mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-13 04:58:28 +08:00
Initial commit
This commit is contained in:
9
src/plugins/rate_limit/__init__.py
Normal file
9
src/plugins/rate_limit/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
速率限制插件模块
|
||||
"""
|
||||
|
||||
from .base import RateLimitResult, RateLimitStrategy
|
||||
from .sliding_window import SlidingWindowStrategy
|
||||
from .token_bucket import TokenBucketStrategy
|
||||
|
||||
__all__ = ["RateLimitStrategy", "RateLimitResult", "TokenBucketStrategy", "SlidingWindowStrategy"]
|
||||
132
src/plugins/rate_limit/base.py
Normal file
132
src/plugins/rate_limit/base.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
速率限制策略基类
|
||||
定义速率限制策略的接口
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..common import BasePlugin, HealthStatus, PluginMetadata
|
||||
|
||||
|
||||
@dataclass
class RateLimitResult:
    """
    Outcome of a single rate-limit check.

    Attributes are plain data; ``headers`` is filled in automatically by
    ``__post_init__`` when the caller does not supply one.
    """

    # Whether the request may proceed.
    allowed: bool
    # Quota left after the check.
    remaining: int
    # Moment at which the quota is restored, if known.
    reset_at: Optional[datetime] = None
    # Seconds the client should wait before retrying, if rejected.
    retry_after: Optional[int] = None
    # Human-readable explanation for a rejection.
    message: Optional[str] = None
    # HTTP response headers describing the limit state.
    headers: Optional[Dict[str, str]] = None

    def __post_init__(self):
        """Build the standard rate-limit headers when none were provided."""
        if self.headers is not None:
            # Caller supplied its own headers; leave them untouched.
            return

        generated: Dict[str, str] = {}
        if self.remaining is not None:
            generated["X-RateLimit-Remaining"] = str(self.remaining)
        if self.reset_at:
            # Header value is the reset moment as whole epoch seconds.
            generated["X-RateLimit-Reset"] = str(int(self.reset_at.timestamp()))
        if self.retry_after:
            generated["Retry-After"] = str(self.retry_after)
        self.headers = generated
|
||||
|
||||
|
||||
class RateLimitStrategy(BasePlugin):
    """
    Base class for rate-limit strategies.

    Every rate-limit strategy must inherit from this class and implement
    the abstract coroutines declared below.
    """

    def __init__(
        self,
        name: str,
        priority: int = 0,
        version: str = "1.0.0",
        author: str = "Unknown",
        description: str = "",
        api_version: str = "1.0",
        dependencies: Optional[List[str]] = None,
        provides: Optional[List[str]] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialise the strategy by forwarding all metadata to ``BasePlugin``.

        Args:
            name: Strategy name.
            priority: Priority (a larger number means higher priority).
            version: Plugin version.
            author: Plugin author.
            description: Plugin description.
            api_version: API version.
            dependencies: Names of other plugins this one depends on.
            provides: Services this plugin provides.
            config: Configuration dictionary.
        """
        super().__init__(
            name=name,
            priority=priority,
            version=version,
            author=author,
            description=description,
            api_version=api_version,
            dependencies=dependencies,
            provides=provides,
            config=config,
        )

    @abstractmethod
    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """
        Check the rate limit for ``key``.

        Args:
            key: Limit key (e.g. a user ID or API-key ID).
            **kwargs: Strategy-specific extra parameters.

        Returns:
            The result of the rate-limit check.
        """
        pass

    @abstractmethod
    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """
        Consume quota for ``key``.

        Args:
            key: Limit key.
            amount: Amount of quota to consume.
            **kwargs: Strategy-specific extra parameters.

        Returns:
            True if the quota was consumed, False otherwise.
        """
        pass

    @abstractmethod
    async def reset(self, key: str) -> None:
        """
        Reset the limit state for ``key``.

        Args:
            key: Limit key.
        """
        pass

    @abstractmethod
    async def get_stats(self, key: str) -> Dict[str, Any]:
        """
        Return statistics for ``key``.

        Args:
            key: Limit key.

        Returns:
            A dictionary of statistics.
        """
        pass
