Files
Aether/src/api/admin/provider_strategy.py

280 lines
11 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
提供商策略管理 API 端点
"""
from datetime import datetime, timedelta, timezone
2025-12-10 20:52:44 +08:00
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
# All endpoints in this module are mounted under the admin provider-strategy prefix.
router = APIRouter(tags=["Provider Strategy"], prefix="/api/admin/provider-strategy")
# Shared pipeline; every endpoint below dispatches its adapter through it.
pipeline = ApiRequestPipeline()
class ProviderBillingUpdate(BaseModel):
    """Request body for ``PUT /providers/{provider_id}/billing``.

    Every numeric limit carries a lower bound so nonsensical values are
    rejected during model validation instead of being written to the provider.
    Timestamp fields are kept as raw strings and parsed by the billing adapter.
    """

    billing_type: ProviderBillingType
    # Quota amount in USD (may be None). Negative amounts are rejected, in line
    # with the non-negative bounds on the other limits below.
    monthly_quota_usd: Optional[float] = Field(default=None, ge=0)
    # Length of the quota reset cycle, in days.
    quota_reset_day: int = Field(default=30, ge=1, le=365)
    # Start of the current quota period (date/time string).
    quota_last_reset_at: Optional[str] = None
    # Quota expiry (date/time string).
    quota_expires_at: Optional[str] = None
    rpm_limit: Optional[int] = Field(default=None, ge=0)
    provider_priority: int = Field(default=100, ge=0, le=200)
@router.put("/providers/{provider_id}/billing")
async def update_provider_billing(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update a provider's billing/quota configuration from the JSON request body."""
    billing_adapter = AdminProviderBillingAdapter(provider_id=provider_id)
    result = await pipeline.run(
        adapter=billing_adapter,
        http_request=request,
        db=db,
        mode=billing_adapter.mode,
    )
    return result
@router.get("/providers/{provider_id}/stats")
async def get_provider_stats(
    provider_id: str,
    request: Request,
    hours: int = 24,
    db: Session = Depends(get_db),
):
    """Return billing, RPM and usage statistics for a provider over the last ``hours`` hours."""
    stats_adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
    result = await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
    return result
