Mirror of https://github.com/fawney19/Aether.git (synced 2026-01-09 19:22:26 +08:00)
Initial commit
19
src/services/rate_limit/__init__.py
Normal file
@@ -0,0 +1,19 @@
"""
Rate-limiting service module.

Includes adaptive concurrency control, RPM limiting, IP limiting, and related features.
"""

from src.services.rate_limit.adaptive_concurrency import AdaptiveConcurrencyManager
from src.services.rate_limit.concurrency_manager import ConcurrencyManager
from src.services.rate_limit.detector import RateLimitDetector
from src.services.rate_limit.ip_limiter import IPRateLimiter
from src.services.rate_limit.rpm_limiter import RPMLimiter

__all__ = [
    "AdaptiveConcurrencyManager",
    "ConcurrencyManager",
    "IPRateLimiter",
    "RPMLimiter",
    "RateLimitDetector",
]
558
src/services/rate_limit/adaptive_concurrency.py
Normal file
@@ -0,0 +1,558 @@
"""
Adaptive concurrency adjuster - adjusts concurrency limits based on sliding-window utilization.

Key improvements (over the old scheme based on "sustained high utilization"):
- Uses sliding-window sampling, tolerating concurrency fluctuations
- Decides based on the ratio of high-utilization samples within the window, rather than requiring continuously high utilization
- Adds a probing scale-up mechanism that proactively attempts expansion after long stable periods
"""

from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, cast

from sqlalchemy.orm import Session

from src.config.constants import ConcurrencyDefaults
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import ProviderAPIKey
from src.services.rate_limit.detector import RateLimitInfo, RateLimitType


class AdaptiveStrategy:
    """Adaptive strategy types"""

    AIMD = "aimd"  # Additive Increase Multiplicative Decrease
    CONSERVATIVE = "conservative"  # Conservative strategy (decrease only, never increase)
    AGGRESSIVE = "aggressive"  # Aggressive strategy (fast probing)


class AdaptiveConcurrencyManager:
    """
    Adaptive concurrency manager.

    Core algorithm: AIMD driven by sliding-window utilization
    - A sliding window records the utilization of the most recent N requests
    - Scale-up triggers when >= 60% of samples in the window show high utilization
    - On a 429 error, decrease multiplicatively (*0.7)
    - A probing scale-up triggers after a long period without 429s while traffic is present

    Scale-up conditions (any one suffices):
    1. Sliding-window scale-up: >= 60% of window samples have utilization >= 70%, and not in cooldown
    2. Probing scale-up: more than 30 minutes since the last 429, with sufficient request volume in between

    Key properties:
    1. The sliding window tolerates concurrency fluctuations; a single low-utilization sample does not reset progress
    2. Distinguishes concurrency limits from RPM limits
    3. Probing scale-up avoids being stuck at a low limit indefinitely
    4. Records adjustment history
    """

    # Default configuration - uses shared constants
    DEFAULT_INITIAL_LIMIT = ConcurrencyDefaults.INITIAL_LIMIT
    MIN_CONCURRENT_LIMIT = ConcurrencyDefaults.MIN_CONCURRENT_LIMIT
    MAX_CONCURRENT_LIMIT = ConcurrencyDefaults.MAX_CONCURRENT_LIMIT

    # AIMD parameters
    INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
    DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER

    # Sliding-window parameters
    UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
    UTILIZATION_WINDOW_SECONDS = ConcurrencyDefaults.UTILIZATION_WINDOW_SECONDS
    UTILIZATION_THRESHOLD = ConcurrencyDefaults.UTILIZATION_THRESHOLD
    HIGH_UTILIZATION_RATIO = ConcurrencyDefaults.HIGH_UTILIZATION_RATIO
    MIN_SAMPLES_FOR_DECISION = ConcurrencyDefaults.MIN_SAMPLES_FOR_DECISION

    # Probing scale-up parameters
    PROBE_INCREASE_INTERVAL_MINUTES = ConcurrencyDefaults.PROBE_INCREASE_INTERVAL_MINUTES
    PROBE_INCREASE_MIN_REQUESTS = ConcurrencyDefaults.PROBE_INCREASE_MIN_REQUESTS

    # Number of history records to keep
    MAX_HISTORY_RECORDS = 20

    def __init__(self, strategy: str = AdaptiveStrategy.AIMD):
        """
        Initialize the adaptive concurrency manager.

        Args:
            strategy: adjustment strategy
        """
        self.strategy = strategy

    def handle_429_error(
        self,
        db: Session,
        key: ProviderAPIKey,
        rate_limit_info: RateLimitInfo,
        current_concurrent: Optional[int] = None,
    ) -> int:
        """
        Handle a 429 error and adjust the concurrency limit.

        Args:
            db: database session
            key: API key object
            rate_limit_info: rate-limit information
            current_concurrent: current concurrency

        Returns:
            The adjusted concurrency limit
        """
        # max_concurrent=NULL enables adaptive mode; a number means a fixed limit
        is_adaptive = key.max_concurrent is None

        if not is_adaptive:
            logger.debug(
                f"Key {key.id} has a fixed concurrency limit ({key.max_concurrent}); skipping adaptive adjustment"
            )
            return int(key.max_concurrent)  # type: ignore[arg-type]

        # Update 429 statistics
        key.last_429_at = datetime.now(timezone.utc)  # type: ignore[assignment]
        key.last_429_type = rate_limit_info.limit_type  # type: ignore[assignment]
        key.last_concurrent_peak = current_concurrent  # type: ignore[assignment]

        # On a 429 error, clear the utilization sample window (start collecting afresh)
        key.utilization_samples = []  # type: ignore[assignment]

        if rate_limit_info.limit_type == RateLimitType.CONCURRENT:
            # Concurrency limit: decrease concurrency
            key.concurrent_429_count = int(key.concurrent_429_count or 0) + 1  # type: ignore[assignment]

            # Get the current effective limit (adaptive mode uses learned_max_concurrent)
            old_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
            new_limit = self._decrease_limit(old_limit, current_concurrent)

            logger.warning(
                f"[CONCURRENT] Concurrency limit hit: Key {key.id[:8]}... | "
                f"current concurrency: {current_concurrent} | "
                f"adjusted: {old_limit} -> {new_limit}"
            )

            # Record the adjustment
            self._record_adjustment(
                key,
                old_limit=old_limit,
                new_limit=new_limit,
                reason="concurrent_429",
                current_concurrent=current_concurrent,
            )

            # Update the learned concurrency limit
            key.learned_max_concurrent = new_limit  # type: ignore[assignment]

        elif rate_limit_info.limit_type == RateLimitType.RPM:
            # RPM limit: do not adjust concurrency, only record
            key.rpm_429_count = int(key.rpm_429_count or 0) + 1  # type: ignore[assignment]

            logger.info(
                f"[RPM] RPM limit hit: Key {key.id[:8]}... | "
                f"retry_after: {rate_limit_info.retry_after}s | "
                f"concurrency limit unchanged"
            )

        else:
            # Unknown type: handle conservatively with a slight decrease
            logger.warning(
                f"[UNKNOWN] Unknown 429 type: Key {key.id[:8]}... | "
                f"current concurrency: {current_concurrent} | "
                f"decreasing concurrency conservatively"
            )

            old_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
            new_limit = max(int(old_limit * 0.9), self.MIN_CONCURRENT_LIMIT)  # decrease by 10%

            self._record_adjustment(
                key,
                old_limit=old_limit,
                new_limit=new_limit,
                reason="unknown_429",
                current_concurrent=current_concurrent,
            )

            key.learned_max_concurrent = new_limit  # type: ignore[assignment]

        db.flush()
        get_batch_committer().mark_dirty(db)

        return int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)

    def handle_success(
        self,
        db: Session,
        key: ProviderAPIKey,
        current_concurrent: int,
    ) -> Optional[int]:
        """
        Handle a successful request; consider raising the concurrency limit based on sliding-window utilization.

        Args:
            db: database session
            key: API key object
            current_concurrent: current concurrency (required; used to compute utilization)

        Returns:
            The adjusted concurrency limit if an adjustment was made, otherwise None
        """
        # max_concurrent=NULL enables adaptive mode
        is_adaptive = key.max_concurrent is None

        if not is_adaptive:
            return None

        current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)

        # Compute current utilization
        utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0

        now = datetime.now(timezone.utc)
        now_ts = now.timestamp()

        # Update the sliding window
        samples = self._update_utilization_window(key, now_ts, utilization)

        # Check whether scale-up conditions are met
        increase_reason = self._check_increase_conditions(key, samples, now)

        if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
            old_limit = current_limit
            new_limit = self._increase_limit(current_limit)

            # Compute window statistics for logging
            avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
            high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
            high_util_ratio = high_util_count / len(samples) if samples else 0

            logger.info(
                f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
                f"window samples: {len(samples)} | "
                f"avg utilization: {avg_util:.1%} | "
                f"high-utilization ratio: {high_util_ratio:.1%} | "
                f"adjusted: {old_limit} -> {new_limit}"
            )

            # Record the adjustment
            self._record_adjustment(
                key,
                old_limit=old_limit,
                new_limit=new_limit,
                reason=increase_reason,
                avg_utilization=round(avg_util, 2),
                high_util_ratio=round(high_util_ratio, 2),
                sample_count=len(samples),
                current_concurrent=current_concurrent,
            )

            # Update the limit
            key.learned_max_concurrent = new_limit  # type: ignore[assignment]

            # For a probing scale-up, update the probe timestamp
            if increase_reason == "probe_increase":
                key.last_probe_increase_at = now  # type: ignore[assignment]

            # After scaling up, clear the sample window and start collecting afresh
            key.utilization_samples = []  # type: ignore[assignment]

            db.flush()
            get_batch_committer().mark_dirty(db)

            return new_limit

        # Periodically persist sample data (every 5 samples)
        if len(samples) % 5 == 0:
            db.flush()
            get_batch_committer().mark_dirty(db)

        return None

    def _update_utilization_window(
        self, key: ProviderAPIKey, now_ts: float, utilization: float
    ) -> List[Dict[str, Any]]:
        """
        Update the utilization sliding window.

        Args:
            key: API key object
            now_ts: current timestamp
            utilization: current utilization

        Returns:
            The updated sample list
        """
        samples: List[Dict[str, Any]] = list(key.utilization_samples or [])

        # Add the new sample
        samples.append({"ts": now_ts, "util": round(utilization, 3)})

        # Drop expired samples (outside the time window)
        cutoff_ts = now_ts - self.UTILIZATION_WINDOW_SECONDS
        samples = [s for s in samples if s["ts"] > cutoff_ts]

        # Cap the number of samples
        if len(samples) > self.UTILIZATION_WINDOW_SIZE:
            samples = samples[-self.UTILIZATION_WINDOW_SIZE:]

        # Write back to the key object
        key.utilization_samples = samples  # type: ignore[assignment]

        return samples

    def _check_increase_conditions(
        self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
    ) -> Optional[str]:
        """
        Check whether the scale-up conditions are met.

        Args:
            key: API key object
            samples: utilization sample list
            now: current time

        Returns:
            The scale-up reason if a condition is met, otherwise None
        """
        # Check whether we are in the cooldown period
        if self._is_in_cooldown(key):
            return None

        # Condition 1: sliding-window scale-up
        if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
            high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
            high_util_ratio = high_util_count / len(samples)

            if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
                return "high_utilization"

        # Condition 2: probing scale-up (long period without 429s while traffic is present)
        if self._should_probe_increase(key, samples, now):
            return "probe_increase"

        return None

    def _should_probe_increase(
        self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
    ) -> bool:
        """
        Check whether a probing scale-up should be performed.

        Conditions:
        1. More than PROBE_INCREASE_INTERVAL_MINUTES minutes since the last 429
        2. More than PROBE_INCREASE_INTERVAL_MINUTES minutes since the last probing scale-up
        3. Sufficient request volume in between (sample count >= PROBE_INCREASE_MIN_REQUESTS)
        4. Average utilization > 30% (indicating genuine demand)

        Args:
            key: API key object
            samples: utilization sample list
            now: current time

        Returns:
            Whether a probing scale-up should be performed
        """
        probe_interval_seconds = self.PROBE_INCREASE_INTERVAL_MINUTES * 60

        # Check time since the last 429
        if key.last_429_at:
            last_429_at = cast(datetime, key.last_429_at)
            time_since_429 = (now - last_429_at).total_seconds()
            if time_since_429 < probe_interval_seconds:
                return False

        # Check time since the last probing scale-up
        if key.last_probe_increase_at:
            last_probe = cast(datetime, key.last_probe_increase_at)
            time_since_probe = (now - last_probe).total_seconds()
            if time_since_probe < probe_interval_seconds:
                return False

        # Check request volume
        if len(samples) < self.PROBE_INCREASE_MIN_REQUESTS:
            return False

        # Check average utilization (ensure there is genuine demand)
        avg_util = sum(s["util"] for s in samples) / len(samples)
        if avg_util < 0.3:  # at least 30% utilization
            return False

        return True

    def _is_in_cooldown(self, key: ProviderAPIKey) -> bool:
        """
        Check whether we are within the cooldown period after a 429 error.

        Args:
            key: API key object

        Returns:
            True if within the cooldown period, otherwise False
        """
        if key.last_429_at is None:
            return False

        last_429_at = cast(datetime, key.last_429_at)
        time_since_429 = (datetime.now(timezone.utc) - last_429_at).total_seconds()
        cooldown_seconds = ConcurrencyDefaults.COOLDOWN_AFTER_429_MINUTES * 60

        return bool(time_since_429 < cooldown_seconds)

    def _decrease_limit(
        self,
        current_limit: int,
        current_concurrent: Optional[int] = None,
    ) -> int:
        """
        Decrease the concurrency limit.

        Strategy:
        - If the current concurrency is known, set the limit to 70% of it
        - Otherwise, apply a multiplicative decrease
        """
        if current_concurrent:
            # Decrease based on the current concurrency
            new_limit = max(
                int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
            )
        else:
            # Multiplicative decrease
            new_limit = max(
                int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
            )

        return new_limit

    def _increase_limit(self, current_limit: int) -> int:
        """
        Increase the concurrency limit.

        Strategy: additive increase (+1)
        """
        new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
        return new_limit

    def _record_adjustment(
        self,
        key: ProviderAPIKey,
        old_limit: int,
        new_limit: int,
        reason: str,
        **extra_data: Any,
    ) -> None:
        """
        Record a concurrency adjustment in the history.

        Args:
            key: API key object
            old_limit: previous limit
            new_limit: new limit
            reason: adjustment reason
            **extra_data: additional data
        """
        history: List[Dict[str, Any]] = list(key.adjustment_history or [])

        record = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "old_limit": old_limit,
            "new_limit": new_limit,
            "reason": reason,
            **extra_data,
        }
        history.append(record)

        # Keep only the most recent N records
        if len(history) > self.MAX_HISTORY_RECORDS:
            history = history[-self.MAX_HISTORY_RECORDS:]

        key.adjustment_history = history  # type: ignore[assignment]

    def get_adjustment_stats(self, key: ProviderAPIKey) -> Dict[str, Any]:
        """
        Get adjustment statistics.

        Args:
            key: API key object

        Returns:
            Statistics
        """
        history: List[Dict[str, Any]] = list(key.adjustment_history or [])
        samples: List[Dict[str, Any]] = list(key.utilization_samples or [])

        # max_concurrent=NULL means adaptive mode; otherwise a fixed limit
        is_adaptive = key.max_concurrent is None
        current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
        effective_limit = current_limit if is_adaptive else int(key.max_concurrent)  # type: ignore

        # Compute window statistics
        avg_utilization: Optional[float] = None
        high_util_ratio: Optional[float] = None
        if samples:
            avg_utilization = sum(s["util"] for s in samples) / len(samples)
            high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
            high_util_ratio = high_util_count / len(samples)

        last_429_at_str: Optional[str] = None
        if key.last_429_at:
            last_429_at_str = cast(datetime, key.last_429_at).isoformat()

        last_probe_at_str: Optional[str] = None
        if key.last_probe_increase_at:
            last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()

        return {
            "adaptive_mode": is_adaptive,
            "max_concurrent": key.max_concurrent,  # NULL=adaptive, number=fixed limit
            "effective_limit": effective_limit,  # current effective limit
            "learned_limit": key.learned_max_concurrent,  # learned limit
            "concurrent_429_count": int(key.concurrent_429_count or 0),
            "rpm_429_count": int(key.rpm_429_count or 0),
            "last_429_at": last_429_at_str,
            "last_429_type": key.last_429_type,
            "adjustment_count": len(history),
            "recent_adjustments": history[-5:] if history else [],
            # Sliding-window fields
            "window_sample_count": len(samples),
            "window_avg_utilization": round(avg_utilization, 3) if avg_utilization else None,
            "window_high_util_ratio": round(high_util_ratio, 3) if high_util_ratio else None,
            "utilization_threshold": self.UTILIZATION_THRESHOLD,
            "high_util_ratio_threshold": self.HIGH_UTILIZATION_RATIO,
            "min_samples_for_decision": self.MIN_SAMPLES_FOR_DECISION,
            # Probing scale-up fields
            "last_probe_increase_at": last_probe_at_str,
            "probe_increase_interval_minutes": self.PROBE_INCREASE_INTERVAL_MINUTES,
        }

    def reset_learning(self, db: Session, key: ProviderAPIKey) -> None:
        """
        Reset the learning state (admin feature).

        Args:
            db: database session
            key: API key object
        """
        logger.info(f"[RESET] Resetting learning state: Key {key.id[:8]}...")

        key.learned_max_concurrent = None  # type: ignore[assignment]
        key.concurrent_429_count = 0  # type: ignore[assignment]
        key.rpm_429_count = 0  # type: ignore[assignment]
        key.last_429_at = None  # type: ignore[assignment]
        key.last_429_type = None  # type: ignore[assignment]
        key.last_concurrent_peak = None  # type: ignore[assignment]
        key.adjustment_history = []  # type: ignore[assignment]
        key.utilization_samples = []  # type: ignore[assignment]
        key.last_probe_increase_at = None  # type: ignore[assignment]

        db.flush()
        get_batch_committer().mark_dirty(db)


# Global singleton
_adaptive_manager: Optional[AdaptiveConcurrencyManager] = None


def get_adaptive_manager() -> AdaptiveConcurrencyManager:
    """Get the global adaptive manager singleton."""
    global _adaptive_manager
    if _adaptive_manager is None:
        _adaptive_manager = AdaptiveConcurrencyManager()
    return _adaptive_manager
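A minimal caller-side sketch of how the two entry points above fit together. The manager and detector calls match the functions defined in this commit; the wrapper function and its arguments (on_upstream_response, db, key, status_code) are hypothetical names for illustration only:

from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
from src.services.rate_limit.detector import detect_rate_limit_type

manager = get_adaptive_manager()

def on_upstream_response(db, key, status_code, headers, current_concurrent):
    # Hypothetical hook invoked once per upstream response.
    if status_code == 429:
        info = detect_rate_limit_type(
            headers, provider_name="anthropic", current_concurrent=current_concurrent
        )
        # Multiplicative decrease (or no change for RPM-type 429s)
        return manager.handle_429_error(db, key, info, current_concurrent)
    # Additive increase when the sliding window shows sustained high utilization
    return manager.handle_success(db, key, current_concurrent)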
340
src/services/rate_limit/adaptive_reservation.py
Normal file
@@ -0,0 +1,340 @@
"""
Adaptive reservation-ratio manager.

Dynamically computes the reservation ratio for cached users based on learning
confidence and current load, addressing the poor fit of a fixed 30% reservation
during early learning and under shifting load.

Core ideas:
1. Probe phase: use a low reservation so the system can quickly learn the true concurrency limit
2. Stable phase: adjust the reservation ratio dynamically based on confidence and load
3. Confidence calculation: combines consecutive successes, time since the last 429, and the stability of the adjustment history
"""

import statistics
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, Optional

from src.core.logger import logger

from src.config.constants import AdaptiveReservationDefaults

if TYPE_CHECKING:
    from src.models.database import ProviderAPIKey


@dataclass
class ReservationConfig:
    """Reservation-ratio configuration (defaults come from shared constants)"""

    # Probe-phase configuration
    probe_phase_requests: int = field(
        default_factory=lambda: AdaptiveReservationDefaults.PROBE_PHASE_REQUESTS
    )
    probe_reservation: float = field(
        default_factory=lambda: AdaptiveReservationDefaults.PROBE_RESERVATION
    )

    # Stable-phase configuration
    stable_min_reservation: float = field(
        default_factory=lambda: AdaptiveReservationDefaults.STABLE_MIN_RESERVATION
    )
    stable_max_reservation: float = field(
        default_factory=lambda: AdaptiveReservationDefaults.STABLE_MAX_RESERVATION
    )

    # Confidence-calculation parameters
    success_count_for_full_confidence: int = field(
        default_factory=lambda: AdaptiveReservationDefaults.SUCCESS_COUNT_FOR_FULL_CONFIDENCE
    )
    cooldown_hours_for_full_confidence: int = field(
        default_factory=lambda: AdaptiveReservationDefaults.COOLDOWN_HOURS_FOR_FULL_CONFIDENCE
    )

    # Load thresholds
    low_load_threshold: float = field(
        default_factory=lambda: AdaptiveReservationDefaults.LOW_LOAD_THRESHOLD
    )
    high_load_threshold: float = field(
        default_factory=lambda: AdaptiveReservationDefaults.HIGH_LOAD_THRESHOLD
    )


@dataclass
class ReservationResult:
    """Result of a reservation-ratio calculation"""

    ratio: float  # final reservation ratio
    phase: str  # current phase: "probe" | "stable"
    confidence: float  # confidence (0-1)
    load_factor: float  # load factor (0-1)
    details: Dict[str, Any]  # detailed information


class AdaptiveReservationManager:
    """
    Adaptive reservation-ratio manager.

    How it works:
    1. Probe phase (request count < threshold):
       - Uses a low reservation ratio (10%) to avoid wasting resources
       - Lets the system quickly probe the true concurrency limit

    2. Stable phase (request count >= threshold):
       - Computes the reservation ratio dynamically from confidence and load
       - High confidence + high load = high reservation (protects cached users)
       - Low confidence or low load = low reservation (avoids waste)

    Confidence factors:
    - Consecutive successes: more successes mean the current limit is more accurate
    - 429 cooldown: the longer since the last 429, the more stable
    - Adjustment-history stability: lower variance among recent adjustments means more stable
    """

    def __init__(self, config: Optional[ReservationConfig] = None):
        self.config = config or ReservationConfig()
        self._cache: Dict[str, ReservationResult] = {}  # simple in-memory cache

    def calculate_reservation(
        self,
        key: "ProviderAPIKey",
        current_concurrent: int = 0,
        effective_limit: Optional[int] = None,
    ) -> ReservationResult:
        """
        Compute the reservation ratio to use right now.

        Args:
            key: ProviderAPIKey object
            current_concurrent: current concurrency
            effective_limit: effective concurrency limit (learned or configured)

        Returns:
            A ReservationResult with the reservation ratio and details
        """
        # Compute the total request count (used to determine the phase)
        total_requests = self._get_total_requests(key)

        # Compute the load ratio
        load_ratio = self._calculate_load_ratio(current_concurrent, effective_limit)

        # Phase 1: probe phase
        if total_requests < self.config.probe_phase_requests:
            return ReservationResult(
                ratio=self.config.probe_reservation,
                phase="probe",
                confidence=0.0,
                load_factor=load_ratio,
                details={
                    "total_requests": total_requests,
                    "probe_threshold": self.config.probe_phase_requests,
                    "reason": "Probe phase: a low reservation lets the system learn the true limit",
                },
            )

        # Phase 2: stable phase
        confidence = self._calculate_confidence(key)
        ratio = self._calculate_stable_ratio(confidence, load_ratio)

        return ReservationResult(
            ratio=ratio,
            phase="stable",
            confidence=confidence,
            load_factor=load_ratio,
            details={
                "total_requests": total_requests,
                "confidence_factors": self._get_confidence_breakdown(key),
                "reason": self._get_ratio_reason(confidence, load_ratio),
            },
        )

    def _get_total_requests(self, key: "ProviderAPIKey") -> int:
        """Get the total request count (used to decide whether the probe phase is over)."""
        # Use the total request counter as the baseline
        request_count = key.request_count or 0

        # If request_count is 0, approximate with the 429 counts plus the success count
        if request_count == 0:
            concurrent_429 = key.concurrent_429_count or 0
            rpm_429 = key.rpm_429_count or 0
            success_count = key.success_count or 0
            # The number of records in the adjustment history is also a useful signal
            history_count = len(key.adjustment_history or []) * 10
            return concurrent_429 + rpm_429 + success_count + history_count

        return request_count

    def _calculate_load_ratio(
        self, current_concurrent: int, effective_limit: Optional[int]
    ) -> float:
        """Compute the current load ratio."""
        if not effective_limit or effective_limit <= 0:
            return 0.0
        return min(current_concurrent / effective_limit, 1.0)

    def _calculate_confidence(self, key: "ProviderAPIKey") -> float:
        """
        Compute the confidence in the learned value (0-1).

        Three weighted factors:
        - Success rate: 40% (based on total successes / total requests)
        - 429 cooldown: 30%
        - Adjustment-history stability: 30%
        """
        scores = self._get_confidence_breakdown(key)
        return min(
            scores["success_score"] + scores["cooldown_score"] + scores["stability_score"], 1.0
        )

    def _get_confidence_breakdown(self, key: "ProviderAPIKey") -> Dict[str, float]:
        """Get the per-factor confidence scores."""
        # Factor 1: success rate (weight 40%)
        # Uses the success rate rather than consecutive successes; it reflects key stability more accurately
        request_count = key.request_count or 0
        success_count = key.success_count or 0

        if request_count >= self.config.success_count_for_full_confidence:
            # With enough requests, score by success rate
            success_rate = success_count / request_count if request_count > 0 else 0
            success_score = success_rate * 0.4
        elif request_count > 0:
            # With too few requests, scale the score proportionally
            progress_ratio = request_count / self.config.success_count_for_full_confidence
            success_rate = success_count / request_count
            success_score = success_rate * progress_ratio * 0.4
        else:
            success_score = 0.0

        # Factor 2: 429 cooldown (weight 30%)
        if key.last_429_at:
            now = datetime.now(timezone.utc)
            # Ensure last_429_at is timezone-aware
            last_429 = key.last_429_at
            if last_429.tzinfo is None:
                last_429 = last_429.replace(tzinfo=timezone.utc)
            hours_since_429 = (now - last_429).total_seconds() / 3600
            cooldown_ratio = min(
                hours_since_429 / self.config.cooldown_hours_for_full_confidence, 1.0
            )
            cooldown_score = cooldown_ratio * 0.3
        else:
            # Never hit a 429: full score
            cooldown_score = 0.3

        # Factor 3: adjustment-history stability (weight 30%)
        history = key.adjustment_history or []
        if len(history) >= 3:
            # Take the most recent adjustment records
            recent = history[-5:] if len(history) >= 5 else history
            limits = [h.get("new_limit", 0) for h in recent if h.get("new_limit")]

            if len(limits) >= 2:
                try:
                    variance = statistics.variance(limits)
                    # Lower variance means more stable; the score approaches 0 at variance 10
                    stability_ratio = max(0, 1 - variance / 10)
                    stability_score = stability_ratio * 0.3
                except statistics.StatisticsError:
                    stability_score = 0.15
            else:
                stability_score = 0.15
        else:
            # Not enough history: half score
            stability_score = 0.15

        # Compute the success rate for the return value
        success_rate_pct = (success_count / request_count * 100) if request_count > 0 else None

        return {
            "success_score": round(success_score, 3),
            "cooldown_score": round(cooldown_score, 3),
            "stability_score": round(stability_score, 3),
            "request_count": request_count,
            "success_count": success_count,
            "success_rate": round(success_rate_pct, 1) if success_rate_pct is not None else None,
            "hours_since_429": (
                round(
                    (
                        datetime.now(timezone.utc) - key.last_429_at.replace(tzinfo=timezone.utc)
                    ).total_seconds()
                    / 3600,
                    1,
                )
                if key.last_429_at
                else None
            ),
            "history_count": len(history),
        }

    def _calculate_stable_ratio(self, confidence: float, load_ratio: float) -> float:
        """
        Compute the reservation ratio for the stable phase.

        Strategy:
        - Low load (<50%): use the minimum reservation; slots are plentiful, so there is no need to reserve much
        - Medium load (50-80%): increase the reservation linearly with confidence
        - High load (>80%): use a higher reservation, scaled by confidence, to protect cached users
        """
        min_r = self.config.stable_min_reservation
        max_r = self.config.stable_max_reservation

        if load_ratio < self.config.low_load_threshold:
            # Low load: use the minimum reservation
            return min_r

        if load_ratio < self.config.high_load_threshold:
            # Medium load: interpolate linearly by confidence and load
            # Higher load and higher confidence mean a larger reservation
            load_factor = (load_ratio - self.config.low_load_threshold) / (
                self.config.high_load_threshold - self.config.low_load_threshold
            )
            return min_r + confidence * load_factor * (max_r - min_r)

        # High load: the reservation ratio is determined by confidence
        # High confidence -> close to the maximum reservation
        # Low confidence -> conservative reservation (avoid over-reserving based on inaccurate learned values)
        return min_r + confidence * (max_r - min_r)

    def _get_ratio_reason(self, confidence: float, load_ratio: float) -> str:
        """Generate an explanation for the reservation ratio."""
        if load_ratio < self.config.low_load_threshold:
            return f"Low load ({load_ratio:.0%}); using the minimum reservation"

        if confidence < 0.3:
            return f"Low confidence ({confidence:.0%}); reserving conservatively to avoid waste"

        if confidence > 0.7 and load_ratio > self.config.high_load_threshold:
            return f"High confidence ({confidence:.0%}) + high load ({load_ratio:.0%}); using a higher reservation to protect cached users"

        return f"Confidence {confidence:.0%}, load {load_ratio:.0%}; reservation computed dynamically"

    def get_stats(self) -> Dict[str, Any]:
        """Get manager statistics."""
        return {
            "config": {
                "probe_phase_requests": self.config.probe_phase_requests,
                "probe_reservation": self.config.probe_reservation,
                "stable_min_reservation": self.config.stable_min_reservation,
                "stable_max_reservation": self.config.stable_max_reservation,
                "low_load_threshold": self.config.low_load_threshold,
                "high_load_threshold": self.config.high_load_threshold,
            },
        }


# Global singleton
_reservation_manager: Optional[AdaptiveReservationManager] = None


def get_adaptive_reservation_manager() -> AdaptiveReservationManager:
    """Get the global adaptive reservation manager singleton."""
    global _reservation_manager
    if _reservation_manager is None:
        _reservation_manager = AdaptiveReservationManager()
    return _reservation_manager


def reset_adaptive_reservation_manager():
    """Reset the global singleton (for tests)."""
    global _reservation_manager
    _reservation_manager = None
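To make the stable-phase formula concrete, here is a worked example with assumed config values (min 0.10, max 0.50, load thresholds 0.5 and 0.8 - the real defaults live in AdaptiveReservationDefaults, which this commit does not show):

# Medium load: interpolate by confidence and by position within [low, high).
min_r, max_r, low, high = 0.10, 0.50, 0.5, 0.8   # assumed values
confidence, load_ratio = 0.8, 0.65

load_factor = (load_ratio - low) / (high - low)   # (0.65 - 0.5) / 0.3 = 0.5
ratio = min_r + confidence * load_factor * (max_r - min_r)
print(round(ratio, 2))  # 0.26 -> reserve 26% of slots for cached users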
582
src/services/rate_limit/concurrency_manager.py
Normal file
@@ -0,0 +1,582 @@
"""
Concurrency manager - concurrency control backed by Redis or memory.

Features:
1. Endpoint-level concurrency limits
2. ProviderAPIKey-level concurrency limits
3. Prefers Redis in distributed deployments so counts are shared across instances
4. Automatically falls back to in-memory counting in development / single-instance scenarios
5. Automatic release and exception handling (Redis provides a TTL; in memory mode, be sure to release manually)
"""

import asyncio
import math
import os
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Optional, Tuple

import redis.asyncio as aioredis
from src.core.logger import logger


class ConcurrencyManager:
    """Distributed concurrency manager"""

    _instance: Optional["ConcurrencyManager"] = None
    _redis: Optional[aioredis.Redis] = None

    def __new__(cls):
        """Singleton pattern"""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the in-memory backend structures (runs only once)."""
        if hasattr(self, "_memory_initialized"):
            return

        self._memory_lock: asyncio.Lock = asyncio.Lock()
        self._memory_endpoint_counts: dict[str, int] = {}
        self._memory_key_counts: dict[str, int] = {}
        self._memory_initialized = True

    async def initialize(self) -> None:
        """Initialize the Redis connection."""
        if self._redis is not None:
            return

        # Prefer REDIS_URL; otherwise build a URL from the password
        redis_url = os.getenv("REDIS_URL")

        if not redis_url:
            # Local development mode: build the URL from REDIS_PASSWORD
            redis_password = os.getenv("REDIS_PASSWORD")
            if redis_password:
                redis_url = f"redis://:{redis_password}@localhost:6379/0"
            else:
                redis_url = "redis://localhost:6379/0"

        try:
            self._redis = await aioredis.from_url(
                redis_url,
                encoding="utf-8",
                decode_responses=True,
                socket_timeout=5.0,
                socket_connect_timeout=5.0,
            )
            # Test the connection
            await self._redis.ping()
            # Masked display (hide the password)
            safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
            logger.info(f"[OK] Redis connected: {safe_url}")
        except Exception as e:
            logger.error(f"[ERROR] Redis connection failed: {e}")
            logger.warning("[WARN] Concurrency control will be disabled (safe only in a single-instance environment)")
            self._redis = None

    async def close(self) -> None:
        """Close the Redis connection."""
        if self._redis:
            await self._redis.close()
            self._redis = None
            logger.info("Redis connection closed")

    def _get_endpoint_key(self, endpoint_id: str) -> str:
        """Get the Redis key for an endpoint's concurrency count."""
        return f"concurrency:endpoint:{endpoint_id}"

    def _get_key_key(self, key_id: str) -> str:
        """Get the Redis key for a ProviderAPIKey's concurrency count."""
        return f"concurrency:key:{key_id}"

    async def get_current_concurrency(
        self, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
    ) -> Tuple[int, int]:
        """
        Get the current concurrency counts.

        Args:
            endpoint_id: endpoint ID (optional)
            key_id: ProviderAPIKey ID (optional)

        Returns:
            (endpoint_concurrency, key_concurrency)
        """
        if self._redis is None:
            async with self._memory_lock:
                endpoint_count = (
                    self._memory_endpoint_counts.get(endpoint_id, 0) if endpoint_id else 0
                )
                key_count = self._memory_key_counts.get(key_id, 0) if key_id else 0
                return endpoint_count, key_count

        endpoint_count = 0
        key_count = 0

        try:
            if endpoint_id:
                endpoint_key = self._get_endpoint_key(endpoint_id)
                result = await self._redis.get(endpoint_key)
                endpoint_count = int(result) if result else 0

            if key_id:
                key_key = self._get_key_key(key_id)
                result = await self._redis.get(key_key)
                key_count = int(result) if result else 0

        except Exception as e:
            logger.error(f"Failed to read concurrency counts: {e}")

        return endpoint_count, key_count

    async def check_available(
        self,
        endpoint_id: str,
        endpoint_max_concurrent: Optional[int],
        key_id: str,
        key_max_concurrent: Optional[int],
    ) -> bool:
        """
        Check whether a concurrency slot could be acquired (without actually acquiring it).

        Args:
            endpoint_id: endpoint ID
            endpoint_max_concurrent: endpoint max concurrency (None means unlimited)
            key_id: ProviderAPIKey ID
            key_max_concurrent: key max concurrency (None means unlimited)

        Returns:
            Whether a slot is available (True/False)
        """
        if self._redis is None:
            async with self._memory_lock:
                endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
                key_count = self._memory_key_counts.get(key_id, 0)

                if (
                    endpoint_max_concurrent is not None
                    and endpoint_count >= endpoint_max_concurrent
                ):
                    return False

                if key_max_concurrent is not None and key_count >= key_max_concurrent:
                    return False

                return True

        endpoint_count, key_count = await self.get_current_concurrency(endpoint_id, key_id)

        # Check the endpoint-level limit
        if endpoint_max_concurrent is not None and endpoint_count >= endpoint_max_concurrent:
            return False

        # Check the key-level limit
        if key_max_concurrent is not None and key_count >= key_max_concurrent:
            return False

        return True

    async def acquire_slot(
        self,
        endpoint_id: str,
        endpoint_max_concurrent: Optional[int],
        key_id: str,
        key_max_concurrent: Optional[int],
        is_cached_user: bool = False,  # new: whether the caller is a cached user
        cache_reservation_ratio: float = 0.3,  # new: cache reservation ratio
        ttl_seconds: int = 600,  # 10-minute TTL to prevent deadlocks
    ) -> bool:
        """
        Try to acquire a concurrency slot (with cached-user priority support).

        Args:
            endpoint_id: endpoint ID
            endpoint_max_concurrent: endpoint max concurrency (None means unlimited)
            key_id: ProviderAPIKey ID
            key_max_concurrent: key max concurrency (None means unlimited)
            is_cached_user: whether the caller is a cached user (cached users may use all slots)
            cache_reservation_ratio: cache reservation ratio (default 30%; applies to new users only)
            ttl_seconds: TTL in seconds, preventing deadlocks in abnormal situations

        Returns:
            Whether the slot was acquired (True/False)

        How the cache reservation works:
        - Suppose key_max_concurrent = 10 and cache_reservation_ratio = 0.3
        - New users may use at most 7 slots (10 * (1 - 0.3))
        - Cached users may use up to all 10 slots
        - The 3 reserved slots are dedicated to cached users so their requests take priority
        """
        if self._redis is None:
            async with self._memory_lock:
                endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
                key_count = self._memory_key_counts.get(key_id, 0)

                # Endpoint limit
                if (
                    endpoint_max_concurrent is not None
                    and endpoint_count >= endpoint_max_concurrent
                ):
                    return False

                # Key limit, including the cache reservation
                if key_max_concurrent is not None:
                    if is_cached_user:
                        if key_count >= key_max_concurrent:
                            return False
                    else:
                        available_for_new = max(
                            1, math.ceil(key_max_concurrent * (1 - cache_reservation_ratio))
                        )
                        if key_count >= available_for_new:
                            return False

                # Limits passed; update the counters
                self._memory_endpoint_counts[endpoint_id] = endpoint_count + 1
                self._memory_key_counts[key_id] = key_count + 1
                return True

        endpoint_key = self._get_endpoint_key(endpoint_id)
        key_key = self._get_key_key(key_id)

        try:
            # Use a Lua script for atomicity (now with cache-reservation logic)
            lua_script = """
            local endpoint_key = KEYS[1]
            local key_key = KEYS[2]
            local endpoint_max = tonumber(ARGV[1])
            local key_max = tonumber(ARGV[2])
            local ttl = tonumber(ARGV[3])
            local is_cached = tonumber(ARGV[4]) -- 0=new user, 1=cached user
            local cache_ratio = tonumber(ARGV[5]) -- cache reservation ratio

            -- Read current values
            local endpoint_count = tonumber(redis.call('GET', endpoint_key) or '0')
            local key_count = tonumber(redis.call('GET', key_key) or '0')

            -- Check the endpoint limit (-1 means unlimited)
            if endpoint_max >= 0 and endpoint_count >= endpoint_max then
                return 0 -- failure: endpoint is full
            end

            -- Check the key limit (with cache reservation)
            if key_max >= 0 then
                if is_cached == 0 then
                    -- New users may only use (1 - cache_ratio) of the slots
                    local available_for_new = math.floor(key_max * (1 - cache_ratio))
                    if key_count >= available_for_new then
                        return 0 -- failure: new-user quota is full
                    end
                else
                    -- Cached users may use all slots
                    if key_count >= key_max then
                        return 0 -- failure: total quota is full
                    end
                end
            end

            -- Increment the counters
            redis.call('INCR', endpoint_key)
            redis.call('EXPIRE', endpoint_key, ttl)
            redis.call('INCR', key_key)
            redis.call('EXPIRE', key_key, ttl)

            return 1 -- success
            """

            # Run the script
            result = await self._redis.eval(
                lua_script,
                2,  # 2 KEYS
                endpoint_key,
                key_key,
                endpoint_max_concurrent if endpoint_max_concurrent is not None else -1,
                key_max_concurrent if key_max_concurrent is not None else -1,
                ttl_seconds,
                1 if is_cached_user else 0,  # cached-user flag
                cache_reservation_ratio,  # reservation ratio
            )

            success = result == 1

            if success:
                user_type = "cached user" if is_cached_user else "new user"
                logger.debug(
                    f"[OK] Acquired concurrency slot: endpoint={endpoint_id}, key={key_id}, "
                    f"type={user_type}"
                )
            else:
                endpoint_count, key_count = await self.get_current_concurrency(endpoint_id, key_id)

                # Compute the slots available to new users
                if key_max_concurrent and not is_cached_user:
                    available_for_new = int(key_max_concurrent * (1 - cache_reservation_ratio))
                    user_info = f"new-user quota={available_for_new}, current={key_count}"
                else:
                    user_info = f"cached user, current={key_count}/{key_max_concurrent}"

                logger.warning(
                    f"[WARN] Concurrency slots exhausted: endpoint={endpoint_id}({endpoint_count}/{endpoint_max_concurrent}), "
                    f"key={key_id}({user_info})"
                )

            return success

        except Exception as e:
            logger.error(f"Failed to acquire concurrency slot; falling back to memory mode: {e}")
            # On Redis errors, fall back to memory mode with conservative limiting
            # Use lower limits (50% of the originals) to avoid hammering the upstream API
            async with self._memory_lock:
                endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
                key_count = self._memory_key_counts.get(key_id, 0)

                # Use more conservative limits (50%) in fallback mode
                fallback_endpoint_limit = (
                    max(1, endpoint_max_concurrent // 2)
                    if endpoint_max_concurrent is not None
                    else None
                )
                fallback_key_limit = (
                    max(1, key_max_concurrent // 2) if key_max_concurrent is not None else None
                )

                if (
                    fallback_endpoint_limit is not None
                    and endpoint_count >= fallback_endpoint_limit
                ):
                    logger.warning(
                        f"[FALLBACK] Endpoint concurrency hit the fallback limit: {endpoint_count}/{fallback_endpoint_limit}"
                    )
                    return False

                if fallback_key_limit is not None and key_count >= fallback_key_limit:
                    logger.warning(
                        f"[FALLBACK] Key concurrency hit the fallback limit: {key_count}/{fallback_key_limit}"
                    )
                    return False

                # Update the in-memory counters
                self._memory_endpoint_counts[endpoint_id] = endpoint_count + 1
                self._memory_key_counts[key_id] = key_count + 1
                logger.debug(
                    f"[FALLBACK] Acquired slot in memory mode: endpoint={endpoint_id}, key={key_id}"
                )
                return True

    async def release_slot(self, endpoint_id: str, key_id: str) -> None:
        """
        Release a concurrency slot.

        Args:
            endpoint_id: endpoint ID
            key_id: ProviderAPIKey ID
        """
        if self._redis is None:
            async with self._memory_lock:
                if endpoint_id in self._memory_endpoint_counts:
                    self._memory_endpoint_counts[endpoint_id] = max(
                        0, self._memory_endpoint_counts[endpoint_id] - 1
                    )
                    if self._memory_endpoint_counts[endpoint_id] == 0:
                        self._memory_endpoint_counts.pop(endpoint_id, None)

                if key_id in self._memory_key_counts:
                    self._memory_key_counts[key_id] = max(0, self._memory_key_counts[key_id] - 1)
                    if self._memory_key_counts[key_id] == 0:
                        self._memory_key_counts.pop(key_id, None)
            return

        endpoint_key = self._get_endpoint_key(endpoint_id)
        key_key = self._get_key_key(key_id)

        try:
            # Use a Lua script for atomicity (never decrement below zero)
            lua_script = """
            local endpoint_key = KEYS[1]
            local key_key = KEYS[2]

            local endpoint_count = tonumber(redis.call('GET', endpoint_key) or '0')
            local key_count = tonumber(redis.call('GET', key_key) or '0')

            if endpoint_count > 0 then
                redis.call('DECR', endpoint_key)
            end

            if key_count > 0 then
                redis.call('DECR', key_key)
            end

            return 1
            """

            await self._redis.eval(lua_script, 2, endpoint_key, key_key)

            logger.debug(f"[OK] Released concurrency slot: endpoint={endpoint_id}, key={key_id}")

        except Exception as e:
            logger.error(f"Failed to release concurrency slot: {e}")

    @asynccontextmanager
    async def concurrency_guard(
        self,
        endpoint_id: str,
        endpoint_max_concurrent: Optional[int],
        key_id: str,
        key_max_concurrent: Optional[int],
        is_cached_user: bool = False,  # new: whether the caller is a cached user
        cache_reservation_ratio: float = 0.3,  # new: cache reservation ratio
    ):
        """
        Concurrency-control context manager (with cached-user priority support).

        Usage:
            async with manager.concurrency_guard(
                endpoint_id, endpoint_max, key_id, key_max,
                is_cached_user=True  # cached user
            ):
                # perform the request
                response = await send_request(...)

        Raises ConcurrencyLimitError if the slot cannot be acquired.
        """
        # Try to acquire a slot (passing the cached-user parameters)
        acquired = await self.acquire_slot(
            endpoint_id,
            endpoint_max_concurrent,
            key_id,
            key_max_concurrent,
            is_cached_user,
            cache_reservation_ratio,
        )

        if not acquired:
            from src.core.exceptions import ConcurrencyLimitError

            user_type = "cached user" if is_cached_user else "new user"
            raise ConcurrencyLimitError(
                f"Concurrency limit reached: endpoint={endpoint_id}, key={key_id}, type={user_type}"
            )

        # Record the start time and state
        import time

        slot_acquired_at = time.time()
        exception_occurred = False

        try:
            yield  # perform the request
        except Exception:
            # Record the exception
            exception_occurred = True
            raise
        finally:
            # Compute how long the slot was held
            slot_duration = time.time() - slot_acquired_at

            # Record Prometheus metrics
            try:
                from src.core.metrics import (
                    concurrency_slot_duration_seconds,
                    concurrency_slot_release_total,
                )

                # Record the slot-hold duration distribution
                concurrency_slot_duration_seconds.labels(
                    key_id=key_id[:8] if key_id else "unknown",  # only the first 8 chars
                    exception=str(exception_occurred),
                ).observe(slot_duration)

                # Record the slot-release counter
                concurrency_slot_release_total.labels(
                    key_id=key_id[:8] if key_id else "unknown", exception=str(exception_occurred)
                ).inc()

                # Alert: slot held too long (over 60 seconds)
                if slot_duration > 60:
                    logger.warning(
                        f"[WARN] Concurrency slot held too long: "
                        f"key_id={key_id[:8] if key_id else 'unknown'}..., "
                        f"duration={slot_duration:.1f}s, "
                        f"exception={exception_occurred}"
                    )

            except Exception as metric_error:
                # Metric failures must not affect business logic
                logger.debug(f"Failed to record concurrency metrics: {metric_error}")

            # Automatically release the slot (even on exceptions)
            await self.release_slot(endpoint_id, key_id)

    async def reset_concurrency(
        self, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
    ) -> None:
        """
        Reset concurrency counters (admin feature; use with care).

        Args:
            endpoint_id: endpoint ID (optional; None resets all endpoints)
            key_id: ProviderAPIKey ID (optional; None resets all keys)
        """
        if self._redis is None:
            async with self._memory_lock:
                if endpoint_id:
                    self._memory_endpoint_counts.pop(endpoint_id, None)
                    logger.info(f"[RESET] Reset endpoint concurrency count (memory): {endpoint_id}")
                else:
                    count = len(self._memory_endpoint_counts)
                    self._memory_endpoint_counts.clear()
                    if count:
                        logger.info(f"[RESET] Reset all endpoint concurrency counts (memory): {count}")

                if key_id:
                    self._memory_key_counts.pop(key_id, None)
                    logger.info(f"[RESET] Reset key concurrency count (memory): {key_id}")
                else:
                    count = len(self._memory_key_counts)
                    self._memory_key_counts.clear()
                    if count:
                        logger.info(f"[RESET] Reset all key concurrency counts (memory): {count}")
            return

        try:
            if endpoint_id:
                endpoint_key = self._get_endpoint_key(endpoint_id)
                await self._redis.delete(endpoint_key)
                logger.info(f"[RESET] Reset endpoint concurrency count: {endpoint_id}")
            else:
                # Reset all endpoints
                keys = await self._redis.keys("concurrency:endpoint:*")
                if keys:
                    await self._redis.delete(*keys)
                    logger.info(f"[RESET] Reset all endpoint concurrency counts: {len(keys)}")

            if key_id:
                key_key = self._get_key_key(key_id)
                await self._redis.delete(key_key)
                logger.info(f"[RESET] Reset key concurrency count: {key_id}")
            else:
                # Reset all keys
                keys = await self._redis.keys("concurrency:key:*")
                if keys:
                    await self._redis.delete(*keys)
                    logger.info(f"[RESET] Reset all key concurrency counts: {len(keys)}")

        except Exception as e:
            logger.error(f"Failed to reset concurrency counts: {e}")


# Global singleton
_concurrency_manager: Optional[ConcurrencyManager] = None


async def get_concurrency_manager() -> ConcurrencyManager:
    """Get the global ConcurrencyManager instance."""
    global _concurrency_manager

    if _concurrency_manager is None:
        _concurrency_manager = ConcurrencyManager()
        await _concurrency_manager.initialize()

    return _concurrency_manager
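A minimal usage sketch of the guard, following the pattern in the concurrency_guard docstring above; the IDs and limits ("endpoint-1", 100, "key-1", 10) are placeholders:

import asyncio

from src.core.exceptions import ConcurrencyLimitError
from src.services.rate_limit.concurrency_manager import get_concurrency_manager

async def proxied_call() -> None:
    manager = await get_concurrency_manager()
    try:
        async with manager.concurrency_guard(
            "endpoint-1", 100,    # endpoint ID and its max concurrency
            "key-1", 10,          # provider key ID and its max concurrency
            is_cached_user=True,  # cached users may use all 10 slots
        ):
            ...  # send the upstream request here; the slot is released on exit
    except ConcurrencyLimitError:
        ...  # every slot is busy: queue, retry later, or surface a 429

asyncio.run(proxied_call())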
333
src/services/rate_limit/detector.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
速率限制检测器 - 解析429响应头,区分并发限制和RPM限制
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class RateLimitType:
|
||||
"""速率限制类型"""
|
||||
|
||||
CONCURRENT = "concurrent" # 并发限制
|
||||
RPM = "rpm" # 每分钟请求数限制
|
||||
DAILY = "daily" # 每日限制
|
||||
MONTHLY = "monthly" # 每月限制
|
||||
UNKNOWN = "unknown" # 未知类型
|
||||
|
||||
|
||||
class RateLimitInfo:
|
||||
"""速率限制信息"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limit_type: str,
|
||||
retry_after: Optional[int] = None,
|
||||
limit_value: Optional[int] = None,
|
||||
remaining: Optional[int] = None,
|
||||
reset_at: Optional[datetime] = None,
|
||||
current_usage: Optional[int] = None,
|
||||
raw_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
self.limit_type = limit_type
|
||||
self.retry_after = retry_after # 需要等待的秒数
|
||||
self.limit_value = limit_value # 限制值
|
||||
self.remaining = remaining # 剩余配额
|
||||
self.reset_at = reset_at # 重置时间
|
||||
self.current_usage = current_usage # 当前使用量
|
||||
self.raw_headers = raw_headers or {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"RateLimitInfo(type={self.limit_type}, "
|
||||
f"retry_after={self.retry_after}, "
|
||||
f"limit={self.limit_value}, "
|
||||
f"remaining={self.remaining})"
|
||||
)
|
||||
|
||||
|
||||
class RateLimitDetector:
|
||||
"""
|
||||
速率限制检测器
|
||||
|
||||
支持的提供商:
|
||||
- Anthropic Claude API
|
||||
- OpenAI API
|
||||
- 通用 HTTP 标准头
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def detect_from_headers(
|
||||
headers: Dict[str, str],
|
||||
provider_name: str = "unknown",
|
||||
current_concurrent: Optional[int] = None,
|
||||
) -> RateLimitInfo:
|
||||
"""
|
||||
从响应头中检测速率限制类型
|
||||
|
||||
Args:
|
||||
headers: 429响应的HTTP头
|
||||
provider_name: 提供商名称(用于选择解析策略)
|
||||
current_concurrent: 当前并发数(用于判断是否为并发限制)
|
||||
|
||||
Returns:
|
||||
RateLimitInfo对象
|
||||
"""
|
||||
# 标准化header key (转小写)
|
||||
headers_lower = {k.lower(): v for k, v in headers.items()}
|
||||
|
||||
# 根据提供商选择解析策略
|
||||
if "anthropic" in provider_name.lower() or "claude" in provider_name.lower():
|
||||
return RateLimitDetector._parse_anthropic_headers(headers_lower, current_concurrent)
|
||||
elif "openai" in provider_name.lower():
|
||||
return RateLimitDetector._parse_openai_headers(headers_lower, current_concurrent)
|
||||
else:
|
||||
return RateLimitDetector._parse_generic_headers(headers_lower, current_concurrent)
|
||||
|
||||
@staticmethod
|
||||
def _parse_anthropic_headers(
|
||||
headers: Dict[str, str],
|
||||
current_concurrent: Optional[int] = None,
|
||||
) -> RateLimitInfo:
|
||||
"""
|
||||
解析 Anthropic Claude API 的速率限制头
|
||||
|
||||
常见头部:
|
||||
- anthropic-ratelimit-requests-limit: 50
|
||||
- anthropic-ratelimit-requests-remaining: 0
|
||||
- anthropic-ratelimit-requests-reset: 2024-01-01T00:00:00Z
|
||||
- anthropic-ratelimit-tokens-limit: 100000
|
||||
- anthropic-ratelimit-tokens-remaining: 50000
|
||||
- retry-after: 60
|
||||
"""
|
||||
retry_after = RateLimitDetector._parse_retry_after(headers)
|
||||
|
||||
# 获取请求限制信息
|
||||
requests_limit = RateLimitDetector._parse_int(
|
||||
headers.get("anthropic-ratelimit-requests-limit")
|
||||
)
|
||||
requests_remaining = RateLimitDetector._parse_int(
|
||||
headers.get("anthropic-ratelimit-requests-remaining")
|
||||
)
|
||||
requests_reset = RateLimitDetector._parse_datetime(
|
||||
headers.get("anthropic-ratelimit-requests-reset")
|
||||
)
|
||||
|
||||
# 判断限制类型
|
||||
# 1. 明确的 RPM 限制:请求数剩余为 0
|
||||
if requests_remaining is not None and requests_remaining == 0:
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.RPM,
|
||||
retry_after=retry_after,
|
||||
limit_value=requests_limit,
|
||||
remaining=requests_remaining,
|
||||
reset_at=requests_reset,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
# 2. 可能的并发限制判断(多条件综合)
|
||||
# 条件:当前并发数存在,且 remaining > 0(说明不是 RPM 耗尽)
|
||||
# 同时 retry_after 较短(并发限制通常 retry_after 较短,如 1-10 秒)
|
||||
is_likely_concurrent = (
|
||||
current_concurrent is not None
|
||||
and current_concurrent >= 2 # 至少有 2 个并发
|
||||
and (requests_remaining is None or requests_remaining > 0) # RPM 未耗尽
|
||||
and (retry_after is None or retry_after <= 30) # 短暂等待
|
||||
)
|
||||
|
||||
if is_likely_concurrent:
|
||||
logger.info(
|
||||
f"检测到可能的并发限制: current_concurrent={current_concurrent}, "
|
||||
f"remaining={requests_remaining}, retry_after={retry_after}"
|
||||
)
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.CONCURRENT,
|
||||
retry_after=retry_after,
|
||||
current_usage=current_concurrent,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
# 3. 未知类型
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.UNKNOWN,
|
||||
retry_after=retry_after,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_openai_headers(
|
||||
headers: Dict[str, str],
|
||||
current_concurrent: Optional[int] = None,
|
||||
) -> RateLimitInfo:
|
||||
"""
|
||||
解析 OpenAI API 的速率限制头
|
||||
|
||||
常见头部:
|
||||
- x-ratelimit-limit-requests: 3500
|
||||
- x-ratelimit-remaining-requests: 0
|
||||
- x-ratelimit-reset-requests: 2024-01-01T00:00:00Z
|
||||
- x-ratelimit-limit-tokens: 90000
|
||||
- x-ratelimit-remaining-tokens: 50000
|
||||
- retry-after: 60
|
||||
"""
|
||||
retry_after = RateLimitDetector._parse_retry_after(headers)
|
||||
|
||||
# 获取请求限制信息
|
||||
requests_limit = RateLimitDetector._parse_int(headers.get("x-ratelimit-limit-requests"))
|
||||
requests_remaining = RateLimitDetector._parse_int(
|
||||
headers.get("x-ratelimit-remaining-requests")
|
||||
)
|
||||
requests_reset = RateLimitDetector._parse_datetime(
|
||||
headers.get("x-ratelimit-reset-requests")
|
||||
)
|
||||
|
||||
# 判断限制类型
|
||||
# 1. 明确的 RPM 限制
|
||||
if requests_remaining is not None and requests_remaining == 0:
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.RPM,
|
||||
retry_after=retry_after,
|
||||
limit_value=requests_limit,
|
||||
remaining=requests_remaining,
|
||||
reset_at=requests_reset,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
# 2. 可能的并发限制(多条件综合判断)
|
||||
is_likely_concurrent = (
|
||||
current_concurrent is not None
|
||||
and current_concurrent >= 2
|
||||
and (requests_remaining is None or requests_remaining > 0)
|
||||
and (retry_after is None or retry_after <= 30)
|
||||
)
|
||||
|
||||
if is_likely_concurrent:
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.CONCURRENT,
|
||||
retry_after=retry_after,
|
||||
current_usage=current_concurrent,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
# 3. 未知类型
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.UNKNOWN,
|
||||
retry_after=retry_after,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_generic_headers(
|
||||
headers: Dict[str, str],
|
||||
current_concurrent: Optional[int] = None,
|
||||
) -> RateLimitInfo:
|
||||
"""
|
||||
解析通用的速率限制头
|
||||
|
||||
标准头部:
|
||||
- retry-after: 60
|
||||
- x-ratelimit-limit: 100
|
||||
- x-ratelimit-remaining: 0
|
||||
- x-ratelimit-reset: 1609459200
|
||||
"""
|
||||
retry_after = RateLimitDetector._parse_retry_after(headers)
|
||||
|
||||
limit_value = RateLimitDetector._parse_int(headers.get("x-ratelimit-limit"))
|
||||
remaining = RateLimitDetector._parse_int(headers.get("x-ratelimit-remaining"))
|
||||
|
||||
# 1. 明确的 RPM 限制
|
||||
if remaining is not None and remaining == 0:
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.RPM,
|
||||
retry_after=retry_after,
|
||||
limit_value=limit_value,
|
||||
remaining=remaining,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
# 2. 可能的并发限制
|
||||
is_likely_concurrent = (
|
||||
current_concurrent is not None
|
||||
and current_concurrent >= 2
|
||||
and (remaining is None or remaining > 0)
|
||||
and (retry_after is None or retry_after <= 30)
|
||||
)
|
||||
|
||||
if is_likely_concurrent:
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.CONCURRENT,
|
||||
retry_after=retry_after,
|
||||
current_usage=current_concurrent,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
# 3. 未知类型
|
||||
return RateLimitInfo(
|
||||
limit_type=RateLimitType.UNKNOWN,
|
||||
retry_after=retry_after,
|
||||
raw_headers=headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_retry_after(headers: Dict[str, str]) -> Optional[int]:
|
||||
"""解析 Retry-After 头"""
|
||||
retry_after_str = headers.get("retry-after")
|
||||
if not retry_after_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 尝试解析为整数(秒数)
|
||||
return int(retry_after_str)
|
||||
except ValueError:
|
||||
# 尝试解析为HTTP日期格式
|
||||
try:
|
||||
retry_date = datetime.strptime(retry_after_str, "%a, %d %b %Y %H:%M:%S %Z")
|
||||
delta = retry_date - datetime.now(timezone.utc)
|
||||
return max(int(delta.total_seconds()), 0)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_int(value: Optional[str]) -> Optional[int]:
|
||||
"""安全解析整数"""
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
    @staticmethod
    def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
        """Safely parse an ISO 8601 datetime."""
        if not value:
            return None
        try:
            # Normalize a trailing "Z" to an explicit UTC offset for fromisoformat
            if value.endswith("Z"):
                value = value[:-1] + "+00:00"
            return datetime.fromisoformat(value)
        except (ValueError, TypeError):
            return None


# Convenience function
def detect_rate_limit_type(
    headers: Dict[str, str],
    provider_name: str = "unknown",
    current_concurrent: Optional[int] = None,
) -> RateLimitInfo:
    """
    Detect the rate-limit type (convenience wrapper).

    Args:
        headers: Response headers from the 429
        provider_name: Provider name
        current_concurrent: Current concurrency level

    Returns:
        A RateLimitInfo object
    """
    return RateLimitDetector.detect_from_headers(headers, provider_name, current_concurrent)
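
# Usage sketch (illustrative; the header values are hypothetical, and it
# assumes the default provider name falls back to the generic parser). With
# remaining > 0, a small retry-after, and concurrency >= 2, the heuristics
# above classify the 429 as a concurrency limit:
#
#     info = detect_rate_limit_type(
#         {"retry-after": "5", "x-ratelimit-remaining": "7"},
#         current_concurrent=4,
#     )
#     assert info.limit_type == RateLimitType.CONCURRENT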
351
src/services/rate_limit/ip_limiter.py
Normal file
@@ -0,0 +1,351 @@
"""
IP-level rate limiting service.

Provides per-IP rate limits to mitigate brute-force and DDoS attacks.
"""

import ipaddress
from typing import Dict, Optional, Set

from src.clients.redis_client import get_redis_client
from src.core.logger import logger


class IPRateLimiter:
    """IP rate-limiting service."""

    # Redis key prefixes
    RATE_LIMIT_PREFIX = "ip:rate_limit:"
    BLACKLIST_PREFIX = "ip:blacklist:"
    WHITELIST_KEY = "ip:whitelist"

    # Default limits (requests per minute)
    DEFAULT_LIMITS = {
        "default": 100,  # Fallback limit
        "login": 5,  # Login endpoint
        "register": 3,  # Registration endpoint
        "api": 60,  # API endpoints
        "public": 60,  # Public endpoints
    }

    @staticmethod
    async def check_limit(
        ip_address: str, endpoint_type: str = "default", limit: Optional[int] = None
    ) -> tuple[bool, int, int]:
        """
        Check whether an IP has exceeded its rate limit.

        Args:
            ip_address: IP address
            endpoint_type: Endpoint type (default, login, register, api, public)
            limit: Custom limit; None falls back to the defaults

        Returns:
            (allowed, remaining requests, seconds until the window resets)
        """
        # Whitelisted IPs bypass all limits
        if await IPRateLimiter.is_whitelisted(ip_address):
            return True, 999999, 60

        # Blacklisted IPs are rejected outright
        if await IPRateLimiter.is_blacklisted(ip_address):
            logger.warning(f"Blacklisted IP attempted access: {ip_address}, type: {endpoint_type}")
            return False, 0, 0

        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            # Degrade gracefully when Redis is unavailable: allow but warn
            logger.warning("Redis unavailable; skipping IP rate limiting (degraded mode)")
            return True, 0, 60

        # Resolve the effective limit
        rate_limit = (
            limit if limit is not None else IPRateLimiter.DEFAULT_LIMITS.get(endpoint_type, 100)
        )

        try:
            # Redis key: ip:rate_limit:{type}:{ip}
            redis_key = f"{IPRateLimiter.RATE_LIMIT_PREFIX}{endpoint_type}:{ip_address}"

            # Fixed-window counter: INCR, then arm the expiry on the first hit.
            # (Not a sliding window; the count resets when the key expires.)
            count = await redis_client.incr(redis_key)

            # Set the expiry on the first hit in the window
            if count == 1:
                await redis_client.expire(redis_key, 60)  # 60-second window

            # Remaining time until the window resets
            ttl = await redis_client.ttl(redis_key)
            if ttl < 0:
                # Re-arm the expiry if the key somehow has none
                await redis_client.expire(redis_key, 60)
                ttl = 60

            remaining = max(0, rate_limit - count)
            allowed = count <= rate_limit

            if not allowed:
                logger.warning(
                    f"IP rate limit triggered: {ip_address}, type: {endpoint_type}, count: {count}/{rate_limit}"
                )

            return allowed, remaining, ttl

        except Exception as e:
            logger.error(f"Failed to check IP rate limit: {e}")
            # Fail open on errors to avoid rejecting legitimate traffic
            return True, 0, 60

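    # Usage sketch (illustrative; `client_ip` and the error type are
    # assumptions, not part of this module): guard a login handler with the
    # fixed-window counter above.
    #
    #     allowed, remaining, reset_in = await IPRateLimiter.check_limit(
    #         client_ip, endpoint_type="login"
    #     )
    #     if not allowed:
    #         raise SomeHTTPError(status_code=429, retry_after=reset_in)
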
    @staticmethod
    async def add_to_blacklist(
        ip_address: str, reason: str = "manual", ttl: Optional[int] = None
    ) -> bool:
        """
        Add an IP to the blacklist.

        Args:
            ip_address: IP address
            reason: Reason for blacklisting
            ttl: Expiry in seconds; None means permanent

        Returns:
            Whether the operation succeeded
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            logger.warning("Redis unavailable; cannot add IP to blacklist")
            return False

        try:
            redis_key = f"{IPRateLimiter.BLACKLIST_PREFIX}{ip_address}"

            if ttl is not None:
                await redis_client.setex(redis_key, ttl, reason)
            else:
                await redis_client.set(redis_key, reason)

            logger.warning(f"IP blacklisted: {ip_address}, reason: {reason}, TTL: {ttl or 'permanent'}")
            return True

        except Exception as e:
            logger.error(f"Failed to add IP to blacklist: {e}")
            return False

    @staticmethod
    async def remove_from_blacklist(ip_address: str) -> bool:
        """
        Remove an IP from the blacklist.

        Args:
            ip_address: IP address

        Returns:
            Whether the operation succeeded
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            logger.warning("Redis unavailable; cannot remove IP from blacklist")
            return False

        try:
            redis_key = f"{IPRateLimiter.BLACKLIST_PREFIX}{ip_address}"
            deleted = await redis_client.delete(redis_key)

            if deleted:
                logger.info(f"IP removed from blacklist: {ip_address}")

            return bool(deleted)

        except Exception as e:
            logger.error(f"Failed to remove IP from blacklist: {e}")
            return False

    @staticmethod
    async def is_blacklisted(ip_address: str) -> bool:
        """
        Check whether an IP is blacklisted.

        Args:
            ip_address: IP address

        Returns:
            Whether the IP is on the blacklist
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            return False

        try:
            redis_key = f"{IPRateLimiter.BLACKLIST_PREFIX}{ip_address}"
            exists = await redis_client.exists(redis_key)
            return bool(exists)

        except Exception as e:
            logger.error(f"Failed to check IP blacklist status: {e}")
            return False

    @staticmethod
    async def add_to_whitelist(ip_address: str) -> bool:
        """
        Add an IP to the whitelist.

        Args:
            ip_address: IP address or CIDR range (e.g. 192.168.1.0/24)

        Returns:
            Whether the operation succeeded
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            logger.warning("Redis unavailable; cannot add IP to whitelist")
            return False

        try:
            # Validate the IP / CIDR format
            try:
                ipaddress.ip_network(ip_address, strict=False)
            except ValueError as e:
                logger.error(f"Invalid IP address format: {ip_address}, error: {e}")
                return False

            # Whitelist entries are stored in a Redis set
            await redis_client.sadd(IPRateLimiter.WHITELIST_KEY, ip_address)

            logger.info(f"IP whitelisted: {ip_address}")
            return True

        except Exception as e:
            logger.error(f"Failed to add IP to whitelist: {e}")
            return False

    @staticmethod
    async def remove_from_whitelist(ip_address: str) -> bool:
        """
        Remove an IP from the whitelist.

        Args:
            ip_address: IP address

        Returns:
            Whether the operation succeeded
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            logger.warning("Redis unavailable; cannot remove IP from whitelist")
            return False

        try:
            removed = await redis_client.srem(IPRateLimiter.WHITELIST_KEY, ip_address)

            if removed:
                logger.info(f"IP removed from whitelist: {ip_address}")

            return bool(removed)

        except Exception as e:
            logger.error(f"Failed to remove IP from whitelist: {e}")
            return False

    @staticmethod
    async def is_whitelisted(ip_address: str) -> bool:
        """
        Check whether an IP is whitelisted (supports CIDR matching).

        Args:
            ip_address: IP address

        Returns:
            Whether the IP is on the whitelist
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            return False

        try:
            # Fetch all whitelist entries
            whitelist = await redis_client.smembers(IPRateLimiter.WHITELIST_KEY)

            if not whitelist:
                return False

            # Parse the IP into an address object
            try:
                ip_obj = ipaddress.ip_address(ip_address)
            except ValueError:
                return False

            # Match against each whitelist entry (plain IPs and CIDR ranges)
            for entry in whitelist:
                try:
                    network = ipaddress.ip_network(entry, strict=False)
                    if ip_obj in network:
                        return True
                except ValueError:
                    # Skip malformed entries
                    continue

            return False

        except Exception as e:
            logger.error(f"Failed to check IP whitelist status: {e}")
            return False

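    # Illustrative CIDR match (hypothetical entry): with "10.0.0.0/8" stored
    # in the whitelist set, is_whitelisted("10.1.2.3") returns True because
    # ipaddress.ip_address("10.1.2.3") in ipaddress.ip_network("10.0.0.0/8").
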
    @staticmethod
    async def get_blacklist_stats() -> Dict:
        """
        Get blacklist statistics.

        Returns:
            A statistics dictionary
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            return {"available": False, "total": 0, "error": "Redis unavailable"}

        try:
            pattern = f"{IPRateLimiter.BLACKLIST_PREFIX}*"
            cursor = 0
            total = 0

            # SCAN iterates the keyspace in batches without blocking Redis
            while True:
                cursor, keys = await redis_client.scan(cursor=cursor, match=pattern, count=100)
                total += len(keys)

                if cursor == 0:
                    break

            return {"available": True, "total": total}

        except Exception as e:
            logger.error(f"Failed to get blacklist stats: {e}")
            return {"available": False, "total": 0, "error": str(e)}

    @staticmethod
    async def get_whitelist() -> Set[str]:
        """
        Get the whitelist.

        Returns:
            The set of whitelisted IPs
        """
        redis_client = await get_redis_client(require_redis=False)

        if redis_client is None:
            return set()

        try:
            whitelist = await redis_client.smembers(IPRateLimiter.WHITELIST_KEY)
            return whitelist if whitelist else set()

        except Exception as e:
            logger.error(f"Failed to get whitelist: {e}")
            return set()
139
src/services/rate_limit/rpm_limiter.py
Normal file
@@ -0,0 +1,139 @@
"""
RPM (Requests Per Minute) limiting service.
"""

import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple

from sqlalchemy.orm import Session

from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking


class RPMLimiter:
    """RPM limiter."""

    def __init__(self, db: Session):
        self.db = db
        # In-memory RPM counters: {provider_id: (count, window_start)}
        self._rpm_counters: Dict[str, Tuple[int, float]] = {}

    def check_and_increment(self, provider_id: str) -> bool:
        """
        Check and increment the RPM count.

        Returns:
            True if allowed, False if rate limited
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            return True

        rpm_limit = provider.rpm_limit
        if rpm_limit is None:
            # No limit configured
            return True

        if rpm_limit == 0:
            logger.warning(f"Provider {provider.name} is fully restricted by RPM limit=0")
            return False

        current_time = time.time()

        # Reset the window if it has expired
        if provider.rpm_reset_at and provider.rpm_reset_at < datetime.now(timezone.utc):
            provider.rpm_used = 0
            provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
            self.db.commit()  # Commit immediately to release database locks

        # Reject if over the limit
        if provider.rpm_used >= rpm_limit:
            logger.warning(f"Provider {provider.name} RPM limit exceeded")
            return False

        # Increment the counter and start a window if none is active
        provider.rpm_used += 1
        if not provider.rpm_reset_at:
            provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)

        self.db.commit()  # Commit immediately to release database locks
        return True

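    # Usage sketch (illustrative; `db` and `provider_id` are assumed to come
    # from the caller's request context):
    #
    #     limiter = RPMLimiter(db)
    #     if not limiter.check_and_increment(provider_id):
    #         # surface a 429 to the client
    #         ...
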
    def record_usage(
        self, provider_id: str, success: bool, response_time_ms: float, cost_usd: float
    ):
        """Record usage in the tracking table."""

        # Current one-minute window. timedelta handles minute/hour/day
        # rollover; the previous replace()-based arithmetic raised
        # ValueError at 23:59.
        now = datetime.now(timezone.utc)
        window_start = now.replace(second=0, microsecond=0)
        window_end = window_start + timedelta(minutes=1)

        # Find or create the tracking row
        tracking = (
            self.db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == provider_id,
                ProviderUsageTracking.window_start == window_start,
            )
            .first()
        )

        if not tracking:
            tracking = ProviderUsageTracking(
                provider_id=provider_id, window_start=window_start, window_end=window_end
            )
            self.db.add(tracking)

        # Update aggregates
        tracking.total_requests += 1
        if success:
            tracking.successful_requests += 1
        else:
            tracking.failed_requests += 1

        tracking.total_response_time_ms += response_time_ms
        tracking.avg_response_time_ms = tracking.total_response_time_ms / tracking.total_requests
        tracking.total_cost_usd += cost_usd

        self.db.flush()  # Flush only; do not commit yet
        # Usage stats are non-critical data; defer to the batch committer
        get_batch_committer().mark_dirty(self.db)

        logger.debug(f"Recorded usage for provider {provider_id}")

    def get_rpm_status(self, provider_id: str) -> Dict:
        """Get a provider's RPM status."""
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            return {"error": "Provider not found"}

        return {
            "provider_id": provider_id,
            "provider_name": provider.name,
            "rpm_limit": provider.rpm_limit,
            "rpm_used": provider.rpm_used,
            "rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
            "available": (
                provider.rpm_limit - provider.rpm_used if provider.rpm_limit is not None else None
            ),
        }

    def reset_rpm_counter(self, provider_id: str):
        """Manually reset the RPM counter."""
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if provider:
            provider.rpm_used = 0
            provider.rpm_reset_at = None
            self.db.commit()  # Commit immediately to release database locks

            logger.info(f"Reset RPM counter for provider {provider.name}")