Files
Aether/src/plugins/rate_limit/token_bucket.py
fawney19 dec681fea0 fix: 统一时区处理,确保所有 datetime 带时区信息
- token_bucket.py: get_reset_time 和 Redis 后端使用 timezone.utc
- sliding_window.py: get_reset_time 和 retry_after 计算使用 timezone.utc
- provider_strategy.py: dateutil.parser 解析后确保有时区信息
2026-01-05 02:23:24 +08:00

432 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""令牌桶速率限制策略,支持 Redis 分布式后端"""
import asyncio
import os
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync
from src.core.logger import logger
from .base import RateLimitResult, RateLimitStrategy
class TokenBucket:
    """In-memory token bucket.

    Tokens accrue continuously at ``refill_rate`` per second up to
    ``capacity``. Each request spends tokens, so short bursts are allowed
    while the long-run average rate stays bounded.
    """

    def __init__(self, capacity: int, refill_rate: float) -> None:
        """Create a bucket that starts full.

        Args:
            capacity: Maximum number of tokens the bucket can hold.
            refill_rate: Tokens added per second.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last_refill = time.time()

    def _refill(self) -> None:
        """Credit the tokens accrued since the last refill, capped at capacity."""
        now = time.time()
        tokens_to_add = (now - self.last_refill) * self.refill_rate
        # A non-positive amount (zero rate, or the wall clock stepping
        # backwards) leaves both the level and the timestamp untouched.
        if tokens_to_add > 0:
            self.tokens = min(self.capacity, self.tokens + tokens_to_add)
            self.last_refill = now

    def consume(self, amount: int = 1) -> bool:
        """Try to spend ``amount`` tokens.

        Args:
            amount: Number of tokens to spend.

        Returns:
            True if the tokens were available and have been deducted,
            False if the bucket did not hold enough (nothing is deducted).
        """
        self._refill()
        if self.tokens >= amount:
            self.tokens -= amount
            return True
        return False

    def get_remaining(self) -> int:
        """Return the number of whole tokens currently available."""
        self._refill()
        return int(self.tokens)

    def get_reset_time(self) -> datetime:
        """Return the UTC time at which the bucket will be full again.

        Returns:
            The current time if the bucket is already full; otherwise the
            current time plus the time needed to refill the missing tokens.
            If the bucket is not full and ``refill_rate`` is not positive it
            can never refill, so ``datetime.max`` (UTC) is returned.
        """
        # Bring the level up to date first; otherwise the estimate is too
        # late by however long it has been since the last refill.
        self._refill()
        if self.tokens >= self.capacity:
            return datetime.now(timezone.utc)
        if self.refill_rate <= 0:
            # Avoid dividing by zero: a non-refilling bucket never resets.
            return datetime.max.replace(tzinfo=timezone.utc)
        seconds_to_full = (self.capacity - self.tokens) / self.refill_rate
        return datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)