@router.delete("/providers/{provider_id}/quota")
async def reset_provider_quota(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Zero out the provider's used quota (monthly-quota providers only)."""
    reset_adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
    result = await pipeline.run(
        adapter=reset_adapter,
        http_request=request,
        db=db,
        mode=reset_adapter.mode,
    )
    return result
@router.get("/strategies")
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
    """List the registered load-balancer strategy plugins."""
    listing_adapter = AdminListStrategiesAdapter()
    result = await pipeline.run(
        adapter=listing_adapter,
        http_request=request,
        db=db,
        mode=listing_adapter.mode,
    )
    return result
def _parse_utc_datetime(field_name: str, raw_value: str) -> datetime:
    """Parse a client-supplied date/time string into a timezone-aware datetime.

    Naive values are assumed to be UTC.

    Raises:
        InvalidRequestException: if ``raw_value`` cannot be parsed. This is the
            same exception type used for body validation failures, so a bad
            timestamp is reported as a request error rather than escaping as an
            unhandled parser exception.
    """
    from dateutil import parser

    try:
        parsed = parser.parse(raw_value)
    except (ValueError, OverflowError) as exc:
        # dateutil raises ParserError (a ValueError subclass) for unparseable
        # text and OverflowError for out-of-range numeric components.
        raise InvalidRequestException(f"Invalid datetime for {field_name}: {raw_value!r}") from exc
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


class AdminProviderBillingAdapter(AdminApiAdapter):
    """Admin adapter that updates a provider's billing and quota configuration."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):
        """Validate the JSON body, apply it to the provider and commit.

        All non-timestamp fields are overwritten (omitted ones take the
        ``ProviderBillingUpdate`` defaults). Timestamps are only applied when a
        non-empty string is supplied. Setting ``quota_last_reset_at`` also
        re-syncs ``monthly_used_usd`` from the ``Usage`` rows since that instant.

        Raises:
            InvalidRequestException: the body fails model validation or carries
                an unparseable timestamp.
            HTTPException: 404 when the provider does not exist.
        """
        from sqlalchemy import func

        from src.models.database import Usage

        db = context.db
        payload = context.ensure_json_body()
        try:
            config = ProviderBillingUpdate.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Parse timestamps before mutating ``provider`` so an invalid value
        # rejects the request without leaving half-applied changes on the ORM object.
        new_reset_at = (
            _parse_utc_datetime("quota_last_reset_at", config.quota_last_reset_at)
            if config.quota_last_reset_at
            else None
        )
        expires_at = (
            _parse_utc_datetime("quota_expires_at", config.quota_expires_at)
            if config.quota_expires_at
            else None
        )

        provider.billing_type = config.billing_type
        provider.monthly_quota_usd = config.monthly_quota_usd
        provider.quota_reset_day = config.quota_reset_day
        provider.rpm_limit = config.rpm_limit
        provider.provider_priority = config.provider_priority

        if new_reset_at is not None:
            provider.quota_last_reset_at = new_reset_at
            # Recompute the used amount from Usage rows recorded since the new period start.
            period_usage = (
                db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
                .filter(
                    Usage.provider_id == self.provider_id,
                    Usage.created_at >= new_reset_at,
                )
                .scalar()
            )
            provider.monthly_used_usd = float(period_usage or 0)
            logger.info(
                f"Synced usage for provider {provider.name}: "
                f"${provider.monthly_used_usd:.4f} since {new_reset_at}"
            )

        if expires_at is not None:
            provider.quota_expires_at = expires_at

        db.commit()
        db.refresh(provider)
        logger.info(f"Updated billing config for provider {provider.name}")
        return JSONResponse(
            {
                "message": "Provider billing config updated successfully",
                "provider": {
                    "id": provider.id,
                    "name": provider.name,
                    "billing_type": provider.billing_type.value,
                    "provider_priority": provider.provider_priority,
                },
            }
        )
class AdminProviderStatsAdapter(AdminApiAdapter):
    """Admin adapter reporting billing, RPM and recent usage statistics for one provider."""

    def __init__(self, provider_id: str, hours: int):
        self.provider_id = provider_id
        self.hours = hours

    async def handle(self, context):
        """Build the stats payload for ``self.provider_id`` over the last ``self.hours`` hours.

        Usage figures aggregate the ``ProviderUsageTracking`` rows whose
        ``window_start`` falls inside the look-back period.

        Raises:
            InvalidRequestException: ``hours`` is below 1, or so large that the
                look-back start cannot be represented as a datetime.
            HTTPException: 404 when the provider does not exist.
        """
        # A negative look-back would put ``since`` in the future. A huge one
        # makes the timedelta/datetime arithmetic raise OverflowError. Reject
        # both as request errors instead of letting them surface as server errors.
        if self.hours < 1:
            raise InvalidRequestException("hours must be a positive integer")
        try:
            since = datetime.now(timezone.utc) - timedelta(hours=self.hours)
        except OverflowError as exc:
            raise InvalidRequestException(f"hours is out of range: {self.hours}") from exc

        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        stats = (
            db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == self.provider_id,
                ProviderUsageTracking.window_start >= since,
            )
            .all()
        )

        # ``or 0`` tolerates rows whose counters/costs are NULL.
        total_requests = sum(s.total_requests or 0 for s in stats)
        total_success = sum(s.successful_requests or 0 for s in stats)
        total_failures = sum(s.failed_requests or 0 for s in stats)
        # Weight each window's mean latency by its request count. A plain mean of
        # per-window means would let a nearly idle window count as much as a busy one.
        weighted_time_ms = sum(
            float(s.avg_response_time_ms or 0) * (s.total_requests or 0) for s in stats
        )
        avg_response_time = weighted_time_ms / total_requests if total_requests > 0 else 0
        # float() keeps the value JSON-serialisable even if the column yields Decimal.
        total_cost = float(sum(s.total_cost_usd or 0 for s in stats))

        monthly_used = provider.monthly_used_usd or 0
        return JSONResponse(
            {
                "provider_id": self.provider_id,
                "provider_name": provider.name,
                "period_hours": self.hours,
                "billing_info": {
                    "billing_type": provider.billing_type.value,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "monthly_used_usd": provider.monthly_used_usd,
                    "quota_remaining_usd": (
                        provider.monthly_quota_usd - monthly_used
                        if provider.monthly_quota_usd is not None
                        else None
                    ),
                    "quota_expires_at": (
                        provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
                    ),
                },
                "rpm_info": {
                    "rpm_limit": provider.rpm_limit,
                    "rpm_used": provider.rpm_used,
                    "rpm_reset_at": (
                        provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
                    ),
                },
                "usage_stats": {
                    "total_requests": total_requests,
                    "successful_requests": total_success,
                    "failed_requests": total_failures,
                    "success_rate": total_success / total_requests if total_requests > 0 else 0,
                    "avg_response_time_ms": round(avg_response_time, 2),
                    "total_cost_usd": round(total_cost, 4),
                },
            }
        )
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
    """Admin adapter that zeroes a monthly-quota provider's usage counters."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):
        """Reset ``monthly_used_usd`` and the RPM counters, then commit.

        The response reports the usage value held before the reset.

        Raises:
            HTTPException: 404 if the provider does not exist; 400 if its
                billing type is not ``ProviderBillingType.MONTHLY_QUOTA``.
        """
        session = context.db
        target = (
            session.query(Provider)
            .filter(Provider.id == self.provider_id)
            .first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="Provider not found")
        if target.billing_type != ProviderBillingType.MONTHLY_QUOTA:
            raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")

        # Capture the old value before overwriting it.
        previous_used = target.monthly_used_usd
        target.monthly_used_usd = 0.0
        target.rpm_used, target.rpm_reset_at = 0, None
        session.commit()

        logger.info(f"Manually reset quota for provider {target.name}")
        body = {
            "message": "Provider quota reset successfully",
            "provider_name": target.name,
            "previous_used": previous_used,
            "current_used": 0.0,
        }
        return JSONResponse(body)
class AdminListStrategiesAdapter(AdminApiAdapter):
    """Admin adapter listing the registered ``load_balancer`` plugins."""

    async def handle(self, context):
        """Describe every load-balancer plugin, highest ``priority`` first.

        Missing plugin attributes fall back to defaults. A plugin whose
        attributes raise while being read is logged and skipped.
        """
        from src.plugins.manager import get_plugin_manager

        registered = get_plugin_manager().plugins.get("load_balancer", {})
        collected = []
        for key, plugin in registered.items():
            try:
                display_name = getattr(plugin, "name", key)
                priority = getattr(plugin, "priority", 0)
                # A missing (or None) ``metadata`` yields the same defaults below.
                meta = getattr(plugin, "metadata", None)
                collected.append(
                    {
                        "name": display_name,
                        "priority": priority,
                        "version": getattr(meta, "version", "1.0.0"),
                        "description": getattr(meta, "description", ""),
                        "author": getattr(meta, "author", "Unknown"),
                    }
                )
            except Exception as exc:  # pragma: no cover
                logger.error(f"Error accessing plugin {key}: {exc}")

        ordered = sorted(collected, key=lambda item: item["priority"], reverse=True)
        return JSONResponse({"strategies": ordered, "total": len(ordered)})