Files
Aether/src/services/rate_limit/rpm_limiter.py

136 lines
4.5 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
RPM (Requests Per Minute) 限流服务
"""
import time
from datetime import datetime, timedelta, timezone
2025-12-10 20:52:44 +08:00
from typing import Dict, Tuple
from sqlalchemy.orm import Session
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
class RPMLimiter:
    """Requests-per-minute (RPM) limiter for providers.

    The per-provider counter is persisted on the ``Provider`` row
    (``rpm_used``) together with the end of the current window
    (``rpm_reset_at``); per-minute usage statistics are accumulated in
    ``ProviderUsageTracking``. All reads and writes go through the SQLAlchemy
    session passed to the constructor.
    """

    # Length of one RPM window, in seconds, used by ``check_and_increment``.
    WINDOW_SECONDS = 60

    def __init__(self, db: "Session"):
        """Bind the limiter to a database session.

        Args:
            db: SQLAlchemy session used for all queries and commits.
        """
        self.db = db
        # In-memory RPM counters {provider_id: (count, window_start)}.
        # NOTE(review): not read or written anywhere else in this class; kept
        # only so existing external references to the attribute keep working.
        self._rpm_counters: Dict[str, Tuple[int, float]] = {}

    def check_and_increment(self, provider_id: str) -> bool:
        """Check the provider's RPM budget and consume one request if allowed.

        Args:
            provider_id: Primary key of the ``Provider`` row.

        Returns:
            True if the request is allowed (also when the provider does not
            exist or has no ``rpm_limit`` configured), False if rate limited.
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            return True

        rpm_limit = provider.rpm_limit
        if rpm_limit is None:
            # No limit configured.
            return True
        if rpm_limit == 0:
            logger.warning(f"Provider {provider.name} is fully restricted by RPM limit=0")
            return False

        # NOTE(review): this is an unlocked read-modify-write of rpm_used; callers
        # using separate sessions concurrently can both pass the check and exceed
        # the limit. Consider SELECT ... FOR UPDATE or an atomic UPDATE.
        # NOTE(review): assumes rpm_reset_at is stored timezone-aware (it is
        # compared with an aware UTC datetime) — confirm against the model.
        now = datetime.now(timezone.utc)
        window = timedelta(seconds=self.WINDOW_SECONDS)
        dirty = False

        if provider.rpm_reset_at is None:
            # No window open yet: open one now so the counter is eventually
            # reset even if this request is rejected below.
            provider.rpm_reset_at = now + window
            dirty = True
        elif provider.rpm_reset_at < now:
            # Previous window expired: start a fresh one.
            provider.rpm_used = 0
            provider.rpm_reset_at = now + window
            dirty = True

        used = provider.rpm_used or 0
        if used >= rpm_limit:
            if dirty:
                self.db.commit()  # Persist the window change and release DB locks.
            logger.warning(f"Provider {provider.name} RPM limit exceeded")
            return False

        provider.rpm_used = used + 1
        self.db.commit()  # Commit immediately to release DB locks.
        return True

    def record_usage(
        self, provider_id: str, success: bool, response_time_ms: float, cost_usd: float
    ):
        """Accumulate one request's outcome into the current per-minute tracking row.

        The row for (provider_id, current UTC minute) is looked up or created,
        its counters are updated, and the session is flushed (not committed);
        the commit is delegated to the batch committer.

        Args:
            provider_id: Provider the request was sent to.
            success: Whether the request succeeded.
            response_time_ms: Response time of the request in milliseconds.
            cost_usd: Cost of the request in USD.
        """
        # Current one-minute window, aligned to the start of the minute.
        now = datetime.now(timezone.utc)
        window_start = now.replace(second=0, microsecond=0)
        window_end = window_start + timedelta(minutes=1)

        # Find or create the tracking row for this window.
        # NOTE(review): two concurrent callers can both miss and both insert a
        # row for the same window; guard with a unique constraint/upsert if needed.
        tracking = (
            self.db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == provider_id,
                ProviderUsageTracking.window_start == window_start,
            )
            .first()
        )
        if not tracking:
            tracking = ProviderUsageTracking(
                provider_id=provider_id, window_start=window_start, window_end=window_end
            )
            self.db.add(tracking)

        # SQLAlchemy applies column defaults at INSERT time, so counters on a
        # freshly constructed row may still be None here; treat None as 0.
        tracking.total_requests = (tracking.total_requests or 0) + 1
        if success:
            tracking.successful_requests = (tracking.successful_requests or 0) + 1
        else:
            tracking.failed_requests = (tracking.failed_requests or 0) + 1
        tracking.total_response_time_ms = (
            tracking.total_response_time_ms or 0
        ) + response_time_ms
        tracking.avg_response_time_ms = tracking.total_response_time_ms / tracking.total_requests
        tracking.total_cost_usd = (tracking.total_cost_usd or 0) + cost_usd

        self.db.flush()  # Flush only; do not commit here.
        # RPM usage statistics are non-critical, so commit them in batches.
        get_batch_committer().mark_dirty(self.db)
        logger.debug(f"Recorded usage for provider {provider_id}")

    @staticmethod
    def _available(rpm_limit, rpm_used):
        """Return the requests left in the window (never negative), or None if unlimited."""
        if rpm_limit is None:
            return None
        # Clamp at 0: usage can exceed the limit after the limit is lowered.
        return max(0, rpm_limit - (rpm_used or 0))

    def get_rpm_status(self, provider_id: str) -> Dict:
        """Return the provider's RPM state as a plain dict.

        Args:
            provider_id: Primary key of the ``Provider`` row.

        Returns:
            ``{"error": "Provider not found"}`` if the provider does not exist,
            otherwise a dict with the limit, stored usage, ISO-formatted reset
            time (or None) and ``available`` (None when no limit is set).
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            return {"error": "Provider not found"}
        return {
            "provider_id": provider_id,
            "provider_name": provider.name,
            "rpm_limit": provider.rpm_limit,
            "rpm_used": provider.rpm_used,
            "rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
            "available": self._available(provider.rpm_limit, provider.rpm_used),
        }

    def reset_rpm_counter(self, provider_id: str):
        """Manually reset the provider's RPM counter and close its window.

        Does nothing if the provider does not exist.

        Args:
            provider_id: Primary key of the ``Provider`` row.
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if provider:
            provider.rpm_used = 0
            provider.rpm_reset_at = None
            self.db.commit()  # Commit immediately to release DB locks.
            logger.info(f"Reset RPM counter for provider {provider.name}")