2025-12-10 20:52:44 +08:00
|
|
|
|
"""
|
|
|
|
|
|
RPM (Requests Per Minute) 限流服务
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import time
|
2025-12-18 00:35:46 +08:00
|
|
|
|
from datetime import datetime, timedelta, timezone
|
2025-12-10 20:52:44 +08:00
|
|
|
|
from typing import Dict, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
|
|
from src.core.batch_committer import get_batch_committer
|
|
|
|
|
|
from src.core.logger import logger
|
|
|
|
|
|
from src.models.database import Provider
|
|
|
|
|
|
from src.models.database_extensions import ProviderUsageTracking
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RPMLimiter:
|
|
|
|
|
|
"""RPM限流器"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, db: Session):
|
|
|
|
|
|
self.db = db
|
|
|
|
|
|
# 内存中的RPM计数器 {provider_id: (count, window_start)}
|
|
|
|
|
|
self._rpm_counters: Dict[str, Tuple[int, float]] = {}
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_increment(self, provider_id: str) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
检查并递增RPM计数
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
True if allowed, False if rate limited
|
|
|
|
|
|
"""
|
|
|
|
|
|
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
|
|
|
|
|
|
if not provider:
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
rpm_limit = provider.rpm_limit
|
|
|
|
|
|
if rpm_limit is None:
|
|
|
|
|
|
# 未设置限制
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
if rpm_limit == 0:
|
|
|
|
|
|
logger.warning(f"Provider {provider.name} is fully restricted by RPM limit=0")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
current_time = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
# 检查是否需要重置
|
|
|
|
|
|
if provider.rpm_reset_at and provider.rpm_reset_at < datetime.now(timezone.utc):
|
|
|
|
|
|
provider.rpm_used = 0
|
|
|
|
|
|
provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
|
|
|
|
|
|
self.db.commit() # 立即提交事务,释放数据库锁
|
|
|
|
|
|
|
|
|
|
|
|
# 检查是否超限
|
|
|
|
|
|
if provider.rpm_used >= rpm_limit:
|
|
|
|
|
|
logger.warning(f"Provider {provider.name} RPM limit exceeded")
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
# 递增计数
|
|
|
|
|
|
provider.rpm_used += 1
|
|
|
|
|
|
if not provider.rpm_reset_at:
|
|
|
|
|
|
provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
|
|
|
|
|
|
|
|
|
|
|
|
self.db.commit() # 立即提交事务,释放数据库锁
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
def record_usage(
|
|
|
|
|
|
self, provider_id: str, success: bool, response_time_ms: float, cost_usd: float
|
|
|
|
|
|
):
|
|
|
|
|
|
"""记录使用情况到追踪表"""
|
|
|
|
|
|
|
|
|
|
|
|
# 获取当前分钟窗口
|
|
|
|
|
|
now = datetime.now(timezone.utc)
|
|
|
|
|
|
window_start = now.replace(second=0, microsecond=0)
|
2025-12-18 00:35:46 +08:00
|
|
|
|
window_end = window_start + timedelta(minutes=1)
|
2025-12-10 20:52:44 +08:00
|
|
|
|
|
|
|
|
|
|
# 查找或创建追踪记录
|
|
|
|
|
|
tracking = (
|
|
|
|
|
|
self.db.query(ProviderUsageTracking)
|
|
|
|
|
|
.filter(
|
|
|
|
|
|
ProviderUsageTracking.provider_id == provider_id,
|
|
|
|
|
|
ProviderUsageTracking.window_start == window_start,
|
|
|
|
|
|
)
|
|
|
|
|
|
.first()
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if not tracking:
|
|
|
|
|
|
tracking = ProviderUsageTracking(
|
|
|
|
|
|
provider_id=provider_id, window_start=window_start, window_end=window_end
|
|
|
|
|
|
)
|
|
|
|
|
|
self.db.add(tracking)
|
|
|
|
|
|
|
|
|
|
|
|
# 更新统计
|
|
|
|
|
|
tracking.total_requests += 1
|
|
|
|
|
|
if success:
|
|
|
|
|
|
tracking.successful_requests += 1
|
|
|
|
|
|
else:
|
|
|
|
|
|
tracking.failed_requests += 1
|
|
|
|
|
|
|
|
|
|
|
|
tracking.total_response_time_ms += response_time_ms
|
|
|
|
|
|
tracking.avg_response_time_ms = tracking.total_response_time_ms / tracking.total_requests
|
|
|
|
|
|
tracking.total_cost_usd += cost_usd
|
|
|
|
|
|
|
|
|
|
|
|
self.db.flush() # 只 flush,不立即 commit
|
|
|
|
|
|
# RPM 使用统计是非关键数据,使用批量提交
|
|
|
|
|
|
get_batch_committer().mark_dirty(self.db)
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"Recorded usage for provider {provider_id}")
|
|
|
|
|
|
|
|
|
|
|
|
def get_rpm_status(self, provider_id: str) -> Dict:
|
|
|
|
|
|
"""获取提供商的RPM状态"""
|
|
|
|
|
|
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
|
|
|
|
|
|
if not provider:
|
|
|
|
|
|
return {"error": "Provider not found"}
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
"provider_id": provider_id,
|
|
|
|
|
|
"provider_name": provider.name,
|
|
|
|
|
|
"rpm_limit": provider.rpm_limit,
|
|
|
|
|
|
"rpm_used": provider.rpm_used,
|
|
|
|
|
|
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
|
|
|
|
|
|
"available": (
|
|
|
|
|
|
provider.rpm_limit - provider.rpm_used if provider.rpm_limit is not None else None
|
|
|
|
|
|
),
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def reset_rpm_counter(self, provider_id: str):
|
|
|
|
|
|
"""手动重置RPM计数器"""
|
|
|
|
|
|
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
|
|
|
|
|
|
if provider:
|
|
|
|
|
|
provider.rpm_used = 0
|
|
|
|
|
|
provider.rpm_reset_at = None
|
|
|
|
|
|
self.db.commit() # 立即提交事务,释放数据库锁
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"Reset RPM counter for provider {provider.name}")
|