class TokenBucketStrategy(RateLimitStrategy):
    """Token-bucket rate limiting strategy.

    Characteristics:
    - allows bursts of traffic
    - bounds the average rate
    - suits uneven traffic patterns

    State is kept in per-key in-memory buckets unless the
    ``RATE_LIMIT_BACKEND`` environment variable is ``auto`` (the default) or
    ``redis`` and a Redis client is available, in which case every operation
    is delegated to ``RedisTokenBucketBackend``.
    """

    def __init__(self):
        super().__init__("token_bucket")
        self.buckets: Dict[str, TokenBucket] = {}
        self._lock = asyncio.Lock()

        # Defaults used when neither a per-call rate_limit nor a matching
        # config entry is available.
        self.default_capacity = 100  # bucket size
        self.default_refill_rate = 10  # tokens per second

        # Optional Redis backend, resolved lazily on first use.
        self._redis_backend: Optional["RedisTokenBucketBackend"] = None
        self._redis_checked = False
        self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower()

    def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket:
        """Return the in-memory bucket for ``key``, creating it on first use.

        Args:
            key: Rate-limit key.
            rate_limit: Requests per minute (from the database configuration).
                When given it takes precedence over prefix-based config.

        Returns:
            The bucket stored under ``key``.

        Note:
            An existing bucket keeps the parameters it was created with; a
            later call with a different ``rate_limit`` does not resize it.
        """
        bucket = self.buckets.get(key)
        if bucket is None:
            # Same resolution rules as the Redis path, so both backends size
            # a given key identically.
            bucket = TokenBucket(
                self._resolve_capacity(key, rate_limit),
                self._resolve_refill_rate(key, rate_limit),
            )
            self.buckets[key] = bucket
        return bucket

    def _want_redis_backend(self) -> bool:
        """Return True when the configured backend mode permits Redis."""
        return self._backend_mode in {"auto", "redis"}

    async def _ensure_backend(self):
        """Resolve the Redis backend once; later calls are no-ops."""
        if self._redis_checked:
            return
        # Mark as checked up front so the lookup is attempted only once.
        self._redis_checked = True
        if not self._want_redis_backend():
            return
        redis_client = get_redis_client_sync()
        if redis_client:
            self._redis_backend = RedisTokenBucketBackend(redis_client)
            logger.info("速率限制改用 Redis 令牌桶后端")
        elif self._backend_mode == "redis":
            # Redis was requested explicitly but is unavailable; "auto" falls
            # back to the in-memory buckets silently.
            logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶")

    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """Report whether a request would be allowed, without spending tokens.

        Args:
            key: Rate-limit key.
            **kwargs: Optional ``rate_limit`` (requests per minute, from the
                database configuration) and ``amount`` (tokens required,
                default 1).

        Returns:
            The rate-limit check result. ``retry_after`` is set only when the
            request is denied and the bucket refills at a positive rate.
        """
        await self._ensure_backend()
        rate_limit = kwargs.get("rate_limit")
        amount = kwargs.get("amount", 1)

        if self._redis_backend:
            return await self._redis_backend.peek(
                key=key,
                capacity=self._resolve_capacity(key, rate_limit),
                refill_rate=self._resolve_refill_rate(key, rate_limit),
                amount=amount,
            )

        async with self._lock:
            bucket = self._get_bucket(key, rate_limit)
            remaining = bucket.get_remaining()
            reset_at = bucket.get_reset_time()
            refill_rate = bucket.refill_rate

        allowed = remaining >= amount
        retry_after = None
        message = None
        if not allowed:
            if refill_rate > 0:
                retry_after = int((amount - remaining) / refill_rate) + 1
                message = f"Rate limit exceeded. Please retry after {retry_after} seconds."
            else:
                # A non-positive refill rate (e.g. rate_limit=0) would divide
                # by zero; such a bucket never recovers, so give no estimate.
                message = "Rate limit exceeded."

        return RateLimitResult(
            allowed=allowed,
            remaining=remaining,
            reset_at=reset_at,
            retry_after=retry_after,
            message=message,
        )

    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """Spend tokens from the bucket for ``key``.

        Args:
            key: Rate-limit key.
            amount: Number of tokens to spend.
            **kwargs: Optional ``rate_limit`` (requests per minute) used to
                size the bucket.

        Returns:
            True if the tokens were spent, False if the limit was exceeded.
        """
        await self._ensure_backend()
        rate_limit = kwargs.get("rate_limit")

        if self._redis_backend:
            success, _ = await self._redis_backend.consume(
                key=key,
                capacity=self._resolve_capacity(key, rate_limit),
                refill_rate=self._resolve_refill_rate(key, rate_limit),
                amount=amount,
            )
            if success:
                logger.debug("Redis 令牌消费成功")
            else:
                logger.warning("Redis 令牌消费失败")
            return success

        async with self._lock:
            # Pass rate_limit through so a bucket first created here gets the
            # same parameters check_limit and the Redis path would use.
            bucket = self._get_bucket(key, rate_limit)
            success = bucket.consume(amount)
            if success:
                logger.debug("令牌消费成功")
            else:
                logger.warning("令牌消费失败:超出速率限制")
            return success

    async def reset(self, key: str):
        """Restore the bucket for ``key`` to full capacity.

        Args:
            key: Rate-limit key.
        """
        await self._ensure_backend()
        if self._redis_backend:
            await self._redis_backend.reset(key)
            return

        async with self._lock:
            bucket = self.buckets.get(key)
            if bucket is not None:
                bucket.tokens = bucket.capacity
                bucket.last_refill = time.time()
                logger.info("令牌桶已重置")

    async def get_stats(self, key: str) -> Dict[str, Any]:
        """Return a snapshot of the bucket state for ``key``.

        Args:
            key: Rate-limit key.

        Returns:
            A dict describing the bucket. In in-memory mode a bucket is
            created for ``key`` if none exists yet.
        """
        await self._ensure_backend()
        if self._redis_backend:
            return await self._redis_backend.get_stats(
                key,
                capacity=self._resolve_capacity(key),
                refill_rate=self._resolve_refill_rate(key),
            )

        async with self._lock:
            bucket = self._get_bucket(key)
            return {
                "strategy": "token_bucket",
                "key": key,
                "capacity": bucket.capacity,
                "remaining": bucket.get_remaining(),
                "refill_rate": bucket.refill_rate,
                "reset_at": bucket.get_reset_time().isoformat(),
            }

    def configure(self, config: Dict[str, Any]):
        """Apply configuration to the strategy.

        Keys read by this class:
        - default_capacity / default_refill_rate: fallback bucket parameters
        - api_key_capacity / api_key_refill_rate: for keys starting "api_key:"
        - user_capacity / user_refill_rate: for keys starting "user:"

        The prefix-specific keys are looked up in ``self.config`` when a
        bucket is sized (presumably populated by the base ``configure``).
        """
        super().configure(config)
        self.default_capacity = config.get("default_capacity", self.default_capacity)
        self.default_refill_rate = config.get("default_refill_rate", self.default_refill_rate)

    def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int:
        """Return the bucket capacity to use for ``key``.

        An explicit per-minute ``rate_limit`` wins; otherwise the key prefix
        selects a config entry, falling back to the default capacity.
        """
        if rate_limit is not None:
            # The bucket holds one minute's worth of requests.
            return rate_limit
        if key.startswith("api_key:"):
            return self.config.get("api_key_capacity", self.default_capacity)
        if key.startswith("user:"):
            return self.config.get("user_capacity", self.default_capacity * 2)
        return self.default_capacity

    def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float:
        """Return the refill rate (tokens per second) to use for ``key``.

        An explicit per-minute ``rate_limit`` wins; otherwise the key prefix
        selects a config entry, falling back to the default refill rate.
        """
        if rate_limit is not None:
            # Convert requests per minute into tokens per second.
            return rate_limit / 60.0
        if key.startswith("api_key:"):
            return self.config.get("api_key_refill_rate", self.default_refill_rate)
        if key.startswith("user:"):
            return self.config.get("user_refill_rate", self.default_refill_rate * 2)
        return self.default_refill_rate