|
||||
363
src/plugins/rate_limit/sliding_window.py
Normal file
363
src/plugins/rate_limit/sliding_window.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
滑动窗口算法速率限制策略
|
||||
精确的速率限制,不允许突发
|
||||
|
||||
WARNING: 多进程环境注意事项
|
||||
=============================
|
||||
此插件的窗口状态存储在进程内存中。如果使用 Gunicorn/uvicorn 多 worker 模式,
|
||||
每个 worker 进程有独立的限流状态,可能导致:
|
||||
- 实际允许的请求数 = 配置限制 * worker数量
|
||||
- 限流效果大打折扣
|
||||
|
||||
解决方案:
|
||||
1. 单 worker 模式:适用于低流量场景
|
||||
2. Redis 共享状态:使用 Redis 实现分布式滑动窗口
|
||||
3. 使用 token_bucket.py:令牌桶策略可以更容易迁移到 Redis
|
||||
|
||||
目前项目已有 Redis 依赖(src/clients/redis_client.py),
|
||||
建议在生产环境使用 Redis 实现分布式限流。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any, Deque, Dict
|
||||
|
||||
from src.core.logger import logger
|
||||
from .base import RateLimitResult, RateLimitStrategy
|
||||
|
||||
|
||||
|
||||
class SlidingWindow:
    """Sliding-window request counter backed by a deque of timestamps."""

    def __init__(self, window_size: int, max_requests: int):
        """
        Create an empty window.

        Args:
            window_size: Window length in seconds.
            max_requests: Maximum number of requests allowed inside the window.
        """
        self.window_size = window_size
        self.max_requests = max_requests
        # One timestamp per recorded request, oldest first.
        self.requests: Deque[float] = deque()
        self.last_access_time: float = time.time()

    def _cleanup(self):
        """Drop timestamps that fell out of the window and mark the window as accessed."""
        now = time.time()
        self.last_access_time = now
        oldest_allowed = now - self.window_size

        timestamps = self.requests
        while timestamps:
            if timestamps[0] >= oldest_allowed:
                # Deque is ordered, so everything after this is in-window too.
                break
            timestamps.popleft()

    def can_accept(self, amount: int = 1) -> bool:
        """
        Tell whether ``amount`` more requests fit in the window.

        Args:
            amount: Number of requests to test for.

        Returns:
            True if they fit, False otherwise.
        """
        self._cleanup()
        return amount + len(self.requests) <= self.max_requests

    def add_request(self, amount: int = 1) -> bool:
        """
        Record ``amount`` requests if they fit.

        Args:
            amount: Number of requests to record.

        Returns:
            True if the requests were recorded, False if the window is full.
        """
        if self.can_accept(amount):
            stamp = time.time()
            # All requests of one call share the same timestamp.
            self.requests.extend([stamp] * amount)
            return True
        return False

    def get_remaining(self) -> int:
        """Return the quota still available in the window (never negative)."""
        self._cleanup()
        spare = self.max_requests - len(self.requests)
        return spare if spare > 0 else 0

    def get_reset_time(self) -> datetime:
        """Return when the oldest recorded request expires, or now if the window is empty."""
        self._cleanup()
        if self.requests:
            # The oldest request leaves the window window_size seconds after it was made.
            return datetime.fromtimestamp(self.requests[0] + self.window_size)
        return datetime.now()
