""" 提供商策略管理 API 端点 """ from datetime import datetime, timedelta, timezone from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.pipeline import ApiRequestPipeline from src.core.enums import ProviderBillingType from src.core.exceptions import InvalidRequestException, translate_pydantic_error from src.core.logger import logger from src.database import get_db from src.models.database import Provider from src.models.database_extensions import ProviderUsageTracking router = APIRouter(prefix="/api/admin/provider-strategy", tags=["Provider Strategy"]) pipeline = ApiRequestPipeline() class ProviderBillingUpdate(BaseModel): billing_type: ProviderBillingType monthly_quota_usd: Optional[float] = None quota_reset_day: int = Field(default=30, ge=1, le=365) # 重置周期(天数) quota_last_reset_at: Optional[str] = None # 当前周期开始时间 quota_expires_at: Optional[str] = None rpm_limit: Optional[int] = Field(default=None, ge=0) provider_priority: int = Field(default=100, ge=0, le=200) @router.put("/providers/{provider_id}/billing") async def update_provider_billing( provider_id: str, request: Request, db: Session = Depends(get_db), ): adapter = AdminProviderBillingAdapter(provider_id=provider_id) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/providers/{provider_id}/stats") async def get_provider_stats( provider_id: str, request: Request, hours: int = 24, db: Session = Depends(get_db), ): adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.delete("/providers/{provider_id}/quota") async def reset_provider_quota( provider_id: str, request: Request, db: Session = Depends(get_db), ): """Reset provider quota usage to zero""" adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @router.get("/strategies") async def list_available_strategies(request: Request, db: Session = Depends(get_db)): adapter = AdminListStrategiesAdapter() return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) class AdminProviderBillingAdapter(AdminApiAdapter): def __init__(self, provider_id: str): self.provider_id = provider_id async def handle(self, context): db = context.db payload = context.ensure_json_body() try: config = ProviderBillingUpdate.model_validate(payload) except ValidationError as e: errors = e.errors() if errors: raise InvalidRequestException(translate_pydantic_error(errors[0])) raise InvalidRequestException("请求数据验证失败") provider = db.query(Provider).filter(Provider.id == self.provider_id).first() if not provider: raise HTTPException(status_code=404, detail="Provider not found") provider.billing_type = config.billing_type provider.monthly_quota_usd = config.monthly_quota_usd provider.quota_reset_day = config.quota_reset_day provider.rpm_limit = config.rpm_limit provider.provider_priority = config.provider_priority from dateutil import parser from sqlalchemy import func from src.models.database import Usage if config.quota_last_reset_at: new_reset_at = parser.parse(config.quota_last_reset_at) # 确保有时区信息,如果没有则假设为 UTC if new_reset_at.tzinfo is None: new_reset_at = new_reset_at.replace(tzinfo=timezone.utc) provider.quota_last_reset_at = new_reset_at # 自动同步该周期内的历史使用量 period_usage = ( db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0)) .filter( Usage.provider_id == self.provider_id, Usage.created_at >= new_reset_at, ) .scalar() ) provider.monthly_used_usd = float(period_usage or 0) logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}") if config.quota_expires_at: expires_at = parser.parse(config.quota_expires_at) # 确保有时区信息,如果没有则假设为 UTC if expires_at.tzinfo is None: expires_at = expires_at.replace(tzinfo=timezone.utc) provider.quota_expires_at = expires_at db.commit() db.refresh(provider) logger.info(f"Updated billing config for provider {provider.name}") return JSONResponse( { "message": "Provider billing config updated successfully", "provider": { "id": provider.id, "name": provider.name, "billing_type": provider.billing_type.value, "provider_priority": provider.provider_priority, }, } ) class AdminProviderStatsAdapter(AdminApiAdapter): def __init__(self, provider_id: str, hours: int): self.provider_id = provider_id self.hours = hours async def handle(self, context): db = context.db provider = db.query(Provider).filter(Provider.id == self.provider_id).first() if not provider: raise HTTPException(status_code=404, detail="Provider not found") since = datetime.now(timezone.utc) - timedelta(hours=self.hours) stats = ( db.query(ProviderUsageTracking) .filter( ProviderUsageTracking.provider_id == self.provider_id, ProviderUsageTracking.window_start >= since, ) .all() ) total_requests = sum(s.total_requests for s in stats) total_success = sum(s.successful_requests for s in stats) total_failures = sum(s.failed_requests for s in stats) avg_response_time = sum(s.avg_response_time_ms for s in stats) / len(stats) if stats else 0 total_cost = sum(s.total_cost_usd for s in stats) return JSONResponse( { "provider_id": self.provider_id, "provider_name": provider.name, "period_hours": self.hours, "billing_info": { "billing_type": provider.billing_type.value, "monthly_quota_usd": provider.monthly_quota_usd, "monthly_used_usd": provider.monthly_used_usd, "quota_remaining_usd": ( provider.monthly_quota_usd - provider.monthly_used_usd if provider.monthly_quota_usd is not None else None ), "quota_expires_at": ( provider.quota_expires_at.isoformat() if provider.quota_expires_at else None ), }, "rpm_info": { "rpm_limit": provider.rpm_limit, "rpm_used": provider.rpm_used, "rpm_reset_at": ( provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None ), }, "usage_stats": { "total_requests": total_requests, "successful_requests": total_success, "failed_requests": total_failures, "success_rate": total_success / total_requests if total_requests > 0 else 0, "avg_response_time_ms": round(avg_response_time, 2), "total_cost_usd": round(total_cost, 4), }, } ) class AdminProviderResetQuotaAdapter(AdminApiAdapter): def __init__(self, provider_id: str): self.provider_id = provider_id async def handle(self, context): db = context.db provider = db.query(Provider).filter(Provider.id == self.provider_id).first() if not provider: raise HTTPException(status_code=404, detail="Provider not found") if provider.billing_type != ProviderBillingType.MONTHLY_QUOTA: raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset") old_used = provider.monthly_used_usd provider.monthly_used_usd = 0.0 provider.rpm_used = 0 provider.rpm_reset_at = None db.commit() logger.info(f"Manually reset quota for provider {provider.name}") return JSONResponse( { "message": "Provider quota reset successfully", "provider_name": provider.name, "previous_used": old_used, "current_used": 0.0, } ) class AdminListStrategiesAdapter(AdminApiAdapter): async def handle(self, context): from src.plugins.manager import get_plugin_manager plugin_manager = get_plugin_manager() lb_plugins = plugin_manager.plugins.get("load_balancer", {}) strategies = [] for name, plugin in lb_plugins.items(): try: strategies.append( { "name": getattr(plugin, "name", name), "priority": getattr(plugin, "priority", 0), "version": ( getattr(plugin.metadata, "version", "1.0.0") if hasattr(plugin, "metadata") else "1.0.0" ), "description": ( getattr(plugin.metadata, "description", "") if hasattr(plugin, "metadata") else "" ), "author": ( getattr(plugin.metadata, "author", "Unknown") if hasattr(plugin, "metadata") else "Unknown" ), } ) except Exception as exc: # pragma: no cover logger.error(f"Error accessing plugin {name}: {exc}") continue strategies.sort(key=lambda x: x["priority"], reverse=True) return JSONResponse({"strategies": strategies, "total": len(strategies)})