class RedisTokenBucketBackend:
    """Token-bucket state stored in Redis so several instances can share it.

    Each bucket is a Redis hash with the fields ``tokens`` and ``timestamp``
    (the time of the last update, as passed in by the caller).
    """

    # Atomic refill-and-consume.
    # KEYS[1] = bucket key
    # ARGV    = now, capacity, refill_rate, amount
    # Returns {allowed (0/1), tokens left, retry_after}; retry_after is -1
    # when the request was denied and the bucket does not refill.
    _SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])

local data = redis.call('HMGET', key, 'tokens', 'timestamp')
local tokens = tonumber(data[1])
local last_refill = tonumber(data[2])

-- A missing (or partially written) bucket starts out full.
if tokens == nil or last_refill == nil then
    tokens = capacity
    last_refill = now
end

local delta = math.max(0, now - last_refill)
local refill = delta * refill_rate
tokens = math.min(capacity, tokens + refill)

local allowed = 0
local retry_after = 0
if tokens >= amount then
    tokens = tokens - amount
    allowed = 1
elseif refill_rate > 0 then
    retry_after = math.ceil((amount - tokens) / refill_rate)
else
    retry_after = -1
end

redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now)
-- After capacity / refill_rate idle seconds the bucket would be full again,
-- so the key can expire. Skipped when the bucket never refills, which also
-- avoids dividing by zero.
if refill_rate > 0 then
    local ttl = math.max(1, math.ceil(capacity / refill_rate))
    redis.call('EXPIRE', key, ttl)
