Files
Aether/src/api/admin/provider_strategy.py
fawney19 dec681fea0 fix: 统一时区处理,确保所有 datetime 带时区信息
- token_bucket.py: get_reset_time 和 Redis 后端使用 timezone.utc
- sliding_window.py: get_reset_time 和 retry_after 计算使用 timezone.utc
- provider_strategy.py: dateutil.parser 解析后确保有时区信息
2026-01-05 02:23:24 +08:00

280 lines
11 KiB
Python

"""
Provider strategy management API endpoints.

Admin endpoints for updating a provider's billing configuration, reading its
usage statistics, manually resetting its quota, and listing the registered
load-balancer strategies.
"""
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
# Router for the admin provider-strategy endpoints; every route below is
# registered on it under this prefix.
router = APIRouter(prefix="/api/admin/provider-strategy", tags=["Provider Strategy"])
# Shared request pipeline; each endpoint hands its adapter to pipeline.run().
pipeline = ApiRequestPipeline()
class ProviderBillingUpdate(BaseModel):
    """Request body for ``PUT /providers/{provider_id}/billing``.

    The two datetime fields are accepted as strings and parsed by the request
    handler; values without timezone information are treated as UTC there.
    """

    billing_type: ProviderBillingType
    # Quota amount in USD. Must be non-negative (consistent with rpm_limit);
    # None is stored as-is on the provider.
    monthly_quota_usd: Optional[float] = Field(default=None, ge=0)
    # Length of the quota reset cycle, in days (despite "day" in the name).
    quota_reset_day: int = Field(default=30, ge=1, le=365)
    # Start of the current quota period, as a datetime string.
    quota_last_reset_at: Optional[str] = None
    # Quota expiry time, as a datetime string.
    quota_expires_at: Optional[str] = None
    rpm_limit: Optional[int] = Field(default=None, ge=0)
    provider_priority: int = Field(default=100, ge=0, le=200)
