"""令牌桶速率限制策略,支持 Redis 分布式后端""" import asyncio import os import time from datetime import datetime, timedelta from typing import Any, Dict, Optional, Tuple from ...clients.redis_client import get_redis_client_sync from src.core.logger import logger from .base import RateLimitResult, RateLimitStrategy class TokenBucket: """令牌桶实现""" def __init__(self, capacity: int, refill_rate: float): """ 初始化令牌桶 Args: capacity: 桶容量(最大令牌数) refill_rate: 令牌补充速率(每秒) """ self.capacity = capacity self.refill_rate = refill_rate self.tokens = capacity self.last_refill = time.time() def _refill(self): """补充令牌""" now = time.time() time_passed = now - self.last_refill tokens_to_add = time_passed * self.refill_rate if tokens_to_add > 0: self.tokens = min(self.capacity, self.tokens + tokens_to_add) self.last_refill = now def consume(self, amount: int = 1) -> bool: """ 消费令牌 Args: amount: 要消费的令牌数 Returns: 是否成功消费 """ self._refill() if self.tokens >= amount: self.tokens -= amount return True return False def get_remaining(self) -> int: """获取剩余令牌数""" self._refill() return int(self.tokens) def get_reset_time(self) -> datetime: """获取下次完全恢复的时间""" if self.tokens >= self.capacity: return datetime.now() tokens_needed = self.capacity - self.tokens seconds_to_full = tokens_needed / self.refill_rate return datetime.now() + timedelta(seconds=seconds_to_full) class TokenBucketStrategy(RateLimitStrategy): """ 令牌桶算法速率限制策略 特点: - 允许突发流量 - 平均速率受限 - 适合处理不均匀的流量模式 """ def __init__(self): super().__init__("token_bucket") self.buckets: Dict[str, TokenBucket] = {} self._lock = asyncio.Lock() # 默认配置 self.default_capacity = 100 # 默认桶容量 self.default_refill_rate = 10 # 默认每秒补充10个令牌 # 可选的 Redis 后端 self._redis_backend: Optional[RedisTokenBucketBackend] = None self._redis_checked = False self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower() def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket: """ 获取或创建令牌桶 Args: key: 限制键 rate_limit: 每分钟请求限制(来自数据库配置),如果提供则使用此值 Returns: 令牌桶实例 """ if key not in self.buckets: # 如果提供了rate_limit参数(来自数据库),优先使用 if rate_limit is not None: # rate_limit 是每分钟请求数,转换为令牌桶参数 capacity = rate_limit # 桶容量等于每分钟限制 refill_rate = rate_limit / 60.0 # 每秒补充的令牌数 # 否则根据key的不同前缀使用不同的配置 elif key.startswith("api_key:"): capacity = self.config.get("api_key_capacity", self.default_capacity) refill_rate = self.config.get("api_key_refill_rate", self.default_refill_rate) elif key.startswith("user:"): capacity = self.config.get("user_capacity", self.default_capacity * 2) refill_rate = self.config.get("user_refill_rate", self.default_refill_rate * 2) else: capacity = self.default_capacity refill_rate = self.default_refill_rate self.buckets[key] = TokenBucket(capacity, refill_rate) return self.buckets[key] def _want_redis_backend(self) -> bool: return self._backend_mode in {"auto", "redis"} async def _ensure_backend(self): if self._redis_checked: return self._redis_checked = True if not self._want_redis_backend(): return redis_client = get_redis_client_sync() if redis_client: self._redis_backend = RedisTokenBucketBackend(redis_client) logger.info("速率限制改用 Redis 令牌桶后端") elif self._backend_mode == "redis": logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶") async def check_limit(self, key: str, **kwargs) -> RateLimitResult: """ 检查速率限制 Args: key: 限制键 **kwargs: 额外参数,包括 rate_limit (从数据库配置) Returns: 速率限制检查结果 """ await self._ensure_backend() rate_limit = kwargs.get("rate_limit") amount = kwargs.get("amount", 1) if self._redis_backend: return await self._redis_backend.peek( key=key, capacity=self._resolve_capacity(key, rate_limit), refill_rate=self._resolve_refill_rate(key, rate_limit), amount=amount, ) async with self._lock: bucket = self._get_bucket(key, rate_limit) remaining = bucket.get_remaining() reset_at = bucket.get_reset_time() allowed = remaining >= amount retry_after = None if not allowed: tokens_needed = amount - remaining retry_after = int(tokens_needed / bucket.refill_rate) + 1 return RateLimitResult( allowed=allowed, remaining=remaining, reset_at=reset_at, retry_after=retry_after, message=( None if allowed else f"Rate limit exceeded. Please retry after {retry_after} seconds." ), ) async def consume(self, key: str, amount: int = 1, **kwargs) -> bool: """ 消费令牌 Args: key: 限制键 amount: 消费数量 Returns: 是否成功消费 """ await self._ensure_backend() if self._redis_backend: success, remaining = await self._redis_backend.consume( key=key, capacity=self._resolve_capacity(key, kwargs.get("rate_limit")), refill_rate=self._resolve_refill_rate(key, kwargs.get("rate_limit")), amount=amount, ) if success: logger.debug("Redis 令牌消费成功") else: logger.warning("Redis 令牌消费失败") return success async with self._lock: bucket = self._get_bucket(key) success = bucket.consume(amount) if success: logger.debug(f"令牌消费成功") else: logger.warning(f"令牌消费失败:超出速率限制") return success async def reset(self, key: str): """ 重置令牌桶 Args: key: 限制键 """ await self._ensure_backend() if self._redis_backend: await self._redis_backend.reset(key) return async with self._lock: if key in self.buckets: bucket = self.buckets[key] bucket.tokens = bucket.capacity bucket.last_refill = time.time() logger.info(f"令牌桶已重置") async def get_stats(self, key: str) -> Dict[str, Any]: """ 获取统计信息 Args: key: 限制键 Returns: 统计信息 """ await self._ensure_backend() if self._redis_backend: return await self._redis_backend.get_stats( key, capacity=self._resolve_capacity(key), refill_rate=self._resolve_refill_rate(key), ) async with self._lock: bucket = self._get_bucket(key) return { "strategy": "token_bucket", "key": key, "capacity": bucket.capacity, "remaining": bucket.get_remaining(), "refill_rate": bucket.refill_rate, "reset_at": bucket.get_reset_time().isoformat(), } def configure(self, config: Dict[str, Any]): """ 配置策略 支持的配置项: - api_key_capacity: API Key的桶容量 - api_key_refill_rate: API Key的令牌补充速率 - user_capacity: 用户的桶容量 - user_refill_rate: 用户的令牌补充速率 """ super().configure(config) self.default_capacity = config.get("default_capacity", self.default_capacity) self.default_refill_rate = config.get("default_refill_rate", self.default_refill_rate) def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int: if rate_limit is not None: return rate_limit if key.startswith("api_key:"): return self.config.get("api_key_capacity", self.default_capacity) if key.startswith("user:"): return self.config.get("user_capacity", self.default_capacity * 2) return self.default_capacity def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float: if rate_limit is not None: return rate_limit / 60.0 if key.startswith("api_key:"): return self.config.get("api_key_refill_rate", self.default_refill_rate) if key.startswith("user:"): return self.config.get("user_refill_rate", self.default_refill_rate * 2) return self.default_refill_rate class RedisTokenBucketBackend: """使用 Redis 存储令牌桶状态,支持多实例共享""" _SCRIPT = """ local key = KEYS[1] local now = tonumber(ARGV[1]) local capacity = tonumber(ARGV[2]) local refill_rate = tonumber(ARGV[3]) local amount = tonumber(ARGV[4]) local data = redis.call('HMGET', key, 'tokens', 'timestamp') local tokens = tonumber(data[1]) local last_refill = tonumber(data[2]) if tokens == nil then tokens = capacity last_refill = now end local delta = math.max(0, now - last_refill) local refill = delta * refill_rate tokens = math.min(capacity, tokens + refill) local allowed = 0 local retry_after = 0 if tokens >= amount then tokens = tokens - amount allowed = 1 else retry_after = math.ceil((amount - tokens) / refill_rate) end redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now) local ttl = math.max(1, math.ceil(capacity / refill_rate)) redis.call('EXPIRE', key, ttl) return {allowed, tokens, retry_after} """ def __init__(self, redis_client): self.redis = redis_client self._consume_script = self.redis.register_script(self._SCRIPT) def _redis_key(self, key: str) -> str: return f"rate_limit:bucket:{key}" async def peek( self, key: str, capacity: int, refill_rate: float, amount: int, ) -> RateLimitResult: bucket_key = self._redis_key(key) data = await self.redis.hmget(bucket_key, "tokens", "timestamp") tokens = data[0] last_refill = data[1] if tokens is None or last_refill is None: remaining = capacity reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate) else: tokens_value = float(tokens) last_refill_value = float(last_refill) delta = max(0.0, time.time() - last_refill_value) tokens_value = min(capacity, tokens_value + delta * refill_rate) remaining = int(tokens_value) reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate reset_at = datetime.now() + timedelta(seconds=reset_after) allowed = remaining >= amount retry_after = None if not allowed: needed = max(0, amount - remaining) retry_after = int(needed / refill_rate) + 1 return RateLimitResult( allowed=allowed, remaining=int(remaining), reset_at=reset_at, retry_after=retry_after, message=( None if allowed else f"Rate limit exceeded. Please retry after {retry_after} seconds." ), ) async def consume( self, key: str, capacity: int, refill_rate: float, amount: int, ) -> Tuple[bool, int]: result = await self._consume_script( keys=[self._redis_key(key)], args=[time.time(), capacity, refill_rate, amount], ) allowed = bool(result[0]) remaining = int(float(result[1])) return allowed, remaining async def reset(self, key: str): await self.redis.delete(self._redis_key(key)) async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]: data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp") tokens = data[0] timestamp = data[1] return { "strategy": "token_bucket", "key": key, "capacity": capacity, "remaining": float(tokens) if tokens else capacity, "refill_rate": refill_rate, "last_refill": float(timestamp) if timestamp else time.time(), "backend": "redis", }