Initial commit

fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
"""
Rate limiting service module.
Provides adaptive concurrency control, RPM limiting, IP rate limiting, and related utilities.
"""
from src.services.rate_limit.adaptive_concurrency import AdaptiveConcurrencyManager
from src.services.rate_limit.concurrency_manager import ConcurrencyManager
from src.services.rate_limit.detector import RateLimitDetector
from src.services.rate_limit.ip_limiter import IPRateLimiter
from src.services.rate_limit.rpm_limiter import RPMLimiter
__all__ = [
"AdaptiveConcurrencyManager",
"ConcurrencyManager",
"IPRateLimiter",
"RPMLimiter",
"RateLimitDetector",
]

View File

@@ -0,0 +1,558 @@
"""
Adaptive concurrency tuner - adjusts the concurrency limit based on sliding-window utilization.
Key improvements over the previous "sustained high utilization" approach:
- Uses sliding-window sampling, tolerating concurrency fluctuations
- Decides based on the proportion of high-utilization samples in the window instead of requiring consecutive high utilization
- Adds a probing scale-up mechanism that proactively tries to raise the limit after a long stable period
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, cast
from sqlalchemy.orm import Session
from src.config.constants import ConcurrencyDefaults
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import ProviderAPIKey
from src.services.rate_limit.detector import RateLimitInfo, RateLimitType
class AdaptiveStrategy:
"""自适应策略类型"""
AIMD = "aimd" # 加性增-乘性减 (Additive Increase Multiplicative Decrease)
CONSERVATIVE = "conservative" # 保守策略(只减不增)
AGGRESSIVE = "aggressive" # 激进策略(快速探测)
class AdaptiveConcurrencyManager:
"""
Adaptive concurrency manager.
Core algorithm: AIMD driven by sliding-window utilization
- A sliding window records the utilization of the most recent N requests
- When >= 60% of the samples in the window show high utilization, the limit is raised
- On a 429 error, the limit is decreased multiplicatively (*0.7)
- A long stretch without 429s, combined with traffic, triggers a probing increase
Increase conditions (either one is sufficient):
1. Sliding-window increase: >= 60% of the samples in the window have utilization >= 70%, and the key is not in the cooldown period
2. Probing increase: more than 30 minutes since the last 429, with enough requests in between
Key properties:
1. The sliding window tolerates concurrency fluctuations and is not reset by a single low-utilization sample
2. Concurrency limits and RPM limits are treated separately
3. Probing increases prevent the key from being stuck at a low limit indefinitely
4. Adjustment history is recorded
"""
# Default configuration - uses shared constants
DEFAULT_INITIAL_LIMIT = ConcurrencyDefaults.INITIAL_LIMIT
MIN_CONCURRENT_LIMIT = ConcurrencyDefaults.MIN_CONCURRENT_LIMIT
MAX_CONCURRENT_LIMIT = ConcurrencyDefaults.MAX_CONCURRENT_LIMIT
# AIMD parameters
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER
# Sliding-window parameters
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
UTILIZATION_WINDOW_SECONDS = ConcurrencyDefaults.UTILIZATION_WINDOW_SECONDS
UTILIZATION_THRESHOLD = ConcurrencyDefaults.UTILIZATION_THRESHOLD
HIGH_UTILIZATION_RATIO = ConcurrencyDefaults.HIGH_UTILIZATION_RATIO
MIN_SAMPLES_FOR_DECISION = ConcurrencyDefaults.MIN_SAMPLES_FOR_DECISION
# Probing scale-up parameters
PROBE_INCREASE_INTERVAL_MINUTES = ConcurrencyDefaults.PROBE_INCREASE_INTERVAL_MINUTES
PROBE_INCREASE_MIN_REQUESTS = ConcurrencyDefaults.PROBE_INCREASE_MIN_REQUESTS
# Number of history records to keep
MAX_HISTORY_RECORDS = 20
def __init__(self, strategy: str = AdaptiveStrategy.AIMD):
"""
Initialize the adaptive concurrency manager.
Args:
strategy: Adjustment strategy
"""
self.strategy = strategy
def handle_429_error(
self,
db: Session,
key: ProviderAPIKey,
rate_limit_info: RateLimitInfo,
current_concurrent: Optional[int] = None,
) -> int:
"""
Handle a 429 error and adjust the concurrency limit.
Args:
db: Database session
key: API Key object
rate_limit_info: Rate limit information
current_concurrent: Current concurrency
Returns:
The adjusted concurrency limit
"""
# max_concurrent=NULL enables adaptive mode; max_concurrent=<number> sets a fixed limit
is_adaptive = key.max_concurrent is None
if not is_adaptive:
logger.debug(
f"Key {key.id} 设置了固定并发限制 ({key.max_concurrent}),跳过自适应调整"
)
return int(key.max_concurrent) # type: ignore[arg-type]
# 更新429统计
key.last_429_at = datetime.now(timezone.utc) # type: ignore[assignment]
key.last_429_type = rate_limit_info.limit_type # type: ignore[assignment]
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
# On a 429 error, clear the utilization sample window (start collecting again)
key.utilization_samples = [] # type: ignore[assignment]
if rate_limit_info.limit_type == RateLimitType.CONCURRENT:
# 并发限制:减少并发数
key.concurrent_429_count = int(key.concurrent_429_count or 0) + 1 # type: ignore[assignment]
# Get the current effective limit (adaptive mode uses learned_max_concurrent)
old_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
new_limit = self._decrease_limit(old_limit, current_concurrent)
logger.warning(
f"[CONCURRENT] 并发限制触发: Key {key.id[:8]}... | "
f"当前并发: {current_concurrent} | "
f"调整: {old_limit} -> {new_limit}"
)
# 记录调整历史
self._record_adjustment(
key,
old_limit=old_limit,
new_limit=new_limit,
reason="concurrent_429",
current_concurrent=current_concurrent,
)
# 更新学习到的并发限制
key.learned_max_concurrent = new_limit # type: ignore[assignment]
elif rate_limit_info.limit_type == RateLimitType.RPM:
# RPM limit: do not adjust concurrency, only record it
key.rpm_429_count = int(key.rpm_429_count or 0) + 1 # type: ignore[assignment]
logger.info(
f"[RPM] RPM限制触发: Key {key.id[:8]}... | "
f"retry_after: {rate_limit_info.retry_after}s | "
f"不调整并发限制"
)
else:
# 未知类型:保守处理,轻微减少
logger.warning(
f"[UNKNOWN] 未知429类型: Key {key.id[:8]}... | "
f"当前并发: {current_concurrent} | "
f"保守减少并发"
)
old_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
new_limit = max(int(old_limit * 0.9), self.MIN_CONCURRENT_LIMIT)  # reduce by 10%
self._record_adjustment(
key,
old_limit=old_limit,
new_limit=new_limit,
reason="unknown_429",
current_concurrent=current_concurrent,
)
key.learned_max_concurrent = new_limit # type: ignore[assignment]
db.flush()
get_batch_committer().mark_dirty(db)
return int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
def handle_success(
self,
db: Session,
key: ProviderAPIKey,
current_concurrent: int,
) -> Optional[int]:
"""
Handle a successful request; based on sliding-window utilization, consider raising the concurrency limit.
Args:
db: Database session
key: API Key object
current_concurrent: Current concurrency (required, used to compute utilization)
Returns:
The adjusted concurrency limit if an adjustment was made, otherwise None
"""
# max_concurrent=NULL 表示启用自适应
is_adaptive = key.max_concurrent is None
if not is_adaptive:
return None
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
# 计算当前利用率
utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
now = datetime.now(timezone.utc)
now_ts = now.timestamp()
# 更新滑动窗口
samples = self._update_utilization_window(key, now_ts, utilization)
# 检查是否满足扩容条件
increase_reason = self._check_increase_conditions(key, samples, now)
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
old_limit = current_limit
new_limit = self._increase_limit(current_limit)
# 计算窗口统计用于日志
avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples) if samples else 0
logger.info(
f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
f"窗口采样: {len(samples)} | "
f"平均利用率: {avg_util:.1%} | "
f"高利用率比例: {high_util_ratio:.1%} | "
f"调整: {old_limit} -> {new_limit}"
)
# 记录调整历史
self._record_adjustment(
key,
old_limit=old_limit,
new_limit=new_limit,
reason=increase_reason,
avg_utilization=round(avg_util, 2),
high_util_ratio=round(high_util_ratio, 2),
sample_count=len(samples),
current_concurrent=current_concurrent,
)
# 更新限制
key.learned_max_concurrent = new_limit # type: ignore[assignment]
# 如果是探测性扩容,更新探测时间
if increase_reason == "probe_increase":
key.last_probe_increase_at = now # type: ignore[assignment]
# 扩容后清空采样窗口,重新开始收集
key.utilization_samples = [] # type: ignore[assignment]
db.flush()
get_batch_committer().mark_dirty(db)
return new_limit
# Periodically persist sample data (every 5 samples)
if len(samples) % 5 == 0:
db.flush()
get_batch_committer().mark_dirty(db)
return None
def _update_utilization_window(
self, key: ProviderAPIKey, now_ts: float, utilization: float
) -> List[Dict[str, Any]]:
"""
Update the utilization sliding window.
Args:
key: API Key object
now_ts: Current timestamp
utilization: Current utilization
Returns:
The updated list of samples
"""
samples: List[Dict[str, Any]] = list(key.utilization_samples or [])
# 添加新采样
samples.append({"ts": now_ts, "util": round(utilization, 3)})
# 移除过期采样(超过时间窗口)
cutoff_ts = now_ts - self.UTILIZATION_WINDOW_SECONDS
samples = [s for s in samples if s["ts"] > cutoff_ts]
# 限制采样数量
if len(samples) > self.UTILIZATION_WINDOW_SIZE:
samples = samples[-self.UTILIZATION_WINDOW_SIZE:]
# 更新到 key 对象
key.utilization_samples = samples # type: ignore[assignment]
return samples
def _check_increase_conditions(
self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
) -> Optional[str]:
"""
Check whether the conditions for increasing the limit are met.
Args:
key: API Key object
samples: List of utilization samples
now: Current time
Returns:
The reason for the increase if a condition is met, otherwise None
"""
# 检查是否在冷却期
if self._is_in_cooldown(key):
return None
# Condition 1: sliding-window increase
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples)
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
return "high_utilization"
# Condition 2: probing increase (a long period without 429s, with traffic)
if self._should_probe_increase(key, samples, now):
return "probe_increase"
return None
def _should_probe_increase(
self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
) -> bool:
"""
Check whether a probing increase should be attempted.
Conditions:
1. More than PROBE_INCREASE_INTERVAL_MINUTES minutes since the last 429
2. More than PROBE_INCREASE_INTERVAL_MINUTES minutes since the last probing increase
3. Enough requests in the meantime (sample count >= PROBE_INCREASE_MIN_REQUESTS)
4. Average utilization > 30% (i.e. there is real demand)
Args:
key: API Key object
samples: List of utilization samples
now: Current time
Returns:
Whether a probing increase should be attempted
"""
probe_interval_seconds = self.PROBE_INCREASE_INTERVAL_MINUTES * 60
# 检查距上次 429 的时间
if key.last_429_at:
last_429_at = cast(datetime, key.last_429_at)
time_since_429 = (now - last_429_at).total_seconds()
if time_since_429 < probe_interval_seconds:
return False
# 检查距上次探测性扩容的时间
if key.last_probe_increase_at:
last_probe = cast(datetime, key.last_probe_increase_at)
time_since_probe = (now - last_probe).total_seconds()
if time_since_probe < probe_interval_seconds:
return False
# 检查请求量
if len(samples) < self.PROBE_INCREASE_MIN_REQUESTS:
return False
# 检查平均利用率(确保确实有使用需求)
avg_util = sum(s["util"] for s in samples) / len(samples)
if avg_util < 0.3: # 至少 30% 利用率
return False
return True
def _is_in_cooldown(self, key: ProviderAPIKey) -> bool:
"""
Check whether the key is in the cooldown period after a 429 error.
Args:
key: API Key object
Returns:
True if within the cooldown period, otherwise False
"""
if key.last_429_at is None:
return False
last_429_at = cast(datetime, key.last_429_at)
time_since_429 = (datetime.now(timezone.utc) - last_429_at).total_seconds()
cooldown_seconds = ConcurrencyDefaults.COOLDOWN_AFTER_429_MINUTES * 60
return bool(time_since_429 < cooldown_seconds)
def _decrease_limit(
self,
current_limit: int,
current_concurrent: Optional[int] = None,
) -> int:
"""
Decrease the concurrency limit.
Strategy:
- If the current concurrency is known, set the limit to 70% of it
- Otherwise, apply a multiplicative decrease
"""
if current_concurrent:
# 基于当前并发数减少
new_limit = max(
int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
)
else:
# 乘性减少
new_limit = max(
int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
)
return new_limit
def _increase_limit(self, current_limit: int) -> int:
"""
Increase the concurrency limit.
Strategy: additive increase (+1)
"""
new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
return new_limit
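# The two helpers above are the AIMD halves described in the class docstring.
# A minimal, self-contained sketch of the resulting limit trajectory; the numbers
# (step=1, multiplier=0.7, floor=1, cap=100) are stand-ins for ConcurrencyDefaults,
# not the module's real constants.
def _aimd_trajectory(events, limit=5, step=1, multiplier=0.7, floor=1, cap=100):
    """Replay a list of events ('ok' or '429') and return the limit after each one."""
    history = []
    for event in events:
        if event == "429":
            limit = max(int(limit * multiplier), floor)  # multiplicative decrease
        else:
            limit = min(limit + step, cap)  # additive increase
        history.append(limit)
    return history
# Three quiet successes, a 429, then recovery: [6, 7, 8, 5, 6, 7]
print(_aimd_trajectory(["ok", "ok", "ok", "429", "ok", "ok"]))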
def _record_adjustment(
self,
key: ProviderAPIKey,
old_limit: int,
new_limit: int,
reason: str,
**extra_data: Any,
) -> None:
"""
Record the concurrency adjustment history.
Args:
key: API Key object
old_limit: Previous limit
new_limit: New limit
reason: Reason for the adjustment
**extra_data: Additional data
"""
history: List[Dict[str, Any]] = list(key.adjustment_history or [])
record = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"old_limit": old_limit,
"new_limit": new_limit,
"reason": reason,
**extra_data,
}
history.append(record)
# 保留最近N条记录
if len(history) > self.MAX_HISTORY_RECORDS:
history = history[-self.MAX_HISTORY_RECORDS:]
key.adjustment_history = history # type: ignore[assignment]
def get_adjustment_stats(self, key: ProviderAPIKey) -> Dict[str, Any]:
"""
Get adjustment statistics.
Args:
key: API Key object
Returns:
A statistics dictionary
"""
history: List[Dict[str, Any]] = list(key.adjustment_history or [])
samples: List[Dict[str, Any]] = list(key.utilization_samples or [])
# max_concurrent=NULL 表示自适应,否则为固定限制
is_adaptive = key.max_concurrent is None
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
effective_limit = current_limit if is_adaptive else int(key.max_concurrent) # type: ignore
# 计算窗口统计
avg_utilization: Optional[float] = None
high_util_ratio: Optional[float] = None
if samples:
avg_utilization = sum(s["util"] for s in samples) / len(samples)
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples)
last_429_at_str: Optional[str] = None
if key.last_429_at:
last_429_at_str = cast(datetime, key.last_429_at).isoformat()
last_probe_at_str: Optional[str] = None
if key.last_probe_increase_at:
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
return {
"adaptive_mode": is_adaptive,
"max_concurrent": key.max_concurrent, # NULL=自适应,数字=固定限制
"effective_limit": effective_limit, # 当前有效限制
"learned_limit": key.learned_max_concurrent, # 学习到的限制
"concurrent_429_count": int(key.concurrent_429_count or 0),
"rpm_429_count": int(key.rpm_429_count or 0),
"last_429_at": last_429_at_str,
"last_429_type": key.last_429_type,
"adjustment_count": len(history),
"recent_adjustments": history[-5:] if history else [],
# 滑动窗口相关
"window_sample_count": len(samples),
"window_avg_utilization": round(avg_utilization, 3) if avg_utilization else None,
"window_high_util_ratio": round(high_util_ratio, 3) if high_util_ratio else None,
"utilization_threshold": self.UTILIZATION_THRESHOLD,
"high_util_ratio_threshold": self.HIGH_UTILIZATION_RATIO,
"min_samples_for_decision": self.MIN_SAMPLES_FOR_DECISION,
# 探测性扩容相关
"last_probe_increase_at": last_probe_at_str,
"probe_increase_interval_minutes": self.PROBE_INCREASE_INTERVAL_MINUTES,
}
def reset_learning(self, db: Session, key: ProviderAPIKey) -> None:
"""
Reset the learning state (admin function).
Args:
db: Database session
key: API Key object
"""
logger.info(f"[RESET] 重置学习状态: Key {key.id[:8]}...")
key.learned_max_concurrent = None # type: ignore[assignment]
key.concurrent_429_count = 0 # type: ignore[assignment]
key.rpm_429_count = 0 # type: ignore[assignment]
key.last_429_at = None # type: ignore[assignment]
key.last_429_type = None # type: ignore[assignment]
key.last_concurrent_peak = None # type: ignore[assignment]
key.adjustment_history = [] # type: ignore[assignment]
key.utilization_samples = [] # type: ignore[assignment]
key.last_probe_increase_at = None # type: ignore[assignment]
db.flush()
get_batch_committer().mark_dirty(db)
# Global singleton
_adaptive_manager: Optional[AdaptiveConcurrencyManager] = None
def get_adaptive_manager() -> AdaptiveConcurrencyManager:
"""获取全局自适应管理器单例"""
global _adaptive_manager
if _adaptive_manager is None:
_adaptive_manager = AdaptiveConcurrencyManager()
return _adaptive_manager
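# Illustrative sketch (not part of the module) of the sliding-window increase
# condition from the AdaptiveConcurrencyManager docstring: enough samples, and a
# large enough share of them at or above the utilization threshold. The values
# 0.7 / 0.6 / 5 mirror the docstring and are assumptions, not ConcurrencyDefaults.
def _window_wants_increase(samples, util_threshold=0.7, ratio_threshold=0.6, min_samples=5):
    """samples: list of {"ts": ..., "util": ...} dicts as stored in utilization_samples."""
    if len(samples) < min_samples:
        return False
    high = sum(1 for s in samples if s["util"] >= util_threshold)
    return high / len(samples) >= ratio_threshold
window = [{"ts": i, "util": u} for i, u in enumerate([0.9, 0.8, 0.4, 0.75, 0.85])]
print(_window_wants_increase(window))  # 4 of 5 samples >= 0.7 -> 80% >= 60% -> True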

View File

@@ -0,0 +1,340 @@
"""
Adaptive reservation ratio manager.
Dynamically computes the slot reservation ratio for cached users based on learning confidence and current load,
addressing the mismatch of a fixed 30% reservation during early learning and under changing load.
Core ideas:
1. Probe phase: use a low reservation so the system quickly learns the real concurrency limit
2. Stable phase: adjust the reservation dynamically based on confidence and load
3. Confidence combines consecutive successes, time since the last 429, and the stability of the adjustment history
"""
import statistics
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.core.logger import logger
from src.config.constants import AdaptiveReservationDefaults
if TYPE_CHECKING:
from src.models.database import ProviderAPIKey
@dataclass
class ReservationConfig:
"""预留比例配置(使用统一常量作为默认值)"""
# Probe phase configuration
probe_phase_requests: int = field(
default_factory=lambda: AdaptiveReservationDefaults.PROBE_PHASE_REQUESTS
)
probe_reservation: float = field(
default_factory=lambda: AdaptiveReservationDefaults.PROBE_RESERVATION
)
# Stable phase configuration
stable_min_reservation: float = field(
default_factory=lambda: AdaptiveReservationDefaults.STABLE_MIN_RESERVATION
)
stable_max_reservation: float = field(
default_factory=lambda: AdaptiveReservationDefaults.STABLE_MAX_RESERVATION
)
# Confidence calculation parameters
success_count_for_full_confidence: int = field(
default_factory=lambda: AdaptiveReservationDefaults.SUCCESS_COUNT_FOR_FULL_CONFIDENCE
)
cooldown_hours_for_full_confidence: int = field(
default_factory=lambda: AdaptiveReservationDefaults.COOLDOWN_HOURS_FOR_FULL_CONFIDENCE
)
# Load thresholds
low_load_threshold: float = field(
default_factory=lambda: AdaptiveReservationDefaults.LOW_LOAD_THRESHOLD
)
high_load_threshold: float = field(
default_factory=lambda: AdaptiveReservationDefaults.HIGH_LOAD_THRESHOLD
)
@dataclass
class ReservationResult:
"""预留比例计算结果"""
ratio: float # 最终预留比例
phase: str # 当前阶段: "probe" | "stable"
confidence: float # 置信度 (0-1)
load_factor: float # 负载因子 (0-1)
details: Dict[str, Any] # 详细信息
class AdaptiveReservationManager:
"""
Adaptive reservation ratio manager.
How it works:
1. Probe phase (request count < threshold):
- Use a low reservation ratio (10%) so no slots are wasted
- Let the system quickly probe the real concurrency limit
2. Stable phase (request count >= threshold):
- Compute the reservation ratio dynamically from confidence and load
- High confidence + high load = high reservation (protect cached users)
- Low confidence or low load = low reservation (avoid waste)
Confidence factors:
- Consecutive success count: the more, the more accurate the current limit
- Time since the last 429: the longer ago, the more stable
- Stability of the adjustment history: the smaller the variance of recent adjustments, the more stable
"""
def __init__(self, config: Optional[ReservationConfig] = None):
self.config = config or ReservationConfig()
self._cache: Dict[str, ReservationResult] = {} # 简单的内存缓存
def calculate_reservation(
self,
key: "ProviderAPIKey",
current_concurrent: int = 0,
effective_limit: Optional[int] = None,
) -> ReservationResult:
"""
Compute the reservation ratio to use right now.
Args:
key: ProviderAPIKey object
current_concurrent: Current concurrency
effective_limit: Effective concurrency limit (learned or configured)
Returns:
ReservationResult containing the reservation ratio and details
"""
# 计算总请求数(用于判断阶段)
total_requests = self._get_total_requests(key)
# 计算负载率
load_ratio = self._calculate_load_ratio(current_concurrent, effective_limit)
# 阶段1: 探测阶段
if total_requests < self.config.probe_phase_requests:
return ReservationResult(
ratio=self.config.probe_reservation,
phase="probe",
confidence=0.0,
load_factor=load_ratio,
details={
"total_requests": total_requests,
"probe_threshold": self.config.probe_phase_requests,
"reason": "探测阶段,使用低预留让系统学习真实限制",
},
)
# 阶段2: 稳定阶段
confidence = self._calculate_confidence(key)
ratio = self._calculate_stable_ratio(confidence, load_ratio)
return ReservationResult(
ratio=ratio,
phase="stable",
confidence=confidence,
load_factor=load_ratio,
details={
"total_requests": total_requests,
"confidence_factors": self._get_confidence_breakdown(key),
"reason": self._get_ratio_reason(confidence, load_ratio),
},
)
def _get_total_requests(self, key: "ProviderAPIKey") -> int:
"""获取总请求数(用于判断是否过了探测阶段)"""
# 使用总请求计数作为基准
request_count = key.request_count or 0
# 如果 request_count 为 0使用 429 计数 + 成功计数作为近似值
if request_count == 0:
concurrent_429 = key.concurrent_429_count or 0
rpm_429 = key.rpm_429_count or 0
success_count = key.success_count or 0
# 调整历史中的记录数也可以参考
history_count = len(key.adjustment_history or []) * 10
return concurrent_429 + rpm_429 + success_count + history_count
return request_count
def _calculate_load_ratio(
self, current_concurrent: int, effective_limit: Optional[int]
) -> float:
"""计算当前负载率"""
if not effective_limit or effective_limit <= 0:
return 0.0
return min(current_concurrent / effective_limit, 1.0)
def _calculate_confidence(self, key: "ProviderAPIKey") -> float:
"""
Compute the confidence of the learned value (0-1).
Three weighted factors:
- Success rate: 40% (based on total successes / total requests)
- Time since the last 429: 30%
- Stability of the adjustment history: 30%
"""
scores = self._get_confidence_breakdown(key)
return min(
scores["success_score"] + scores["cooldown_score"] + scores["stability_score"], 1.0
)
def _get_confidence_breakdown(self, key: "ProviderAPIKey") -> Dict[str, float]:
"""获取置信度各因素的详细分数"""
# Factor 1: success rate (weight 40%)
# Use the success rate instead of the consecutive success count; it reflects the key's stability more accurately
request_count = key.request_count or 0
success_count = key.success_count or 0
if request_count >= self.config.success_count_for_full_confidence:
# 请求数足够时,根据成功率计算
success_rate = success_count / request_count if request_count > 0 else 0
success_score = success_rate * 0.4
elif request_count > 0:
# 请求数不足时,按比例折算
progress_ratio = request_count / self.config.success_count_for_full_confidence
success_rate = success_count / request_count
success_score = success_rate * progress_ratio * 0.4
else:
success_score = 0.0
# Factor 2: time since the last 429 (weight 30%)
if key.last_429_at:
now = datetime.now(timezone.utc)
# 确保 last_429_at 有时区信息
last_429 = key.last_429_at
if last_429.tzinfo is None:
last_429 = last_429.replace(tzinfo=timezone.utc)
hours_since_429 = (now - last_429).total_seconds() / 3600
cooldown_ratio = min(
hours_since_429 / self.config.cooldown_hours_for_full_confidence, 1.0
)
cooldown_score = cooldown_ratio * 0.3
else:
# 从未触发 429给满分
cooldown_score = 0.3
# Factor 3: stability of the adjustment history (weight 30%)
history = key.adjustment_history or []
if len(history) >= 3:
# 取最近的调整记录
recent = history[-5:] if len(history) >= 5 else history
limits = [h.get("new_limit", 0) for h in recent if h.get("new_limit")]
if len(limits) >= 2:
try:
variance = statistics.variance(limits)
# The smaller the variance, the more stable (a variance of 10 scores close to 0)
stability_ratio = max(0, 1 - variance / 10)
stability_score = stability_ratio * 0.3
except statistics.StatisticsError:
stability_score = 0.15
else:
stability_score = 0.15
else:
# 历史数据不足,给一半分
stability_score = 0.15
# 计算成功率用于返回
success_rate_pct = (success_count / request_count * 100) if request_count > 0 else None
return {
"success_score": round(success_score, 3),
"cooldown_score": round(cooldown_score, 3),
"stability_score": round(stability_score, 3),
"request_count": request_count,
"success_count": success_count,
"success_rate": round(success_rate_pct, 1) if success_rate_pct is not None else None,
"hours_since_429": (
round(
(
datetime.now(timezone.utc) - key.last_429_at.replace(tzinfo=timezone.utc)
).total_seconds()
/ 3600,
1,
)
if key.last_429_at
else None
),
"history_count": len(history),
}
def _calculate_stable_ratio(self, confidence: float, load_ratio: float) -> float:
"""
Compute the reservation ratio for the stable phase.
Strategy:
- Low load (<50%): use the minimum reservation; slots are plentiful, no need to reserve much
- Medium load (50-80%): increase the reservation linearly with confidence and load
- High load (>80%): use a higher reservation, scaled by confidence, to protect cached users
"""
min_r = self.config.stable_min_reservation
max_r = self.config.stable_max_reservation
if load_ratio < self.config.low_load_threshold:
# 低负载:使用最小预留
return min_r
if load_ratio < self.config.high_load_threshold:
# 中等负载:根据置信度和负载线性插值
# 负载越高、置信度越高,预留越多
load_factor = (load_ratio - self.config.low_load_threshold) / (
self.config.high_load_threshold - self.config.low_load_threshold
)
return min_r + confidence * load_factor * (max_r - min_r)
# 高负载:根据置信度决定预留比例
# 置信度高 → 接近最大预留
# 置信度低 → 保守预留(避免基于不准确的学习值过度预留)
return min_r + confidence * (max_r - min_r)
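# Worked example of the medium-load branch above. The config values used here
# (min_r=0.1, max_r=0.5, low=0.5, high=0.8) are assumed stand-ins for
# AdaptiveReservationDefaults, not the real constants.
# With load_ratio=0.65 and confidence=0.8:
#   load_factor = (0.65 - 0.5) / (0.8 - 0.5) = 0.5
#   ratio = 0.1 + 0.8 * 0.5 * (0.5 - 0.1) = 0.26
min_r, max_r, low, high = 0.1, 0.5, 0.5, 0.8
confidence, load_ratio = 0.8, 0.65
load_factor = (load_ratio - low) / (high - low)
ratio = min_r + confidence * load_factor * (max_r - min_r)
assert abs(ratio - 0.26) < 1e-6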
def _get_ratio_reason(self, confidence: float, load_ratio: float) -> str:
"""生成预留比例的解释"""
if load_ratio < self.config.low_load_threshold:
return f"低负载({load_ratio:.0%}),使用最小预留"
if confidence < 0.3:
return f"置信度低({confidence:.0%}),保守预留避免浪费"
if confidence > 0.7 and load_ratio > self.config.high_load_threshold:
return f"高置信度({confidence:.0%})+高负载({load_ratio:.0%}),使用较高预留保护缓存用户"
return f"置信度{confidence:.0%},负载{load_ratio:.0%},动态计算预留"
def get_stats(self) -> Dict[str, Any]:
"""获取管理器统计信息"""
return {
"config": {
"probe_phase_requests": self.config.probe_phase_requests,
"probe_reservation": self.config.probe_reservation,
"stable_min_reservation": self.config.stable_min_reservation,
"stable_max_reservation": self.config.stable_max_reservation,
"low_load_threshold": self.config.low_load_threshold,
"high_load_threshold": self.config.high_load_threshold,
},
}
# Global singleton
_reservation_manager: Optional[AdaptiveReservationManager] = None
def get_adaptive_reservation_manager() -> AdaptiveReservationManager:
"""获取全局自适应预留管理器单例"""
global _reservation_manager
if _reservation_manager is None:
_reservation_manager = AdaptiveReservationManager()
return _reservation_manager
def reset_adaptive_reservation_manager():
"""重置全局单例(用于测试)"""
global _reservation_manager
_reservation_manager = None
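# Minimal usage sketch (illustrative, not part of the module): a stand-in object
# carrying just the attributes this manager reads; real callers pass a
# ProviderAPIKey row. The numbers are made up and assume request_count is past
# the probe threshold configured in AdaptiveReservationDefaults.
if __name__ == "__main__":
    from types import SimpleNamespace
    fake_key = SimpleNamespace(
        request_count=500,
        success_count=490,
        concurrent_429_count=2,
        rpm_429_count=0,
        last_429_at=None,  # never rate limited -> full cooldown score
        adjustment_history=[{"new_limit": 8}, {"new_limit": 9}, {"new_limit": 9}],
    )
    result = get_adaptive_reservation_manager().calculate_reservation(
        fake_key, current_concurrent=6, effective_limit=10
    )
    print(result.phase, round(result.ratio, 3), round(result.confidence, 2))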

View File

@@ -0,0 +1,582 @@
"""
Concurrency manager - concurrency control backed by Redis or in-memory counters.
Features:
1. Endpoint-level concurrency limits
2. ProviderAPIKey-level concurrency limits
3. Prefers Redis in distributed deployments (shared across instances)
4. Automatically falls back to in-memory counting in development / single-instance scenarios
5. Automatic release and exception handling (Redis provides a TTL; in memory mode make sure slots are released manually)
"""
import asyncio
import math
import os
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Optional, Tuple
import redis.asyncio as aioredis
from src.core.logger import logger
class ConcurrencyManager:
"""分布式并发管理器"""
_instance: Optional["ConcurrencyManager"] = None
_redis: Optional[aioredis.Redis] = None
def __new__(cls):
"""单例模式"""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""初始化内存后端结构(只执行一次)"""
if hasattr(self, "_memory_initialized"):
return
self._memory_lock: asyncio.Lock = asyncio.Lock()
self._memory_endpoint_counts: dict[str, int] = {}
self._memory_key_counts: dict[str, int] = {}
self._memory_initialized = True
async def initialize(self) -> None:
"""初始化 Redis 连接"""
if self._redis is not None:
return
# Prefer REDIS_URL; if it is not set, build a URL from the password
redis_url = os.getenv("REDIS_URL")
if not redis_url:
# Local development mode: build the URL from REDIS_PASSWORD
redis_password = os.getenv("REDIS_PASSWORD")
if redis_password:
redis_url = f"redis://:{redis_password}@localhost:6379/0"
else:
redis_url = "redis://localhost:6379/0"
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
)
# 测试连接
await self._redis.ping()
# 脱敏显示(隐藏密码)
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
logger.info(f"[OK] Redis 连接成功: {safe_url}")
except Exception as e:
logger.error(f"[ERROR] Redis 连接失败: {e}")
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
self._redis = None
async def close(self) -> None:
"""关闭 Redis 连接"""
if self._redis:
await self._redis.close()
self._redis = None
logger.info("Redis 连接已关闭")
def _get_endpoint_key(self, endpoint_id: str) -> str:
"""获取 Endpoint 并发计数的 Redis Key"""
return f"concurrency:endpoint:{endpoint_id}"
def _get_key_key(self, key_id: str) -> str:
"""获取 ProviderAPIKey 并发计数的 Redis Key"""
return f"concurrency:key:{key_id}"
async def get_current_concurrency(
self, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
) -> Tuple[int, int]:
"""
Get the current concurrency counts.
Args:
endpoint_id: Endpoint ID (optional)
key_id: ProviderAPIKey ID (optional)
Returns:
(endpoint_concurrency, key_concurrency)
"""
if self._redis is None:
async with self._memory_lock:
endpoint_count = (
self._memory_endpoint_counts.get(endpoint_id, 0) if endpoint_id else 0
)
key_count = self._memory_key_counts.get(key_id, 0) if key_id else 0
return endpoint_count, key_count
endpoint_count = 0
key_count = 0
try:
if endpoint_id:
endpoint_key = self._get_endpoint_key(endpoint_id)
result = await self._redis.get(endpoint_key)
endpoint_count = int(result) if result else 0
if key_id:
key_key = self._get_key_key(key_id)
result = await self._redis.get(key_key)
key_count = int(result) if result else 0
except Exception as e:
logger.error(f"获取并发数失败: {e}")
return endpoint_count, key_count
async def check_available(
self,
endpoint_id: str,
endpoint_max_concurrent: Optional[int],
key_id: str,
key_max_concurrent: Optional[int],
) -> bool:
"""
Check whether a concurrency slot could be acquired (without actually acquiring it).
Args:
endpoint_id: Endpoint ID
endpoint_max_concurrent: Maximum endpoint concurrency (None means unlimited)
key_id: ProviderAPIKey ID
key_max_concurrent: Maximum key concurrency (None means unlimited)
Returns:
Whether a slot is available (True/False)
"""
if self._redis is None:
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
key_count = self._memory_key_counts.get(key_id, 0)
if (
endpoint_max_concurrent is not None
and endpoint_count >= endpoint_max_concurrent
):
return False
if key_max_concurrent is not None and key_count >= key_max_concurrent:
return False
return True
endpoint_count, key_count = await self.get_current_concurrency(endpoint_id, key_id)
# 检查 Endpoint 级别限制
if endpoint_max_concurrent is not None and endpoint_count >= endpoint_max_concurrent:
return False
# 检查 Key 级别限制
if key_max_concurrent is not None and key_count >= key_max_concurrent:
return False
return True
async def acquire_slot(
self,
endpoint_id: str,
endpoint_max_concurrent: Optional[int],
key_id: str,
key_max_concurrent: Optional[int],
is_cached_user: bool = False,  # Whether this is a cached user
cache_reservation_ratio: float = 0.3,  # Cache reservation ratio
ttl_seconds: int = 600,  # 10-minute TTL to prevent deadlocks
) -> bool:
"""
Try to acquire a concurrency slot (with cached-user priority support).
Args:
endpoint_id: Endpoint ID
endpoint_max_concurrent: Maximum endpoint concurrency (None means unlimited)
key_id: ProviderAPIKey ID
key_max_concurrent: Maximum key concurrency (None means unlimited)
is_cached_user: Whether this is a cached user (cached users may use all slots)
cache_reservation_ratio: Cache reservation ratio (default 30%, applies only to new users)
ttl_seconds: TTL in seconds, guarding against deadlocks in abnormal situations
Returns:
Whether the slot was acquired (True/False)
How the cache reservation works:
- Suppose key_max_concurrent = 10 and cache_reservation_ratio = 0.3
- New users may use at most 7 slots (10 * (1 - 0.3))
- Cached users may use at most 10 slots (all of them)
- The 3 reserved slots are dedicated to cached users so their requests take priority
"""
if self._redis is None:
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
key_count = self._memory_key_counts.get(key_id, 0)
# Endpoint 限制
if (
endpoint_max_concurrent is not None
and endpoint_count >= endpoint_max_concurrent
):
return False
# Key 限制,包含缓存预留
if key_max_concurrent is not None:
if is_cached_user:
if key_count >= key_max_concurrent:
return False
else:
available_for_new = max(
1, math.ceil(key_max_concurrent * (1 - cache_reservation_ratio))
)
if key_count >= available_for_new:
return False
# 通过限制,更新计数
self._memory_endpoint_counts[endpoint_id] = endpoint_count + 1
self._memory_key_counts[key_id] = key_count + 1
return True
endpoint_key = self._get_endpoint_key(endpoint_id)
key_key = self._get_key_key(key_id)
try:
# Use a Lua script for atomicity (includes the cache reservation logic)
lua_script = """
local endpoint_key = KEYS[1]
local key_key = KEYS[2]
local endpoint_max = tonumber(ARGV[1])
local key_max = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local is_cached = tonumber(ARGV[4]) -- 0=new user, 1=cached user
local cache_ratio = tonumber(ARGV[5]) -- cache reservation ratio
-- read the current counts
local endpoint_count = tonumber(redis.call('GET', endpoint_key) or '0')
local key_count = tonumber(redis.call('GET', key_key) or '0')
-- check the endpoint limit (-1 means unlimited)
if endpoint_max >= 0 and endpoint_count >= endpoint_max then
return 0 -- fail: endpoint is full
end
-- check the key limit (with cache reservation)
if key_max >= 0 then
if is_cached == 0 then
-- new users may only use (1 - cache_ratio) of the slots
local available_for_new = math.floor(key_max * (1 - cache_ratio))
if key_count >= available_for_new then
return 0 -- fail: new-user quota is full
end
else
-- cached users may use all slots
if key_count >= key_max then
return 0 -- fail: total quota is full
end
end
end
-- 增加计数
redis.call('INCR', endpoint_key)
redis.call('EXPIRE', endpoint_key, ttl)
redis.call('INCR', key_key)
redis.call('EXPIRE', key_key, ttl)
return 1 -- 成功
"""
# 执行脚本
result = await self._redis.eval(
lua_script,
2, # 2 个 KEYS
endpoint_key,
key_key,
endpoint_max_concurrent if endpoint_max_concurrent is not None else -1,
key_max_concurrent if key_max_concurrent is not None else -1,
ttl_seconds,
1 if is_cached_user else 0, # 缓存用户标志
cache_reservation_ratio, # 预留比例
)
success = result == 1
if success:
user_type = "缓存用户" if is_cached_user else "新用户"
logger.debug(
f"[OK] 获取并发槽位成功: endpoint={endpoint_id}, key={key_id}, "
f"类型={user_type}"
)
else:
endpoint_count, key_count = await self.get_current_concurrency(endpoint_id, key_id)
# 计算新用户可用槽位
if key_max_concurrent and not is_cached_user:
available_for_new = int(key_max_concurrent * (1 - cache_reservation_ratio))
user_info = f"新用户配额={available_for_new}, 当前={key_count}"
else:
user_info = f"缓存用户, 当前={key_count}/{key_max_concurrent}"
logger.warning(
f"[WARN] 并发槽位已满: endpoint={endpoint_id}({endpoint_count}/{endpoint_max_concurrent}), "
f"key={key_id}({user_info})"
)
return success
except Exception as e:
logger.error(f"获取并发槽位失败,降级到内存模式: {e}")
# When Redis fails, fall back to in-memory mode with conservative limiting
# Use a lower limit (50% of the original) so the upstream API is not overwhelmed
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
key_count = self._memory_key_counts.get(key_id, 0)
# Fallback mode uses a more conservative limit (50%)
fallback_endpoint_limit = (
max(1, endpoint_max_concurrent // 2)
if endpoint_max_concurrent is not None
else None
)
fallback_key_limit = (
max(1, key_max_concurrent // 2) if key_max_concurrent is not None else None
)
if (
fallback_endpoint_limit is not None
and endpoint_count >= fallback_endpoint_limit
):
logger.warning(
f"[FALLBACK] Endpoint 并发达到降级限制: {endpoint_count}/{fallback_endpoint_limit}"
)
return False
if fallback_key_limit is not None and key_count >= fallback_key_limit:
logger.warning(
f"[FALLBACK] Key 并发达到降级限制: {key_count}/{fallback_key_limit}"
)
return False
# 更新内存计数
self._memory_endpoint_counts[endpoint_id] = endpoint_count + 1
self._memory_key_counts[key_id] = key_count + 1
logger.debug(
f"[FALLBACK] 使用内存模式获取槽位: endpoint={endpoint_id}, key={key_id}"
)
return True
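# Worked example of the cache-reservation split enforced above (memory fallback
# and Lua script). The numbers are illustrative; the 10-slot / 0.3-ratio case in
# the acquire_slot docstring works out the same way (7 slots for new users, 3 reserved).
import math
key_max, cache_ratio = 8, 0.25
available_for_new = max(1, math.ceil(key_max * (1 - cache_ratio)))
reserved_for_cached = key_max - available_for_new
print(available_for_new, reserved_for_cached)  # 6 slots for new users, 2 reserved for cached users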
async def release_slot(self, endpoint_id: str, key_id: str) -> None:
"""
Release a concurrency slot.
Args:
endpoint_id: Endpoint ID
key_id: ProviderAPIKey ID
"""
if self._redis is None:
async with self._memory_lock:
if endpoint_id in self._memory_endpoint_counts:
self._memory_endpoint_counts[endpoint_id] = max(
0, self._memory_endpoint_counts[endpoint_id] - 1
)
if self._memory_endpoint_counts[endpoint_id] == 0:
self._memory_endpoint_counts.pop(endpoint_id, None)
if key_id in self._memory_key_counts:
self._memory_key_counts[key_id] = max(0, self._memory_key_counts[key_id] - 1)
if self._memory_key_counts[key_id] == 0:
self._memory_key_counts.pop(key_id, None)
return
endpoint_key = self._get_endpoint_key(endpoint_id)
key_key = self._get_key_key(key_id)
try:
# Use a Lua script for atomicity (never decrements below zero)
lua_script = """
local endpoint_key = KEYS[1]
local key_key = KEYS[2]
local endpoint_count = tonumber(redis.call('GET', endpoint_key) or '0')
local key_count = tonumber(redis.call('GET', key_key) or '0')
if endpoint_count > 0 then
redis.call('DECR', endpoint_key)
end
if key_count > 0 then
redis.call('DECR', key_key)
end
return 1
"""
await self._redis.eval(lua_script, 2, endpoint_key, key_key)
logger.debug(f"[OK] 释放并发槽位: endpoint={endpoint_id}, key={key_id}")
except Exception as e:
logger.error(f"释放并发槽位失败: {e}")
@asynccontextmanager
async def concurrency_guard(
self,
endpoint_id: str,
endpoint_max_concurrent: Optional[int],
key_id: str,
key_max_concurrent: Optional[int],
is_cached_user: bool = False,  # Whether this is a cached user
cache_reservation_ratio: float = 0.3,  # Cache reservation ratio
):
"""
Concurrency-control context manager (with cached-user priority support).
Usage:
async with manager.concurrency_guard(
endpoint_id, endpoint_max, key_id, key_max,
is_cached_user=True  # cached user
):
# perform the request
response = await send_request(...)
Raises ConcurrencyLimitError if the slot cannot be acquired.
"""
# 尝试获取槽位(传递缓存用户参数)
acquired = await self.acquire_slot(
endpoint_id,
endpoint_max_concurrent,
key_id,
key_max_concurrent,
is_cached_user,
cache_reservation_ratio,
)
if not acquired:
from src.core.exceptions import ConcurrencyLimitError
user_type = "缓存用户" if is_cached_user else "新用户"
raise ConcurrencyLimitError(
f"并发限制已达上限: endpoint={endpoint_id}, key={key_id}, 类型={user_type}"
)
# 记录开始时间和状态
import time
slot_acquired_at = time.time()
exception_occurred = False
try:
yield # 执行请求
except Exception as e:
# 记录异常
exception_occurred = True
raise
finally:
# 计算槽位占用时长
slot_duration = time.time() - slot_acquired_at
# 记录 Prometheus 指标
try:
from src.core.metrics import (
concurrency_slot_duration_seconds,
concurrency_slot_release_total,
)
# 记录槽位占用时长分布
concurrency_slot_duration_seconds.labels(
key_id=key_id[:8] if key_id else "unknown", # 只记录前8位
exception=str(exception_occurred),
).observe(slot_duration)
# 记录槽位释放计数
concurrency_slot_release_total.labels(
key_id=key_id[:8] if key_id else "unknown", exception=str(exception_occurred)
).inc()
# 告警:槽位占用时间过长(超过 60 秒)
if slot_duration > 60:
logger.warning(
f"[WARN] 并发槽位占用时间过长: "
f"key_id={key_id[:8] if key_id else 'unknown'}..., "
f"duration={slot_duration:.1f}s, "
f"exception={exception_occurred}"
)
except Exception as metric_error:
# 指标记录失败不应影响业务逻辑
logger.debug(f"记录并发指标失败: {metric_error}")
# Always release the slot (even when an exception occurred)
await self.release_slot(endpoint_id, key_id)
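# Minimal usage sketch for concurrency_guard. The IDs and the sleep are made up;
# ConcurrencyLimitError is the exception the guard raises when no slot is free.
import asyncio
async def _demo_guarded_request():
    manager = await get_concurrency_manager()
    try:
        async with manager.concurrency_guard(
            endpoint_id="endpoint-123",
            endpoint_max_concurrent=20,
            key_id="key-456",
            key_max_concurrent=10,
            is_cached_user=True,
        ):
            await asyncio.sleep(0.1)  # stand-in for the upstream request
    except Exception as exc:  # ConcurrencyLimitError when the slots are exhausted
        print(f"slot not acquired: {exc}")
# asyncio.run(_demo_guarded_request())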
async def reset_concurrency(
self, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
) -> None:
"""
Reset concurrency counters (administrative function, use with care).
Args:
endpoint_id: Endpoint ID (optional; None resets all endpoints)
key_id: ProviderAPIKey ID (optional; None resets all keys)
"""
if self._redis is None:
async with self._memory_lock:
if endpoint_id:
self._memory_endpoint_counts.pop(endpoint_id, None)
logger.info(f"[RESET] 重置 Endpoint 并发计数(内存): {endpoint_id}")
else:
count = len(self._memory_endpoint_counts)
self._memory_endpoint_counts.clear()
if count:
logger.info(f"[RESET] 重置所有 Endpoint 并发计数(内存): {count}")
if key_id:
self._memory_key_counts.pop(key_id, None)
logger.info(f"[RESET] 重置 Key 并发计数(内存): {key_id}")
else:
count = len(self._memory_key_counts)
self._memory_key_counts.clear()
if count:
logger.info(f"[RESET] 重置所有 Key 并发计数(内存): {count}")
return
try:
if endpoint_id:
endpoint_key = self._get_endpoint_key(endpoint_id)
await self._redis.delete(endpoint_key)
logger.info(f"[RESET] 重置 Endpoint 并发计数: {endpoint_id}")
else:
# 重置所有 endpoint
keys = await self._redis.keys("concurrency:endpoint:*")
if keys:
await self._redis.delete(*keys)
logger.info(f"[RESET] 重置所有 Endpoint 并发计数: {len(keys)}")
if key_id:
key_key = self._get_key_key(key_id)
await self._redis.delete(key_key)
logger.info(f"[RESET] 重置 Key 并发计数: {key_id}")
else:
# 重置所有 key
keys = await self._redis.keys("concurrency:key:*")
if keys:
await self._redis.delete(*keys)
logger.info(f"[RESET] 重置所有 Key 并发计数: {len(keys)}")
except Exception as e:
logger.error(f"重置并发计数失败: {e}")
# Global singleton
_concurrency_manager: Optional[ConcurrencyManager] = None
async def get_concurrency_manager() -> ConcurrencyManager:
"""获取全局 ConcurrencyManager 实例"""
global _concurrency_manager
if _concurrency_manager is None:
_concurrency_manager = ConcurrencyManager()
await _concurrency_manager.initialize()
return _concurrency_manager

View File

@@ -0,0 +1,333 @@
"""
Rate limit detector - parses 429 response headers to distinguish concurrency limits from RPM limits.
"""
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Tuple
from src.core.logger import logger
class RateLimitType:
"""速率限制类型"""
CONCURRENT = "concurrent" # 并发限制
RPM = "rpm" # 每分钟请求数限制
DAILY = "daily" # 每日限制
MONTHLY = "monthly" # 每月限制
UNKNOWN = "unknown" # 未知类型
class RateLimitInfo:
"""速率限制信息"""
def __init__(
self,
limit_type: str,
retry_after: Optional[int] = None,
limit_value: Optional[int] = None,
remaining: Optional[int] = None,
reset_at: Optional[datetime] = None,
current_usage: Optional[int] = None,
raw_headers: Optional[Dict[str, str]] = None,
):
self.limit_type = limit_type
self.retry_after = retry_after  # Seconds to wait
self.limit_value = limit_value  # Limit value
self.remaining = remaining  # Remaining quota
self.reset_at = reset_at  # Reset time
self.current_usage = current_usage  # Current usage
self.raw_headers = raw_headers or {}
def __repr__(self) -> str:
return (
f"RateLimitInfo(type={self.limit_type}, "
f"retry_after={self.retry_after}, "
f"limit={self.limit_value}, "
f"remaining={self.remaining})"
)
class RateLimitDetector:
"""
Rate limit detector.
Supported providers:
- Anthropic Claude API
- OpenAI API
- Generic standard HTTP headers
"""
@staticmethod
def detect_from_headers(
headers: Dict[str, str],
provider_name: str = "unknown",
current_concurrent: Optional[int] = None,
) -> RateLimitInfo:
"""
Detect the rate limit type from response headers.
Args:
headers: HTTP headers of the 429 response
provider_name: Provider name (used to pick a parsing strategy)
current_concurrent: Current concurrency (used to decide whether this is a concurrency limit)
Returns:
A RateLimitInfo object
"""
# 标准化header key (转小写)
headers_lower = {k.lower(): v for k, v in headers.items()}
# 根据提供商选择解析策略
if "anthropic" in provider_name.lower() or "claude" in provider_name.lower():
return RateLimitDetector._parse_anthropic_headers(headers_lower, current_concurrent)
elif "openai" in provider_name.lower():
return RateLimitDetector._parse_openai_headers(headers_lower, current_concurrent)
else:
return RateLimitDetector._parse_generic_headers(headers_lower, current_concurrent)
@staticmethod
def _parse_anthropic_headers(
headers: Dict[str, str],
current_concurrent: Optional[int] = None,
) -> RateLimitInfo:
"""
Parse Anthropic Claude API rate limit headers.
Common headers:
- anthropic-ratelimit-requests-limit: 50
- anthropic-ratelimit-requests-remaining: 0
- anthropic-ratelimit-requests-reset: 2024-01-01T00:00:00Z
- anthropic-ratelimit-tokens-limit: 100000
- anthropic-ratelimit-tokens-remaining: 50000
- retry-after: 60
"""
retry_after = RateLimitDetector._parse_retry_after(headers)
# 获取请求限制信息
requests_limit = RateLimitDetector._parse_int(
headers.get("anthropic-ratelimit-requests-limit")
)
requests_remaining = RateLimitDetector._parse_int(
headers.get("anthropic-ratelimit-requests-remaining")
)
requests_reset = RateLimitDetector._parse_datetime(
headers.get("anthropic-ratelimit-requests-reset")
)
# Determine the limit type
# 1. Explicit RPM limit: remaining request quota is 0
if requests_remaining is not None and requests_remaining == 0:
return RateLimitInfo(
limit_type=RateLimitType.RPM,
retry_after=retry_after,
limit_value=requests_limit,
remaining=requests_remaining,
reset_at=requests_reset,
raw_headers=headers,
)
# 2. Possibly a concurrency limit (combined conditions)
# Conditions: a current concurrency is known, and remaining > 0 (so RPM is not exhausted),
# and retry_after is short (concurrency limits usually have a short retry_after, e.g. 1-10 seconds)
is_likely_concurrent = (
current_concurrent is not None
and current_concurrent >= 2  # at least 2 concurrent requests
and (requests_remaining is None or requests_remaining > 0)  # RPM not exhausted
and (retry_after is None or retry_after <= 30)  # short wait
)
if is_likely_concurrent:
logger.info(
f"检测到可能的并发限制: current_concurrent={current_concurrent}, "
f"remaining={requests_remaining}, retry_after={retry_after}"
)
return RateLimitInfo(
limit_type=RateLimitType.CONCURRENT,
retry_after=retry_after,
current_usage=current_concurrent,
raw_headers=headers,
)
# 3. 未知类型
return RateLimitInfo(
limit_type=RateLimitType.UNKNOWN,
retry_after=retry_after,
raw_headers=headers,
)
@staticmethod
def _parse_openai_headers(
headers: Dict[str, str],
current_concurrent: Optional[int] = None,
) -> RateLimitInfo:
"""
解析 OpenAI API 的速率限制头
常见头部:
- x-ratelimit-limit-requests: 3500
- x-ratelimit-remaining-requests: 0
- x-ratelimit-reset-requests: 2024-01-01T00:00:00Z
- x-ratelimit-limit-tokens: 90000
- x-ratelimit-remaining-tokens: 50000
- retry-after: 60
"""
retry_after = RateLimitDetector._parse_retry_after(headers)
# 获取请求限制信息
requests_limit = RateLimitDetector._parse_int(headers.get("x-ratelimit-limit-requests"))
requests_remaining = RateLimitDetector._parse_int(
headers.get("x-ratelimit-remaining-requests")
)
requests_reset = RateLimitDetector._parse_datetime(
headers.get("x-ratelimit-reset-requests")
)
# 判断限制类型
# 1. 明确的 RPM 限制
if requests_remaining is not None and requests_remaining == 0:
return RateLimitInfo(
limit_type=RateLimitType.RPM,
retry_after=retry_after,
limit_value=requests_limit,
remaining=requests_remaining,
reset_at=requests_reset,
raw_headers=headers,
)
# 2. 可能的并发限制(多条件综合判断)
is_likely_concurrent = (
current_concurrent is not None
and current_concurrent >= 2
and (requests_remaining is None or requests_remaining > 0)
and (retry_after is None or retry_after <= 30)
)
if is_likely_concurrent:
return RateLimitInfo(
limit_type=RateLimitType.CONCURRENT,
retry_after=retry_after,
current_usage=current_concurrent,
raw_headers=headers,
)
# 3. 未知类型
return RateLimitInfo(
limit_type=RateLimitType.UNKNOWN,
retry_after=retry_after,
raw_headers=headers,
)
@staticmethod
def _parse_generic_headers(
headers: Dict[str, str],
current_concurrent: Optional[int] = None,
) -> RateLimitInfo:
"""
Parse generic rate limit headers.
Standard headers:
- retry-after: 60
- x-ratelimit-limit: 100
- x-ratelimit-remaining: 0
- x-ratelimit-reset: 1609459200
"""
retry_after = RateLimitDetector._parse_retry_after(headers)
limit_value = RateLimitDetector._parse_int(headers.get("x-ratelimit-limit"))
remaining = RateLimitDetector._parse_int(headers.get("x-ratelimit-remaining"))
# 1. 明确的 RPM 限制
if remaining is not None and remaining == 0:
return RateLimitInfo(
limit_type=RateLimitType.RPM,
retry_after=retry_after,
limit_value=limit_value,
remaining=remaining,
raw_headers=headers,
)
# 2. 可能的并发限制
is_likely_concurrent = (
current_concurrent is not None
and current_concurrent >= 2
and (remaining is None or remaining > 0)
and (retry_after is None or retry_after <= 30)
)
if is_likely_concurrent:
return RateLimitInfo(
limit_type=RateLimitType.CONCURRENT,
retry_after=retry_after,
current_usage=current_concurrent,
raw_headers=headers,
)
# 3. 未知类型
return RateLimitInfo(
limit_type=RateLimitType.UNKNOWN,
retry_after=retry_after,
raw_headers=headers,
)
@staticmethod
def _parse_retry_after(headers: Dict[str, str]) -> Optional[int]:
"""解析 Retry-After 头"""
retry_after_str = headers.get("retry-after")
if not retry_after_str:
return None
try:
# 尝试解析为整数(秒数)
return int(retry_after_str)
except ValueError:
# 尝试解析为HTTP日期格式
try:
retry_date = datetime.strptime(retry_after_str, "%a, %d %b %Y %H:%M:%S %Z")
delta = retry_date - datetime.now(timezone.utc)
return max(int(delta.total_seconds()), 0)
except Exception:
return None
@staticmethod
def _parse_int(value: Optional[str]) -> Optional[int]:
"""安全解析整数"""
if not value:
return None
try:
return int(value)
except (ValueError, TypeError):
return None
@staticmethod
def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
"""安全解析ISO 8601日期时间"""
if not value:
return None
try:
# 尝试解析 ISO 8601 格式
if value.endswith("Z"):
value = value[:-1] + "+00:00"
return datetime.fromisoformat(value)
except (ValueError, TypeError):
return None
# 便捷函数
def detect_rate_limit_type(
headers: Dict[str, str],
provider_name: str = "unknown",
current_concurrent: Optional[int] = None,
) -> RateLimitInfo:
"""
Detect the rate limit type (convenience function).
Args:
headers: Headers of the 429 response
provider_name: Provider name
current_concurrent: Current concurrency
Returns:
A RateLimitInfo object
"""
return RateLimitDetector.detect_from_headers(headers, provider_name, current_concurrent)
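# Example of the convenience function with Anthropic-style headers (values made up).
# remaining == 0 is classified as an RPM limit; remaining quota plus a short
# retry-after and several in-flight requests is classified as a concurrency limit.
if __name__ == "__main__":
    rpm_headers = {
        "anthropic-ratelimit-requests-limit": "50",
        "anthropic-ratelimit-requests-remaining": "0",
        "retry-after": "60",
    }
    print(detect_rate_limit_type(rpm_headers, provider_name="anthropic").limit_type)  # rpm
    concurrent_headers = {
        "anthropic-ratelimit-requests-remaining": "12",
        "retry-after": "2",
    }
    print(
        detect_rate_limit_type(
            concurrent_headers, provider_name="anthropic", current_concurrent=4
        ).limit_type
    )  # concurrent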

View File

@@ -0,0 +1,351 @@
"""
IP-level rate limiting service.
Provides IP-address-based rate limiting to protect against brute force and DDoS attacks.
"""
import ipaddress
from datetime import datetime, timezone
from typing import Dict, Optional, Set
from src.clients.redis_client import get_redis_client
from src.core.logger import logger
class IPRateLimiter:
"""IP 速率限制服务"""
# Redis key 前缀
RATE_LIMIT_PREFIX = "ip:rate_limit:"
BLACKLIST_PREFIX = "ip:blacklist:"
WHITELIST_KEY = "ip:whitelist"
# Default limits (per minute)
DEFAULT_LIMITS = {
"default": 100,  # Default limit
"login": 5,  # Login endpoints
"register": 3,  # Registration endpoints
"api": 60,  # API endpoints
"public": 60,  # Public endpoints
}
@staticmethod
async def check_limit(
ip_address: str, endpoint_type: str = "default", limit: Optional[int] = None
) -> tuple[bool, int, int]:
"""
Check whether an IP has exceeded its rate limit.
Args:
ip_address: IP address
endpoint_type: Endpoint type (default, login, register, api, public)
limit: Custom limit (None uses the default)
Returns:
(allowed, remaining count, seconds until reset)
"""
# 检查白名单
if await IPRateLimiter.is_whitelisted(ip_address):
return True, 999999, 60
# 检查黑名单
if await IPRateLimiter.is_blacklisted(ip_address):
logger.warning(f"黑名单 IP 尝试访问: {ip_address}, 类型: {endpoint_type}")
return False, 0, 0
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
# Redis 不可用时降级:允许访问但记录警告
logger.warning("Redis 不可用,跳过 IP 速率限制(降级模式)")
return True, 0, 60
# 确定限制值
rate_limit = (
limit if limit is not None else IPRateLimiter.DEFAULT_LIMITS.get(endpoint_type, 100)
)
try:
# Redis key: ip:rate_limit:{type}:{ip}
redis_key = f"{IPRateLimiter.RATE_LIMIT_PREFIX}{endpoint_type}:{ip_address}"
# 使用 Redis 的滑动窗口计数器
# INCR 并设置过期时间
count = await redis_client.incr(redis_key)
# 第一次访问时设置过期时间
if count == 1:
await redis_client.expire(redis_key, 60) # 60秒窗口
# 获取 TTL剩余过期时间
ttl = await redis_client.ttl(redis_key)
if ttl < 0:
# 如果没有过期时间,重新设置
await redis_client.expire(redis_key, 60)
ttl = 60
remaining = max(0, rate_limit - count)
allowed = count <= rate_limit
if not allowed:
logger.warning(f"IP 速率限制触发: {ip_address}, 类型: {endpoint_type}, 计数: {count}/{rate_limit}")
return allowed, remaining, ttl
except Exception as e:
logger.error(f"检查 IP 速率限制失败: {e}")
# 发生错误时允许访问,避免误杀
return True, 0, 60
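# Usage sketch (hypothetical helper, not part of the class): how a request
# handler would consume the (allowed, remaining, reset) tuple returned above.
async def _reject_if_rate_limited(client_ip: str) -> None:
    allowed, remaining, reset_in = await IPRateLimiter.check_limit(client_ip, endpoint_type="login")
    if not allowed:
        # a real handler would return HTTP 429 with Retry-After; here we just raise
        raise RuntimeError(f"rate limited, retry in {reset_in}s (remaining={remaining})")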
@staticmethod
async def add_to_blacklist(
ip_address: str, reason: str = "manual", ttl: Optional[int] = None
) -> bool:
"""
Add an IP to the blacklist.
Args:
ip_address: IP address
reason: Reason for blacklisting
ttl: Expiry in seconds (None means permanent)
Returns:
Whether the operation succeeded
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法将 IP 加入黑名单")
return False
try:
redis_key = f"{IPRateLimiter.BLACKLIST_PREFIX}{ip_address}"
if ttl is not None:
await redis_client.setex(redis_key, ttl, reason)
else:
await redis_client.set(redis_key, reason)
logger.warning(f"IP 已加入黑名单: {ip_address}, 原因: {reason}, TTL: {ttl or '永久'}")
return True
except Exception as e:
logger.error(f"添加 IP 到黑名单失败: {e}")
return False
@staticmethod
async def remove_from_blacklist(ip_address: str) -> bool:
"""
Remove an IP from the blacklist.
Args:
ip_address: IP address
Returns:
Whether the operation succeeded
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法从黑名单移除 IP")
return False
try:
redis_key = f"{IPRateLimiter.BLACKLIST_PREFIX}{ip_address}"
deleted = await redis_client.delete(redis_key)
if deleted:
logger.info(f"IP 已从黑名单移除: {ip_address}")
return bool(deleted)
except Exception as e:
logger.error(f"从黑名单移除 IP 失败: {e}")
return False
@staticmethod
async def is_blacklisted(ip_address: str) -> bool:
"""
Check whether an IP is blacklisted.
Args:
ip_address: IP address
Returns:
Whether the IP is blacklisted
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
return False
try:
redis_key = f"{IPRateLimiter.BLACKLIST_PREFIX}{ip_address}"
exists = await redis_client.exists(redis_key)
return bool(exists)
except Exception as e:
logger.error(f"检查 IP 黑名单状态失败: {e}")
return False
@staticmethod
async def add_to_whitelist(ip_address: str) -> bool:
"""
Add an IP to the whitelist.
Args:
ip_address: IP address or CIDR notation (e.g. 192.168.1.0/24)
Returns:
Whether the operation succeeded
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法将 IP 加入白名单")
return False
try:
# 验证 IP 格式
try:
ipaddress.ip_network(ip_address, strict=False)
except ValueError as e:
logger.error(f"无效的 IP 地址格式: {ip_address}, 错误: {e}")
return False
# 使用 Redis Set 存储白名单
await redis_client.sadd(IPRateLimiter.WHITELIST_KEY, ip_address)
logger.info(f"IP 已加入白名单: {ip_address}")
return True
except Exception as e:
logger.error(f"添加 IP 到白名单失败: {e}")
return False
@staticmethod
async def remove_from_whitelist(ip_address: str) -> bool:
"""
Remove an IP from the whitelist.
Args:
ip_address: IP address
Returns:
Whether the operation succeeded
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
logger.warning("Redis 不可用,无法从白名单移除 IP")
return False
try:
removed = await redis_client.srem(IPRateLimiter.WHITELIST_KEY, ip_address)
if removed:
logger.info(f"IP 已从白名单移除: {ip_address}")
return bool(removed)
except Exception as e:
logger.error(f"从白名单移除 IP 失败: {e}")
return False
@staticmethod
async def is_whitelisted(ip_address: str) -> bool:
"""
Check whether an IP is whitelisted (supports CIDR matching).
Args:
ip_address: IP address
Returns:
Whether the IP is whitelisted
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
return False
try:
# 获取所有白名单条目
whitelist = await redis_client.smembers(IPRateLimiter.WHITELIST_KEY)
if not whitelist:
return False
# 将 IP 地址转换为 ip_address 对象
try:
ip_obj = ipaddress.ip_address(ip_address)
except ValueError:
return False
# 检查是否匹配白名单中的任何条目
for entry in whitelist:
try:
network = ipaddress.ip_network(entry, strict=False)
if ip_obj in network:
return True
except ValueError:
# 如果条目格式无效,跳过
continue
return False
except Exception as e:
logger.error(f"检查 IP 白名单状态失败: {e}")
return False
@staticmethod
async def get_blacklist_stats() -> Dict:
"""
获取黑名单统计信息
Returns:
统计信息字典
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
return {"available": False, "total": 0, "error": "Redis 不可用"}
try:
pattern = f"{IPRateLimiter.BLACKLIST_PREFIX}*"
cursor = 0
total = 0
while True:
cursor, keys = await redis_client.scan(cursor=cursor, match=pattern, count=100)
total += len(keys)
if cursor == 0:
break
return {"available": True, "total": total}
except Exception as e:
logger.error(f"获取黑名单统计失败: {e}")
return {"available": False, "total": 0, "error": str(e)}
@staticmethod
async def get_whitelist() -> Set[str]:
"""
获取白名单列表
Returns:
白名单 IP 集合
"""
redis_client = await get_redis_client(require_redis=False)
if redis_client is None:
return set()
try:
whitelist = await redis_client.smembers(IPRateLimiter.WHITELIST_KEY)
return whitelist if whitelist else set()
except Exception as e:
logger.error(f"获取白名单失败: {e}")
return set()

View File

@@ -0,0 +1,139 @@
"""
RPM (Requests Per Minute) rate limiting service
"""
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from sqlalchemy.orm import Session
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
class RPMLimiter:
"""RPM限流器"""
def __init__(self, db: Session):
self.db = db
# 内存中的RPM计数器 {provider_id: (count, window_start)}
self._rpm_counters: Dict[str, Tuple[int, float]] = {}
def check_and_increment(self, provider_id: str) -> bool:
"""
Check and increment the RPM count.
Returns:
True if allowed, False if rate limited
"""
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
return True
rpm_limit = provider.rpm_limit
if rpm_limit is None:
# 未设置限制
return True
if rpm_limit == 0:
logger.warning(f"Provider {provider.name} is fully restricted by RPM limit=0")
return False
current_time = time.time()
# 检查是否需要重置
if provider.rpm_reset_at and provider.rpm_reset_at < datetime.now(timezone.utc):
provider.rpm_used = 0
provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
self.db.commit() # 立即提交事务,释放数据库锁
# 检查是否超限
if provider.rpm_used >= rpm_limit:
logger.warning(f"Provider {provider.name} RPM limit exceeded")
return False
# 递增计数
provider.rpm_used += 1
if not provider.rpm_reset_at:
provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
self.db.commit() # 立即提交事务,释放数据库锁
return True
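# Usage sketch (hypothetical helper, not part of the class): gate a dispatch on
# the RPM counter; only check_and_increment/get_rpm_status come from this class.
def _try_dispatch(db: Session, provider_id: str) -> bool:
    limiter = RPMLimiter(db)
    if not limiter.check_and_increment(provider_id):
        status = limiter.get_rpm_status(provider_id)
        logger.info(f"RPM exhausted for {provider_id}, resets at {status.get('rpm_reset_at')}")
        return False
    return True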
def record_usage(
self, provider_id: str, success: bool, response_time_ms: float, cost_usd: float
):
"""记录使用情况到追踪表"""
# 获取当前分钟窗口
now = datetime.now(timezone.utc)
window_start = now.replace(second=0, microsecond=0)
window_end = (
window_start.replace(minute=window_start.minute + 1)
if window_start.minute < 59
else window_start.replace(hour=window_start.hour + 1, minute=0)
)
# 查找或创建追踪记录
tracking = (
self.db.query(ProviderUsageTracking)
.filter(
ProviderUsageTracking.provider_id == provider_id,
ProviderUsageTracking.window_start == window_start,
)
.first()
)
if not tracking:
tracking = ProviderUsageTracking(
provider_id=provider_id, window_start=window_start, window_end=window_end
)
self.db.add(tracking)
# 更新统计
tracking.total_requests += 1
if success:
tracking.successful_requests += 1
else:
tracking.failed_requests += 1
tracking.total_response_time_ms += response_time_ms
tracking.avg_response_time_ms = tracking.total_response_time_ms / tracking.total_requests
tracking.total_cost_usd += cost_usd
self.db.flush()  # Only flush; do not commit immediately
# RPM usage statistics are non-critical data; use batched commits
get_batch_committer().mark_dirty(self.db)
logger.debug(f"Recorded usage for provider {provider_id}")
def get_rpm_status(self, provider_id: str) -> Dict:
"""获取提供商的RPM状态"""
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
return {"error": "Provider not found"}
return {
"provider_id": provider_id,
"provider_name": provider.name,
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
"available": (
provider.rpm_limit - provider.rpm_used if provider.rpm_limit is not None else None
),
}
def reset_rpm_counter(self, provider_id: str):
"""手动重置RPM计数器"""
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
if provider:
provider.rpm_used = 0
provider.rpm_reset_at = None
self.db.commit() # 立即提交事务,释放数据库锁
logger.info(f"Reset RPM counter for provider {provider.name}")