@router.put("/providers/{provider_id}/billing")
async def update_provider_billing(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update a provider's billing configuration.

    Delegates to ``AdminProviderBillingAdapter`` through the shared pipeline.
    """
    billing_adapter = AdminProviderBillingAdapter(provider_id=provider_id)
    result = await pipeline.run(
        adapter=billing_adapter,
        http_request=request,
        db=db,
        mode=billing_adapter.mode,
    )
    return result
@router.get("/providers/{provider_id}/stats")
async def get_provider_stats(
    provider_id: str,
    request: Request,
    hours: int = 24,
    db: Session = Depends(get_db),
):
    """Return billing, RPM and usage statistics for a provider.

    ``hours`` is the look-back window passed to ``AdminProviderStatsAdapter``.
    """
    stats_adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
    result = await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
    return result
@router.delete("/providers/{provider_id}/quota")
async def reset_provider_quota(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Reset provider quota usage to zero"""
    reset_adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
    result = await pipeline.run(
        adapter=reset_adapter,
        http_request=request,
        db=db,
        mode=reset_adapter.mode,
    )
    return result
@router.get("/strategies")
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
    """List the registered load-balancer strategies via ``AdminListStrategiesAdapter``."""
    strategies_adapter = AdminListStrategiesAdapter()
    result = await pipeline.run(
        adapter=strategies_adapter,
        http_request=request,
        db=db,
        mode=strategies_adapter.mode,
    )
    return result
class AdminProviderBillingAdapter(AdminApiAdapter):
    """Apply a ``ProviderBillingUpdate`` payload to a single provider."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    @staticmethod
    def _parse_aware_datetime(value: str, field_name: str) -> datetime:
        """Parse a client-supplied datetime string into an aware datetime.

        Values without timezone information are assumed to be UTC.

        Raises:
            InvalidRequestException: if ``value`` cannot be parsed.
        """
        from dateutil import parser

        try:
            parsed = parser.parse(value)
        except (ValueError, OverflowError) as exc:
            # dateutil raises ParserError (a ValueError subclass) for unknown
            # formats and OverflowError for out-of-range values; report both
            # as a client error instead of an unhandled 500.
            raise InvalidRequestException(
                f"Invalid datetime for {field_name}: {value!r}"
            ) from exc
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    async def handle(self, context):
        """Validate the JSON body and apply it to the provider.

        When ``quota_last_reset_at`` is supplied, ``monthly_used_usd`` is
        recomputed as the sum of ``Usage.total_cost_usd`` for this provider
        since that instant. An omitted or empty ``quota_last_reset_at`` /
        ``quota_expires_at`` leaves the stored value unchanged.

        Raises:
            InvalidRequestException: on payload validation failure or an
                unparseable datetime string.
            HTTPException: 404 if the provider does not exist.
        """
        from sqlalchemy import func

        from src.models.database import Usage

        db = context.db
        payload = context.ensure_json_body()
        try:
            config = ProviderBillingUpdate.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0])) from e
            raise InvalidRequestException("请求数据验证失败") from e

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Parse the datetimes before mutating the ORM object, so a bad value
        # cannot leave the provider partially updated in the session.
        new_reset_at = (
            self._parse_aware_datetime(config.quota_last_reset_at, "quota_last_reset_at")
            if config.quota_last_reset_at
            else None
        )
        expires_at = (
            self._parse_aware_datetime(config.quota_expires_at, "quota_expires_at")
            if config.quota_expires_at
            else None
        )

        provider.billing_type = config.billing_type
        provider.monthly_quota_usd = config.monthly_quota_usd
        provider.quota_reset_day = config.quota_reset_day
        provider.rpm_limit = config.rpm_limit
        provider.provider_priority = config.provider_priority

        if new_reset_at is not None:
            provider.quota_last_reset_at = new_reset_at
            # Re-sync the used amount with historical usage of the new period.
            period_usage = (
                db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
                .filter(
                    Usage.provider_id == self.provider_id,
                    Usage.created_at >= new_reset_at,
                )
                .scalar()
            )
            synced_usage = float(period_usage or 0)
            provider.monthly_used_usd = synced_usage
            logger.info(
                f"Synced usage for provider {provider.name}: ${synced_usage:.4f} since {new_reset_at}"
            )

        if expires_at is not None:
            provider.quota_expires_at = expires_at

        db.commit()
        db.refresh(provider)
        logger.info(f"Updated billing config for provider {provider.name}")
        return JSONResponse(
            {
                "message": "Provider billing config updated successfully",
                "provider": {
                    "id": provider.id,
                    "name": provider.name,
                    "billing_type": provider.billing_type.value,
                    "provider_priority": provider.provider_priority,
                },
            }
        )
class AdminProviderStatsAdapter(AdminApiAdapter):
    """Report billing, RPM and recent usage statistics for a single provider."""

    def __init__(self, provider_id: str, hours: int):
        self.provider_id = provider_id
        # Look-back window, in hours, for the usage aggregation.
        self.hours = hours

    def _window_start(self) -> datetime:
        """Return the aware UTC datetime ``self.hours`` hours ago.

        Raises:
            InvalidRequestException: if ``hours`` is not positive, or so large
                that ``timedelta``/``datetime`` arithmetic overflows.
        """
        if self.hours <= 0:
            raise InvalidRequestException("hours must be a positive integer")
        try:
            return datetime.now(timezone.utc) - timedelta(hours=self.hours)
        except OverflowError as exc:
            raise InvalidRequestException("hours is out of range") from exc

    @staticmethod
    def _summarize_usage(stats) -> dict:
        """Aggregate ``ProviderUsageTracking`` rows into the ``usage_stats`` payload."""
        # `or 0` guards against NULL counters; results are coerced to float so
        # that Decimal column values stay JSON-serializable.
        total_requests = sum(s.total_requests or 0 for s in stats)
        total_success = sum(s.successful_requests or 0 for s in stats)
        total_failures = sum(s.failed_requests or 0 for s in stats)
        total_cost = float(sum(s.total_cost_usd or 0 for s in stats))
        # Weight each window's average latency by its request count; a plain
        # mean of per-window averages over-weights quiet windows.
        # NOTE(review): assumes avg_response_time_ms is the per-window mean over
        # total_requests - confirm against the tracking writer.
        weighted_latency = sum(
            float(s.avg_response_time_ms or 0) * (s.total_requests or 0) for s in stats
        )
        avg_response_time = weighted_latency / total_requests if total_requests > 0 else 0
        return {
            "total_requests": total_requests,
            "successful_requests": total_success,
            "failed_requests": total_failures,
            "success_rate": total_success / total_requests if total_requests > 0 else 0,
            "avg_response_time_ms": round(avg_response_time, 2),
            "total_cost_usd": round(total_cost, 4),
        }

    async def handle(self, context):
        """Build the statistics response for the provider.

        Raises:
            HTTPException: 404 if the provider does not exist.
            InvalidRequestException: if ``hours`` is invalid.
        """
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        since = self._window_start()
        stats = (
            db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == self.provider_id,
                ProviderUsageTracking.window_start >= since,
            )
            .all()
        )

        return JSONResponse(
            {
                "provider_id": self.provider_id,
                "provider_name": provider.name,
                "period_hours": self.hours,
                "billing_info": {
                    "billing_type": provider.billing_type.value,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "monthly_used_usd": provider.monthly_used_usd,
                    "quota_remaining_usd": (
                        provider.monthly_quota_usd - (provider.monthly_used_usd or 0)
                        if provider.monthly_quota_usd is not None
                        else None
                    ),
                    "quota_expires_at": (
                        provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
                    ),
                },
                "rpm_info": {
                    "rpm_limit": provider.rpm_limit,
                    "rpm_used": provider.rpm_used,
                    "rpm_reset_at": (
                        provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
                    ),
                },
                "usage_stats": self._summarize_usage(stats),
            }
        )
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
    """Manually zero the usage counters of a monthly-quota provider."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):
        """Reset ``monthly_used_usd``, ``rpm_used`` and ``rpm_reset_at``.

        Raises:
            HTTPException: 404 if the provider does not exist, 400 if its
                billing type is not ``MONTHLY_QUOTA``.
        """
        session = context.db
        target = session.query(Provider).filter(Provider.id == self.provider_id).first()
        if not target:
            raise HTTPException(status_code=404, detail="Provider not found")
        if target.billing_type != ProviderBillingType.MONTHLY_QUOTA:
            raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")

        previous_usage = target.monthly_used_usd
        # Zero the spend counter and clear the RPM counter and its reset time.
        target.monthly_used_usd, target.rpm_used, target.rpm_reset_at = 0.0, 0, None
        session.commit()
        logger.info(f"Manually reset quota for provider {target.name}")

        body = {
            "message": "Provider quota reset successfully",
            "provider_name": target.name,
            "previous_used": previous_usage,
            "current_used": 0.0,
        }
        return JSONResponse(body)
class AdminListStrategiesAdapter(AdminApiAdapter):
    """List registered load-balancer plugins, highest priority first."""

    async def handle(self, context):
        """Describe every plugin in the plugin manager's "load_balancer" group.

        Missing attributes fall back to defaults; a plugin whose attributes
        raise is logged and skipped.
        """
        from src.plugins.manager import get_plugin_manager

        registered = get_plugin_manager().plugins.get("load_balancer", {})
        entries = []
        for plugin_key, plugin in registered.items():
            try:
                # Absent metadata -> None, so every getattr below yields its default.
                metadata = getattr(plugin, "metadata", None)
                entry = {
                    "name": getattr(plugin, "name", plugin_key),
                    "priority": getattr(plugin, "priority", 0),
                    "version": getattr(metadata, "version", "1.0.0"),
                    "description": getattr(metadata, "description", ""),
                    "author": getattr(metadata, "author", "Unknown"),
                }
            except Exception as exc:  # pragma: no cover
                logger.error(f"Error accessing plugin {plugin_key}: {exc}")
                continue
            entries.append(entry)

        entries.sort(key=lambda item: item["priority"], reverse=True)
        return JSONResponse({"strategies": entries, "total": len(entries)})