|
||||
|
||||
|
||||
class SlidingWindowStrategy(RateLimitStrategy):
    """
    Sliding-window rate-limit strategy.

    Characteristics:
    - Exact rate limiting.
    - No bursts allowed.
    - Suited to scenarios that need strict rate control.
    - Windows that stay idle for a long time are dropped automatically so
      the per-key state held in ``self.windows`` does not grow without bound.
    """

    # Default cap on the number of cached windows.
    DEFAULT_MAX_WINDOWS = 10000
    # Default idle time (seconds) after which an untouched window is dropped.
    DEFAULT_WINDOW_EXPIRY = 3600  # 1 hour

    def __init__(self):
        """Create the strategy with default limits and an empty window cache."""
        super().__init__("sliding_window")
        self.windows: Dict[str, SlidingWindow] = {}
        self._lock = asyncio.Lock()

        # Default limits.
        self.default_window_size = 60  # 60-second window
        self.default_max_requests = 100  # 100 requests per window

        # Memory-management settings.
        self.max_windows = self.DEFAULT_MAX_WINDOWS
        self.window_expiry = self.DEFAULT_WINDOW_EXPIRY
        self._last_cleanup_time: float = time.time()
        self._cleanup_interval = 300  # run the expiry sweep at most every 5 minutes

    def _cleanup_expired_windows(self) -> int:
        """
        Remove windows that have not been accessed for ``window_expiry`` seconds.

        Returns:
            Number of windows removed.
        """
        current_time = time.time()
        expired_keys = [
            key
            for key, window in self.windows.items()
            if current_time - window.last_access_time > self.window_expiry
        ]

        for key in expired_keys:
            del self.windows[key]

        if expired_keys:
            logger.info(f"清理了 {len(expired_keys)} 个过期的滑动窗口")

        return len(expired_keys)

    def _evict_lru_windows(self, count: int) -> int:
        """
        Evict the least recently accessed windows.

        Args:
            count: Number of windows to evict.

        Returns:
            Number of windows actually evicted.
        """
        if not self.windows or count <= 0:
            return 0

        # Oldest last_access_time first.
        sorted_keys = sorted(self.windows, key=lambda k: self.windows[k].last_access_time)

        evicted = 0
        for key in sorted_keys[:count]:
            del self.windows[key]
            evicted += 1

        if evicted:
            logger.warning(f"LRU 淘汰了 {evicted} 个滑动窗口(达到容量上限)")

        return evicted

    async def _maybe_cleanup(self):
        """Run the periodic expiry sweep and, at capacity, an LRU eviction."""
        current_time = time.time()

        # Periodic sweep of idle windows.
        if current_time - self._last_cleanup_time > self._cleanup_interval:
            self._cleanup_expired_windows()
            self._last_cleanup_time = current_time

        # At or above the cap: evict 10% of the configured capacity (at least one).
        if len(self.windows) >= self.max_windows:
            evict_count = max(1, self.max_windows // 10)
            self._evict_lru_windows(evict_count)

    def _get_window(self, key: str) -> SlidingWindow:
        """
        Return the window for ``key``, creating it on first use.

        The limits of a new window depend on the key prefix: ``api_key:`` and
        ``user:`` keys read their limits from ``self.config`` (falling back to
        the defaults, doubled request count for users); any other key uses the
        defaults directly.

        Args:
            key: Limit key.

        Returns:
            The sliding window for the key.
        """
        if key not in self.windows:
            if key.startswith("api_key:"):
                window_size = self.config.get("api_key_window_size", self.default_window_size)
                max_requests = self.config.get("api_key_max_requests", self.default_max_requests)
            elif key.startswith("user:"):
                window_size = self.config.get("user_window_size", self.default_window_size)
                max_requests = self.config.get("user_max_requests", self.default_max_requests * 2)
            else:
                window_size = self.default_window_size
                max_requests = self.default_max_requests

            self.windows[key] = SlidingWindow(window_size, max_requests)

        return self.windows[key]

    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """
        Check whether a request for ``key`` would be accepted, without recording it.

        Args:
            key: Limit key.
            **kwargs: ``amount`` (default 1) is the number of requests to test for.

        Returns:
            The result of the rate-limit check.
        """
        async with self._lock:
            await self._maybe_cleanup()

            window = self._get_window(key)
            amount = kwargs.get("amount", 1)

            allowed = window.can_accept(amount)
            remaining = window.get_remaining()
            reset_at = window.get_reset_time()

            retry_after = None
            if not allowed:
                # Wait until the oldest recorded request leaves the window.
                retry_after = int((reset_at - datetime.now()).total_seconds()) + 1

            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                message=(
                    None
                    if allowed
                    else f"Rate limit exceeded. Please retry after {retry_after} seconds."
                ),
            )

    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """
        Record ``amount`` requests for ``key`` if the window has room.

        Args:
            key: Limit key.
            amount: Number of requests to record.

        Returns:
            True if the requests were recorded, False if the limit is exceeded.
        """
        async with self._lock:
            # consume() can create windows too, so it must enforce the same
            # expiry / capacity housekeeping as check_limit(); otherwise
            # callers that only consume would grow self.windows unchecked.
            await self._maybe_cleanup()

            window = self._get_window(key)
            success = window.add_request(amount)

        if success:
            logger.debug("滑动窗口请求记录成功")
        else:
            logger.warning("滑动窗口请求被拒绝:超出速率限制")

        return success

    async def reset(self, key: str):
        """
        Clear all recorded requests for ``key``.

        Args:
            key: Limit key.
        """
        async with self._lock:
            if key in self.windows:
                window = self.windows[key]
                window.requests.clear()

            logger.info("滑动窗口已重置")

    async def get_stats(self, key: str) -> Dict[str, Any]:
        """
        Return statistics for ``key``.

        Note that an unknown key gets a fresh window created for it.

        Args:
            key: Limit key.

        Returns:
            A dictionary describing the window state.
        """
        async with self._lock:
            window = self._get_window(key)
            window._cleanup()  # drop expired requests before counting

            return {
                "strategy": "sliding_window",
                "key": key,
                "window_size": window.window_size,
                "max_requests": window.max_requests,
                "current_requests": len(window.requests),
                "remaining": window.get_remaining(),
                "reset_at": window.get_reset_time().isoformat(),
            }

    def configure(self, config: Dict[str, Any]):
        """
        Apply configuration.

        Keys read here: ``default_window_size``, ``default_max_requests``,
        ``max_windows`` (cap on cached windows), ``window_expiry`` (idle
        seconds before a window is dropped) and ``cleanup_interval`` (seconds
        between expiry sweeps).

        Per-prefix keys (``api_key_window_size``, ``api_key_max_requests``,
        ``user_window_size``, ``user_max_requests``) are looked up later from
        ``self.config`` when a window is created.
        NOTE(review): presumably ``BasePlugin.configure`` stores ``config`` on
        ``self.config`` - confirm.

        Already-created windows keep the limits they were created with.
        """
        super().configure(config)
        self.default_window_size = config.get("default_window_size", self.default_window_size)
        self.default_max_requests = config.get("default_max_requests", self.default_max_requests)
        self.max_windows = config.get("max_windows", self.max_windows)
        self.window_expiry = config.get("window_expiry", self.window_expiry)
        self._cleanup_interval = config.get("cleanup_interval", self._cleanup_interval)

    def get_memory_stats(self) -> Dict[str, Any]:
        """
        Return figures describing the size of the window cache.

        Returns:
            A dictionary with the window count, configured limits and usage percentage.
        """
        return {
            "total_windows": len(self.windows),
            "max_windows": self.max_windows,
            "window_expiry": self.window_expiry,
            "cleanup_interval": self._cleanup_interval,
            "last_cleanup_time": self._last_cleanup_time,
            "usage_percent": (
                (len(self.windows) / self.max_windows * 100) if self.max_windows > 0 else 0
            ),
        }
