Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
"""
速率限制插件模块
"""
from .base import RateLimitResult, RateLimitStrategy
from .sliding_window import SlidingWindowStrategy
from .token_bucket import TokenBucketStrategy
__all__ = ["RateLimitStrategy", "RateLimitResult", "TokenBucketStrategy", "SlidingWindowStrategy"]

View File

@@ -0,0 +1,132 @@
"""
速率限制策略基类
定义速率限制策略的接口
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..common import BasePlugin, HealthStatus, PluginMetadata
@dataclass
class RateLimitResult:
"""
速率限制检查结果
"""
allowed: bool
remaining: int
reset_at: Optional[datetime] = None
retry_after: Optional[int] = None
message: Optional[str] = None
headers: Optional[Dict[str, str]] = None
def __post_init__(self):
if self.headers is None:
self.headers = {}
if self.remaining is not None:
self.headers["X-RateLimit-Remaining"] = str(self.remaining)
if self.reset_at:
self.headers["X-RateLimit-Reset"] = str(int(self.reset_at.timestamp()))
if self.retry_after:
self.headers["Retry-After"] = str(self.retry_after)
class RateLimitStrategy(BasePlugin):
"""
速率限制策略基类
所有速率限制策略必须继承此类
"""
def __init__(
self,
name: str,
priority: int = 0,
version: str = "1.0.0",
author: str = "Unknown",
description: str = "",
api_version: str = "1.0",
dependencies: List[str] = None,
provides: List[str] = None,
config: Dict[str, Any] = None,
):
"""
初始化速率限制策略
Args:
name: 策略名称
priority: 优先级(数字越大优先级越高)
version: 插件版本
author: 插件作者
description: 插件描述
api_version: API版本
dependencies: 依赖的其他插件
provides: 提供的服务
config: 配置字典
"""
super().__init__(
name=name,
priority=priority,
version=version,
author=author,
description=description,
api_version=api_version,
dependencies=dependencies,
provides=provides,
config=config,
)
@abstractmethod
async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
"""
检查速率限制
Args:
key: 限制键如用户ID、API Key ID等
**kwargs: 额外参数
Returns:
速率限制检查结果
"""
pass
@abstractmethod
async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
"""
消费配额
Args:
key: 限制键
amount: 消费数量
**kwargs: 额外参数
Returns:
是否成功消费
"""
pass
@abstractmethod
async def reset(self, key: str):
"""
重置限制
Args:
key: 限制键
"""
pass
@abstractmethod
async def get_stats(self, key: str) -> Dict[str, Any]:
"""
获取统计信息
Args:
key: 限制键
Returns:
统计信息字典
"""
pass

View File

@@ -0,0 +1,363 @@
"""
滑动窗口算法速率限制策略
精确的速率限制,不允许突发
WARNING: 多进程环境注意事项
=============================
此插件的窗口状态存储在进程内存中。如果使用 Gunicorn/uvicorn 多 worker 模式,
每个 worker 进程有独立的限流状态,可能导致:
- 实际允许的请求数 = 配置限制 * worker数量
- 限流效果大打折扣
解决方案:
1. 单 worker 模式:适用于低流量场景
2. Redis 共享状态:使用 Redis 实现分布式滑动窗口
3. 使用 token_bucket.py令牌桶策略可以更容易迁移到 Redis
目前项目已有 Redis 依赖src/clients/redis_client.py
建议在生产环境使用 Redis 实现分布式限流。
"""
import asyncio
import time
from collections import deque
from datetime import datetime
from typing import Any, Deque, Dict
from src.core.logger import logger
from .base import RateLimitResult, RateLimitStrategy
class SlidingWindow:
"""滑动窗口实现"""
def __init__(self, window_size: int, max_requests: int):
"""
初始化滑动窗口
Args:
window_size: 窗口大小(秒)
max_requests: 窗口内最大请求数
"""
self.window_size = window_size
self.max_requests = max_requests
self.requests: Deque[float] = deque()
self.last_access_time: float = time.time()
def _cleanup(self):
"""清理过期的请求记录"""
current_time = time.time()
self.last_access_time = current_time # 更新最后访问时间
cutoff_time = current_time - self.window_size
# 移除窗口外的请求
while self.requests and self.requests[0] < cutoff_time:
self.requests.popleft()
def can_accept(self, amount: int = 1) -> bool:
"""
检查是否可以接受新请求
Args:
amount: 请求数量
Returns:
是否可以接受
"""
self._cleanup()
return len(self.requests) + amount <= self.max_requests
def add_request(self, amount: int = 1) -> bool:
"""
添加请求记录
Args:
amount: 请求数量
Returns:
是否成功添加
"""
if not self.can_accept(amount):
return False
current_time = time.time()
for _ in range(amount):
self.requests.append(current_time)
return True
def get_remaining(self) -> int:
"""获取剩余配额"""
self._cleanup()
return max(0, self.max_requests - len(self.requests))
def get_reset_time(self) -> datetime:
"""获取最早的重置时间"""
self._cleanup()
if not self.requests:
return datetime.now()
# 最早的请求将在window_size秒后过期
oldest_request = self.requests[0]
reset_time = oldest_request + self.window_size
return datetime.fromtimestamp(reset_time)
class SlidingWindowStrategy(RateLimitStrategy):
"""
滑动窗口算法速率限制策略
特点:
- 精确的速率限制
- 不允许突发流量
- 适合需要严格速率控制的场景
- 自动清理长时间不活跃的窗口,防止内存泄漏
"""
# 默认最大缓存窗口数量
DEFAULT_MAX_WINDOWS = 10000
# 默认窗口过期时间(秒)- 超过此时间未访问的窗口将被清理
DEFAULT_WINDOW_EXPIRY = 3600 # 1小时
def __init__(self):
super().__init__("sliding_window")
self.windows: Dict[str, SlidingWindow] = {}
self._lock = asyncio.Lock()
# 默认配置
self.default_window_size = 60 # 默认60秒窗口
self.default_max_requests = 100 # 默认100个请求
# 内存管理配置
self.max_windows = self.DEFAULT_MAX_WINDOWS
self.window_expiry = self.DEFAULT_WINDOW_EXPIRY
self._last_cleanup_time: float = time.time()
self._cleanup_interval = 300 # 每5分钟检查一次是否需要清理
def _cleanup_expired_windows(self) -> int:
"""
清理过期的窗口,防止内存泄漏
Returns:
清理的窗口数量
"""
current_time = time.time()
expired_keys = []
for key, window in self.windows.items():
# 检查窗口是否过期(长时间未访问)
if current_time - window.last_access_time > self.window_expiry:
expired_keys.append(key)
# 删除过期窗口
for key in expired_keys:
del self.windows[key]
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的滑动窗口")
return len(expired_keys)
def _evict_lru_windows(self, count: int) -> int:
"""
使用 LRU 策略淘汰最久未使用的窗口
Args:
count: 需要淘汰的数量
Returns:
实际淘汰的数量
"""
if not self.windows or count <= 0:
return 0
# 按最后访问时间排序,淘汰最久未访问的
sorted_keys = sorted(self.windows.keys(), key=lambda k: self.windows[k].last_access_time)
evicted = 0
for key in sorted_keys[:count]:
del self.windows[key]
evicted += 1
if evicted:
logger.warning(f"LRU 淘汰了 {evicted} 个滑动窗口(达到容量上限)")
return evicted
async def _maybe_cleanup(self):
"""检查是否需要执行清理操作"""
current_time = time.time()
# 定期清理过期窗口
if current_time - self._last_cleanup_time > self._cleanup_interval:
self._cleanup_expired_windows()
self._last_cleanup_time = current_time
# 如果超过容量上限,执行 LRU 淘汰
if len(self.windows) >= self.max_windows:
# 淘汰 10% 的窗口
evict_count = max(1, self.max_windows // 10)
self._evict_lru_windows(evict_count)
def _get_window(self, key: str) -> SlidingWindow:
"""
获取或创建滑动窗口
Args:
key: 限制键
Returns:
滑动窗口实例
"""
if key not in self.windows:
# 根据key的不同前缀使用不同的配置
if key.startswith("api_key:"):
window_size = self.config.get("api_key_window_size", self.default_window_size)
max_requests = self.config.get("api_key_max_requests", self.default_max_requests)
elif key.startswith("user:"):
window_size = self.config.get("user_window_size", self.default_window_size)
max_requests = self.config.get("user_max_requests", self.default_max_requests * 2)
else:
window_size = self.default_window_size
max_requests = self.default_max_requests
self.windows[key] = SlidingWindow(window_size, max_requests)
return self.windows[key]
async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
"""
检查速率限制
Args:
key: 限制键
Returns:
速率限制检查结果
"""
async with self._lock:
# 检查是否需要清理过期窗口
await self._maybe_cleanup()
window = self._get_window(key)
amount = kwargs.get("amount", 1)
# 检查是否可以接受请求
allowed = window.can_accept(amount)
remaining = window.get_remaining()
reset_at = window.get_reset_time()
retry_after = None
if not allowed:
# 计算需要等待的时间(最早请求过期的时间)
retry_after = int((reset_at - datetime.now()).total_seconds()) + 1
return RateLimitResult(
allowed=allowed,
remaining=remaining,
reset_at=reset_at,
retry_after=retry_after,
message=(
None
if allowed
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
),
)
async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
"""
消费配额
Args:
key: 限制键
amount: 消费数量
Returns:
是否成功消费
"""
async with self._lock:
window = self._get_window(key)
success = window.add_request(amount)
if success:
logger.debug(f"滑动窗口请求记录成功")
else:
logger.warning(f"滑动窗口请求被拒绝:超出速率限制")
return success
async def reset(self, key: str):
"""
重置滑动窗口
Args:
key: 限制键
"""
async with self._lock:
if key in self.windows:
window = self.windows[key]
window.requests.clear()
logger.info(f"滑动窗口已重置")
async def get_stats(self, key: str) -> Dict[str, Any]:
"""
获取统计信息
Args:
key: 限制键
Returns:
统计信息
"""
async with self._lock:
window = self._get_window(key)
window._cleanup() # 先清理过期请求
return {
"strategy": "sliding_window",
"key": key,
"window_size": window.window_size,
"max_requests": window.max_requests,
"current_requests": len(window.requests),
"remaining": window.get_remaining(),
"reset_at": window.get_reset_time().isoformat(),
}
def configure(self, config: Dict[str, Any]):
"""
配置策略
支持的配置项:
- api_key_window_size: API Key的窗口大小
- api_key_max_requests: API Key的最大请求数
- user_window_size: 用户的窗口大小(秒)
- user_max_requests: 用户的最大请求数
- max_windows: 最大缓存窗口数量(防止内存泄漏)
- window_expiry: 窗口过期时间(秒)
- cleanup_interval: 清理检查间隔(秒)
"""
super().configure(config)
self.default_window_size = config.get("default_window_size", self.default_window_size)
self.default_max_requests = config.get("default_max_requests", self.default_max_requests)
self.max_windows = config.get("max_windows", self.max_windows)
self.window_expiry = config.get("window_expiry", self.window_expiry)
self._cleanup_interval = config.get("cleanup_interval", self._cleanup_interval)
def get_memory_stats(self) -> Dict[str, Any]:
"""
获取内存使用统计信息
Returns:
内存使用统计
"""
return {
"total_windows": len(self.windows),
"max_windows": self.max_windows,
"window_expiry": self.window_expiry,
"cleanup_interval": self._cleanup_interval,
"last_cleanup_time": self._last_cleanup_time,
"usage_percent": (
(len(self.windows) / self.max_windows * 100) if self.max_windows > 0 else 0
),
}

View File

@@ -0,0 +1,431 @@
"""令牌桶速率限制策略,支持 Redis 分布式后端"""
import asyncio
import os
import time
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync
from src.core.logger import logger
from .base import RateLimitResult, RateLimitStrategy
class TokenBucket:
"""令牌桶实现"""
def __init__(self, capacity: int, refill_rate: float):
"""
初始化令牌桶
Args:
capacity: 桶容量(最大令牌数)
refill_rate: 令牌补充速率(每秒)
"""
self.capacity = capacity
self.refill_rate = refill_rate
self.tokens = capacity
self.last_refill = time.time()
def _refill(self):
"""补充令牌"""
now = time.time()
time_passed = now - self.last_refill
tokens_to_add = time_passed * self.refill_rate
if tokens_to_add > 0:
self.tokens = min(self.capacity, self.tokens + tokens_to_add)
self.last_refill = now
def consume(self, amount: int = 1) -> bool:
"""
消费令牌
Args:
amount: 要消费的令牌数
Returns:
是否成功消费
"""
self._refill()
if self.tokens >= amount:
self.tokens -= amount
return True
return False
def get_remaining(self) -> int:
"""获取剩余令牌数"""
self._refill()
return int(self.tokens)
def get_reset_time(self) -> datetime:
"""获取下次完全恢复的时间"""
if self.tokens >= self.capacity:
return datetime.now()
tokens_needed = self.capacity - self.tokens
seconds_to_full = tokens_needed / self.refill_rate
return datetime.now() + timedelta(seconds=seconds_to_full)
class TokenBucketStrategy(RateLimitStrategy):
"""
令牌桶算法速率限制策略
特点:
- 允许突发流量
- 平均速率受限
- 适合处理不均匀的流量模式
"""
def __init__(self):
super().__init__("token_bucket")
self.buckets: Dict[str, TokenBucket] = {}
self._lock = asyncio.Lock()
# 默认配置
self.default_capacity = 100 # 默认桶容量
self.default_refill_rate = 10 # 默认每秒补充10个令牌
# 可选的 Redis 后端
self._redis_backend: Optional[RedisTokenBucketBackend] = None
self._redis_checked = False
self._backend_mode = os.getenv("RATE_LIMIT_BACKEND", "auto").lower()
def _get_bucket(self, key: str, rate_limit: Optional[int] = None) -> TokenBucket:
"""
获取或创建令牌桶
Args:
key: 限制键
rate_limit: 每分钟请求限制(来自数据库配置),如果提供则使用此值
Returns:
令牌桶实例
"""
if key not in self.buckets:
# 如果提供了rate_limit参数来自数据库优先使用
if rate_limit is not None:
# rate_limit 是每分钟请求数,转换为令牌桶参数
capacity = rate_limit # 桶容量等于每分钟限制
refill_rate = rate_limit / 60.0 # 每秒补充的令牌数
# 否则根据key的不同前缀使用不同的配置
elif key.startswith("api_key:"):
capacity = self.config.get("api_key_capacity", self.default_capacity)
refill_rate = self.config.get("api_key_refill_rate", self.default_refill_rate)
elif key.startswith("user:"):
capacity = self.config.get("user_capacity", self.default_capacity * 2)
refill_rate = self.config.get("user_refill_rate", self.default_refill_rate * 2)
else:
capacity = self.default_capacity
refill_rate = self.default_refill_rate
self.buckets[key] = TokenBucket(capacity, refill_rate)
return self.buckets[key]
def _want_redis_backend(self) -> bool:
return self._backend_mode in {"auto", "redis"}
async def _ensure_backend(self):
if self._redis_checked:
return
self._redis_checked = True
if not self._want_redis_backend():
return
redis_client = get_redis_client_sync()
if redis_client:
self._redis_backend = RedisTokenBucketBackend(redis_client)
logger.info("速率限制改用 Redis 令牌桶后端")
elif self._backend_mode == "redis":
logger.warning("RATE_LIMIT_BACKEND=redis 但 Redis 客户端不可用,回退到内存桶")
async def check_limit(self, key: str, **kwargs) -> RateLimitResult:
"""
检查速率限制
Args:
key: 限制键
**kwargs: 额外参数,包括 rate_limit (从数据库配置)
Returns:
速率限制检查结果
"""
await self._ensure_backend()
rate_limit = kwargs.get("rate_limit")
amount = kwargs.get("amount", 1)
if self._redis_backend:
return await self._redis_backend.peek(
key=key,
capacity=self._resolve_capacity(key, rate_limit),
refill_rate=self._resolve_refill_rate(key, rate_limit),
amount=amount,
)
async with self._lock:
bucket = self._get_bucket(key, rate_limit)
remaining = bucket.get_remaining()
reset_at = bucket.get_reset_time()
allowed = remaining >= amount
retry_after = None
if not allowed:
tokens_needed = amount - remaining
retry_after = int(tokens_needed / bucket.refill_rate) + 1
return RateLimitResult(
allowed=allowed,
remaining=remaining,
reset_at=reset_at,
retry_after=retry_after,
message=(
None
if allowed
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
),
)
async def consume(self, key: str, amount: int = 1, **kwargs) -> bool:
"""
消费令牌
Args:
key: 限制键
amount: 消费数量
Returns:
是否成功消费
"""
await self._ensure_backend()
if self._redis_backend:
success, remaining = await self._redis_backend.consume(
key=key,
capacity=self._resolve_capacity(key, kwargs.get("rate_limit")),
refill_rate=self._resolve_refill_rate(key, kwargs.get("rate_limit")),
amount=amount,
)
if success:
logger.debug("Redis 令牌消费成功")
else:
logger.warning("Redis 令牌消费失败")
return success
async with self._lock:
bucket = self._get_bucket(key)
success = bucket.consume(amount)
if success:
logger.debug(f"令牌消费成功")
else:
logger.warning(f"令牌消费失败:超出速率限制")
return success
async def reset(self, key: str):
"""
重置令牌桶
Args:
key: 限制键
"""
await self._ensure_backend()
if self._redis_backend:
await self._redis_backend.reset(key)
return
async with self._lock:
if key in self.buckets:
bucket = self.buckets[key]
bucket.tokens = bucket.capacity
bucket.last_refill = time.time()
logger.info(f"令牌桶已重置")
async def get_stats(self, key: str) -> Dict[str, Any]:
"""
获取统计信息
Args:
key: 限制键
Returns:
统计信息
"""
await self._ensure_backend()
if self._redis_backend:
return await self._redis_backend.get_stats(
key,
capacity=self._resolve_capacity(key),
refill_rate=self._resolve_refill_rate(key),
)
async with self._lock:
bucket = self._get_bucket(key)
return {
"strategy": "token_bucket",
"key": key,
"capacity": bucket.capacity,
"remaining": bucket.get_remaining(),
"refill_rate": bucket.refill_rate,
"reset_at": bucket.get_reset_time().isoformat(),
}
def configure(self, config: Dict[str, Any]):
"""
配置策略
支持的配置项:
- api_key_capacity: API Key的桶容量
- api_key_refill_rate: API Key的令牌补充速率
- user_capacity: 用户的桶容量
- user_refill_rate: 用户的令牌补充速率
"""
super().configure(config)
self.default_capacity = config.get("default_capacity", self.default_capacity)
self.default_refill_rate = config.get("default_refill_rate", self.default_refill_rate)
def _resolve_capacity(self, key: str, rate_limit: Optional[int] = None) -> int:
if rate_limit is not None:
return rate_limit
if key.startswith("api_key:"):
return self.config.get("api_key_capacity", self.default_capacity)
if key.startswith("user:"):
return self.config.get("user_capacity", self.default_capacity * 2)
return self.default_capacity
def _resolve_refill_rate(self, key: str, rate_limit: Optional[int] = None) -> float:
if rate_limit is not None:
return rate_limit / 60.0
if key.startswith("api_key:"):
return self.config.get("api_key_refill_rate", self.default_refill_rate)
if key.startswith("user:"):
return self.config.get("user_refill_rate", self.default_refill_rate * 2)
return self.default_refill_rate
class RedisTokenBucketBackend:
"""使用 Redis 存储令牌桶状态,支持多实例共享"""
_SCRIPT = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local refill_rate = tonumber(ARGV[3])
local amount = tonumber(ARGV[4])
local data = redis.call('HMGET', key, 'tokens', 'timestamp')
local tokens = tonumber(data[1])
local last_refill = tonumber(data[2])
if tokens == nil then
tokens = capacity
last_refill = now
end
local delta = math.max(0, now - last_refill)
local refill = delta * refill_rate
tokens = math.min(capacity, tokens + refill)
local allowed = 0
local retry_after = 0
if tokens >= amount then
tokens = tokens - amount
allowed = 1
else
retry_after = math.ceil((amount - tokens) / refill_rate)
end
redis.call('HMSET', key, 'tokens', tokens, 'timestamp', now)
local ttl = math.max(1, math.ceil(capacity / refill_rate))
redis.call('EXPIRE', key, ttl)
return {allowed, tokens, retry_after}
"""
def __init__(self, redis_client):
self.redis = redis_client
self._consume_script = self.redis.register_script(self._SCRIPT)
def _redis_key(self, key: str) -> str:
return f"rate_limit:bucket:{key}"
async def peek(
self,
key: str,
capacity: int,
refill_rate: float,
amount: int,
) -> RateLimitResult:
bucket_key = self._redis_key(key)
data = await self.redis.hmget(bucket_key, "tokens", "timestamp")
tokens = data[0]
last_refill = data[1]
if tokens is None or last_refill is None:
remaining = capacity
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
else:
tokens_value = float(tokens)
last_refill_value = float(last_refill)
delta = max(0.0, time.time() - last_refill_value)
tokens_value = min(capacity, tokens_value + delta * refill_rate)
remaining = int(tokens_value)
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
reset_at = datetime.now() + timedelta(seconds=reset_after)
allowed = remaining >= amount
retry_after = None
if not allowed:
needed = max(0, amount - remaining)
retry_after = int(needed / refill_rate) + 1
return RateLimitResult(
allowed=allowed,
remaining=int(remaining),
reset_at=reset_at,
retry_after=retry_after,
message=(
None
if allowed
else f"Rate limit exceeded. Please retry after {retry_after} seconds."
),
)
async def consume(
self,
key: str,
capacity: int,
refill_rate: float,
amount: int,
) -> Tuple[bool, int]:
result = await self._consume_script(
keys=[self._redis_key(key)],
args=[time.time(), capacity, refill_rate, amount],
)
allowed = bool(result[0])
remaining = int(float(result[1]))
return allowed, remaining
async def reset(self, key: str):
await self.redis.delete(self._redis_key(key))
async def get_stats(self, key: str, capacity: int, refill_rate: float) -> Dict[str, Any]:
data = await self.redis.hmget(self._redis_key(key), "tokens", "timestamp")
tokens = data[0]
timestamp = data[1]
return {
"strategy": "token_bucket",
"key": key,
"capacity": capacity,
"remaining": float(tokens) if tokens else capacity,
"refill_rate": refill_rate,
"last_refill": float(timestamp) if timestamp else time.time(),
"backend": "redis",
}