Files
Aether/src/services/rate_limit/rpm_limiter.py
fawney19 b579420690 refactor: optimize database session lifecycle and middleware architecture
- Improve database pool capacity logging with detailed configuration parameters
- Optimize database session dependency injection with middleware-managed lifecycle
- Simplify plugin middleware by delegating session creation to FastAPI dependencies
- Fix import path in auth routes (relative to absolute)
- Add safety checks for database session management across middleware exception handlers
- Ensure session cleanup only when not managed by middleware (avoid premature cleanup)
2025-12-18 00:35:46 +08:00

136 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RPM (Requests Per Minute) 限流服务
"""
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from sqlalchemy.orm import Session
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
class RPMLimiter:
    """Requests-per-minute (RPM) limiter for providers.

    The per-provider counter (``rpm_used``) and the end of the current window
    (``rpm_reset_at``) live on the ``Provider`` row and are read and written
    through the session passed to the constructor. Per-minute usage statistics
    are accumulated in ``ProviderUsageTracking`` rows.
    """

    def __init__(self, db: "Session"):
        """Bind the limiter to a database session.

        Args:
            db: Session used for every query, flush and commit made here.
        """
        self.db = db
        # In-memory RPM counters: {provider_id: (count, window_start)}.
        # NOTE(review): no method visible in this class reads or writes it.
        self._rpm_counters: Dict[str, Tuple[int, float]] = {}

    @staticmethod
    def _as_utc(moment):
        """Return *moment* as a timezone-aware datetime (``None`` passes through).

        A naive value is tagged as UTC, matching how this class writes
        ``rpm_reset_at``. Some database backends return the stored value
        without timezone info, and comparing a naive datetime with an aware
        one raises ``TypeError``.
        """
        if moment is None:
            return None
        if moment.tzinfo is None or moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    @staticmethod
    def _add(current, delta):
        """Return ``current + delta``, treating an unset (``None``) value as 0."""
        return (current or 0) + delta

    @staticmethod
    def _remaining(limit, used):
        """Return ``limit - used`` (unset ``used`` counts as 0), or ``None`` if unlimited."""
        if limit is None:
            return None
        return limit - (used or 0)

    def check_and_increment(self, provider_id: str) -> bool:
        """Check the provider's RPM budget and consume one request from it.

        Args:
            provider_id: Primary key of the provider to check.

        Returns:
            True if the request is allowed, False if it is rate limited.
            An unknown provider or a provider without a limit is always
            allowed; a limit of 0 always rejects.
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            return True

        rpm_limit = provider.rpm_limit
        if rpm_limit is None:
            # No limit configured.
            return True
        if rpm_limit == 0:
            logger.warning(f"Provider {provider.name} is fully restricted by RPM limit=0")
            return False

        now = datetime.now(timezone.utc)
        next_reset = now + timedelta(seconds=60)

        # Start a new window once the previous one has expired.
        reset_at = self._as_utc(provider.rpm_reset_at)
        if reset_at is not None and reset_at < now:
            provider.rpm_used = 0
            provider.rpm_reset_at = next_reset
            self.db.commit()  # Commit immediately to release the database lock.

        # Reject when the window budget is already used up.
        used = provider.rpm_used or 0
        if used >= rpm_limit:
            logger.warning(f"Provider {provider.name} RPM limit exceeded")
            return False

        # Consume one request; open a window if none is active yet.
        provider.rpm_used = used + 1
        if not provider.rpm_reset_at:
            provider.rpm_reset_at = next_reset
        self.db.commit()  # Commit immediately to release the database lock.
        return True

    def record_usage(
        self, provider_id: str, success: bool, response_time_ms: float, cost_usd: float
    ):
        """Record one request in the provider's per-minute usage tracking row.

        Args:
            provider_id: Provider the request was sent to.
            success: Whether the request succeeded.
            response_time_ms: Response time of the request in milliseconds.
            cost_usd: Cost of the request in USD.
        """
        # Current one-minute window, aligned to the start of the minute (UTC).
        now = datetime.now(timezone.utc)
        window_start = now.replace(second=0, microsecond=0)
        window_end = window_start + timedelta(minutes=1)

        # Find or create the tracking row for this window.
        tracking = (
            self.db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == provider_id,
                ProviderUsageTracking.window_start == window_start,
            )
            .first()
        )
        if not tracking:
            tracking = ProviderUsageTracking(
                provider_id=provider_id, window_start=window_start, window_end=window_end
            )
            self.db.add(tracking)

        # Update the statistics. Column defaults are only applied at INSERT
        # time, so counters on a freshly constructed row may still be None;
        # _add() treats those as 0 instead of raising TypeError.
        tracking.total_requests = self._add(tracking.total_requests, 1)
        if success:
            tracking.successful_requests = self._add(tracking.successful_requests, 1)
        else:
            tracking.failed_requests = self._add(tracking.failed_requests, 1)
        tracking.total_response_time_ms = self._add(
            tracking.total_response_time_ms, response_time_ms
        )
        tracking.avg_response_time_ms = tracking.total_response_time_ms / tracking.total_requests
        tracking.total_cost_usd = self._add(tracking.total_cost_usd, cost_usd)

        self.db.flush()  # Flush only; do not commit immediately.
        # RPM usage statistics are non-critical data, so use batch commit.
        get_batch_committer().mark_dirty(self.db)
        logger.debug(f"Recorded usage for provider {provider_id}")

    def get_rpm_status(self, provider_id: str) -> Dict:
        """Return the provider's current RPM state.

        Args:
            provider_id: Primary key of the provider to inspect.

        Returns:
            ``{"error": "Provider not found"}`` for an unknown provider,
            otherwise a dict with the provider id and name, ``rpm_limit``,
            ``rpm_used``, ``rpm_reset_at`` (ISO-8601 string or None) and
            ``available`` (limit minus used, or None when no limit is set).
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            return {"error": "Provider not found"}

        reset_at = provider.rpm_reset_at
        return {
            "provider_id": provider_id,
            "provider_name": provider.name,
            "rpm_limit": provider.rpm_limit,
            "rpm_used": provider.rpm_used,
            "rpm_reset_at": reset_at.isoformat() if reset_at else None,
            "available": self._remaining(provider.rpm_limit, provider.rpm_used),
        }

    def reset_rpm_counter(self, provider_id: str):
        """Manually reset the provider's RPM counter and clear its window.

        Does nothing when the provider does not exist.

        Args:
            provider_id: Primary key of the provider to reset.
        """
        provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
        if provider:
            provider.rpm_used = 0
            provider.rpm_reset_at = None
            self.db.commit()  # Commit immediately to release the database lock.
            logger.info(f"Reset RPM counter for provider {provider.name}")