end
return {allowed, tokens, retry_after}
"""

    def __init__(self, redis_client):
        """Bind to ``redis_client`` and register the consume script on it."""
        self.redis = redis_client
        self._consume_script = self.redis.register_script(self._SCRIPT)

    def _redis_key(self, key: str) -> str:
        """Return the Redis key under which the bucket for ``key`` is stored."""
        return f"rate_limit:bucket:{key}"

    @staticmethod
    def _project_tokens(tokens, last_refill, capacity, refill_rate, now) -> float:
        """Return the token level at ``now`` given the stored hash fields.

        Args:
            tokens: Stored ``tokens`` field, or None if absent.
            last_refill: Stored ``timestamp`` field, or None if absent.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.
            now: Current time, on the same clock as ``last_refill``.

        Returns:
            The stored level plus the refill accrued since ``last_refill``,
            capped at ``capacity``. A missing bucket counts as full.
        """
        if tokens is None or last_refill is None:
            return float(capacity)
        elapsed = max(0.0, now - float(last_refill))
        return float(min(capacity, float(tokens) + elapsed * refill_rate))

    @staticmethod
    def _seconds_until(tokens_needed, refill_rate) -> Optional[float]:
        """Return the seconds needed to accrue ``tokens_needed`` tokens.

        Returns:
            0.0 when nothing is needed, None when tokens are needed but the
            refill rate is not positive (they would never arrive), otherwise
            ``tokens_needed / refill_rate``.
        """
        if tokens_needed <= 0:
            return 0.0
        if refill_rate <= 0:
            return None
        return tokens_needed / refill_rate

    async def peek(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> "RateLimitResult":
        """Report whether ``amount`` tokens are available, without consuming.

        The stored state is read and the refill since the last update is
        applied locally; nothing is written back to Redis.

        Args:
            key: Rate-limit key.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.
            amount: Tokens the request would need.

        Returns:
            The check result. ``reset_at`` is the time the bucket will be
            full (now if it already is, including when no state exists yet);
            ``retry_after`` is set only when the request is denied and the
            bucket refills at a positive rate.
        """
        data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
        tokens_value = self._project_tokens(data[0], data[1], capacity, refill_rate, time.time())
        remaining = int(tokens_value)

        seconds_to_full = self._seconds_until(capacity - tokens_value, refill_rate)
        if seconds_to_full is None:
            # Not full and never refilling: there is no reset time.
            reset_at = datetime.max.replace(tzinfo=timezone.utc)
        else:
            reset_at = datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)

        allowed = remaining >= amount
        retry_after = None
        message = None
        if not allowed:
            wait = self._seconds_until(amount - remaining, refill_rate)
            if wait is None:
                message = "Rate limit exceeded."
            else:
                retry_after = int(wait) + 1
                message = f"Rate limit exceeded. Please retry after {retry_after} seconds."

        return RateLimitResult(
            allowed=allowed,
            remaining=remaining,
            reset_at=reset_at,
            retry_after=retry_after,
            message=message,
        )

    async def consume(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> Tuple[bool, int]:
        """Atomically refill the bucket and try to spend ``amount`` tokens.

        Args:
            key: Rate-limit key.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.
            amount: Tokens to spend.

        Returns:
            ``(allowed, remaining)``: whether the tokens were spent, and the
            whole number of tokens left afterwards.
        """
        result = await self._consume_script(
            keys=[self._redis_key(key)],
            args=[time.time(), capacity, refill_rate, amount],
        )
        allowed = bool(result[0])
        remaining = int(float(result[1]))
        return allowed, remaining

    async def reset(self, key: str):
        """Delete the stored state so the bucket for ``key`` starts full again."""
        await self.redis.delete(self._redis_key(key))

    async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]:
        """Return a snapshot of the bucket state for ``key``.

        Args:
            key: Rate-limit key.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.

        Returns:
            A dict describing the bucket. ``remaining`` is a float that
            includes the refill accrued since the last stored update;
            ``last_refill`` is the stored timestamp, or the current time when
            no state exists.
        """
        data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
        tokens, timestamp = data[0], data[1]
        now = time.time()
        return {
            "strategy": "token_bucket",
            "key": key,
            "capacity": capacity,
            "remaining": self._project_tokens(tokens, timestamp, capacity, refill_rate, now),
            "refill_rate": refill_rate,
            "last_refill": float(timestamp) if timestamp is not None else now,
            "backend": "redis",
        }