Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,431 @@
"""令牌桶速率限制策略,支持 Redis 分布式后端"""
import asyncio
import os
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync
from src.core.logger import logger
from .base import RateLimitResult, RateLimitStrategy
class TokenBucket:
"""令牌桶实现"""
def __init__(self, capacity: int, refill_rate: float):
"""
初始化令牌桶
Args:
capacity: 桶容量(最大令牌数)
refill_rate: 令牌补充速率(每秒)
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = capacity
self.last_refill = time.time()
def _refill(self):
"""补充令牌"""
now = time.time()
time_passed = now - self.last_refill
tokens_to_add = time_passed * self.refill_rate
if tokens_to_add > 0:
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
self.last_refill = now
def consume(self, amount: int = 1) -> bool:
"""
消费令牌
Args:
amount: 要消费的令牌数
Returns:
是否成功消费
"""
self._refill()
if self.tokens >= amount:
self.tokens -= amount
return True
return False
def get_remaining(self) -> int:
"""获取剩余令牌数"""
self._refill()
return int(self.tokens)
def get_reset_time(self) -> datetime:
"""获取下次完全恢复的时间"""
if self.tokens >= self.capacity:
return datetime.now()
tokens_needed = self.capacity - self.tokens
seconds_to_full = tokens_needed / self.refill_rate
return datetime.now() + timedelta(seconds=seconds_to_full)
class TokenBucketStrategy(RateLimitStrategy):
"""
令牌桶算法速率限制策略
特点:
- 允许突发流量
- 平均速率受限
- 适合处理不均匀的流量模式
"""
def __init__(self):
super().__init__("token_bucket")
self.buckets: Dict[str, TokenBucket] = {}
self._lock = asyncio.Lock()
# 默认配置
self.default_capacity = 100 # 默认桶容量
self.default_refill_rate = 10 # 默认每秒补充10个令牌
# 可选的 Redis 后端
self._redis_backend: Optional[RedisTokenBucketBackend] = None
self._redis_checked = False
self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower()
def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket:
"""
获取或创建令牌桶
Args:
key: 限制键
rate_limit: 每分钟请求限制(来自数据库配置),如果提供则使用此值
Returns:
令牌桶实例
"""
if key not in self.buckets:
# 如果提供了rate_limit参数来自数据库优先使用
if rate_limit is not None:
# rate_limit 是每分钟请求数,转换为令牌桶参数
capacity = rate_limit # 桶容量等于每分钟限制
refill_rate = rate_limit / 60.0 # 每秒补充的令牌数
# 否则根据key的不同前缀使用不同的配置
elif key.startswith("api_key:"):
capacity = self.config.get("api_key_capacity", self.default_capacity)
refill_rate = self.config.get("api_key_refill_rate", self.default_refill_rate)
elif key.startswith("user:"):
capacity = self.config.get("user_capacity", self.default_capacity * 2)
refill_rate = self.config.get("user_refill_rate", self.default_refill_rate * 2)
else:
capacity = self.default_capacity
refill_rate = self.default_refill_rate
self.buckets[key] = TokenBucket(capacity, refill_rate)
return self.buckets[key]
def _want_redis_backend(self) -> bool:
return self._backend_mode in {"auto", "redis"}
async def _ensure_backend(self):
if self._redis_checked:
return
self._redis_checked = True
if not self._want_redis_backend():
return
redis_client = get_redis_client_sync()
if redis_client:
self._redis_backend = RedisTokenBucketBackend(redis_client)
logger.info("速率限制改用 Redis 令牌桶后端")
elif self._backend_mode == "redis":
logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶")
async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
"""
检查速率限制
Args:
key: 限制键
**kwargs: 额外参数,包括 rate_limit (从数据库配置)
Returns:
速率限制检查结果
"""
await self._ensure_backend()
rate_limit = kwargs.get("rate_limit")
amount = kwargs.get("amount", 1)
if self._redis_backend:
return await self._redis_backend.peek(
key=key,
capacity=self._resolve_capacity(key, rate_limit),
refill_rate=self._resolve_refill_rate(key, rate_limit),
amount=amount,
)
async with self._lock:
bucket = self._get_bucket(key, rate_limit)
remaining = bucket.get_remaining()
reset_at = bucket.get_reset_time()
allowed = remaining >= amount
retry_after = None
if not allowed:
tokens_needed = amount - remaining
retry_after = int(tokens_needed / bucket.refill_rate) + 1
return RateLimitResult(
allowed=allowed,
remaining=remaining,
reset_at=reset_at,
retry_after=retry_after,
message=(
None
if allowed
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
),
)
async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
"""
消费令牌
Args:
key: 限制键
amount: 消费数量
Returns:
是否成功消费
"""
await self._ensure_backend()
if self._redis_backend:
success, remaining = await self._redis_backend.consume(
key=key,
capacity=self._resolve_capacity(key, kwargs.get("rate_limit")),
refill_rate=self._resolve_refill_rate(key, kwargs.get("rate_limit")),
amount=amount,
)
if success:
logger.debug("Redis 令牌消费成功")
else:
logger.warning("Redis 令牌消费失败")
return success
async with self._lock:
bucket = self._get_bucket(key)
success = bucket.consume(amount)
if success:
logger.debug(f"令牌消费成功")
else:
logger.warning(f"令牌消费失败:超出速率限制")
return success
async def reset(self, key: str):
"""
重置令牌桶
Args:
key: 限制键
"""
await self._ensure_backend()
if self._redis_backend:
await self._redis_backend.reset(key)
return
async with self._lock:
if key in self.buckets:
bucket = self.buckets[key]
bucket.tokens = bucket.capacity
bucket.last_refill = time.time()
logger.info(f"令牌桶已重置")
async def get_stats(self, key: str) -> Dict[str, Any]:
"""
获取统计信息
Args:
key: 限制键
Returns:
统计信息
"""
await self._ensure_backend()
if self._redis_backend:
return await self._redis_backend.get_stats(
key,
capacity=self._resolve_capacity(key),
refill_rate=self._resolve_refill_rate(key),
)
async with self._lock:
bucket = self._get_bucket(key)
return {
"strategy": "token_bucket",
"key": key,
"capacity": bucket.capacity,
"remaining": bucket.get_remaining(),
"refill_rate": bucket.refill_rate,
"reset_at": bucket.get_reset_time().isoformat(),
}
def configure(self, config: Dict[str, Any]):
"""
配置策略
支持的配置项:
- api_key_capacity: API Key的桶容量
- api_key_refill_rate: API Key的令牌补充速率
- user_capacity: 用户的桶容量
- user_refill_rate: 用户的令牌补充速率
"""
super().configure(config)
self.default_capacity = config.get("default_capacity", self.default_capacity)
self.default_refill_rate = config.get("default_refill_rate", self.default_refill_rate)
def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int:
if rate_limit is not None:
return rate_limit
if key.startswith("api_key:"):
return self.config.get("api_key_capacity", self.default_capacity)
if key.startswith("user:"):
return self.config.get("user_capacity", self.default_capacity * 2)
return self.default_capacity
def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float:
if rate_limit is not None:
return rate_limit / 60.0
if key.startswith("api_key:"):
return self.config.get("api_key_refill_rate", self.default_refill_rate)
if key.startswith("user:"):
return self.config.get("user_refill_rate", self.default_refill_rate * 2)
return self.default_refill_rate
class RedisTokenBucketBackend:
"""使用 Redis 存储令牌桶状态,支持多实例共享"""
_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])
local data = redis.call('HMGET', key, 'tokens', 'timestamp')
local tokens = tonumber(data[1])
local last_refill = tonumber(data[2])
if tokens == nil then
tokens = capacity
last_refill = now
end
local delta = math.max(0, now - last_refill)
local refill = delta * refill_rate
tokens = math.min(capacity, tokens + refill)
local allowed = 0
local retry_after = 0
if tokens >= amount then
tokens = tokens - amount
allowed = 1
else
retry_after = math.ceil((amount - tokens) / refill_rate)
end
redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now)
local ttl = math.max(1, math.ceil(capacity / refill_rate))
redis.call('EXPIRE', key, ttl)
return {allowed, tokens, retry_after}
"""
def __init__(self, redis_client):
self.redis = redis_client
self._consume_script = self.redis.register_script(self._SCRIPT)
def _redis_key(self, key: str) -> str:
return f"rate_limit:bucket:{key}"
async def peek(
self,
key: str,
capacity: int,
refill_rate: float,
amount: int,
) -> RateLimitResult:
bucket_key = self._redis_key(key)
data = await self.redis.hmget(bucket_key, "tokens", "timestamp")
tokens = data[0]
last_refill = data[1]
if tokens is None or last_refill is None:
remaining = capacity
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
else:
tokens_value = float(tokens)
last_refill_value = float(last_refill)
delta = max(0.0, time.time() - last_refill_value)
tokens_value = min(capacity, tokens_value + delta * refill_rate)
remaining = int(tokens_value)
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
reset_at = datetime.now() + timedelta(seconds=reset_after)
allowed = remaining >= amount
retry_after = None
if not allowed:
needed = max(0, amount - remaining)
retry_after = int(needed / refill_rate) + 1
return RateLimitResult(
allowed=allowed,
remaining=int(remaining),
reset_at=reset_at,
retry_after=retry_after,
message=(
None
if allowed
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
),
)
async def consume(
self,
key: str,
capacity: int,
refill_rate: float,
amount: int,
) -> Tuple[bool, int]:
result = await self._consume_script(
keys=[self._redis_key(key)],
args=[time.time(), capacity, refill_rate, amount],
)
allowed = bool(result[0])
remaining = int(float(result[1]))
return allowed, remaining
async def reset(self, key: str):
await self.redis.delete(self._redis_key(key))
async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]:
data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
tokens = data[0]
timestamp = data[1]
return {
"strategy": "token_bucket",
"key": key,
"capacity": capacity,
"remaining": float(tokens) if tokens else capacity,
"refill_rate": refill_rate,
"last_refill": float(timestamp) if timestamp else time.time(),
"backend": "redis",
}