2025-12-10 20:52:44 +08:00
|
|
|
"""
|
|
|
|
|
提供商策略管理 API 端点
|
|
|
|
|
"""
|
|
|
|
|
|
2026-01-05 02:23:24 +08:00
|
|
|
from datetime import datetime, timedelta, timezone
|
2025-12-10 20:52:44 +08:00
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from src.api.base.admin_adapter import AdminApiAdapter
|
|
|
|
|
from src.api.base.pipeline import ApiRequestPipeline
|
|
|
|
|
from src.core.enums import ProviderBillingType
|
|
|
|
|
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
|
|
|
|
from src.core.logger import logger
|
|
|
|
|
from src.database import get_db
|
|
|
|
|
from src.models.database import Provider
|
|
|
|
|
from src.models.database_extensions import ProviderUsageTracking
|
|
|
|
|
|
|
|
|
|
# Admin routes for per-provider billing/quota settings, usage stats and the
# list of available load-balancing strategies.
router = APIRouter(
    prefix="/api/admin/provider-strategy",
    tags=["Provider Strategy"],
)

# Shared request pipeline; each endpoint in this module hands its adapter to it.
pipeline = ApiRequestPipeline()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProviderBillingUpdate(BaseModel):
    """Request body for ``PUT /providers/{provider_id}/billing``.

    The billing adapter writes every scalar field onto the provider, so omitted
    optional fields fall back to the defaults declared here rather than keeping
    the stored value. The two timestamp fields are the exception: they are only
    applied when a non-empty string is supplied.
    """

    billing_type: ProviderBillingType
    # Quota amount in USD (may be None); negative amounts are rejected.
    monthly_quota_usd: Optional[float] = Field(default=None, ge=0)
    # Length of the quota reset cycle, in days (1..365).
    quota_reset_day: int = Field(default=30, ge=1, le=365)
    # Start of the current quota period, as a date/time string parsed by
    # dateutil; naive values are interpreted as UTC.
    quota_last_reset_at: Optional[str] = None
    # Quota expiry, as a date/time string parsed the same way.
    quota_expires_at: Optional[str] = None
    # Requests-per-minute limit; None leaves it unset.
    rpm_limit: Optional[int] = Field(default=None, ge=0)
    # Provider priority in the range 0..200.
    provider_priority: int = Field(default=100, ge=0, le=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/providers/{provider_id}/billing")
async def update_provider_billing(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Overwrite the billing/quota configuration of one provider."""
    billing_adapter = AdminProviderBillingAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=billing_adapter,
        http_request=request,
        db=db,
        mode=billing_adapter.mode,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/providers/{provider_id}/stats")
async def get_provider_stats(
    provider_id: str,
    request: Request,
    hours: int = 24,
    db: Session = Depends(get_db),
):
    """Return billing, RPM and recent usage statistics for one provider.

    ``hours`` sets how far back the usage-tracking window reaches.
    """
    stats_adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
    return await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/providers/{provider_id}/quota")
async def reset_provider_quota(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Manually zero a monthly-quota provider's usage counters."""
    reset_adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=reset_adapter,
        http_request=request,
        db=db,
        mode=reset_adapter.mode,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/strategies")
async def list_available_strategies(
    request: Request,
    db: Session = Depends(get_db),
):
    """List registered load-balancer strategy plugins, highest priority first."""
    strategies_adapter = AdminListStrategiesAdapter()
    return await pipeline.run(
        adapter=strategies_adapter,
        http_request=request,
        db=db,
        mode=strategies_adapter.mode,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdminProviderBillingAdapter(AdminApiAdapter):
    """Apply a ``ProviderBillingUpdate`` request body to a single provider."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    @staticmethod
    def _parse_utc_datetime(value: str, field_name: str) -> datetime:
        """Parse a client-supplied timestamp string into an aware ``datetime``.

        Args:
            value: Any date/time string accepted by ``dateutil.parser.parse``.
            field_name: Request field name, included in the error message.

        Returns:
            The parsed datetime. Naive values are assumed to be UTC; values that
            carry an offset keep it.

        Raises:
            InvalidRequestException: If ``value`` cannot be parsed, so the
                client gets a validation error instead of a raw parser exception.
        """
        from dateutil import parser

        try:
            parsed = parser.parse(value)
        except (ValueError, OverflowError) as exc:
            # dateutil raises ParserError (a ValueError) for unrecognised input
            # and OverflowError for numeric components that are too large.
            raise InvalidRequestException(f"{field_name} 不是有效的日期时间: {value}") from exc

        # No timezone information: assume UTC.
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    async def handle(self, context):
        """Validate the JSON body and overwrite the provider's billing settings.

        All scalar fields of the body are written to the provider.
        ``quota_last_reset_at`` and ``quota_expires_at`` are applied only when
        non-empty. Setting ``quota_last_reset_at`` also recomputes
        ``monthly_used_usd`` as the summed ``Usage.total_cost_usd`` of this
        provider since that instant.

        Returns:
            JSONResponse with a message and a short summary of the provider.

        Raises:
            InvalidRequestException: The body fails validation or contains an
                unparseable timestamp.
            HTTPException: 404 if the provider does not exist.
        """
        from sqlalchemy import func

        from src.models.database import Usage

        db = context.db
        payload = context.ensure_json_body()
        try:
            config = ProviderBillingUpdate.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Parse both timestamps before mutating the ORM object, so a malformed
        # value is rejected without leaving half-applied changes in the session.
        new_reset_at = (
            self._parse_utc_datetime(config.quota_last_reset_at, "quota_last_reset_at")
            if config.quota_last_reset_at
            else None
        )
        expires_at = (
            self._parse_utc_datetime(config.quota_expires_at, "quota_expires_at")
            if config.quota_expires_at
            else None
        )

        provider.billing_type = config.billing_type
        provider.monthly_quota_usd = config.monthly_quota_usd
        provider.quota_reset_day = config.quota_reset_day
        provider.rpm_limit = config.rpm_limit
        provider.provider_priority = config.provider_priority

        if new_reset_at is not None:
            provider.quota_last_reset_at = new_reset_at

            # Re-sync the usage counter with historical usage in the new period.
            period_usage = (
                db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
                .filter(
                    Usage.provider_id == self.provider_id,
                    Usage.created_at >= new_reset_at,
                )
                .scalar()
            )
            provider.monthly_used_usd = float(period_usage or 0)
            logger.info(
                f"Synced usage for provider {provider.name}: "
                f"${provider.monthly_used_usd:.4f} since {new_reset_at}"
            )

        if expires_at is not None:
            provider.quota_expires_at = expires_at

        db.commit()
        db.refresh(provider)

        logger.info(f"Updated billing config for provider {provider.name}")

        return JSONResponse(
            {
                "message": "Provider billing config updated successfully",
                "provider": {
                    "id": provider.id,
                    "name": provider.name,
                    "billing_type": provider.billing_type.value,
                    "provider_priority": provider.provider_priority,
                },
            }
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdminProviderStatsAdapter(AdminApiAdapter):
    """Report billing, RPM and recent usage statistics for a single provider."""

    def __init__(self, provider_id: str, hours: int):
        self.provider_id = provider_id
        self.hours = hours

    @staticmethod
    def _window_start(hours: int) -> datetime:
        """Return the UTC instant ``hours`` hours before now.

        Raises:
            InvalidRequestException: If ``hours`` is not positive, or so large
                that the timedelta/datetime arithmetic overflows.
        """
        if hours <= 0:
            raise InvalidRequestException("hours 必须为正整数")
        try:
            return datetime.now(timezone.utc) - timedelta(hours=hours)
        except OverflowError as exc:
            raise InvalidRequestException("hours 超出允许范围") from exc

    @staticmethod
    def _summarize_usage(stats) -> dict:
        """Aggregate ``ProviderUsageTracking`` rows into the ``usage_stats`` payload.

        Missing (None) numeric values count as zero, and costs are converted to
        ``float`` so the result is always JSON-serialisable.
        """
        total_requests = sum(s.total_requests or 0 for s in stats)
        total_success = sum(s.successful_requests or 0 for s in stats)
        total_failures = sum(s.failed_requests or 0 for s in stats)
        total_cost = sum(float(s.total_cost_usd or 0) for s in stats)

        # Weight each window's average by its request count. A plain mean of
        # per-window averages lets a nearly idle window count as much as a busy one.
        weighted_ms = sum(
            float(s.avg_response_time_ms or 0) * (s.total_requests or 0) for s in stats
        )
        avg_response_time = weighted_ms / total_requests if total_requests > 0 else 0

        return {
            "total_requests": total_requests,
            "successful_requests": total_success,
            "failed_requests": total_failures,
            "success_rate": total_success / total_requests if total_requests > 0 else 0,
            "avg_response_time_ms": round(avg_response_time, 2),
            "total_cost_usd": round(total_cost, 4),
        }

    async def handle(self, context):
        """Build the stats response for ``self.provider_id``.

        Usage figures come from ``ProviderUsageTracking`` rows whose
        ``window_start`` falls within the last ``self.hours`` hours.

        Raises:
            HTTPException: 404 if the provider does not exist.
            InvalidRequestException: If ``self.hours`` is out of range.
        """
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        since = self._window_start(self.hours)
        stats = (
            db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == self.provider_id,
                ProviderUsageTracking.window_start >= since,
            )
            .all()
        )

        return JSONResponse(
            {
                "provider_id": self.provider_id,
                "provider_name": provider.name,
                "period_hours": self.hours,
                "billing_info": {
                    "billing_type": provider.billing_type.value,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "monthly_used_usd": provider.monthly_used_usd,
                    "quota_remaining_usd": (
                        provider.monthly_quota_usd - (provider.monthly_used_usd or 0)
                        if provider.monthly_quota_usd is not None
                        else None
                    ),
                    "quota_expires_at": (
                        provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
                    ),
                },
                "rpm_info": {
                    "rpm_limit": provider.rpm_limit,
                    "rpm_used": provider.rpm_used,
                    "rpm_reset_at": (
                        provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
                    ),
                },
                "usage_stats": self._summarize_usage(stats),
            }
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
    """Zero the usage counters of a monthly-quota provider.

    Only providers whose ``billing_type`` is ``ProviderBillingType.MONTHLY_QUOTA``
    may be reset. The RPM counter and its reset timestamp are cleared as well.
    """

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):
        """Reset the counters and report the usage value that was discarded.

        Raises:
            HTTPException: 404 if the provider does not exist; 400 if it is not
                billed by monthly quota.
        """
        session = context.db
        target = session.query(Provider).filter(Provider.id == self.provider_id).first()
        if not target:
            raise HTTPException(status_code=404, detail="Provider not found")
        if target.billing_type != ProviderBillingType.MONTHLY_QUOTA:
            raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")

        previous_usage = target.monthly_used_usd
        target.monthly_used_usd, target.rpm_used, target.rpm_reset_at = 0.0, 0, None
        session.commit()

        logger.info(f"Manually reset quota for provider {target.name}")

        body = {
            "message": "Provider quota reset successfully",
            "provider_name": target.name,
            "previous_used": previous_usage,
            "current_used": 0.0,
        }
        return JSONResponse(body)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdminListStrategiesAdapter(AdminApiAdapter):
    """List the plugins registered under the ``load_balancer`` plugin type."""

    async def handle(self, context):
        """Describe each load-balancer plugin, sorted by priority (descending).

        Missing attributes fall back to the registry key (name), 0 (priority),
        "1.0.0" (version), "" (description) and "Unknown" (author). A plugin
        whose attributes raise while being read is logged and skipped.
        """
        from src.plugins.manager import get_plugin_manager

        registered = get_plugin_manager().plugins.get("load_balancer", {})

        collected = []
        for plugin_key, plugin in registered.items():
            try:
                display_name = getattr(plugin, "name", plugin_key)
                priority = getattr(plugin, "priority", 0)
                # A missing ``metadata`` attribute becomes None, and getattr on
                # None falls through to the same defaults.
                meta = getattr(plugin, "metadata", None)
                collected.append(
                    {
                        "name": display_name,
                        "priority": priority,
                        "version": getattr(meta, "version", "1.0.0"),
                        "description": getattr(meta, "description", ""),
                        "author": getattr(meta, "author", "Unknown"),
                    }
                )
            except Exception as exc:  # pragma: no cover
                logger.error(f"Error accessing plugin {plugin_key}: {exc}")
                continue

        ordered = sorted(collected, key=lambda entry: entry["priority"], reverse=True)
        return JSONResponse({"strategies": ordered, "total": len(ordered)})
|