|
||||
431
src/plugins/rate_limit/token_bucket.py
Normal file
431
src/plugins/rate_limit/token_bucket.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""令牌桶速率限制策略,支持 Redis 分布式后端"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from ...clients.redis_client import get_redis_client_sync
|
||||
from src.core.logger import logger
|
||||
from .base import RateLimitResult, RateLimitStrategy
|
||||
|
||||
|
||||
|
||||
class TokenBucket:
    """In-memory token bucket."""

    def __init__(self, capacity: int, refill_rate: float):
        """
        Create a full bucket.

        Args:
            capacity: Bucket capacity (maximum number of tokens).
            refill_rate: Tokens added per second.
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = capacity
        self.last_refill = time.time()

    def _refill(self):
        """Add the tokens earned since the last refill, capped at capacity."""
        now = time.time()
        time_passed = now - self.last_refill
        tokens_to_add = time_passed * self.refill_rate

        # Nothing is credited (and the reference time is kept) when no time
        # has passed or the rate is zero.
        if tokens_to_add > 0:
            self.tokens = min(self.capacity, self.tokens + tokens_to_add)
            self.last_refill = now

    def consume(self, amount: int = 1) -> bool:
        """
        Take ``amount`` tokens from the bucket.

        Args:
            amount: Number of tokens to take.

        Returns:
            True if the tokens were taken, False if the bucket holds fewer.
        """
        self._refill()

        if self.tokens >= amount:
            self.tokens -= amount
            return True
        return False

    def get_remaining(self) -> int:
        """Return the number of whole tokens currently available."""
        self._refill()
        return int(self.tokens)

    def get_reset_time(self) -> datetime:
        """
        Return the time at which the bucket will be full again.

        Returns:
            ``datetime.now()`` if the bucket is already full, otherwise now
            plus the time needed to refill the missing tokens.

        Raises:
            ZeroDivisionError: If the bucket is not full and ``refill_rate`` is 0.
        """
        # Credit tokens earned while idle first; without this the estimate
        # is computed from a stale count and lands too far in the future.
        self._refill()

        if self.tokens >= self.capacity:
            return datetime.now()

        tokens_needed = self.capacity - self.tokens
        seconds_to_full = tokens_needed / self.refill_rate
        return datetime.now() + timedelta(seconds=seconds_to_full)
|
||||
|
||||
|
||||
class TokenBucketStrategy(RateLimitStrategy):
    """
    Token-bucket rate-limit strategy.

    Characteristics:
    - Allows bursts of traffic.
    - Limits the average rate.
    - Suited to uneven traffic patterns.

    State is kept in per-key in-memory buckets unless a Redis client is
    available and the ``RATE_LIMIT_BACKEND`` environment variable is
    ``auto`` (default) or ``redis``, in which case a Redis backend is used.
    """

    def __init__(self):
        """Create the strategy with default limits and no backend selected yet."""
        super().__init__("token_bucket")
        self.buckets: Dict[str, TokenBucket] = {}
        self._lock = asyncio.Lock()

        # Default limits.
        self.default_capacity = 100  # bucket capacity
        self.default_refill_rate = 10  # tokens added per second

        # Optional Redis backend, resolved lazily on first use.
        self._redis_backend: Optional[RedisTokenBucketBackend] = None
        self._redis_checked = False
        self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower()

    def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket:
        """
        Return the in-memory bucket for ``key``, creating it on first use.

        Args:
            key: Limit key.
            rate_limit: Requests per minute (from the database configuration).
                When given it determines the limits of a newly created bucket;
                an already existing bucket is returned unchanged.

        Returns:
            The token bucket for the key.
        """
        bucket = self.buckets.get(key)
        if bucket is None:
            # Same resolution rules as the Redis path, so both backends agree.
            bucket = TokenBucket(
                self._resolve_capacity(key, rate_limit),
                self._resolve_refill_rate(key, rate_limit),
            )
            self.buckets[key] = bucket
        return bucket

    def _want_redis_backend(self) -> bool:
        """Tell whether the configured backend mode permits using Redis."""
        return self._backend_mode in {"auto", "redis"}

    async def _ensure_backend(self):
        """Select the backend once; later calls return immediately."""
        if self._redis_checked:
            return
        self._redis_checked = True
        if not self._want_redis_backend():
            return
        redis_client = get_redis_client_sync()
        if redis_client:
            self._redis_backend = RedisTokenBucketBackend(redis_client)
            logger.info("速率限制改用 Redis 令牌桶后端")
        elif self._backend_mode == "redis":
            # Redis was requested explicitly but is unavailable: fall back to memory.
            logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶")

    async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
        """
        Check whether a request for ``key`` would be accepted, without consuming tokens.

        Args:
            key: Limit key.
            **kwargs: ``rate_limit`` (requests per minute from the database
                configuration) and ``amount`` (tokens requested, default 1).

        Returns:
            The result of the rate-limit check.
        """
        await self._ensure_backend()

        rate_limit = kwargs.get("rate_limit")
        amount = kwargs.get("amount", 1)

        if self._redis_backend:
            return await self._redis_backend.peek(
                key=key,
                capacity=self._resolve_capacity(key, rate_limit),
                refill_rate=self._resolve_refill_rate(key, rate_limit),
                amount=amount,
            )

        async with self._lock:
            bucket = self._get_bucket(key, rate_limit)
            remaining = bucket.get_remaining()
            reset_at = bucket.get_reset_time()

            allowed = remaining >= amount

            retry_after = None
            if not allowed:
                # Time needed to earn the missing tokens, rounded up.
                tokens_needed = amount - remaining
                retry_after = int(tokens_needed / bucket.refill_rate) + 1

            return RateLimitResult(
                allowed=allowed,
                remaining=remaining,
                reset_at=reset_at,
                retry_after=retry_after,
                message=(
                    None
                    if allowed
                    else f"Rate limit exceeded. Please retry after {retry_after} seconds."
                ),
            )

    async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
        """
        Consume ``amount`` tokens for ``key``.

        Args:
            key: Limit key.
            amount: Number of tokens to consume.
            **kwargs: ``rate_limit`` (requests per minute from the database
                configuration), honoured by both backends.

        Returns:
            True if the tokens were consumed, False if the limit is exceeded.
        """
        await self._ensure_backend()

        rate_limit = kwargs.get("rate_limit")

        if self._redis_backend:
            success, remaining = await self._redis_backend.consume(
                key=key,
                capacity=self._resolve_capacity(key, rate_limit),
                refill_rate=self._resolve_refill_rate(key, rate_limit),
                amount=amount,
            )
            if success:
                logger.debug("Redis 令牌消费成功")
            else:
                logger.warning("Redis 令牌消费失败")
            return success

        async with self._lock:
            # Pass rate_limit through so a bucket first created by consume()
            # gets the configured limit rather than the prefix defaults.
            bucket = self._get_bucket(key, rate_limit)
            success = bucket.consume(amount)

        if success:
            logger.debug("令牌消费成功")
        else:
            logger.warning("令牌消费失败:超出速率限制")

        return success

    async def reset(self, key: str):
        """
        Restore the bucket for ``key`` to full.

        Args:
            key: Limit key.
        """
        await self._ensure_backend()

        if self._redis_backend:
            await self._redis_backend.reset(key)
            return

        async with self._lock:
            if key in self.buckets:
                bucket = self.buckets[key]
                bucket.tokens = bucket.capacity
                bucket.last_refill = time.time()

            logger.info("令牌桶已重置")

    async def get_stats(self, key: str) -> Dict[str, Any]:
        """
        Return statistics for ``key``.

        Args:
            key: Limit key.

        Returns:
            A dictionary describing the bucket state.
        """
        await self._ensure_backend()

        if self._redis_backend:
            return await self._redis_backend.get_stats(
                key,
                capacity=self._resolve_capacity(key),
                refill_rate=self._resolve_refill_rate(key),
            )

        async with self._lock:
            bucket = self._get_bucket(key)
            return {
                "strategy": "token_bucket",
                "key": key,
                "capacity": bucket.capacity,
                "remaining": bucket.get_remaining(),
                "refill_rate": bucket.refill_rate,
                "reset_at": bucket.get_reset_time().isoformat(),
            }

    def configure(self, config: Dict[str, Any]):
        """
        Apply configuration.

        Keys read here: ``default_capacity`` and ``default_refill_rate``.

        Per-prefix keys (``api_key_capacity``, ``api_key_refill_rate``,
        ``user_capacity``, ``user_refill_rate``) are looked up later from
        ``self.config`` when limits are resolved.
        NOTE(review): presumably ``BasePlugin.configure`` stores ``config`` on
        ``self.config`` - confirm.
        """
        super().configure(config)
        self.default_capacity = config.get("default_capacity", self.default_capacity)
        self.default_refill_rate = config.get("default_refill_rate", self.default_refill_rate)

    def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int:
        """Return the bucket capacity for ``key``; an explicit per-minute ``rate_limit`` wins."""
        if rate_limit is not None:
            # Capacity equals the per-minute limit.
            return rate_limit
        if key.startswith("api_key:"):
            return self.config.get("api_key_capacity", self.default_capacity)
        if key.startswith("user:"):
            return self.config.get("user_capacity", self.default_capacity * 2)
        return self.default_capacity

    def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float:
        """Return the per-second refill rate for ``key``; an explicit per-minute ``rate_limit`` wins."""
        if rate_limit is not None:
            # Per-minute limit converted to tokens per second.
            return rate_limit / 60.0
        if key.startswith("api_key:"):
            return self.config.get("api_key_refill_rate", self.default_refill_rate)
        if key.startswith("user:"):
            return self.config.get("user_refill_rate", self.default_refill_rate * 2)
        return self.default_refill_rate
|
||||
|
||||
|
||||
class RedisTokenBucketBackend:
    """Token-bucket state stored in Redis so several instances can share it."""

    # Refill-and-consume script. KEYS[1] is the bucket key; ARGV is
    # (now, capacity, refill_rate, amount). It returns
    # {allowed, tokens, retry_after} and refreshes the key's TTL.
    _SCRIPT = """
    local key = KEYS[1]
    local now = tonumber(ARGV[1])
    local capacity = tonumber(ARGV[2])
    local refill_rate = tonumber(ARGV[3])
    local amount = tonumber(ARGV[4])

    local data = redis.call('HMGET', key, 'tokens', 'timestamp')
    local tokens = tonumber(data[1])
    local last_refill = tonumber(data[2])

    if tokens == nil then
        tokens = capacity
        last_refill = now
    end

    local delta = math.max(0, now - last_refill)
    local refill = delta * refill_rate
    tokens = math.min(capacity, tokens + refill)

    local allowed = 0
    local retry_after = 0
    if tokens >= amount then
        tokens = tokens - amount
        allowed = 1
    else
        retry_after = math.ceil((amount - tokens) / refill_rate)
    end

    redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now)
    local ttl = math.max(1, math.ceil(capacity / refill_rate))
    redis.call('EXPIRE', key, ttl)
    return {allowed, tokens, retry_after}
    """

    def __init__(self, redis_client):
        """
        Bind to a Redis client and register the consume script on it.

        Args:
            redis_client: Client exposing ``register_script`` plus awaitable
                ``hmget`` and ``delete``.
        """
        self.redis = redis_client
        self._consume_script = self.redis.register_script(self._SCRIPT)

    def _redis_key(self, key: str) -> str:
        """Return the Redis key under which the bucket for ``key`` is stored."""
        return f"rate_limit:bucket:{key}"

    @staticmethod
    def _projected_tokens(tokens, last_refill, capacity, refill_rate) -> Optional[float]:
        """
        Project the stored token count forward to the current time.

        Args:
            tokens: Stored ``tokens`` field, or None if absent.
            last_refill: Stored ``timestamp`` field, or None if absent.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.

        Returns:
            The refilled token count capped at ``capacity``, or None when no
            complete state is stored (callers treat that as a full bucket).
        """
        if tokens is None or last_refill is None:
            return None
        elapsed = max(0.0, time.time() - float(last_refill))
        return float(min(capacity, float(tokens) + elapsed * refill_rate))

    async def peek(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> RateLimitResult:
        """
        Report whether ``amount`` tokens are available, without consuming them.

        Args:
            key: Limit key.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.
            amount: Number of tokens requested.

        Returns:
            The result of the rate-limit check.
        """
        data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
        projected = self._projected_tokens(data[0], data[1], capacity, refill_rate)

        if projected is None:
            # No stored state means a full bucket, so there is nothing to wait for.
            remaining = capacity
            reset_at = datetime.now()
        else:
            remaining = int(projected)
            reset_after = 0 if projected >= capacity else (capacity - projected) / refill_rate
            reset_at = datetime.now() + timedelta(seconds=reset_after)

        allowed = remaining >= amount
        retry_after = None
        if not allowed:
            # Time needed to earn the missing tokens, rounded up.
            needed = max(0, amount - remaining)
            retry_after = int(needed / refill_rate) + 1

        return RateLimitResult(
            allowed=allowed,
            remaining=int(remaining),
            reset_at=reset_at,
            retry_after=retry_after,
            message=(
                None
                if allowed
                else f"Rate limit exceeded. Please retry after {retry_after} seconds."
            ),
        )

    async def consume(
        self,
        key: str,
        capacity: int,
        refill_rate: float,
        amount: int,
    ) -> Tuple[bool, int]:
        """
        Refill and consume in a single script invocation.

        Args:
            key: Limit key.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.
            amount: Number of tokens to consume.

        Returns:
            ``(allowed, remaining)`` as reported by the script.
        """
        result = await self._consume_script(
            keys=[self._redis_key(key)],
            args=[time.time(), capacity, refill_rate, amount],
        )
        allowed = bool(result[0])
        remaining = int(float(result[1]))
        return allowed, remaining

    async def reset(self, key: str):
        """Delete the stored bucket for ``key``; the next access sees a full bucket."""
        await self.redis.delete(self._redis_key(key))

    async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]:
        """
        Return statistics for ``key``.

        ``remaining`` is the stored token count projected to the current time,
        so an idle bucket does not appear emptier than it is.

        Args:
            key: Limit key.
            capacity: Bucket capacity.
            refill_rate: Tokens added per second.

        Returns:
            A dictionary describing the bucket state.
        """
        data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
        timestamp = data[1]
        projected = self._projected_tokens(data[0], timestamp, capacity, refill_rate)
        return {
            "strategy": "token_bucket",
            "key": key,
            "capacity": capacity,
            "remaining": capacity if projected is None else projected,
            "refill_rate": refill_rate,
            "last_refill": float(timestamp) if timestamp else time.time(),
            "backend": "redis",
        }
|
||||
Reference in New Issue
Block a user