Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

30
src/api/admin/__init__.py Normal file
View File

@@ -0,0 +1,30 @@
"""Admin API routers."""
from fastapi import APIRouter
from .adaptive import router as adaptive_router
from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
from .system import router as system_router
from .usage import router as usage_router
from .users import router as users_router
# Aggregate every admin sub-router under a single APIRouter that the
# application mounts once; each sub-router carries its own prefix/tags.
router = APIRouter()
# Core administration
router.include_router(system_router)
router.include_router(users_router)
router.include_router(providers_router)
router.include_router(api_keys_router)
# Usage / monitoring
router.include_router(usage_router)
router.include_router(monitoring_router)
# Endpoint + routing strategy management
router.include_router(endpoints_router)
router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
# Only the aggregated router is part of this package's public API.
__all__ = ["router"]

377
src/api/admin/adaptive.py Normal file
View File

@@ -0,0 +1,377 @@
"""
自适应并发管理 API 端点
设计原则:
- 自适应模式由 max_concurrent 字段决定:
- max_concurrent = NULL启用自适应模式系统自动学习并调整并发限制
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
- learned_max_concurrent自适应模式下学习到的并发限制值
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
"""
from dataclasses import dataclass
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.database import get_db
from src.models.database import ProviderAPIKey
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
pipeline = ApiRequestPipeline()
# ==================== Pydantic Models ====================
class EnableAdaptiveRequest(BaseModel):
    """Request body for toggling a key's adaptive-concurrency mode."""
    # True -> adaptive mode (max_concurrent becomes NULL);
    # False -> fixed-limit mode (fixed_limit must then be provided).
    enabled: bool = Field(..., description="是否启用自适应模式true=自适应false=固定限制)")
    # Only consulted when enabled is False; bounded to 1..100.
    fixed_limit: Optional[int] = Field(
        None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
    )
class AdaptiveStatsResponse(BaseModel):
    """Adaptive-concurrency statistics for a single provider key."""
    adaptive_mode: bool = Field(..., description="是否为自适应模式max_concurrent=NULL")
    max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制NULL=自适应)")
    effective_limit: Optional[int] = Field(
        None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
    )
    learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
    # 429 counters split by trigger: concurrency-limit vs requests-per-minute.
    concurrent_429_count: int
    rpm_429_count: int
    # Timestamp/type of the most recent 429, if any.
    # NOTE(review): presumably an ISO-formatted string — confirm against
    # get_adjustment_stats() in the adaptive manager.
    last_429_at: Optional[str]
    last_429_type: Optional[str]
    # How many automatic limit adjustments happened, plus a short history.
    adjustment_count: int
    recent_adjustments: List[dict]
class KeyListItem(BaseModel):
    """One row in the adaptive-key listing."""
    id: str
    name: Optional[str]
    endpoint_id: str
    is_adaptive: bool = Field(..., description="是否为自适应模式max_concurrent=NULL")
    max_concurrent: Optional[int] = Field(None, description="固定并发限制NULL=自适应)")
    effective_limit: Optional[int] = Field(None, description="当前有效限制")
    learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
    # 429 counters split by trigger: concurrency-limit vs requests-per-minute.
    concurrent_429_count: int
    rpm_429_count: int
# ==================== API Endpoints ====================
@router.get(
    "/keys",
    response_model=List[KeyListItem],
    summary="获取所有启用自适应模式的Key",
)
async def list_adaptive_keys(
    request: Request,
    endpoint_id: Optional[str] = Query(None, description="按 Endpoint 过滤"),
    db: Session = Depends(get_db),
):
    """
    List all keys currently running in adaptive mode.
    Optional parameter:
    - endpoint_id: restrict results to a single endpoint
    """
    # Auth, auditing and execution are delegated to the shared admin pipeline.
    adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch(
    "/keys/{key_id}/mode",
    summary="Toggle key's concurrency control mode",
)
async def toggle_adaptive_mode(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Toggle the concurrency control mode for a specific key
    Parameters:
    - enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
    - fixed_limit: fixed limit value (required when enabled=false)
    """
    # The JSON body is parsed and validated inside the adapter
    # (against EnableAdaptiveRequest), not here.
    adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
    "/keys/{key_id}/stats",
    response_model=AdaptiveStatsResponse,
    summary="获取Key的自适应统计",
)
async def get_adaptive_stats(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Return adaptive-concurrency statistics for the given key, including:
    - current configuration
    - the learned concurrency limit
    - 429 error counters
    - adjustment history
    """
    adapter = GetAdaptiveStatsAdapter(key_id=key_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete(
    "/keys/{key_id}/learning",
    summary="Reset key's learning state",
)
async def reset_adaptive_learning(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Reset the adaptive learning state for a specific key
    Clears:
    - Learned concurrency limit (learned_max_concurrent)
    - 429 error counts
    - Adjustment history
    Does not change:
    - max_concurrent config (determines adaptive mode)
    """
    adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch(
    "/keys/{key_id}/limit",
    summary="Set key to fixed concurrency limit mode",
)
async def set_concurrent_limit(
    key_id: str,
    request: Request,
    limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
    db: Session = Depends(get_db),
):
    """
    Set key to fixed concurrency limit mode
    Note:
    - After setting this value, key switches to fixed limit mode and won't auto-adjust
    - To restore adaptive mode, use PATCH /keys/{key_id}/mode
    """
    adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
    "/summary",
    summary="获取自适应并发的全局统计",
)
async def get_adaptive_summary(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Return a global summary of adaptive concurrency, including:
    - number of keys in adaptive mode
    - total 429 error counts
    - number of concurrency-limit adjustments
    """
    adapter = AdaptiveSummaryAdapter()
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ==================== Pipeline 适配器 ====================
@dataclass
class ListAdaptiveKeysAdapter(AdminApiAdapter):
    """List provider keys running in adaptive mode (max_concurrent IS NULL)."""

    endpoint_id: Optional[str] = None

    async def handle(self, context):  # type: ignore[override]
        # Adaptive mode is signalled by max_concurrent being NULL.
        query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
        if self.endpoint_id:
            query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
        items = []
        for record in query.all():
            adaptive = record.max_concurrent is None
            # Effective limit: the learned value while adaptive, otherwise
            # the user-configured fixed limit.
            effective = record.learned_max_concurrent if adaptive else record.max_concurrent
            items.append(
                KeyListItem(
                    id=record.id,
                    name=record.name,
                    endpoint_id=record.endpoint_id,
                    is_adaptive=adaptive,
                    max_concurrent=record.max_concurrent,
                    effective_limit=effective,
                    learned_max_concurrent=record.learned_max_concurrent,
                    concurrent_429_count=record.concurrent_429_count or 0,
                    rpm_429_count=record.rpm_429_count or 0,
                )
            )
        return items
@dataclass
class ToggleAdaptiveModeAdapter(AdminApiAdapter):
    """Switch one provider key between adaptive and fixed concurrency modes."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if key is None:
            raise HTTPException(status_code=404, detail="Key not found")
        # Validate the raw JSON body against the request model.
        payload = context.ensure_json_body()
        try:
            body = EnableAdaptiveRequest.model_validate(payload)
        except ValidationError as exc:
            details = exc.errors()
            if not details:
                raise InvalidRequestException("请求数据验证失败")
            raise InvalidRequestException(translate_pydantic_error(details[0]))
        if not body.enabled:
            # Fixed-limit mode requires an explicit limit value.
            if body.fixed_limit is None:
                raise HTTPException(
                    status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
                )
            key.max_concurrent = body.fixed_limit
            message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
        else:
            # Adaptive mode: NULL means "learn the limit automatically".
            key.max_concurrent = None
            message = "已切换为自适应模式,系统将自动学习并调整并发限制"
        db.commit()
        db.refresh(key)
        adaptive_now = key.max_concurrent is None
        return {
            "message": message,
            "key_id": key.id,
            "is_adaptive": adaptive_now,
            "max_concurrent": key.max_concurrent,
            "effective_limit": key.learned_max_concurrent if adaptive_now else key.max_concurrent,
        }
@dataclass
class GetAdaptiveStatsAdapter(AdminApiAdapter):
    """Return adaptive-concurrency statistics for a single provider key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if key is None:
            raise HTTPException(status_code=404, detail="Key not found")
        stats = get_adaptive_manager().get_adjustment_stats(key)
        # The manager's stats dict uses exactly the response model's field
        # names, so the payload can be built by picking those keys.
        field_names = (
            "adaptive_mode",
            "max_concurrent",
            "effective_limit",
            "learned_limit",
            "concurrent_429_count",
            "rpm_429_count",
            "last_429_at",
            "last_429_type",
            "adjustment_count",
            "recent_adjustments",
        )
        return AdaptiveStatsResponse(**{name: stats[name] for name in field_names})
@dataclass
class ResetAdaptiveLearningAdapter(AdminApiAdapter):
    """Clear a key's learned concurrency state without touching its config."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        record = (
            context.db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.id == self.key_id)
            .first()
        )
        if record is None:
            raise HTTPException(status_code=404, detail="Key not found")
        get_adaptive_manager().reset_learning(context.db, record)
        return {"message": "学习状态已重置", "key_id": record.id}
@dataclass
class SetConcurrentLimitAdapter(AdminApiAdapter):
    """Pin a key to a fixed concurrency limit (leaves adaptive mode)."""

    key_id: str
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if record is None:
            raise HTTPException(status_code=404, detail="Key not found")
        # Remember the prior mode for the response payload.
        previously_adaptive = record.max_concurrent is None
        record.max_concurrent = self.limit
        db.commit()
        db.refresh(record)
        return {
            "message": f"已设置为固定限制模式,并发限制为 {self.limit}",
            "key_id": record.id,
            "is_adaptive": False,
            "max_concurrent": record.max_concurrent,
            "previous_mode": "adaptive" if previously_adaptive else "fixed",
        }
class AdaptiveSummaryAdapter(AdminApiAdapter):
    """Aggregate adaptive-concurrency statistics across all adaptive keys."""

    async def handle(self, context):  # type: ignore[override]
        # Adaptive keys are those whose max_concurrent is NULL.
        keys = (
            context.db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.max_concurrent.is_(None))
            .all()
        )
        concurrent_429_total = sum(k.concurrent_429_count or 0 for k in keys)
        rpm_429_total = sum(k.rpm_429_count or 0 for k in keys)
        adjustments_total = sum(len(k.adjustment_history or []) for k in keys)
        # Take up to the three latest adjustments per key, then keep the ten
        # newest overall (sorted by timestamp, descending).
        recent = [
            {"key_id": k.id, "key_name": k.name, **entry}
            for k in keys
            if k.adjustment_history
            for entry in k.adjustment_history[-3:]
        ]
        recent.sort(key=lambda item: item.get("timestamp", ""), reverse=True)
        return {
            "total_adaptive_keys": len(keys),
            "total_concurrent_429_errors": concurrent_429_total,
            "total_rpm_429_errors": rpm_429_total,
            "total_adjustments": adjustments_total,
            "recent_adjustments": recent[:10],
        }

View File

@@ -0,0 +1,5 @@
"""API key admin routes export."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,497 @@
"""管理员独立余额 API Key 管理路由。
独立余额Key不关联用户配额有独立余额限制用于给非注册用户使用。
"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.api import CreateApiKeyRequest
from src.models.database import ApiKey, User
from src.services.user.apikey import ApiKeyService
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
pipeline = ApiRequestPipeline()
@router.get("")
async def list_standalone_api_keys(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
"""列出所有独立余额API Keys"""
adapter = AdminListStandaloneKeysAdapter(skip=skip, limit=limit, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("")
async def create_standalone_api_key(
request: Request,
key_data: CreateApiKeyRequest,
db: Session = Depends(get_db),
):
"""创建独立余额API Key必须设置余额限制"""
adapter = AdminCreateStandaloneKeyAdapter(key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{key_id}")
async def update_api_key(
key_id: str, request: Request, key_data: CreateApiKeyRequest, db: Session = Depends(get_db)
):
"""更新独立余额Key可修改名称、过期时间、余额限制等"""
adapter = AdminUpdateApiKeyAdapter(key_id=key_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{key_id}")
async def toggle_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
"""Toggle API key active status (PATCH with is_active in body)"""
adapter = AdminToggleApiKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{key_id}")
async def delete_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
adapter = AdminDeleteApiKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{key_id}/balance")
async def add_balance_to_key(
key_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""Adjust balance for standalone API key (positive to add, negative to deduct)"""
# 从请求体获取调整金额
body = await request.json()
amount_usd = body.get("amount_usd")
# 参数校验
if amount_usd is None:
raise HTTPException(status_code=400, detail="缺少必需参数: amount_usd")
if amount_usd == 0:
raise HTTPException(status_code=400, detail="调整金额不能为 0")
# 类型校验
try:
amount_usd = float(amount_usd)
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="调整金额必须是有效数字")
# 如果是扣除操作,检查Key是否存在以及余额是否充足
if amount_usd < 0:
api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
if not api_key:
raise HTTPException(status_code=404, detail="API密钥不存在")
if not api_key.is_standalone:
raise HTTPException(status_code=400, detail="只能为独立余额Key调整余额")
if api_key.current_balance_usd is not None:
if abs(amount_usd) > api_key.current_balance_usd:
raise HTTPException(
status_code=400,
detail=f"扣除金额 ${abs(amount_usd):.2f} 超过当前余额 ${api_key.current_balance_usd:.2f}",
)
adapter = AdminAddBalanceAdapter(key_id=key_id, amount_usd=amount_usd)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{key_id}")
async def get_api_key_detail(
key_id: str,
request: Request,
include_key: bool = Query(False, description="Include full decrypted key in response"),
db: Session = Depends(get_db),
):
"""Get API key detail, optionally include full key"""
if include_key:
adapter = AdminGetFullKeyAdapter(key_id=key_id)
else:
# Return basic key info without full key
adapter = AdminGetKeyDetailAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminListStandaloneKeysAdapter(AdminApiAdapter):
    """List standalone-balance API keys for the admin UI."""

    def __init__(
        self,
        skip: int,
        limit: int,
        is_active: Optional[bool],
    ):
        self.skip = skip
        self.limit = limit
        self.is_active = is_active

    @staticmethod
    def _serialize(api_key) -> dict:
        """Flatten one ApiKey row into the JSON shape returned to the UI."""
        return {
            "id": api_key.id,
            "user_id": api_key.user_id,  # id of the creating user
            "name": api_key.name,
            "key_display": api_key.get_display_key(),
            "is_active": api_key.is_active,
            "is_standalone": api_key.is_standalone,
            "current_balance_usd": api_key.current_balance_usd,
            "balance_used_usd": float(api_key.balance_used_usd or 0),
            "total_requests": api_key.total_requests,
            "total_cost_usd": float(api_key.total_cost_usd or 0),
            "rate_limit": api_key.rate_limit,
            "allowed_providers": api_key.allowed_providers,
            "allowed_api_formats": api_key.allowed_api_formats,
            "allowed_models": api_key.allowed_models,
            "last_used_at": (
                api_key.last_used_at.isoformat() if api_key.last_used_at else None
            ),
            "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
            "created_at": api_key.created_at.isoformat(),
            "updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
            "auto_delete_on_expiry": api_key.auto_delete_on_expiry,
        }

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Standalone-balance keys only.
        query = db.query(ApiKey).filter(ApiKey.is_standalone == True)  # noqa: E712 — SQLAlchemy column comparison
        if self.is_active is not None:
            query = query.filter(ApiKey.is_active == self.is_active)
        total = query.count()
        rows = (
            query.order_by(ApiKey.created_at.desc()).offset(self.skip).limit(self.limit).all()
        )
        context.add_audit_metadata(
            action="list_standalone_api_keys",
            filter_is_active=self.is_active,
            limit=self.limit,
            skip=self.skip,
            total=total,
        )
        return {
            "api_keys": [self._serialize(row) for row in rows],
            "total": total,
            "limit": self.limit,
            "skip": self.skip,
        }
class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
    """Create a standalone-balance key (not tied to any user quota)."""
    def __init__(self, key_data: CreateApiKeyRequest):
        self.key_data = key_data
    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # A standalone key must start with a positive initial balance.
        if not self.key_data.initial_balance_usd or self.key_data.initial_balance_usd <= 0:
            raise HTTPException(
                status_code=400,
                detail="创建独立余额Key必须设置有效的初始余额initial_balance_usd > 0",
            )
        # The key is still owned by the admin who creates it (from context).
        admin_user_id = context.user.id
        # Delegate creation to the service; returns the row plus the one-time
        # plaintext key.
        api_key, plain_key = ApiKeyService.create_api_key(
            db=db,
            user_id=admin_user_id,  # owned by the creator
            name=self.key_data.name,
            allowed_providers=self.key_data.allowed_providers,
            allowed_api_formats=self.key_data.allowed_api_formats,
            allowed_models=self.key_data.allowed_models,
            rate_limit=self.key_data.rate_limit or 100,
            expire_days=self.key_data.expire_days,
            initial_balance_usd=self.key_data.initial_balance_usd,
            is_standalone=True,  # mark as a standalone-balance key
            auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
        )
        logger.info(f"管理员创建独立余额Key: ID {api_key.id}, 初始余额 ${self.key_data.initial_balance_usd}")
        context.add_audit_metadata(
            action="create_standalone_api_key",
            key_id=api_key.id,
            initial_balance_usd=self.key_data.initial_balance_usd,
        )
        return {
            "id": api_key.id,
            "key": plain_key,  # the full key is only returned at creation time
            "name": api_key.name,
            "key_display": api_key.get_display_key(),
            "is_standalone": True,
            "current_balance_usd": api_key.current_balance_usd,
            "balance_used_usd": 0.0,
            "rate_limit": api_key.rate_limit,
            "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
            "created_at": api_key.created_at.isoformat(),
            "message": "独立余额Key创建成功请妥善保存完整密钥后续将无法查看",
        }
class AdminUpdateApiKeyAdapter(AdminApiAdapter):
    """Update a standalone-balance key (name, expiry, rate limit, restrictions)."""

    def __init__(self, key_id: str, key_data: CreateApiKeyRequest):
        self.key_id = key_id
        self.key_data = key_data

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not api_key:
            raise NotFoundException("API密钥不存在", "api_key")
        # Only fields the client actually sent should be written. The old
        # `hasattr(self.key_data, ...)` checks were always True (every pydantic
        # model field exists as an attribute), which had two bad effects:
        #   1. any update that omitted expire_days silently cleared expires_at;
        #   2. omitted allowed_* fields overwrote the restrictions with None.
        # pydantic v2's model_fields_set records which fields were explicitly
        # provided (v2 is in use: model_validate appears elsewhere in this app).
        provided = self.key_data.model_fields_set
        update_data = {}
        if self.key_data.name is not None:
            update_data["name"] = self.key_data.name
        if self.key_data.rate_limit is not None:
            update_data["rate_limit"] = self.key_data.rate_limit
        if self.key_data.auto_delete_on_expiry is not None:
            update_data["auto_delete_on_expiry"] = self.key_data.auto_delete_on_expiry
        # Access restrictions: an explicit empty list clears the restriction;
        # an omitted field leaves the stored value untouched.
        for field in ("allowed_providers", "allowed_api_formats", "allowed_models"):
            if field in provided:
                update_data[field] = getattr(self.key_data, field)
        # Expiration handling:
        #   expire_days > 0        -> expire that many days from now (UTC)
        #   expire_days <= 0       -> never expires
        #   explicit null          -> never expires
        #   field omitted entirely -> keep the current expiry
        if self.key_data.expire_days is not None:
            if self.key_data.expire_days > 0:
                from datetime import timedelta

                update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
                    days=self.key_data.expire_days
                )
            else:
                update_data["expires_at"] = None
        elif "expire_days" in provided:
            update_data["expires_at"] = None
        # Persist through the service layer.
        updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
        if not updated_key:
            raise NotFoundException("更新失败", "api_key")
        logger.info(f"管理员更新独立余额Key: ID {self.key_id}, 更新字段 {list(update_data.keys())}")
        context.add_audit_metadata(
            action="update_standalone_api_key",
            key_id=self.key_id,
            updated_fields=list(update_data.keys()),
        )
        return {
            "id": updated_key.id,
            "name": updated_key.name,
            "key_display": updated_key.get_display_key(),
            "is_active": updated_key.is_active,
            "current_balance_usd": updated_key.current_balance_usd,
            "balance_used_usd": float(updated_key.balance_used_usd or 0),
            "rate_limit": updated_key.rate_limit,
            "expires_at": updated_key.expires_at.isoformat() if updated_key.expires_at else None,
            "updated_at": updated_key.updated_at.isoformat() if updated_key.updated_at else None,
            "message": "API密钥已更新",
        }
class AdminToggleApiKeyAdapter(AdminApiAdapter):
    """Flip an API key's active flag and audit the change."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not record:
            raise NotFoundException("API密钥不存在", "api_key")
        # Invert the flag and stamp the modification time (UTC).
        record.is_active = not record.is_active
        record.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(record)
        state_cn = '启用' if record.is_active else '禁用'
        logger.info(f"管理员切换API密钥状态: Key ID {self.key_id}, 新状态 {state_cn}")
        context.add_audit_metadata(
            action="toggle_api_key",
            target_key_id=record.id,
            user_id=record.user_id,
            new_status="enabled" if record.is_active else "disabled",
        )
        return {
            "id": record.id,
            "is_active": record.is_active,
            "message": f"API密钥已{state_cn}",
        }
class AdminDeleteApiKeyAdapter(AdminApiAdapter):
    """Delete an API key, recording its owner in the audit trail."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not record:
            raise HTTPException(status_code=404, detail="API密钥不存在")
        # Capture the owning user before the row is removed.
        owner = record.user
        db.delete(record)
        db.commit()
        logger.info(f"管理员删除API密钥: Key ID {self.key_id}, 用户 {owner.email if owner else '未知'}")
        context.add_audit_metadata(
            action="delete_api_key",
            target_key_id=self.key_id,
            user_id=owner.id if owner else None,
            user_email=owner.email if owner else None,
        )
        return {"message": "API密钥已删除"}
class AdminAddBalanceAdapter(AdminApiAdapter):
    """Apply a balance adjustment to a standalone key via the service layer."""

    def __init__(self, key_id: str, amount_usd: float):
        self.key_id = key_id
        self.amount_usd = amount_usd

    async def handle(self, context):  # type: ignore[override]
        # The service validates that the key exists and is standalone.
        updated = ApiKeyService.add_balance(context.db, self.key_id, self.amount_usd)
        if not updated:
            raise NotFoundException("余额充值失败Key不存在或不是独立余额Key", "api_key")
        logger.info(f"管理员为独立余额Key充值: ID {self.key_id}, 充值 ${self.amount_usd:.4f}")
        context.add_audit_metadata(
            action="add_balance_to_key",
            key_id=self.key_id,
            amount_usd=self.amount_usd,
            new_current_balance=updated.current_balance_usd,
        )
        return {
            "id": updated.id,
            "name": updated.name,
            "current_balance_usd": updated.current_balance_usd,
            "balance_used_usd": float(updated.balance_used_usd or 0),
            "message": f"余额充值成功,充值 ${self.amount_usd:.2f},当前余额 ${updated.current_balance_usd:.2f}",
        }
class AdminGetFullKeyAdapter(AdminApiAdapter):
    """Decrypt and return the full plaintext API key (sensitive, audited)."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        # Imported lazily so the crypto stack is only loaded when needed.
        from src.core.crypto import crypto_service

        db = context.db
        # Look up the key row.
        api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not api_key:
            raise NotFoundException("API密钥不存在", "api_key")
        # Without the encrypted blob there is nothing to decrypt.
        if not api_key.key_encrypted:
            raise HTTPException(status_code=400, detail="该密钥没有存储完整密钥信息")
        try:
            full_key = crypto_service.decrypt(api_key.key_encrypted)
        except Exception as e:
            logger.error(f"解密API密钥失败: Key ID {self.key_id}, 错误: {e}")
            # Chain the original exception (raise ... from e) so the root
            # cause is preserved in tracebacks; the client still gets a
            # generic 500. The old code dropped the cause.
            raise HTTPException(status_code=500, detail="解密密钥失败") from e
        # Viewing a full key is sensitive — always log and audit it.
        logger.info(f"管理员查看完整API密钥: Key ID {self.key_id}")
        context.add_audit_metadata(
            action="view_full_api_key",
            key_id=self.key_id,
            key_name=api_key.name,
        )
        return {
            "key": full_key,
        }
class AdminGetKeyDetailAdapter(AdminApiAdapter):
    """Return API key metadata without exposing the full plaintext key."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        record = context.db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not record:
            raise NotFoundException("API密钥不存在", "api_key")
        context.add_audit_metadata(
            action="get_api_key_detail",
            key_id=self.key_id,
        )

        def iso(value):
            """Render an optional datetime as ISO-8601, else None."""
            return value.isoformat() if value else None

        return {
            "id": record.id,
            "user_id": record.user_id,
            "name": record.name,
            "key_display": record.get_display_key(),
            "is_active": record.is_active,
            "is_standalone": record.is_standalone,
            "current_balance_usd": record.current_balance_usd,
            "balance_used_usd": float(record.balance_used_usd or 0),
            "total_requests": record.total_requests,
            "total_cost_usd": float(record.total_cost_usd or 0),
            "rate_limit": record.rate_limit,
            "allowed_providers": record.allowed_providers,
            "allowed_api_formats": record.allowed_api_formats,
            "allowed_models": record.allowed_models,
            "last_used_at": iso(record.last_used_at),
            "expires_at": iso(record.expires_at),
            "created_at": record.created_at.isoformat(),
            "updated_at": iso(record.updated_at),
        }

View File

@@ -0,0 +1,24 @@
"""Endpoint management API routers."""
from fastapi import APIRouter
from .concurrency import router as concurrency_router
from .health import router as health_router
from .keys import router as keys_router
from .routes import router as routes_router
# Aggregate the endpoint-management sub-routers under one prefix; each
# sub-router contributes its own path suffixes.
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
# Endpoint CRUD
router.include_router(routes_router)
# Endpoint Keys management
router.include_router(keys_router)
# Health monitoring
router.include_router(health_router)
# Concurrency control
router.include_router(concurrency_router)
# Only the aggregated router is part of this package's public API.
__all__ = ["router"]

View File

@@ -0,0 +1,116 @@
"""
Endpoint 并发控制管理 API
"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.database import get_db
from src.models.database import ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ConcurrencyStatusResponse,
ResetConcurrencyRequest,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
router = APIRouter(tags=["Concurrency Control"])
pipeline = ApiRequestPipeline()
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
async def get_endpoint_concurrency(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""获取 Endpoint 当前并发状态"""
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
async def get_key_concurrency(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""获取 Key 当前并发状态"""
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/concurrency")
async def reset_concurrency(
request: ResetConcurrencyRequest,
http_request: Request,
db: Session = Depends(get_db),
) -> dict:
"""Reset concurrency counters (admin function, use with caution)"""
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
    """Report the live concurrency count and configured limit for an endpoint."""

    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        endpoint = (
            context.db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if endpoint is None:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        manager = await get_concurrency_manager()
        # The manager returns (endpoint_count, key_count); only the first
        # element is relevant here.
        current, _ = await manager.get_current_concurrency(endpoint_id=self.endpoint_id)
        return ConcurrencyStatusResponse(
            endpoint_id=self.endpoint_id,
            endpoint_current_concurrency=current,
            endpoint_max_concurrent=endpoint.max_concurrent,
        )
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
    """Report the live concurrency count and configured limit for a key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        record = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if record is None:
            raise NotFoundException(f"Key {self.key_id} 不存在")
        manager = await get_concurrency_manager()
        # The manager returns (endpoint_count, key_count); only the second
        # element is relevant here.
        _, current = await manager.get_current_concurrency(key_id=self.key_id)
        return ConcurrencyStatusResponse(
            key_id=self.key_id,
            key_current_concurrency=current,
            key_max_concurrent=record.max_concurrent,
        )
@dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter):
    """Zero out concurrency counters for an endpoint and/or a key."""

    endpoint_id: Optional[str]
    key_id: Optional[str]

    async def handle(self, context):  # type: ignore[override]
        manager = await get_concurrency_manager()
        await manager.reset_concurrency(endpoint_id=self.endpoint_id, key_id=self.key_id)
        return {"message": "并发计数已重置"}

View File

@@ -0,0 +1,476 @@
"""
Endpoint 健康监控 API
"""
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
from src.models.endpoint_models import (
ApiFormatHealthMonitor,
ApiFormatHealthMonitorResponse,
EndpointHealthEvent,
HealthStatusResponse,
HealthSummaryResponse,
)
from src.services.health.endpoint import EndpointHealthService
from src.services.health.monitor import health_monitor
router = APIRouter(tags=["Endpoint Health"])
pipeline = ApiRequestPipeline()
@router.get("/health/summary", response_model=HealthSummaryResponse)
async def get_health_summary(
request: Request,
db: Session = Depends(get_db),
) -> HealthSummaryResponse:
"""获取健康状态摘要"""
adapter = AdminHealthSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/health/status")
async def get_endpoint_health_status(
request: Request,
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
db: Session = Depends(get_db),
):
"""
获取端点健康状态(简化视图,与用户端点统一)
与 /health/api-formats 的区别:
- /health/status: 返回聚合的时间线状态50个时间段基于 Usage 表
- /health/api-formats: 返回详细的事件列表,基于 RequestCandidate 表
"""
adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/health/api-formats", response_model=ApiFormatHealthMonitorResponse)
async def get_api_format_health_monitor(
request: Request,
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
db: Session = Depends(get_db),
) -> ApiFormatHealthMonitorResponse:
"""获取按 API 格式聚合的健康监控时间线(详细事件列表)"""
adapter = AdminApiFormatHealthMonitorAdapter(
lookback_hours=lookback_hours,
per_format_limit=per_format_limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/health/key/{key_id}", response_model=HealthStatusResponse)
async def get_key_health(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> HealthStatusResponse:
"""获取 Key 健康状态"""
adapter = AdminKeyHealthAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/health/keys/{key_id}")
async def recover_key_health(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
Recover key health status
Resets health_score to 1.0, closes circuit breaker,
cancels auto-disable, and resets all failure counts.
Parameters:
- key_id: Key ID (path parameter)
"""
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/health/keys")
async def recover_all_keys_health(
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
Batch recover all circuit-broken keys
Finds all keys with circuit_breaker_open=True and:
1. Resets health_score to 1.0
2. Closes circuit breaker
3. Resets failure counts
"""
adapter = AdminRecoverAllKeysHealthAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
class AdminHealthSummaryAdapter(AdminApiAdapter):
    """Aggregate global key-health data into a single summary response."""

    async def handle(self, context):  # type: ignore[override]
        raw_summary = health_monitor.get_all_health_status(context.db)
        return HealthSummaryResponse(**raw_summary)
@dataclass
class AdminEndpointHealthStatusAdapter(AdminApiAdapter):
    """Admin endpoint-health adapter (unified with the user view, plus admin fields).

    Fix: dropped the redundant function-local import of ``EndpointHealthService``;
    the module already imports it at the top of the file.
    """

    lookback_hours: int  # how many hours of history to aggregate (1-72)

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Shared service computes per-format health status (admin variant).
        result = EndpointHealthService.get_endpoint_health_by_format(
            db=db,
            lookback_hours=self.lookback_hours,
            include_admin_fields=True,  # include admin-only fields
            use_cache=False,  # admins bypass the cache for real-time data
        )
        context.add_audit_metadata(
            action="endpoint_health_status",
            format_count=len(result),
            lookback_hours=self.lookback_hours,
        )
        return result
@dataclass
class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
    """Build per-API-format health timelines from RequestCandidate records.

    Improvements over the original:
    - the Enum-or-string normalisation that was duplicated inline four times
      is factored into :meth:`_fmt`;
    - manual "init key then append" dict loops use ``defaultdict``
      (already imported at module level).
    """

    lookback_hours: int  # statistics window in hours
    per_format_limit: int  # max events kept per API format

    @staticmethod
    def _fmt(api_format) -> str:
        """Normalise an api_format column value (Enum or plain str) to a string."""
        return api_format.value if hasattr(api_format, "value") else str(api_format)

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        now = datetime.now(timezone.utc)
        since = now - timedelta(hours=self.lookback_hours)
        # 1. All active API formats with their distinct provider counts.
        active_formats = (
            db.query(
                ProviderEndpoint.api_format,
                func.count(func.distinct(ProviderEndpoint.provider_id)).label("provider_count"),
            )
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .group_by(ProviderEndpoint.api_format)
            .all()
        )
        all_formats: Dict[str, int] = {
            self._fmt(fmt): provider_count for fmt, provider_count in active_formats
        }
        # 1.1 Active API-key counts per format.
        active_keys = (
            db.query(
                ProviderEndpoint.api_format,
                func.count(ProviderAPIKey.id).label("key_count"),
            )
            .join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
                ProviderAPIKey.is_active.is_(True),
            )
            .group_by(ProviderEndpoint.api_format)
            .all()
        )
        key_counts: Dict[str, int] = {self._fmt(fmt): key_count for fmt, key_count in active_keys}
        # 1.2 Endpoint-id lists per format, used by the usage-based timeline below.
        endpoint_rows = (
            db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .all()
        )
        endpoint_map: Dict[str, List[str]] = defaultdict(list)
        for fmt, endpoint_id in endpoint_rows:
            endpoint_map[self._fmt(fmt)].append(endpoint_id)
        # 2. Request-status distribution inside the window (real statistics).
        # Only terminal states count: success, failed, skipped.
        final_statuses = ["success", "failed", "skipped"]
        status_counts_query = (
            db.query(
                ProviderEndpoint.api_format,
                RequestCandidate.status,
                func.count(RequestCandidate.id).label("count"),
            )
            .join(RequestCandidate, ProviderEndpoint.id == RequestCandidate.endpoint_id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .group_by(ProviderEndpoint.api_format, RequestCandidate.status)
            .all()
        )
        status_counts: Dict[str, Dict[str, int]] = defaultdict(
            lambda: {"success": 0, "failed": 0, "skipped": 0}
        )
        for fmt, status, count in status_counts_query:
            status_counts[self._fmt(fmt)][status] = count
        # 3. Recent RequestCandidate rows (bounded), newest first.
        # Intermediate states are excluded via final_statuses.
        limit_rows = max(500, self.per_format_limit * 10)
        rows = (
            db.query(
                RequestCandidate,
                ProviderEndpoint.api_format,
                ProviderEndpoint.provider_id,
            )
            .join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .order_by(RequestCandidate.created_at.desc())
            .limit(limit_rows)
            .all()
        )
        grouped_attempts: Dict[str, List[RequestCandidate]] = defaultdict(list)
        for attempt, fmt, _provider_id in rows:
            bucket = grouped_attempts[self._fmt(fmt)]
            # Keep only the most recent per_format_limit rows per format.
            if len(bucket) < self.per_format_limit:
                bucket.append(attempt)
        # 4. Emit a monitor entry for every active format, even silent ones.
        monitors: List[ApiFormatHealthMonitor] = []
        for api_format in all_formats:
            attempts = grouped_attempts.get(api_format, [])
            # Real statistics for the window; terminal states only.
            format_stats = status_counts.get(api_format) or {
                "success": 0,
                "failed": 0,
                "skipped": 0,
            }
            real_success_count = format_stats.get("success", 0)
            real_failed_count = format_stats.get("failed", 0)
            real_skipped_count = format_stats.get("skipped", 0)
            # total_attempts counts terminal-state requests only.
            total_attempts = real_success_count + real_failed_count + real_skipped_count
            # Timeline events in chronological order.
            attempts_sorted = list(reversed(attempts))
            events: List[EndpointHealthEvent] = []
            for attempt in attempts_sorted:
                event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
                events.append(
                    EndpointHealthEvent(
                        timestamp=event_timestamp,
                        status=attempt.status,
                        status_code=attempt.status_code,
                        latency_ms=attempt.latency_ms,
                        error_type=attempt.error_type,
                        error_message=attempt.error_message,
                    )
                )
            # success_rate = success / (success + failed); skipped requests are
            # excluded from the denominator. With no completed requests the
            # rate defaults to 1.0 (rendered as a "gray" state).
            actual_completed = real_success_count + real_failed_count
            success_rate = real_success_count / actual_completed if actual_completed > 0 else 1.0
            last_event_at = events[-1].timestamp if events else None
            # Usage-table based health timeline for the same window.
            timeline_data = EndpointHealthService._generate_timeline_from_usage(
                db=db,
                endpoint_ids=endpoint_map.get(api_format, []),
                now=now,
                lookback_hours=self.lookback_hours,
            )
            monitors.append(
                ApiFormatHealthMonitor(
                    api_format=api_format,
                    total_attempts=total_attempts,
                    success_count=real_success_count,
                    failed_count=real_failed_count,
                    skipped_count=real_skipped_count,
                    success_rate=success_rate,
                    provider_count=all_formats[api_format],
                    key_count=key_counts.get(api_format, 0),
                    last_event_at=last_event_at,
                    events=events,  # capped at per_format_limit for display
                    timeline=timeline_data.get("timeline", []),
                    time_range_start=timeline_data.get("time_range_start"),
                    time_range_end=timeline_data.get("time_range_end"),
                )
            )
        response = ApiFormatHealthMonitorResponse(
            generated_at=now,
            formats=monitors,
        )
        context.add_audit_metadata(
            action="api_format_health_monitor",
            format_count=len(monitors),
            lookback_hours=self.lookback_hours,
            per_format_limit=self.per_format_limit,
        )
        return response
@dataclass
class AdminKeyHealthAdapter(AdminApiAdapter):
    """Fetch the health record of a single provider key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        data = health_monitor.get_key_health(context.db, self.key_id)
        if not data:
            raise NotFoundException(f"Key {self.key_id} 不存在")
        return HealthStatusResponse(
            key_id=data["key_id"],
            key_health_score=data["health_score"],
            key_consecutive_failures=data["consecutive_failures"],
            key_last_failure_at=data["last_failure_at"],
            key_is_active=data["is_active"],
            key_statistics=data["statistics"],
            circuit_breaker_open=data["circuit_breaker_open"],
            circuit_breaker_open_at=data["circuit_breaker_open_at"],
            next_probe_at=data["next_probe_at"],
        )
@dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
    """Fully restore a single key's health state.

    Resets the health score, closes the circuit breaker, clears probe
    scheduling, and re-activates the key if it was disabled.

    Fix: removed the unused ``admin_name`` local (assigned but never read).
    """

    key_id: str  # target key (path parameter)

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not key:
            raise NotFoundException(f"Key {self.key_id} 不存在")
        # Reset every health metric and close the circuit breaker.
        key.health_score = 1.0
        key.consecutive_failures = 0
        key.last_failure_at = None
        key.circuit_breaker_open = False
        key.circuit_breaker_open_at = None
        key.next_probe_at = None
        if not key.is_active:
            key.is_active = True
        db.commit()
        logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
        return {
            "message": "Key已完全恢复",
            "details": {
                "health_score": 1.0,
                "circuit_breaker_open": False,
                "is_active": True,
            },
        }
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
    """Batch-recover the health state of every circuit-broken key.

    Fixes:
    - unused ``admin_name`` local removed;
    - ``circuit_breaker_open == True`` replaced with the SQLAlchemy
      ``.is_(True)`` idiom used everywhere else in this file.
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Find every key whose circuit breaker is currently open.
        circuit_open_keys = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open.is_(True)).all()
        )
        if not circuit_open_keys:
            return {
                "message": "没有需要恢复的 Key",
                "recovered_count": 0,
                "recovered_keys": [],
            }
        recovered_keys = []
        for key in circuit_open_keys:
            key.health_score = 1.0
            key.consecutive_failures = 0
            key.last_failure_at = None
            key.circuit_breaker_open = False
            key.circuit_breaker_open_at = None
            key.next_probe_at = None
            recovered_keys.append(
                {
                    "key_id": key.id,
                    "key_name": key.name,
                    "endpoint_id": key.endpoint_id,
                }
            )
        db.commit()
        # Reset the in-process health-monitor counters and the metric gauge.
        # Local import kept to avoid a circular dependency at module load.
        from src.services.health.monitor import HealthMonitor, health_open_circuits

        HealthMonitor._open_circuit_keys = 0
        health_open_circuits.set(0)
        logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
        return {
            "message": f"已恢复 {len(recovered_keys)} 个 Key",
            "recovered_count": len(recovered_keys),
            "recovered_keys": recovered_keys,
        }

View File

@@ -0,0 +1,425 @@
"""
Endpoint API Keys 管理
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.crypto import crypto_service
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.key_capabilities import get_capability
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
EndpointAPIKeyResponse,
EndpointAPIKeyUpdate,
)
router = APIRouter(tags=["Endpoint Keys"])
pipeline = ApiRequestPipeline()
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_endpoint_keys(
endpoint_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""获取 Endpoint 的所有 Keys"""
adapter = AdminListEndpointKeysAdapter(
endpoint_id=endpoint_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_endpoint_key(
endpoint_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""为 Endpoint 添加 Key"""
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
async def update_endpoint_key(
key_id: str,
key_data: EndpointAPIKeyUpdate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""更新 Endpoint Key"""
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/keys/grouped-by-format")
async def get_keys_grouped_by_format(
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""获取按 API 格式分组的所有 Keys用于全局优先级管理"""
adapter = AdminGetKeysGroupedByFormatAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/keys/{key_id}")
async def delete_endpoint_key(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""删除 Endpoint Key"""
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}/keys/batch-priority")
async def batch_update_key_priority(
endpoint_id: str,
request: Request,
priority_data: BatchUpdateKeyPriorityRequest,
db: Session = Depends(get_db),
) -> dict:
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
    """List an endpoint's keys with masked secrets and derived health metrics."""

    endpoint_id: str
    skip: int
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint_exists = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint_exists:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        key_rows = (
            db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
            .order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
            .offset(self.skip)
            .limit(self.limit)
            .all()
        )
        responses: List[EndpointAPIKeyResponse] = []
        for row in key_rows:
            # Decrypt only to derive the masked display value.
            try:
                plain = crypto_service.decrypt(row.api_key)
                masked = f"{plain[:8]}***{plain[-4:]}"
            except Exception:
                masked = "***ERROR***"
            rate = row.success_count / row.request_count if row.request_count > 0 else 0.0
            avg_ms = (
                row.total_response_time_ms / row.success_count if row.success_count > 0 else 0.0
            )
            # max_concurrent=None means the adaptive concurrency mode.
            adaptive = row.max_concurrent is None
            payload = dict(row.__dict__)
            payload.pop("_sa_instance_state", None)
            payload.update(
                api_key_masked=masked,
                api_key_plain=None,
                success_rate=rate,
                avg_response_time_ms=round(avg_ms, 2),
                is_adaptive=adaptive,
                effective_limit=(
                    row.learned_max_concurrent if adaptive else row.max_concurrent
                ),
            )
            responses.append(EndpointAPIKeyResponse(**payload))
        return responses
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
    """Create a new API key under an endpoint.

    Validates that the endpoint exists and that the payload's endpoint_id
    matches the path parameter, stores the key encrypted, and returns the
    created record — including the plaintext key, echoed back this one time
    so the caller can save it.
    """

    endpoint_id: str  # target endpoint (path parameter)
    key_data: EndpointAPIKeyCreate  # validated creation payload

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        # Defensive check: the body must reference the same endpoint as the URL.
        if self.key_data.endpoint_id != self.endpoint_id:
            raise InvalidRequestException("endpoint_id 不匹配")
        # Keys are persisted encrypted; plaintext never reaches the database.
        encrypted_key = crypto_service.encrypt(self.key_data.api_key)
        now = datetime.now(timezone.utc)
        # max_concurrent=NULL enables adaptive mode; a number is a fixed limit.
        new_key = ProviderAPIKey(
            id=str(uuid.uuid4()),
            endpoint_id=self.endpoint_id,
            api_key=encrypted_key,
            name=self.key_data.name,
            note=self.key_data.note,
            rate_multiplier=self.key_data.rate_multiplier,
            internal_priority=self.key_data.internal_priority,
            max_concurrent=self.key_data.max_concurrent,  # NULL = adaptive mode
            rate_limit=self.key_data.rate_limit,
            daily_limit=self.key_data.daily_limit,
            monthly_limit=self.key_data.monthly_limit,
            allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
            capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
            request_count=0,
            success_count=0,
            error_count=0,
            total_response_time_ms=0,
            is_active=True,
            last_used_at=None,
            created_at=now,
            updated_at=now,
        )
        db.add(new_key)
        db.commit()
        db.refresh(new_key)
        # Log only the key's last 4 characters, never the full secret.
        logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
        masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
        is_adaptive = new_key.max_concurrent is None
        response_dict = new_key.__dict__.copy()
        response_dict.pop("_sa_instance_state", None)
        response_dict.update(
            {
                "api_key_masked": masked_key,
                # Plaintext returned once at creation time only.
                "api_key_plain": self.key_data.api_key,
                "success_rate": 0.0,
                "avg_response_time_ms": 0.0,
                "is_adaptive": is_adaptive,
                "effective_limit": (
                    new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
                ),
            }
        )
        return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
    """Partially update a provider API key; new api_key values are re-encrypted."""

    key_id: str
    key_data: EndpointAPIKeyUpdate

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if key is None:
            raise NotFoundException(f"Key {self.key_id} 不存在")
        changes = self.key_data.model_dump(exclude_unset=True)
        if "api_key" in changes:
            # Never persist a plaintext key.
            changes["api_key"] = crypto_service.encrypt(changes["api_key"])
        for attr, new_value in changes.items():
            setattr(key, attr, new_value)
        key.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(key)
        logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(changes.keys())}")
        try:
            plain = crypto_service.decrypt(key.api_key)
            masked = f"{plain[:8]}***{plain[-4:]}"
        except Exception:
            masked = "***ERROR***"
        rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
        avg_ms = (
            key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
        )
        adaptive = key.max_concurrent is None
        payload = dict(key.__dict__)
        payload.pop("_sa_instance_state", None)
        payload.update(
            api_key_masked=masked,
            api_key_plain=None,
            success_rate=rate,
            avg_response_time_ms=round(avg_ms, 2),
            is_adaptive=adaptive,
            effective_limit=(
                key.learned_max_concurrent if adaptive else key.max_concurrent
            ),
        )
        return EndpointAPIKeyResponse(**payload)
@dataclass
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
    """Delete a provider API key by id, rolling back on failure."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if key is None:
            raise NotFoundException(f"Key {self.key_id} 不存在")
        parent_endpoint_id = key.endpoint_id
        try:
            db.delete(key)
            db.commit()
        except Exception as exc:
            db.rollback()
            logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
            raise
        logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={parent_endpoint_id}")
        return {"message": f"Key {self.key_id} 已删除"}
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
    """Group all active keys by API format for global priority management.

    Fix: the grouping key is normalised to the api_format's string value,
    matching how every other admin view converts this (possibly Enum-typed)
    column; previously the raw column value was used directly as the key.
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        keys = (
            db.query(ProviderAPIKey, ProviderEndpoint, Provider)
            .join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderAPIKey.is_active.is_(True),
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .order_by(
                ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
            )
            .all()
        )
        grouped: Dict[str, List[dict]] = {}
        for key, endpoint, provider in keys:
            # Normalise Enum-or-str to a plain string, consistent with the
            # health-monitor views elsewhere in this package.
            api_format = (
                endpoint.api_format.value
                if hasattr(endpoint.api_format, "value")
                else str(endpoint.api_format)
            )
            if api_format not in grouped:
                grouped[api_format] = []
            try:
                decrypted_key = crypto_service.decrypt(key.api_key)
                masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
            except Exception:
                masked_key = "***ERROR***"
            # Derived health metrics; None when the key has no traffic yet.
            success_rate = key.success_count / key.request_count if key.request_count > 0 else None
            avg_response_time_ms = (
                round(key.total_response_time_ms / key.success_count, 2)
                if key.success_count > 0
                else None
            )
            # Convert the capabilities dict into a list of enabled short names.
            caps_list = []
            if key.capabilities:
                for cap_name, enabled in key.capabilities.items():
                    if enabled:
                        cap_def = get_capability(cap_name)
                        caps_list.append(cap_def.short_name if cap_def else cap_name)
            grouped[api_format].append(
                {
                    "id": key.id,
                    "name": key.name,
                    "api_key_masked": masked_key,
                    "internal_priority": key.internal_priority,
                    "global_priority": key.global_priority,
                    "rate_multiplier": key.rate_multiplier,
                    "is_active": key.is_active,
                    "circuit_breaker_open": key.circuit_breaker_open,
                    "provider_name": provider.display_name or provider.name,
                    "endpoint_base_url": endpoint.base_url,
                    "api_format": api_format,
                    "capabilities": caps_list,
                    "success_rate": success_rate,
                    "avg_response_time_ms": avg_response_time_ms,
                    "request_count": key.request_count,
                }
            )
        # Return the grouped mapping directly; consumed by the frontend as-is.
        return grouped
@dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
    """Apply a batch of internal-priority changes to an endpoint's keys (drag sort)."""

    endpoint_id: str
    priority_data: BatchUpdateKeyPriorityRequest

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint_exists = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint_exists:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        # Collect requested key ids and verify they all belong to the endpoint.
        requested_ids = [entry.key_id for entry in self.priority_data.priorities]
        owned_keys = (
            db.query(ProviderAPIKey)
            .filter(
                ProviderAPIKey.id.in_(requested_ids),
                ProviderAPIKey.endpoint_id == self.endpoint_id,
            )
            .all()
        )
        if len(owned_keys) != len(requested_ids):
            missing_ids = set(requested_ids) - {k.id for k in owned_keys}
            raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
        # Apply priority changes, touching updated_at only on real changes.
        by_id = {k.id: k for k in owned_keys}
        updated_count = 0
        for entry in self.priority_data.priorities:
            target = by_id.get(entry.key_id)
            if target and target.internal_priority != entry.internal_priority:
                target.internal_priority = entry.internal_priority
                target.updated_at = datetime.now(timezone.utc)
                updated_count += 1
        db.commit()
        logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(requested_ids)}")
        return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}

View File

@@ -0,0 +1,345 @@
"""
ProviderEndpoint CRUD 管理 API
"""
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ProviderEndpointCreate,
ProviderEndpointResponse,
ProviderEndpointUpdate,
)
router = APIRouter(tags=["Endpoint Management"])
pipeline = ApiRequestPipeline()
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
provider_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[ProviderEndpointResponse]:
"""获取指定 Provider 的所有 Endpoints"""
adapter = AdminListProviderEndpointsAdapter(
provider_id=provider_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
async def create_provider_endpoint(
provider_id: str,
endpoint_data: ProviderEndpointCreate,
request: Request,
db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
"""为 Provider 创建新的 Endpoint"""
adapter = AdminCreateProviderEndpointAdapter(
provider_id=provider_id,
endpoint_data=endpoint_data,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def get_endpoint(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
"""获取 Endpoint 详情"""
adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def update_endpoint(
endpoint_id: str,
endpoint_data: ProviderEndpointUpdate,
request: Request,
db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
"""更新 Endpoint"""
adapter = AdminUpdateProviderEndpointAdapter(
endpoint_id=endpoint_id,
endpoint_data=endpoint_data,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{endpoint_id}")
async def delete_endpoint(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""删除 Endpoint级联删除所有关联的 Keys"""
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
    """List a provider's endpoints together with per-endpoint key counts."""

    provider_id: str
    skip: int
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")
        endpoints = (
            db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            .order_by(ProviderEndpoint.created_at.desc())
            .offset(self.skip)
            .limit(self.limit)
            .all()
        )
        endpoint_ids = [ep.id for ep in endpoints]
        totals: Dict[str, int] = {}
        actives: Dict[str, int] = {}
        if endpoint_ids:
            # Two grouped counts: all keys, then active-only keys.
            for endpoint_id, count in (
                db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id))
                .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
                .group_by(ProviderAPIKey.endpoint_id)
                .all()
            ):
                totals[endpoint_id] = count
            for endpoint_id, count in (
                db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id))
                .filter(
                    and_(
                        ProviderAPIKey.endpoint_id.in_(endpoint_ids),
                        ProviderAPIKey.is_active.is_(True),
                    )
                )
                .group_by(ProviderAPIKey.endpoint_id)
                .all()
            ):
                actives[endpoint_id] = count
        responses: List[ProviderEndpointResponse] = []
        for ep in endpoints:
            payload = dict(ep.__dict__)
            payload.pop("_sa_instance_state", None)
            payload.update(
                provider_name=provider.name,
                api_format=ep.api_format,
                total_keys=totals.get(ep.id, 0),
                active_keys=actives.get(ep.id, 0),
            )
            responses.append(ProviderEndpointResponse(**payload))
        return responses
@dataclass
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
    """Create a new endpoint for a provider.

    Validates that the provider exists, that the payload's provider_id matches
    the path parameter, and that no endpoint with the same api_format already
    exists for this provider (one endpoint per format per provider).
    """

    provider_id: str  # target provider (path parameter)
    endpoint_data: ProviderEndpointCreate  # validated creation payload

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")
        # Defensive check: the body must reference the same provider as the URL.
        if self.endpoint_data.provider_id != self.provider_id:
            raise InvalidRequestException("provider_id 不匹配")
        # Enforce uniqueness of (provider, api_format).
        existing = (
            db.query(ProviderEndpoint)
            .filter(
                and_(
                    ProviderEndpoint.provider_id == self.provider_id,
                    ProviderEndpoint.api_format == self.endpoint_data.api_format,
                )
            )
            .first()
        )
        if existing:
            raise InvalidRequestException(
                f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
            )
        now = datetime.now(timezone.utc)
        new_endpoint = ProviderEndpoint(
            id=str(uuid.uuid4()),
            provider_id=self.provider_id,
            api_format=self.endpoint_data.api_format,
            base_url=self.endpoint_data.base_url,
            headers=self.endpoint_data.headers,
            timeout=self.endpoint_data.timeout,
            max_retries=self.endpoint_data.max_retries,
            max_concurrent=self.endpoint_data.max_concurrent,
            rate_limit=self.endpoint_data.rate_limit,
            is_active=True,
            config=self.endpoint_data.config,
            created_at=now,
            updated_at=now,
        )
        db.add(new_endpoint)
        db.commit()
        db.refresh(new_endpoint)
        logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")
        # Exclude api_format from the spread so it is passed explicitly below.
        endpoint_dict = {
            k: v
            for k, v in new_endpoint.__dict__.items()
            if k not in {"api_format", "_sa_instance_state"}
        }
        return ProviderEndpointResponse(
            **endpoint_dict,
            provider_name=provider.name,
            api_format=new_endpoint.api_format,
            total_keys=0,  # a freshly created endpoint has no keys yet
            active_keys=0,
        )
@dataclass
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
    """Fetch one endpoint's details, including provider name and key counts."""

    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        row = (
            db.query(ProviderEndpoint, Provider)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if row is None:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        endpoint_obj, provider = row
        key_query = db.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = key_query.count()
        active_keys = key_query.filter(ProviderAPIKey.is_active.is_(True)).count()
        payload = dict(endpoint_obj.__dict__)
        payload.pop("_sa_instance_state", None)
        payload.pop("api_format", None)
        return ProviderEndpointResponse(
            **payload,
            provider_name=provider.name,
            api_format=endpoint_obj.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
@dataclass
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
    """Partially update an endpoint and return its refreshed representation."""

    endpoint_id: str
    endpoint_data: ProviderEndpointUpdate

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if endpoint is None:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        changes = self.endpoint_data.model_dump(exclude_unset=True)
        for attr, new_value in changes.items():
            setattr(endpoint, attr, new_value)
        endpoint.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(endpoint)
        provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
        logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(changes.keys())}")
        key_query = db.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = key_query.count()
        active_keys = key_query.filter(ProviderAPIKey.is_active.is_(True)).count()
        payload = dict(endpoint.__dict__)
        payload.pop("_sa_instance_state", None)
        payload.pop("api_format", None)
        return ProviderEndpointResponse(
            **payload,
            provider_name=provider.name if provider else "Unknown",
            api_format=endpoint.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
@dataclass
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
    """Admin adapter: delete a provider endpoint.

    The endpoint's API keys are only counted here for reporting; their
    removal presumably relies on a DB/ORM cascade — TODO confirm the
    cascade is configured on the relationship.
    """

    endpoint_id: str  # primary key of the ProviderEndpoint to delete

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
        # Count the keys first so the response can report how many went with it.
        keys_count = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
        )
        db.delete(endpoint)
        db.commit()
        logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
        return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}

View File

@@ -0,0 +1,16 @@
"""
模型管理相关 Admin API
"""
from fastapi import APIRouter
from .catalog import router as catalog_router
from .global_models import router as global_models_router
from .mappings import router as mappings_router
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
# 挂载子路由
router.include_router(catalog_router)
router.include_router(global_models_router)
router.include_router(mappings_router)

View File

@@ -0,0 +1,432 @@
"""
统一模型目录 Admin API
阶段一:基于 ModelMapping 和 Model 的聚合视图
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy import func, or_
from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.database import GlobalModel, Model, ModelMapping, Provider
from src.models.pydantic_models import (
BatchAssignError,
BatchAssignModelMappingRequest,
BatchAssignModelMappingResponse,
BatchAssignProviderResult,
DeleteModelMappingResponse,
ModelCapabilities,
ModelCatalogItem,
ModelCatalogProviderDetail,
ModelCatalogResponse,
ModelPriceRange,
OrphanedModel,
UpdateModelMappingRequest,
UpdateModelMappingResponse,
)
from src.services.cache.invalidation import get_cache_invalidation_service
from src.services.model.service import ModelService
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
pipeline = ApiRequestPipeline()
@router.get("", response_model=ModelCatalogResponse)
async def get_model_catalog(
request: Request,
db: Session = Depends(get_db),
) -> ModelCatalogResponse:
adapter = AdminGetModelCatalogAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
async def batch_assign_model_mappings(
request: Request,
payload: BatchAssignModelMappingRequest,
db: Session = Depends(get_db),
) -> BatchAssignModelMappingResponse:
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass
class AdminGetModelCatalogAdapter(AdminApiAdapter):
    """Admin adapter: build the unified model catalog.

    Architecture notes:
    1. Data is aggregated around GlobalModel.
    2. ModelMapping rows supply alias names (provider_id=NULL means global).
    3. Model rows supply per-provider implementations and effective prices.
    """

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db
        # 1. All active GlobalModels.
        global_models: List[GlobalModel] = (
            db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
        )
        # 2. Active aliases. NOTE(review): only *global* aliases are loaded
        # here (provider_id IS NULL), although the original comment mentioned
        # provider-specific ones too — confirm this is intentional.
        aliases_rows: List[ModelMapping] = (
            db.query(ModelMapping)
            .options(joinedload(ModelMapping.target_global_model))
            .filter(
                ModelMapping.is_active == True,
                ModelMapping.provider_id.is_(None),
            )
            .all()
        )
        # Group alias names by target GlobalModel id (deduplicated, order kept).
        aliases_by_global_model: Dict[str, List[str]] = {}
        for alias_row in aliases_rows:
            if not alias_row.target_global_model_id:
                continue
            gm_id = alias_row.target_global_model_id
            if gm_id not in aliases_by_global_model:
                aliases_by_global_model[gm_id] = []
            if alias_row.source_model not in aliases_by_global_model[gm_id]:
                aliases_by_global_model[gm_id].append(alias_row.source_model)
        # 3. Active Model implementations; global_model is eagerly loaded so
        # the get_effective_* helpers can fall back to GlobalModel defaults.
        models: List[Model] = (
            db.query(Model)
            .options(joinedload(Model.provider), joinedload(Model.global_model))
            .filter(Model.is_active == True)
            .all()
        )
        # Group provider implementations by GlobalModel id.
        models_by_global_model: Dict[str, List[Model]] = {}
        for model in models:
            if model.global_model_id:
                models_by_global_model.setdefault(model.global_model_id, []).append(model)
        # 4. Build one catalog item per GlobalModel.
        catalog_items: List[ModelCatalogItem] = []
        for gm in global_models:
            gm_id = gm.id
            provider_entries: List[ModelCatalogProviderDetail] = []
            # Capability flags start from the GlobalModel defaults and are
            # OR-ed with every implementation's effective values below.
            capability_flags = {
                "supports_vision": gm.default_supports_vision or False,
                "supports_function_calling": gm.default_supports_function_calling or False,
                "supports_streaming": gm.default_supports_streaming or False,
            }
            # Walk this GlobalModel's provider implementations.
            for model in models_by_global_model.get(gm_id, []):
                provider = model.provider
                if not provider:
                    continue
                # Effective prices take GlobalModel defaults into account.
                effective_input = model.get_effective_input_price()
                effective_output = model.get_effective_output_price()
                effective_tiered = model.get_effective_tiered_pricing()
                tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
                # Aggregate capabilities: true if any implementation supports it.
                capability_flags["supports_vision"] = (
                    capability_flags["supports_vision"] or model.get_effective_supports_vision()
                )
                capability_flags["supports_function_calling"] = (
                    capability_flags["supports_function_calling"]
                    or model.get_effective_supports_function_calling()
                )
                capability_flags["supports_streaming"] = (
                    capability_flags["supports_streaming"]
                    or model.get_effective_supports_streaming()
                )
                provider_entries.append(
                    ModelCatalogProviderDetail(
                        provider_id=provider.id,
                        provider_name=provider.name,
                        provider_display_name=provider.display_name,
                        model_id=model.id,
                        target_model=model.provider_model_name,
                        # Expose effective (default-aware) prices.
                        input_price_per_1m=effective_input,
                        output_price_per_1m=effective_output,
                        cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
                        cache_read_price_per_1m=model.get_effective_cache_read_price(),
                        cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
                        price_per_request=model.get_effective_price_per_request(),
                        effective_tiered_pricing=effective_tiered,
                        tier_count=tier_count,
                        supports_vision=model.get_effective_supports_vision(),
                        supports_function_calling=model.get_effective_supports_function_calling(),
                        supports_streaming=model.get_effective_supports_streaming(),
                        is_active=bool(model.is_active),
                        mapping_id=None,  # no mapping_id in the new architecture
                    )
                )
            # The catalog shows the GlobalModel's own first pricing tier (not
            # an aggregate over provider prices), hence min == max here.
            tiered = gm.default_tiered_pricing or {}
            first_tier = tiered.get("tiers", [{}])[0] if tiered.get("tiers") else {}
            price_range = ModelPriceRange(
                min_input=first_tier.get("input_price_per_1m", 0),
                max_input=first_tier.get("input_price_per_1m", 0),
                min_output=first_tier.get("output_price_per_1m", 0),
                max_output=first_tier.get("output_price_per_1m", 0),
            )
            catalog_items.append(
                ModelCatalogItem(
                    global_model_name=gm.name,
                    display_name=gm.display_name,
                    description=gm.description,
                    aliases=aliases_by_global_model.get(gm_id, []),
                    providers=provider_entries,
                    price_range=price_range,
                    total_providers=len(provider_entries),
                    capabilities=ModelCapabilities(**capability_flags),
                )
            )
        # 5. Orphaned aliases: active global aliases whose target GlobalModel
        # is missing or inactive.
        orphaned_rows = (
            db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
            .outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
            .filter(
                ModelMapping.is_active == True,
                ModelMapping.provider_id.is_(None),
                or_(GlobalModel.id == None, GlobalModel.is_active == False),
            )
            .group_by(ModelMapping.source_model, GlobalModel.name)
            .all()
        )
        orphaned_models = [
            OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
            for row in orphaned_rows
            if row[0]
        ]
        return ModelCatalogResponse(
            models=catalog_items,
            total=len(catalog_items),
            orphaned_models=orphaned_models,
        )
@dataclass
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
    """Admin adapter: batch-assign model mappings to providers.

    NOTE(review): the mapping-creation step is currently disabled pending
    the GlobalModel-architecture migration — every provider entry ends up
    in ``errors``, so ``success`` is always False (see the backlog note
    inside the loop).
    """

    payload: BatchAssignModelMappingRequest  # per-provider assignment configs

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db
        created: List[BatchAssignProviderResult] = []
        errors: List[BatchAssignError] = []
        # Process each provider independently; one failure does not abort the batch.
        for provider_config in self.payload.providers:
            provider_id = provider_config.provider_id
            try:
                provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
                if not provider:
                    errors.append(
                        BatchAssignError(provider_id=provider_id, error="Provider 不存在")
                    )
                    continue
                model_id: Optional[str] = None
                # NOTE(review): created_model is set but never read — candidate
                # for removal once the feature below is re-enabled.
                created_model = False
                if provider_config.create_model:
                    # Create (or reuse) the provider-side model implementation.
                    model_data = provider_config.model_data
                    if not model_data:
                        errors.append(
                            BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
                        )
                        continue
                    existing_model = ModelService.get_model_by_name(
                        db, provider_id, model_data.provider_model_name
                    )
                    if existing_model:
                        model_id = existing_model.id
                        logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
                            model_data.provider_model_name,
                            provider.name,
                        )
                    else:
                        model = ModelService.create_model(db, provider_id, model_data)
                        model_id = model.id
                        created_model = True
                else:
                    # Reference an existing model owned by this provider.
                    model_id = provider_config.model_id
                    if not model_id:
                        errors.append(
                            BatchAssignError(provider_id=provider_id, error="缺少 model_id")
                        )
                        continue
                    model = (
                        db.query(Model)
                        .filter(Model.id == model_id, Model.provider_id == provider_id)
                        .first()
                    )
                    if not model:
                        errors.append(
                            BatchAssignError(
                                provider_id=provider_id, error="模型不存在或不属于当前 Provider")
                        )
                        continue
                # Batch assignment still needs adapting to the GlobalModel
                # architecture — see docs/optimization-backlog.md.
                errors.append(
                    BatchAssignError(
                        provider_id=provider_id,
                        error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
                    )
                )
                continue
            except Exception as exc:
                # Roll back any partial writes for this provider and record the
                # error; the remaining providers keep processing.
                db.rollback()
                logger.error("批量添加模型映射失败(需要适配新架构)")
                errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
        return BatchAssignModelMappingResponse(
            success=len(created) > 0,
            created_mappings=created,
            errors=errors,
        )
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
async def update_model_mapping(
request: Request,
mapping_id: str,
payload: UpdateModelMappingRequest,
db: Session = Depends(get_db),
) -> UpdateModelMappingResponse:
"""更新模型映射"""
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
async def delete_model_mapping(
request: Request,
mapping_id: str,
db: Session = Depends(get_db),
) -> DeleteModelMappingResponse:
"""删除模型映射"""
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
    """Admin adapter: update a model mapping (PATCH semantics).

    Validates the new provider / target GlobalModel when changed, enforces
    the (source_model, provider_id) uniqueness constraint, refreshes
    ``updated_at`` (consistent with the mappings router), and invalidates
    the mapping cache for BOTH the old and the new lookup key.

    Fixes:
    - Previously only the post-update (source_model, provider_id) key was
      invalidated, leaving a stale cache entry for the old key when the
      mapping's identity changed.
    - ``updated_at`` was not refreshed here although the sibling
      ``mappings.update_mapping`` endpoint refreshes it.
    """

    mapping_id: str  # primary key of the ModelMapping to update
    payload: UpdateModelMappingRequest  # partial update payload

    async def handle(self, context):  # type: ignore[override]
        # Local import: this module does not import datetime at top level.
        from datetime import datetime, timezone

        db: Session = context.db
        mapping: Optional[ModelMapping] = (
            db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
        )
        if not mapping:
            raise HTTPException(status_code=404, detail="映射不存在")
        # Remember the pre-update identity so the stale cache entry can be evicted.
        old_source = mapping.source_model
        old_provider_id = mapping.provider_id
        update_data = self.payload.model_dump(exclude_unset=True)
        # Provider change (a falsy value switches the mapping to global scope).
        if "provider_id" in update_data:
            new_provider_id = update_data["provider_id"]
            if new_provider_id:
                provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
                if not provider:
                    raise HTTPException(status_code=404, detail="Provider 不存在")
            mapping.provider_id = new_provider_id
        # Target GlobalModel change: must exist and be active.
        if "target_global_model_id" in update_data:
            target_model = (
                db.query(GlobalModel)
                .filter(
                    GlobalModel.id == update_data["target_global_model_id"],
                    GlobalModel.is_active == True,
                )
                .first()
            )
            if not target_model:
                raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
            mapping.target_global_model_id = update_data["target_global_model_id"]
        # Source model rename.
        if "source_model" in update_data:
            new_source = update_data["source_model"].strip()
            if not new_source:
                raise HTTPException(status_code=400, detail="source_model 不能为空")
            mapping.source_model = new_source
        # Active flag.
        if "is_active" in update_data:
            mapping.is_active = update_data["is_active"]
        # Uniqueness of (source_model, provider_id) after any identity change.
        duplicate = (
            db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == mapping.source_model,
                ModelMapping.provider_id == mapping.provider_id,
                ModelMapping.id != mapping.id,
            )
            .first()
        )
        if duplicate:
            raise HTTPException(status_code=400, detail="映射已存在")
        mapping.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(mapping)
        cache_service = get_cache_invalidation_service()
        # Evict the old lookup key too when the mapping identity changed.
        if (old_source, old_provider_id) != (mapping.source_model, mapping.provider_id):
            cache_service.on_model_mapping_changed(old_source, old_provider_id)
        cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
        return UpdateModelMappingResponse(
            success=True,
            mapping_id=mapping.id,
            message="映射更新成功",
        )
@dataclass
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
    """Admin adapter: delete a model mapping and evict its cache entry."""

    mapping_id: str

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db
        mapping: Optional[ModelMapping] = (
            db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
        )
        if mapping is None:
            raise HTTPException(status_code=404, detail="映射不存在")
        # Capture the lookup key before the row disappears.
        removed_source = mapping.source_model
        removed_provider_id = mapping.provider_id
        db.delete(mapping)
        db.commit()
        get_cache_invalidation_service().on_model_mapping_changed(
            removed_source, removed_provider_id
        )
        return DeleteModelMappingResponse(
            success=True,
            message=f"映射 {self.mapping_id} 已删除",
        )

View File

@@ -0,0 +1,292 @@
"""
GlobalModel Admin API
提供 GlobalModel 的 CRUD 操作接口
"""
from dataclasses import dataclass
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.pydantic_models import (
BatchAssignToProvidersRequest,
BatchAssignToProvidersResponse,
GlobalModelCreate,
GlobalModelListResponse,
GlobalModelResponse,
GlobalModelUpdate,
GlobalModelWithStats,
)
from src.services.model.global_model import GlobalModelService
router = APIRouter(prefix="/global", tags=["Admin - Global Models"])
pipeline = ApiRequestPipeline()
@router.get("", response_model=GlobalModelListResponse)
async def list_global_models(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
is_active: Optional[bool] = Query(None),
search: Optional[str] = Query(None),
db: Session = Depends(get_db),
) -> GlobalModelListResponse:
"""获取 GlobalModel 列表"""
adapter = AdminListGlobalModelsAdapter(
skip=skip,
limit=limit,
is_active=is_active,
search=search,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{global_model_id}", response_model=GlobalModelWithStats)
async def get_global_model(
request: Request,
global_model_id: str,
db: Session = Depends(get_db),
) -> GlobalModelWithStats:
"""获取单个 GlobalModel 详情(含统计信息)"""
adapter = AdminGetGlobalModelAdapter(global_model_id=global_model_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("", response_model=GlobalModelResponse, status_code=201)
async def create_global_model(
request: Request,
payload: GlobalModelCreate,
db: Session = Depends(get_db),
) -> GlobalModelResponse:
"""创建 GlobalModel"""
adapter = AdminCreateGlobalModelAdapter(payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{global_model_id}", response_model=GlobalModelResponse)
async def update_global_model(
request: Request,
global_model_id: str,
payload: GlobalModelUpdate,
db: Session = Depends(get_db),
) -> GlobalModelResponse:
"""更新 GlobalModel"""
adapter = AdminUpdateGlobalModelAdapter(global_model_id=global_model_id, payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{global_model_id}", status_code=204)
async def delete_global_model(
request: Request,
global_model_id: str,
db: Session = Depends(get_db),
):
"""删除 GlobalModel级联删除所有关联的 Provider 模型实现)"""
adapter = AdminDeleteGlobalModelAdapter(global_model_id=global_model_id)
await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
return None
@router.post(
    "/{global_model_id}/assign-to-providers", response_model=BatchAssignToProvidersResponse
)
async def batch_assign_to_providers(
    request: Request,
    global_model_id: str,
    payload: BatchAssignToProvidersRequest,
    db: Session = Depends(get_db),
) -> BatchAssignToProvidersResponse:
    """Create implementations of one GlobalModel for several providers at once."""
    batch_adapter = AdminBatchAssignToProvidersAdapter(
        global_model_id=global_model_id, payload=payload
    )
    return await pipeline.run(
        adapter=batch_adapter, http_request=request, db=db, mode=batch_adapter.mode
    )
# ========== Adapters ==========
@dataclass
class AdminListGlobalModelsAdapter(AdminApiAdapter):
    """Admin adapter: list GlobalModels with paging and per-model counters."""

    skip: int                  # pagination offset
    limit: int                 # page size (route validates 1-1000)
    is_active: Optional[bool]  # optional active-state filter
    search: Optional[str]      # optional search term (semantics live in the service)

    async def handle(self, context):  # type: ignore[override]
        from sqlalchemy import func

        from src.models.database import Model, ModelMapping

        models = GlobalModelService.list_global_models(
            db=context.db,
            skip=self.skip,
            limit=self.limit,
            is_active=self.is_active,
            search=self.search,
        )
        # Attach counters to every GlobalModel.
        # NOTE(review): two extra queries per row (N+1 pattern); consider a
        # single grouped query if the catalog grows large.
        model_responses = []
        for gm in models:
            # Distinct providers that have an implementation of this GlobalModel.
            provider_count = (
                context.db.query(func.count(func.distinct(Model.provider_id)))
                .filter(Model.global_model_id == gm.id)
                .scalar()
                or 0
            )
            # Aliases pointing at this GlobalModel.
            alias_count = (
                context.db.query(func.count(ModelMapping.id))
                .filter(ModelMapping.target_global_model_id == gm.id)
                .scalar()
                or 0
            )
            response = GlobalModelResponse.model_validate(gm)
            response.provider_count = provider_count
            response.alias_count = alias_count
            # usage_count is read straight from the GlobalModel row and is
            # already mapped by model_validate.
            model_responses.append(response)
        return GlobalModelListResponse(
            models=model_responses,
            # NOTE(review): this is the length of the returned page, not the
            # total number of matching rows — confirm clients expect that.
            total=len(models),
        )
@dataclass
class AdminGetGlobalModelAdapter(AdminApiAdapter):
    """Admin adapter: fetch one GlobalModel plus its aggregate statistics."""

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = GlobalModelService.get_global_model(db, self.global_model_id)
        stats = GlobalModelService.get_global_model_stats(db, self.global_model_id)
        base_fields = GlobalModelResponse.model_validate(record).model_dump()
        return GlobalModelWithStats(
            **base_fields,
            total_models=stats["total_models"],
            total_providers=stats["total_providers"],
            price_range=stats["price_range"],
        )
@dataclass
class AdminCreateGlobalModelAdapter(AdminApiAdapter):
    """Admin adapter: create a new GlobalModel from the validated payload."""

    payload: GlobalModelCreate  # creation payload (validated by FastAPI)

    async def handle(self, context):  # type: ignore[override]
        # TieredPricingConfig -> plain dict for storage.
        tiered_pricing_dict = self.payload.default_tiered_pricing.model_dump()
        global_model = GlobalModelService.create_global_model(
            db=context.db,
            name=self.payload.name,
            display_name=self.payload.display_name,
            description=self.payload.description,
            official_url=self.payload.official_url,
            icon_url=self.payload.icon_url,
            is_active=self.payload.is_active,
            # Per-request pricing
            default_price_per_request=self.payload.default_price_per_request,
            # Tiered pricing
            default_tiered_pricing=tiered_pricing_dict,
            # Default capability flags
            default_supports_vision=self.payload.default_supports_vision,
            default_supports_function_calling=self.payload.default_supports_function_calling,
            default_supports_streaming=self.payload.default_supports_streaming,
            default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
            # Key capability configuration
            supported_capabilities=self.payload.supported_capabilities,
        )
        logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")
        return GlobalModelResponse.model_validate(global_model)
@dataclass
class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
    """Admin adapter: apply a partial update to a GlobalModel and evict caches."""

    global_model_id: str
    payload: GlobalModelUpdate

    async def handle(self, context):  # type: ignore[override]
        from src.services.cache.invalidation import get_cache_invalidation_service

        global_model = GlobalModelService.update_global_model(
            db=context.db,
            global_model_id=self.global_model_id,
            update_data=self.payload,
        )
        logger.info(f"GlobalModel 已更新: id={global_model.id} name={global_model.name}")
        # Invalidate caches keyed on this model's name.
        get_cache_invalidation_service().on_global_model_changed(global_model.name)
        return GlobalModelResponse.model_validate(global_model)
@dataclass
class AdminDeleteGlobalModelAdapter(AdminApiAdapter):
    """Admin adapter: delete a GlobalModel (cascades to provider implementations)."""

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        from src.models.database import GlobalModel

        # Look the row up first so its name is available for cache eviction.
        record = (
            context.db.query(GlobalModel)
            .filter(GlobalModel.id == self.global_model_id)
            .first()
        )
        cached_name = record.name if record else None
        GlobalModelService.delete_global_model(context.db, self.global_model_id)
        logger.info(f"GlobalModel 已删除: id={self.global_model_id}")
        if cached_name:
            from src.services.cache.invalidation import get_cache_invalidation_service

            get_cache_invalidation_service().on_global_model_changed(cached_name)
        return None
@dataclass
class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
    """Admin adapter: create implementations of one GlobalModel for several providers.

    Delegates to ``GlobalModelService.batch_assign_to_providers`` and returns
    the service's per-provider ``success`` / ``errors`` breakdown.

    Improvement: the success/error log line now uses lazy %-style logging
    arguments instead of an f-string, so the message is only formatted when
    the INFO level is enabled.
    """

    global_model_id: str  # the GlobalModel to fan out
    payload: BatchAssignToProvidersRequest  # target provider ids + creation flag

    async def handle(self, context):  # type: ignore[override]
        result = GlobalModelService.batch_assign_to_providers(
            db=context.db,
            global_model_id=self.global_model_id,
            provider_ids=self.payload.provider_ids,
            create_models=self.payload.create_models,
        )
        logger.info(
            "批量为 Provider 添加 GlobalModel: global_model_id=%s success=%d errors=%d",
            self.global_model_id,
            len(result["success"]),
            len(result["errors"]),
        )
        return BatchAssignToProvidersResponse(**result)

View File

@@ -0,0 +1,303 @@
"""模型映射管理 API
提供模型映射的 CRUD 操作。
模型映射Mapping用于将源模型映射到目标模型例如
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
- 用于处理 Provider 不支持请求模型的情况
映射必须关联到特定的 Provider。
"""
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ModelMappingCreate,
ModelMappingResponse,
ModelMappingUpdate,
)
from src.models.database import GlobalModel, ModelMapping, Provider, User
from src.services.cache.invalidation import get_cache_invalidation_service
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
    """Convert a ModelMapping ORM row into its API response model."""
    global_model = mapping.target_global_model
    owner = mapping.provider
    return ModelMappingResponse(
        id=mapping.id,
        source_model=mapping.source_model,
        target_global_model_id=mapping.target_global_model_id,
        target_global_model_name=global_model.name if global_model else None,
        target_global_model_display_name=(
            global_model.display_name if global_model else None
        ),
        provider_id=mapping.provider_id,
        provider_name=owner.name if owner else None,
        # A mapping without a provider binding is global in scope.
        scope="provider" if mapping.provider_id else "global",
        mapping_type=mapping.mapping_type,
        is_active=mapping.is_active,
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
    )
@router.get("", response_model=List[ModelMappingResponse])
async def list_mappings(
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
scope: Optional[str] = Query(None, description="global 或 provider"),
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
is_active: Optional[bool] = Query(None, description="按状态筛选"),
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
db: Session = Depends(get_db),
):
"""获取模型映射列表"""
query = db.query(ModelMapping).options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
if provider_id is not None:
query = query.filter(ModelMapping.provider_id == provider_id)
if scope == "global":
query = query.filter(ModelMapping.provider_id.is_(None))
elif scope == "provider":
query = query.filter(ModelMapping.provider_id.isnot(None))
if mapping_type is not None:
query = query.filter(ModelMapping.mapping_type == mapping_type)
if source_model:
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
if target_global_model_id is not None:
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
if is_active is not None:
query = query.filter(ModelMapping.is_active == is_active)
mappings = query.offset(skip).limit(limit).all()
return [_serialize_mapping(mapping) for mapping in mappings]
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
async def get_mapping(
mapping_id: str,
db: Session = Depends(get_db),
):
"""获取单个模型映射"""
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping_id)
.first()
)
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
return _serialize_mapping(mapping)
@router.post("", response_model=ModelMappingResponse, status_code=201)
async def create_mapping(
data: ModelMappingCreate,
db: Session = Depends(get_db),
):
"""创建模型映射"""
source_model = data.source_model.strip()
if not source_model:
raise HTTPException(status_code=400, detail="source_model 不能为空")
# 验证 mapping_type
if data.mapping_type not in ("alias", "mapping"):
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias''mapping'")
# 验证目标 GlobalModel 存在
target_model = (
db.query(GlobalModel)
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
.first()
)
if not target_model:
raise HTTPException(
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
)
# 验证 Provider 存在
provider = None
provider_id = data.provider_id
if provider_id:
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
existing = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id == provider_id,
)
.first()
)
if existing:
raise HTTPException(status_code=400, detail="映射已存在")
# 创建映射
mapping = ModelMapping(
id=str(uuid.uuid4()),
source_model=source_model,
target_global_model_id=data.target_global_model_id,
provider_id=provider_id,
mapping_type=data.mapping_type,
is_active=data.is_active,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db.add(mapping)
db.commit()
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping.id)
.first()
)
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return _serialize_mapping(mapping)
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
async def update_mapping(
mapping_id: str,
data: ModelMappingUpdate,
db: Session = Depends(get_db),
):
"""更新模型映射"""
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
update_data = data.model_dump(exclude_unset=True)
# 更新 Provider
if "provider_id" in update_data:
new_provider_id = update_data["provider_id"]
if new_provider_id:
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
mapping.provider_id = new_provider_id
# 更新目标模型
if "target_global_model_id" in update_data:
target_model = (
db.query(GlobalModel)
.filter(
GlobalModel.id == update_data["target_global_model_id"],
GlobalModel.is_active == True,
)
.first()
)
if not target_model:
raise HTTPException(
status_code=404,
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
)
mapping.target_global_model_id = update_data["target_global_model_id"]
# 更新源模型名
if "source_model" in update_data:
new_source = update_data["source_model"].strip()
if not new_source:
raise HTTPException(status_code=400, detail="source_model 不能为空")
mapping.source_model = new_source
# 检查唯一约束
duplicate = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == mapping.source_model,
ModelMapping.provider_id == mapping.provider_id,
ModelMapping.id != mapping_id,
)
.first()
)
if duplicate:
raise HTTPException(status_code=400, detail="映射已存在")
# 更新映射类型
if "mapping_type" in update_data:
if update_data["mapping_type"] not in ("alias", "mapping"):
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias''mapping'")
mapping.mapping_type = update_data["mapping_type"]
# 更新状态
if "is_active" in update_data:
mapping.is_active = update_data["is_active"]
mapping.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(mapping)
logger.info(f"更新模型映射 (ID: {mapping.id})")
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping.id)
.first()
)
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
return _serialize_mapping(mapping)
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
mapping_id: str,
db: Session = Depends(get_db),
):
"""删除模型映射"""
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
source_model = mapping.source_model
provider_id = mapping.provider_id
logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})")
db.delete(mapping)
db.commit()
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return None

View File

@@ -0,0 +1,14 @@
"""Admin monitoring router合集。"""
from fastapi import APIRouter
from .audit import router as audit_router
from .cache import router as cache_router
from .trace import router as trace_router
router = APIRouter()
router.include_router(audit_router)
router.include_router(cache_router)
router.include_router(trace_router)
__all__ = ["router"]

View File

@@ -0,0 +1,399 @@
"""管理员监控与审计端点。"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_query
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
ApiKey,
AuditEventType,
AuditLog,
Provider,
Usage,
)
from src.models.database import User as DBUser
from src.services.health.monitor import HealthMonitor
from src.services.system.audit import audit_service
router = APIRouter(prefix="/api/admin/monitoring", tags=["Admin - Monitoring"])
pipeline = ApiRequestPipeline()
@router.get("/audit-logs")
async def get_audit_logs(
request: Request,
user_id: Optional[str] = Query(None, description="用户ID筛选 (支持UUID)"),
event_type: Optional[str] = Query(None, description="事件类型筛选"),
days: int = Query(7, description="查询天数"),
limit: int = Query(100, description="返回数量限制"),
offset: int = Query(0, description="偏移量"),
db: Session = Depends(get_db),
):
adapter = AdminGetAuditLogsAdapter(
user_id=user_id,
event_type=event_type,
days=days,
limit=limit,
offset=offset,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/system-status")
async def get_system_status(request: Request, db: Session = Depends(get_db)):
adapter = AdminSystemStatusAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/suspicious-activities")
async def get_suspicious_activities(
    request: Request,
    hours: int = Query(24, description="时间范围(小时)"),
    db: Session = Depends(get_db),
):
    """List suspicious-activity audit entries within the given number of hours."""
    suspicious_adapter = AdminSuspiciousActivitiesAdapter(hours=hours)
    return await pipeline.run(
        adapter=suspicious_adapter, http_request=request, db=db, mode=suspicious_adapter.mode
    )
@router.get("/user-behavior/{user_id}")
async def analyze_user_behavior(
    user_id: str,
    request: Request,
    days: int = Query(30, description="分析天数"),
    db: Session = Depends(get_db),
):
    """Run the audit service's behavior analysis for one user over *days* days."""
    behavior_adapter = AdminUserBehaviorAdapter(user_id=user_id, days=days)
    return await pipeline.run(
        adapter=behavior_adapter, http_request=request, db=db, mode=behavior_adapter.mode
    )
@router.get("/resilience-status")
async def get_resilience_status(request: Request, db: Session = Depends(get_db)):
    """Report resilience health: error statistics, recent errors and a health score."""
    resilience_adapter = AdminResilienceStatusAdapter()
    return await pipeline.run(
        adapter=resilience_adapter, http_request=request, db=db, mode=resilience_adapter.mode
    )
@router.delete("/resilience/error-stats")
async def reset_error_stats(request: Request, db: Session = Depends(get_db)):
    """Reset the resilience manager's error statistics."""
    reset_adapter = AdminResetErrorStatsAdapter()
    return await pipeline.run(
        adapter=reset_adapter, http_request=request, db=db, mode=reset_adapter.mode
    )
@router.get("/resilience/circuit-history")
async def get_circuit_history(
    request: Request,
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
):
    """Return the most recent circuit-breaker state transitions (up to *limit*)."""
    history_adapter = AdminCircuitHistoryAdapter(limit=limit)
    return await pipeline.run(
        adapter=history_adapter, http_request=request, db=db, mode=history_adapter.mode
    )
@dataclass
class AdminGetAuditLogsAdapter(AdminApiAdapter):
    """Paginated audit-log listing with optional user/event-type filters."""

    user_id: Optional[str]  # optional filter: exact user UUID
    event_type: Optional[str]  # optional filter: audit event type string
    days: int  # look-back window in days
    limit: int  # page size
    offset: int  # page offset

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        cutoff_time = datetime.now(timezone.utc) - timedelta(days=self.days)
        # Outer join keeps log rows whose user no longer exists (user is None).
        base_query = (
            db.query(AuditLog, DBUser)
            .outerjoin(DBUser, AuditLog.user_id == DBUser.id)
            .filter(AuditLog.created_at >= cutoff_time)
        )
        if self.user_id:
            base_query = base_query.filter(AuditLog.user_id == self.user_id)
        if self.event_type:
            base_query = base_query.filter(AuditLog.event_type == self.event_type)
        # Newest first, then apply limit/offset via the shared pagination helper.
        ordered_query = base_query.order_by(AuditLog.created_at.desc())
        total, logs_with_users = paginate_query(ordered_query, self.limit, self.offset)
        items = [
            {
                "id": log.id,
                "event_type": log.event_type,
                "user_id": log.user_id,
                "user_email": user.email if user else None,
                "user_username": user.username if user else None,
                "description": log.description,
                "ip_address": log.ip_address,
                "status_code": log.status_code,
                "error_message": log.error_message,
                "metadata": log.event_metadata,
                "created_at": log.created_at.isoformat() if log.created_at else None,
            }
            for log, user in logs_with_users
        ]
        meta = PaginationMeta(
            total=total,
            limit=self.limit,
            offset=self.offset,
            count=len(items),
        )
        payload = build_pagination_payload(
            items,
            meta,
            filters={
                "user_id": self.user_id,
                "event_type": self.event_type,
                "days": self.days,
            },
        )
        # Record what was queried (filters and counts), not the rows themselves.
        context.add_audit_metadata(
            action="monitor_audit_logs",
            filter_user_id=self.user_id,
            filter_event_type=self.event_type,
            days=self.days,
            limit=self.limit,
            offset=self.offset,
            total=total,
            result_count=meta.count,
        )
        return payload
class AdminSystemStatusAdapter(AdminApiAdapter):
    """One-shot snapshot of user/provider/key counts plus today's traffic."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        total_users = db.query(func.count(DBUser.id)).scalar()
        active_users = db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar()
        total_providers = db.query(func.count(Provider.id)).scalar()
        active_providers = (
            db.query(func.count(Provider.id)).filter(Provider.is_active.is_(True)).scalar()
        )
        total_api_keys = db.query(func.count(ApiKey.id)).scalar()
        active_api_keys = (
            db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar()
        )
        # "Today" is measured from UTC midnight.
        today_start = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
        today_requests = (
            db.query(func.count(Usage.id)).filter(Usage.created_at >= today_start).scalar()
        )
        # SUM returns NULL on empty sets, hence the `or 0` fallbacks.
        today_tokens = (
            db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= today_start).scalar()
            or 0
        )
        today_cost = (
            db.query(func.sum(Usage.total_cost_usd))
            .filter(Usage.created_at >= today_start)
            .scalar()
            or 0
        )
        # Failed requests plus suspicious activity over the last hour.
        recent_errors = (
            db.query(AuditLog)
            .filter(
                AuditLog.event_type.in_(
                    [
                        AuditEventType.REQUEST_FAILED.value,
                        AuditEventType.SUSPICIOUS_ACTIVITY.value,
                    ]
                ),
                AuditLog.created_at >= datetime.now(timezone.utc) - timedelta(hours=1),
            )
            .count()
        )
        context.add_audit_metadata(
            action="system_status_snapshot",
            total_users=int(total_users or 0),
            active_users=int(active_users or 0),
            total_providers=int(total_providers or 0),
            active_providers=int(active_providers or 0),
            total_api_keys=int(total_api_keys or 0),
            active_api_keys=int(active_api_keys or 0),
            today_requests=int(today_requests or 0),
            today_tokens=int(today_tokens or 0),
            today_cost=float(today_cost or 0.0),
            recent_errors=int(recent_errors or 0),
        )
        return {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "users": {"total": total_users, "active": active_users},
            "providers": {"total": total_providers, "active": active_providers},
            "api_keys": {"total": total_api_keys, "active": active_api_keys},
            "today_stats": {
                "requests": today_requests,
                "tokens": today_tokens,
                "cost_usd": f"${today_cost:.4f}",
            },
            "recent_errors": recent_errors,
        }
@dataclass
class AdminSuspiciousActivitiesAdapter(AdminApiAdapter):
    """Return suspicious-activity audit entries within the given time range."""

    hours: int  # look-back window in hours

    async def handle(self, context):  # type: ignore[override]
        records = audit_service.get_suspicious_activities(
            db=context.db, hours=self.hours, limit=100
        )
        serialized = []
        for record in records:
            serialized.append(
                {
                    "id": record.id,
                    "event_type": record.event_type,
                    "user_id": record.user_id,
                    "description": record.description,
                    "ip_address": record.ip_address,
                    "metadata": record.event_metadata,
                    "created_at": record.created_at.isoformat() if record.created_at else None,
                }
            )
        context.add_audit_metadata(
            action="monitor_suspicious_activity",
            hours=self.hours,
            result_count=len(serialized),
        )
        return {
            "activities": serialized,
            "count": len(serialized),
            "time_range_hours": self.hours,
        }
@dataclass
class AdminUserBehaviorAdapter(AdminApiAdapter):
    """Run the audit service's behavior analysis for one user."""

    user_id: str  # target user UUID
    days: int  # analysis window in days

    async def handle(self, context):  # type: ignore[override]
        analysis = audit_service.analyze_user_behavior(
            db=context.db, user_id=self.user_id, days=self.days
        )
        context.add_audit_metadata(
            action="monitor_user_behavior",
            target_user_id=self.user_id,
            days=self.days,
            contains_summary=bool(analysis),
        )
        return analysis
class AdminResilienceStatusAdapter(AdminApiAdapter):
    """Report resilience health: error stats, recent errors and a 0-100 score."""

    async def handle(self, context):  # type: ignore[override]
        # Imported lazily so the endpoint degrades to 503 when the resilience
        # subsystem is not available.
        try:
            from src.core.resilience import resilience_manager
        except ImportError as exc:
            raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc
        error_stats = resilience_manager.get_error_stats()
        # Only the ten most recent errors are surfaced in the response.
        recent_errors = [
            {
                "error_id": info["error_id"],
                "error_type": info["error_type"],
                "operation": info["operation"],
                "timestamp": info["timestamp"].isoformat(),
                "context": info.get("context", {}),
            }
            for info in resilience_manager.last_errors[-10:]
        ]
        total_errors = error_stats.get("total_errors", 0)
        circuit_breakers = error_stats.get("circuit_breakers", {})
        circuit_breakers_open = sum(
            1 for status in circuit_breakers.values() if status.get("state") == "open"
        )
        # Heuristic score: each error costs 2 points, each open breaker 20,
        # floored at zero.
        health_score = max(0, 100 - (total_errors * 2) - (circuit_breakers_open * 20))
        response = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "health_score": health_score,
            "status": (
                "healthy" if health_score > 80 else "degraded" if health_score > 50 else "critical"
            ),
            "error_statistics": error_stats,
            "recent_errors": recent_errors,
            "recommendations": _get_health_recommendations(error_stats, health_score),
        }
        context.add_audit_metadata(
            action="resilience_status",
            health_score=health_score,
            error_total=error_stats.get("total_errors") if isinstance(error_stats, dict) else None,
            open_circuit_breakers=circuit_breakers_open,
        )
        return response
class AdminResetErrorStatsAdapter(AdminApiAdapter):
    """Clear the resilience manager's error counters and recent-error buffer."""

    async def handle(self, context):  # type: ignore[override]
        try:
            from src.core.resilience import resilience_manager
        except ImportError as exc:
            raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc
        # Snapshot the stats before clearing so the response can report them.
        old_stats = resilience_manager.get_error_stats()
        resilience_manager.error_stats.clear()
        resilience_manager.last_errors.clear()
        logger.info(f"管理员 {context.user.email if context.user else 'unknown'} 重置了错误统计")
        context.add_audit_metadata(
            action="reset_error_stats",
            previous_total_errors=(
                old_stats.get("total_errors") if isinstance(old_stats, dict) else None
            ),
        )
        return {
            "message": "错误统计已重置",
            "previous_stats": old_stats,
            "reset_by": context.user.email if context.user else None,
            "reset_at": datetime.now(timezone.utc).isoformat(),
        }
@dataclass
class AdminCircuitHistoryAdapter(AdminApiAdapter):
    """Return the most recent circuit-breaker state transitions.

    Converted to a ``@dataclass`` for consistency with the sibling admin
    adapters in this module (e.g. ``AdminSuspiciousActivitiesAdapter``);
    the ``AdminCircuitHistoryAdapter(limit=50)`` interface is unchanged.
    """

    limit: int = 50  # maximum number of history entries to return

    async def handle(self, context):  # type: ignore[override]
        history = HealthMonitor.get_circuit_history(self.limit)
        context.add_audit_metadata(
            action="circuit_history",
            limit=self.limit,
            result_count=len(history),
        )
        return {"items": history, "count": len(history)}
def _get_health_recommendations(error_stats: dict, health_score: int) -> List[str]:
recommendations: List[str] = []
if health_score < 50:
recommendations.append("系统健康状况严重,请立即检查错误日志")
if error_stats.get("total_errors", 0) > 100:
recommendations.append("错误频率过高,建议检查系统配置和外部依赖")
circuit_breakers = error_stats.get("circuit_breakers", {})
open_breakers = [k for k, v in circuit_breakers.items() if v.get("state") == "open"]
if open_breakers:
recommendations.append(f"以下服务熔断器已打开:{', '.join(open_breakers)}")
if health_score > 90:
recommendations.append("系统运行良好")
return recommendations

View File

@@ -0,0 +1,871 @@
"""
缓存监控端点
提供缓存亲和性统计、管理和监控功能
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import PlainTextResponse
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
from src.api.base.pipeline import ApiRequestPipeline
from src.clients.redis_client import get_redis_client_sync
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database import get_db
from src.models.database import ApiKey, User
from src.services.cache.affinity_manager import get_affinity_manager
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
# Cache-monitoring routes live under the admin monitoring prefix; every
# request is executed through the shared admin pipeline.
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
pipeline = ApiRequestPipeline()
def mask_api_key(api_key: Optional[str], prefix_len: int = 8, suffix_len: int = 4) -> Optional[str]:
    """Mask an API key for display, keeping only a short prefix and suffix.

    Example: ``"sk-jhiIdxxxxxxxxxxxAABB"`` -> ``"sk-jhiId********AABB"``.

    Args:
        api_key: Raw API key; falsy values yield ``None``.
        prefix_len: Number of leading characters to keep (default 8).
        suffix_len: Number of trailing characters to keep (default 4).
    """
    if not api_key:
        return None
    # A key no longer than prefix+suffix would be fully revealed by the normal
    # pattern, so show at most the prefix followed by the mask.
    if len(api_key) <= prefix_len + suffix_len:
        return api_key[:prefix_len] + "********"
    return api_key[:prefix_len] + "********" + api_key[-suffix_len:]
def decrypt_and_mask(encrypted_key: Optional[str], prefix_len: int = 8) -> Optional[str]:
    """Decrypt a stored API key and return its masked form, or ``None`` on failure.

    Args:
        encrypted_key: Ciphertext produced by ``crypto_service``.
        prefix_len: Number of leading plaintext characters to keep.
    """
    if not encrypted_key:
        return None
    try:
        plaintext = crypto_service.decrypt(encrypted_key)
    except Exception:
        # Decryption failures are swallowed: the caller simply shows nothing.
        return None
    return mask_api_key(plaintext, prefix_len)
def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
    """Resolve a user identifier (username/email/user_id/api_key_id) to a user_id.

    Accepted input formats, tried in order:
      1. User UUID
      2. Username
      3. Email
      4. API Key ID (UUID)

    Returns:
        The user's UUID string, or ``None`` when nothing matches.
    """
    identifier = identifier.strip()
    # 1. Try as a User UUID first.
    matched = db.query(User).filter(User.id == identifier).first()
    if matched:
        logger.debug(f"通过User ID解析: {identifier[:8]}... -> {matched.username}")
        return matched.id
    # 2. Try as a username.
    matched = db.query(User).filter(User.username == identifier).first()
    if matched:
        logger.debug(f"通过Username解析: {identifier} -> {matched.id[:8]}...")
        return matched.id
    # 3. Try as an email address.
    matched = db.query(User).filter(User.email == identifier).first()
    if matched:
        logger.debug(f"通过Email解析: {identifier} -> {matched.id[:8]}...")
        return matched.id
    # 4. Finally try as an API Key ID and map it to its owner.
    key_row = db.query(ApiKey).filter(ApiKey.id == identifier).first()
    if key_row:
        logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {key_row.user_id[:8]}...")
        return key_row.user_id
    # Nothing matched.
    logger.debug(f"无法识别的用户标识符: {identifier}")
    return None
@router.get("/stats")
async def get_cache_stats(
    request: Request,
    db: Session = Depends(get_db),
):
    """Return cache-affinity statistics.

    Includes hit rate, number of cached users, provider/key switch counts and
    the cache reservation configuration.
    """
    stats_adapter = AdminCacheStatsAdapter()
    return await pipeline.run(
        adapter=stats_adapter, http_request=request, db=db, mode=stats_adapter.mode
    )
@router.get("/affinity/{user_identifier}")
async def get_user_affinity(
    user_identifier: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Look up every cache affinity of one user (one record per endpoint).

    ``user_identifier`` may be a username, an email, a user UUID or an
    API Key ID; see ``resolve_user_identifier``.
    """
    affinity_adapter = AdminGetUserAffinityAdapter(user_identifier=user_identifier)
    return await pipeline.run(
        adapter=affinity_adapter, http_request=request, db=db, mode=affinity_adapter.mode
    )
@router.get("/affinities")
async def list_affinities(
    request: Request,
    keyword: Optional[str] = None,
    limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
    offset: int = Query(0, ge=0, description="偏移量"),
    db: Session = Depends(get_db),
):
    """List all cache affinities, optionally filtered by a keyword.

    The keyword may be a username/email/User ID/API Key ID, or any fuzzy
    substring matched against the affinity records.
    """
    listing_adapter = AdminListAffinitiesAdapter(keyword=keyword, limit=limit, offset=offset)
    return await pipeline.run(
        adapter=listing_adapter, http_request=request, db=db, mode=listing_adapter.mode
    )
@router.delete("/users/{user_identifier}")
async def clear_user_cache(
    user_identifier: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Clear cache affinity for a specific user.

    ``user_identifier`` may be a username, email, user_id or API Key ID.
    """
    clear_adapter = AdminClearUserCacheAdapter(user_identifier=user_identifier)
    return await pipeline.run(
        adapter=clear_adapter, http_request=request, db=db, mode=clear_adapter.mode
    )
@router.delete("")
async def clear_all_cache(
    request: Request,
    db: Session = Depends(get_db),
):
    """Clear all cache affinities.

    Warning: this affects every user — use with caution.
    """
    clear_all_adapter = AdminClearAllCacheAdapter()
    return await pipeline.run(
        adapter=clear_all_adapter, http_request=request, db=db, mode=clear_all_adapter.mode
    )
@router.delete("/providers/{provider_id}")
async def clear_provider_cache(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Clear every cache affinity that points at the given provider."""
    provider_adapter = AdminClearProviderCacheAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=provider_adapter, http_request=request, db=db, mode=provider_adapter.mode
    )
@router.get("/config")
async def get_cache_config(
    request: Request,
    db: Session = Depends(get_db),
):
    """Return cache-related configuration (TTL and reservation settings)."""
    config_adapter = AdminCacheConfigAdapter()
    return await pipeline.run(
        adapter=config_adapter, http_request=request, db=db, mode=config_adapter.mode
    )
@router.get("/metrics", response_class=PlainTextResponse)
async def get_cache_metrics(
    request: Request,
    db: Session = Depends(get_db),
):
    """Expose cache scheduling metrics as Prometheus text for Grafana scraping."""
    metrics_adapter = AdminCacheMetricsAdapter()
    return await pipeline.run(
        adapter=metrics_adapter, http_request=request, db=db, mode=metrics_adapter.mode
    )
# -------- Cache monitoring adapters --------
class AdminCacheStatsAdapter(AdminApiAdapter):
    """Collect scheduler/affinity cache statistics for the admin dashboard."""

    async def handle(self, context):  # type: ignore[override]
        try:
            redis_conn = get_redis_client_sync()
            scheduler = await get_cache_aware_scheduler(redis_conn)
            stats = await scheduler.get_stats()
            logger.info("缓存统计信息查询成功")
            context.add_audit_metadata(
                action="cache_stats",
                scheduler=stats.get("scheduler"),
                total_affinities=stats.get("total_affinities"),
                cache_hit_rate=stats.get("cache_hit_rate"),
                provider_switches=stats.get("provider_switches"),
            )
            return {"status": "ok", "data": stats}
        except Exception as exc:
            logger.exception(f"获取缓存统计信息失败: {exc}")
            raise HTTPException(status_code=500, detail=f"获取缓存统计失败: {exc}")
class AdminCacheMetricsAdapter(AdminApiAdapter):
    """Expose cache scheduling metrics in Prometheus text exposition format."""

    async def handle(self, context):  # type: ignore[override]
        try:
            redis_client = get_redis_client_sync()
            scheduler = await get_cache_aware_scheduler(redis_client)
            stats = await scheduler.get_stats()
            payload = self._format_prometheus(stats)
            context.add_audit_metadata(
                action="cache_metrics_export",
                scheduler=stats.get("scheduler"),
                metrics_lines=payload.count("\n"),
            )
            return PlainTextResponse(payload)
        except Exception as exc:
            logger.exception(f"导出缓存指标失败: {exc}")
            raise HTTPException(status_code=500, detail=f"导出缓存指标失败: {exc}")

    def _format_prometheus(self, stats: Dict[str, Any]) -> str:
        """
        Convert scheduler/affinity metrics into Prometheus text format.
        """
        scheduler_metrics = stats.get("scheduler_metrics", {})
        affinity_stats = stats.get("affinity_stats", {})
        # (metric name, HELP text, value) triplets for the scheduler metrics.
        metric_map: List[Tuple[str, str, float]] = [
            (
                "cache_scheduler_total_batches",
                "Total batches pulled from provider list",
                float(scheduler_metrics.get("total_batches", 0)),
            ),
            (
                "cache_scheduler_last_batch_size",
                "Size of the latest candidate batch",
                float(scheduler_metrics.get("last_batch_size", 0)),
            ),
            (
                "cache_scheduler_total_candidates",
                "Total candidates enumerated by scheduler",
                float(scheduler_metrics.get("total_candidates", 0)),
            ),
            (
                "cache_scheduler_last_candidate_count",
                "Number of candidates in the most recent batch",
                float(scheduler_metrics.get("last_candidate_count", 0)),
            ),
            (
                "cache_scheduler_cache_hits",
                "Cache hits counted during scheduling",
                float(scheduler_metrics.get("cache_hits", 0)),
            ),
            (
                "cache_scheduler_cache_misses",
                "Cache misses counted during scheduling",
                float(scheduler_metrics.get("cache_misses", 0)),
            ),
            (
                "cache_scheduler_cache_hit_rate",
                "Cache hit rate during scheduling",
                float(scheduler_metrics.get("cache_hit_rate", 0.0)),
            ),
            (
                "cache_scheduler_concurrency_denied",
                "Times candidate rejected due to concurrency limits",
                float(scheduler_metrics.get("concurrency_denied", 0)),
            ),
            (
                "cache_scheduler_avg_candidates_per_batch",
                "Average candidates per batch",
                float(scheduler_metrics.get("avg_candidates_per_batch", 0.0)),
            ),
        ]
        # Same triplet structure for the affinity-manager metrics.
        affinity_map: List[Tuple[str, str, float]] = [
            (
                "cache_affinity_total",
                "Total cache affinities stored",
                float(affinity_stats.get("total_affinities", 0)),
            ),
            (
                "cache_affinity_hits",
                "Affinity cache hits",
                float(affinity_stats.get("cache_hits", 0)),
            ),
            (
                "cache_affinity_misses",
                "Affinity cache misses",
                float(affinity_stats.get("cache_misses", 0)),
            ),
            (
                "cache_affinity_hit_rate",
                "Affinity cache hit rate",
                float(affinity_stats.get("cache_hit_rate", 0.0)),
            ),
            (
                "cache_affinity_invalidations",
                "Affinity invalidations",
                float(affinity_stats.get("cache_invalidations", 0)),
            ),
            (
                "cache_affinity_provider_switches",
                "Affinity provider switches",
                float(affinity_stats.get("provider_switches", 0)),
            ),
            (
                "cache_affinity_key_switches",
                "Affinity key switches",
                float(affinity_stats.get("key_switches", 0)),
            ),
        ]
        # Emit HELP/TYPE/value lines per metric; everything is exported as a gauge.
        lines = []
        for name, help_text, value in metric_map + affinity_map:
            lines.append(f"# HELP {name} {help_text}")
            lines.append(f"# TYPE {name} gauge")
            lines.append(f"{name} {value}")
        scheduler_name = stats.get("scheduler", "cache_aware")
        lines.append(f'cache_scheduler_info{{scheduler="{scheduler_name}"}} 1')
        return "\n".join(lines) + "\n"
@dataclass
class AdminGetUserAffinityAdapter(AdminApiAdapter):
    """Resolve a user identifier and list all of that user's cache affinities."""

    user_identifier: str  # username / email / user UUID / API Key ID

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        try:
            user_id = resolve_user_identifier(db, self.user_identifier)
            if not user_id:
                raise HTTPException(
                    status_code=404,
                    detail=f"无法识别的用户标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
                )
            user = db.query(User).filter(User.id == user_id).first()
            redis_client = get_redis_client_sync()
            affinity_mgr = await get_affinity_manager(redis_client)
            # Fetch every affinity and keep only this user's entries
            # (one record per endpoint).
            all_affinities = await affinity_mgr.list_affinities()
            user_affinities = [aff for aff in all_affinities if aff.get("user_id") == user_id]
            if not user_affinities:
                response = {
                    "status": "not_found",
                    "message": f"用户 {user.username} ({user.email}) 没有缓存亲和性",
                    "user_info": {
                        "user_id": user_id,
                        "username": user.username,
                        "email": user.email,
                    },
                    "affinities": [],
                }
                context.add_audit_metadata(
                    action="cache_user_affinity",
                    user_identifier=self.user_identifier,
                    resolved_user_id=user_id,
                    affinity_count=0,
                    status="not_found",
                )
                return response
            response = {
                "status": "ok",
                "user_info": {
                    "user_id": user_id,
                    "username": user.username,
                    "email": user.email,
                },
                "affinities": [
                    {
                        "provider_id": aff["provider_id"],
                        "endpoint_id": aff["endpoint_id"],
                        "key_id": aff["key_id"],
                        "api_format": aff.get("api_format"),
                        "model_name": aff.get("model_name"),
                        "created_at": aff["created_at"],
                        "expire_at": aff["expire_at"],
                        "request_count": aff["request_count"],
                    }
                    for aff in user_affinities
                ],
                "total_endpoints": len(user_affinities),
            }
            context.add_audit_metadata(
                action="cache_user_affinity",
                user_identifier=self.user_identifier,
                resolved_user_id=user_id,
                affinity_count=len(user_affinities),
                status="ok",
            )
            return response
        except HTTPException:
            # Let deliberate HTTP errors (404) pass through untouched.
            raise
        except Exception as exc:
            logger.exception(f"查询用户缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"查询失败: {exc}")
@dataclass
class AdminListAffinitiesAdapter(AdminApiAdapter):
    """List cache affinities enriched with user/provider/endpoint/key details.

    The optional keyword is matched first as an exact API Key ID, then as a
    user identifier, and finally as a fuzzy substring across the records.
    """

    keyword: Optional[str]  # optional exact identifier or fuzzy search term
    limit: int  # page size
    offset: int  # page offset

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        redis_client = get_redis_client_sync()
        if not redis_client:
            raise HTTPException(status_code=503, detail="Redis未初始化无法获取缓存亲和性")
        affinity_mgr = await get_affinity_manager(redis_client)
        matched_user_id = None
        matched_api_key_id = None
        raw_affinities: List[Dict[str, Any]] = []
        if self.keyword:
            # First check whether the keyword is an API Key ID (the affinity_key).
            api_key = db.query(ApiKey).filter(ApiKey.id == self.keyword).first()
            if api_key:
                # Filter directly by that affinity_key.
                matched_api_key_id = str(api_key.id)
                matched_user_id = str(api_key.user_id)
                all_affinities = await affinity_mgr.list_affinities()
                raw_affinities = [
                    aff for aff in all_affinities if aff.get("affinity_key") == matched_api_key_id
                ]
            else:
                # Otherwise try to resolve the keyword as a user identifier.
                user_id = resolve_user_identifier(db, self.keyword)
                if user_id:
                    matched_user_id = user_id
                    # Collect all API Key IDs owned by this user.
                    user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
                    user_api_key_ids = {str(k.id) for k in user_api_keys}
                    # Keep only affinities belonging to those keys.
                    all_affinities = await affinity_mgr.list_affinities()
                    raw_affinities = [
                        aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
                    ]
                else:
                    # Keyword is not a known identifier: take everything and
                    # rely on the fuzzy matching below.
                    raw_affinities = await affinity_mgr.list_affinities()
        else:
            raw_affinities = await affinity_mgr.list_affinities()
        # Collect every affinity_key (API Key ID) seen in the records.
        affinity_keys = {
            item.get("affinity_key") for item in raw_affinities if item.get("affinity_key")
        }
        # Bulk-load the user API keys referenced by the affinities.
        user_api_key_map: Dict[str, ApiKey] = {}
        if affinity_keys:
            user_api_keys = db.query(ApiKey).filter(ApiKey.id.in_(list(affinity_keys))).all()
            user_api_key_map = {str(k.id): k for k in user_api_keys}
        # Collect the owning user ids and bulk-load the users.
        user_ids = {str(k.user_id) for k in user_api_key_map.values()}
        user_map: Dict[str, User] = {}
        if user_ids:
            users = db.query(User).filter(User.id.in_(list(user_ids))).all()
            user_map = {str(user.id): user for user in users}
        # Collect provider_id / endpoint_id / key_id referenced by the affinities.
        provider_ids = {
            item.get("provider_id") for item in raw_affinities if item.get("provider_id")
        }
        endpoint_ids = {
            item.get("endpoint_id") for item in raw_affinities if item.get("endpoint_id")
        }
        key_ids = {item.get("key_id") for item in raw_affinities if item.get("key_id")}
        # Bulk-load Provider, Endpoint and Key rows.
        from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
        provider_map = {}
        if provider_ids:
            providers = db.query(Provider).filter(Provider.id.in_(list(provider_ids))).all()
            provider_map = {p.id: p for p in providers}
        endpoint_map = {}
        if endpoint_ids:
            endpoints = (
                db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(list(endpoint_ids))).all()
            )
            endpoint_map = {e.id: e for e in endpoints}
        key_map = {}
        if key_ids:
            keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(list(key_ids))).all()
            key_map = {k.id: k for k in keys}
        # The stored model_name is usually a global_model_id; bulk-load GlobalModel.
        from src.models.database import GlobalModel
        global_model_ids = {
            item.get("model_name") for item in raw_affinities if item.get("model_name")
        }
        global_model_map: Dict[str, GlobalModel] = {}
        if global_model_ids:
            # model_name may be a UUID-style global_model_id or a raw model name.
            global_models = db.query(GlobalModel).filter(
                GlobalModel.id.in_(list(global_model_ids))
            ).all()
            global_model_map = {str(gm.id): gm for gm in global_models}
        keyword_lower = self.keyword.lower() if self.keyword else None
        items = []
        for affinity in raw_affinities:
            affinity_key = affinity.get("affinity_key")
            if not affinity_key:
                continue
            # Map affinity_key (API Key ID) back to the user API key and user.
            user_api_key = user_api_key_map.get(affinity_key)
            user = user_map.get(str(user_api_key.user_id)) if user_api_key else None
            user_id = str(user_api_key.user_id) if user_api_key else None
            provider_id = affinity.get("provider_id")
            endpoint_id = affinity.get("endpoint_id")
            key_id = affinity.get("key_id")
            provider = provider_map.get(provider_id)
            endpoint = endpoint_map.get(endpoint_id)
            key = key_map.get(key_id)
            # Decrypt then mask the user's API key for display.
            user_api_key_masked = None
            if user_api_key and user_api_key.key_encrypted:
                user_api_key_masked = decrypt_and_mask(user_api_key.key_encrypted)
            # Decrypt then mask the provider key for display.
            provider_key_masked = None
            if key and key.api_key:
                provider_key_masked = decrypt_and_mask(key.api_key)
            item = {
                "affinity_key": affinity_key,
                "user_api_key_name": user_api_key.name if user_api_key else None,
                "user_api_key_prefix": user_api_key_masked,
                "is_standalone": user_api_key.is_standalone if user_api_key else False,
                "user_id": user_id,
                "username": user.username if user else None,
                "email": user.email if user else None,
                "provider_id": provider_id,
                "provider_name": provider.display_name if provider else None,
                "endpoint_id": endpoint_id,
                "endpoint_api_format": (
                    endpoint.api_format if endpoint and endpoint.api_format else None
                ),
                "endpoint_url": endpoint.base_url if endpoint else None,
                "key_id": key_id,
                "key_name": key.name if key else None,
                "key_prefix": provider_key_masked,
                "rate_multiplier": key.rate_multiplier if key else 1.0,
                "model_name": (
                    global_model_map.get(affinity.get("model_name")).name
                    if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
                    else affinity.get("model_name")  # fall back to the raw value when no GlobalModel matches
                ),
                "model_display_name": (
                    global_model_map.get(affinity.get("model_name")).display_name
                    if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
                    else None
                ),
                "api_format": affinity.get("api_format"),
                "created_at": affinity.get("created_at"),
                "expire_at": affinity.get("expire_at"),
                "request_count": affinity.get("request_count", 0),
            }
            # Fuzzy filtering applies only when the keyword was not already
            # matched as an exact user / API key identifier.
            if keyword_lower and not matched_user_id and not matched_api_key_id:
                searchable = [
                    item["affinity_key"],
                    item["user_api_key_name"] or "",
                    item["user_id"] or "",
                    item["username"] or "",
                    item["email"] or "",
                    item["provider_id"] or "",
                    item["key_id"] or "",
                ]
                if not any(keyword_lower in str(value).lower() for value in searchable if value):
                    continue
            items.append(item)
        # Latest-expiring affinities first.
        items.sort(key=lambda x: x.get("expire_at") or 0, reverse=True)
        paged_items, meta = paginate_sequence(items, self.limit, self.offset)
        payload = build_pagination_payload(
            paged_items,
            meta,
            matched_user_id=matched_user_id,
        )
        response = {
            "status": "ok",
            "data": payload,
        }
        result_count = meta.count if hasattr(meta, "count") else len(paged_items)
        context.add_audit_metadata(
            action="cache_affinity_list",
            keyword=self.keyword,
            matched_user_id=matched_user_id,
            matched_api_key_id=matched_api_key_id,
            limit=self.limit,
            offset=self.offset,
            result_count=result_count,
        )
        return response
@dataclass
class AdminClearUserCacheAdapter(AdminApiAdapter):
    """Invalidate cache affinities for one API key or for all keys of a user."""

    user_identifier: str  # API Key ID, username, email or user UUID

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        try:
            redis_client = get_redis_client_sync()
            affinity_mgr = await get_affinity_manager(redis_client)
            # First check whether the identifier is an API Key ID (affinity_key).
            api_key = db.query(ApiKey).filter(ApiKey.id == self.user_identifier).first()
            if api_key:
                # Clear affinities for that single affinity_key.
                affinity_key = str(api_key.id)
                user = db.query(User).filter(User.id == api_key.user_id).first()
                all_affinities = await affinity_mgr.list_affinities()
                target_affinities = [
                    aff for aff in all_affinities if aff.get("affinity_key") == affinity_key
                ]
                count = 0
                for aff in target_affinities:
                    api_format = aff.get("api_format")
                    model_name = aff.get("model_name")
                    endpoint_id = aff.get("endpoint_id")
                    if api_format and model_name:
                        await affinity_mgr.invalidate_affinity(
                            affinity_key, api_format, model_name, endpoint_id=endpoint_id
                        )
                        count += 1
                logger.info(f"已清除API Key缓存亲和性: api_key_name={api_key.name}, affinity_key={affinity_key[:8]}..., 清除数量={count}")
                response = {
                    "status": "ok",
                    "message": f"已清除 API Key {api_key.name} 的缓存亲和性",
                    "user_info": {
                        "user_id": str(api_key.user_id),
                        "username": user.username if user else None,
                        "email": user.email if user else None,
                        "api_key_id": affinity_key,
                        "api_key_name": api_key.name,
                    },
                }
                context.add_audit_metadata(
                    action="cache_clear_api_key",
                    user_identifier=self.user_identifier,
                    resolved_api_key_id=affinity_key,
                    cleared_count=count,
                )
                return response
            # Not an API Key ID: try to resolve it as a user identifier.
            user_id = resolve_user_identifier(db, self.user_identifier)
            if not user_id:
                raise HTTPException(
                    status_code=404,
                    detail=f"无法识别的标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
                )
            user = db.query(User).filter(User.id == user_id).first()
            # Collect every API key owned by the user.
            user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
            user_api_key_ids = {str(k.id) for k in user_api_keys}
            # Invalidate each affinity belonging to those keys, one by one.
            all_affinities = await affinity_mgr.list_affinities()
            user_affinities = [
                aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
            ]
            count = 0
            for aff in user_affinities:
                affinity_key = aff.get("affinity_key")
                api_format = aff.get("api_format")
                model_name = aff.get("model_name")
                endpoint_id = aff.get("endpoint_id")
                if affinity_key and api_format and model_name:
                    await affinity_mgr.invalidate_affinity(
                        affinity_key, api_format, model_name, endpoint_id=endpoint_id
                    )
                    count += 1
            logger.info(f"已清除用户缓存亲和性: username={user.username}, user_id={user_id[:8]}..., 清除数量={count}")
            response = {
                "status": "ok",
                "message": f"已清除用户 {user.username} 的所有缓存亲和性",
                "user_info": {"user_id": user_id, "username": user.username, "email": user.email},
            }
            context.add_audit_metadata(
                action="cache_clear_user",
                user_identifier=self.user_identifier,
                resolved_user_id=user_id,
                cleared_count=count,
            )
            return response
        except HTTPException:
            # Let deliberate HTTP errors (404) pass through untouched.
            raise
        except Exception as exc:
            logger.exception(f"清除用户缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
class AdminClearAllCacheAdapter(AdminApiAdapter):
    """Drop every cache affinity (global admin action)."""

    async def handle(self, context):  # type: ignore[override]
        try:
            manager = await get_affinity_manager(get_redis_client_sync())
            removed = await manager.clear_all()
            logger.warning(f"已清除所有缓存亲和性(管理员操作): {removed}")
            context.add_audit_metadata(
                action="cache_clear_all",
                cleared_count=removed,
            )
            return {"status": "ok", "message": "已清除所有缓存亲和性", "count": removed}
        except Exception as exc:
            logger.exception(f"清除所有缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
@dataclass
class AdminClearProviderCacheAdapter(AdminApiAdapter):
    """Invalidate every cache affinity pointing at one provider."""

    provider_id: str  # target provider UUID

    async def handle(self, context):  # type: ignore[override]
        try:
            manager = await get_affinity_manager(get_redis_client_sync())
            removed = await manager.invalidate_all_for_provider(self.provider_id)
            logger.info(f"已清除Provider缓存亲和性: provider_id={self.provider_id[:8]}..., count={removed}")
            context.add_audit_metadata(
                action="cache_clear_provider",
                provider_id=self.provider_id,
                cleared_count=removed,
            )
            return {
                "status": "ok",
                "message": "已清除Provider的缓存亲和性",
                "provider_id": self.provider_id,
                "count": removed,
            }
        except Exception as exc:
            logger.exception(f"清除Provider缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
class AdminCacheConfigAdapter(AdminApiAdapter):
    """Expose the read-only cache-affinity and dynamic-reservation configuration."""

    async def handle(self, context):  # type: ignore[override]
        # Imported locally to avoid module-import cycles with the service layer.
        from src.services.cache.affinity_manager import CacheAffinityManager
        from src.services.cache.aware_scheduler import CacheAwareScheduler
        from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager

        # Live config snapshot from the dynamic reservation manager.
        reservation_manager = get_adaptive_reservation_manager()
        reservation_stats = reservation_manager.get_stats()
        response = {
            "status": "ok",
            "data": {
                # Class-level constants: these are defaults, not per-instance values.
                "cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
                "cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
                "dynamic_reservation": {
                    "enabled": True,
                    "config": reservation_stats["config"],
                    "description": {
                        "probe_phase_requests": "探测阶段请求数阈值",
                        "probe_reservation": "探测阶段预留比例",
                        "stable_min_reservation": "稳定阶段最小预留比例",
                        "stable_max_reservation": "稳定阶段最大预留比例",
                        "low_load_threshold": "低负载阈值(低于此值使用最小预留)",
                        "high_load_threshold": "高负载阈值(高于此值根据置信度使用较高预留)",
                    },
                },
                "description": {
                    "cache_ttl": "缓存亲和性有效期(秒)",
                    "cache_reservation_ratio": "静态预留比例(已被动态预留替代)",
                    "dynamic_reservation": "动态预留机制配置",
                },
            },
        }
        context.add_audit_metadata(
            action="cache_config",
            cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
            cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
            dynamic_reservation_enabled=True,
        )
        return response

View File

@@ -0,0 +1,280 @@
"""
请求链路追踪 API 端点
"""
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey
from src.services.request.candidate import RequestCandidateService
router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"])
pipeline = ApiRequestPipeline()
class CandidateResponse(BaseModel):
    """Serialized view of one candidate attempt within a request trace."""

    id: str
    request_id: str
    candidate_index: int
    retry_index: int = 0  # retry ordinal, 0-based
    provider_id: Optional[str] = None
    provider_name: Optional[str] = None
    provider_website: Optional[str] = None  # provider homepage URL
    endpoint_id: Optional[str] = None
    endpoint_name: Optional[str] = None  # endpoint display name (api_format)
    key_id: Optional[str] = None
    key_name: Optional[str] = None  # key display name
    key_preview: Optional[str] = None  # masked key preview (e.g. sk-***abc)
    key_capabilities: Optional[dict] = None  # capabilities supported by the key
    required_capabilities: Optional[dict] = None  # capability tags the request actually required
    status: str  # 'pending', 'success', 'failed', 'skipped' (trace handler also derives 'streaming')
    skip_reason: Optional[str] = None
    is_cached: bool = False
    # Execution-result fields (populated once the attempt has run)
    status_code: Optional[int] = None
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    latency_ms: Optional[int] = None
    concurrent_requests: Optional[int] = None
    extra_data: Optional[dict] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None

    class Config:
        # Allow construction directly from ORM row objects.
        from_attributes = True
class RequestTraceResponse(BaseModel):
    """Full trace of one request: derived final status plus every candidate attempt."""

    request_id: str
    total_candidates: int
    final_status: str  # 'success', 'failed', 'streaming', 'pending'
    total_latency_ms: int  # sum over finished (success/failed) candidates only
    candidates: List[CandidateResponse]
@router.get("/{request_id}", response_model=RequestTraceResponse)
async def get_request_trace(
    request_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Return the full execution trace for a single request (admin)."""
    trace_adapter = AdminGetRequestTraceAdapter(request_id=request_id)
    return await pipeline.run(
        adapter=trace_adapter, http_request=request, db=db, mode=trace_adapter.mode
    )
@router.get("/stats/provider/{provider_id}")
async def get_provider_failure_rate(
    provider_id: str,
    request: Request,
    limit: int = Query(100, ge=1, le=1000, description="统计最近的尝试数量"),
    db: Session = Depends(get_db),
):
    """Failure-rate statistics for one provider (admin only)."""
    stats_adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
    return await pipeline.run(
        adapter=stats_adapter, http_request=request, db=db, mode=stats_adapter.mode
    )
# -------- 请求追踪适配器 --------
@dataclass
class AdminGetRequestTraceAdapter(AdminApiAdapter):
    """Assemble the complete trace (all candidate attempts + derived status) for a request.

    Raises HTTPException 404 when no candidates exist for the request id.
    """

    request_id: str

    @staticmethod
    def _mask_api_key(api_key: str) -> str:
        """Return a masked preview of *api_key*, keeping a known prefix and the last 4 chars.

        Extracted from the inline loop so the masking rule is documented and
        testable in one place; behavior is unchanged.
        """
        if len(api_key) > 8:
            # Detect common key prefixes so the preview stays recognizable.
            for prefix in ("sk-", "key-", "api-", "ak-"):
                if api_key.lower().startswith(prefix):
                    return f"{api_key[:len(prefix)]}***{api_key[-4:]}"
            return f"{api_key[:3]}***{api_key[-4:]}"
        if len(api_key) > 4:
            return f"***{api_key[-4:]}"
        return "***"

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Only the candidates table is consulted; absence means an unknown request id.
        candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id)
        if not candidates:
            raise HTTPException(status_code=404, detail="Request not found")

        # Total latency counts only finished candidates (success or failed);
        # the explicit "is not None" keeps legitimate 0ms fast responses.
        total_latency = sum(
            c.latency_ms
            for c in candidates
            if c.status in ("success", "failed") and c.latency_ms is not None
        )

        # Final status precedence: success > streaming > pending > failed.
        # A 2xx status_code also counts as success, covering legacy rows that
        # never set status and streaming requests whose client disconnected (499)
        # after the provider had already returned data.
        has_success = any(
            c.status == "success"
            or (c.status_code is not None and 200 <= c.status_code < 300)
            for c in candidates
        )
        has_streaming = any(c.status == "streaming" for c in candidates)
        has_pending = any(c.status == "pending" for c in candidates)
        if has_success:
            final_status = "success"
        elif has_streaming:
            final_status = "streaming"
        elif has_pending:
            final_status = "pending"
        else:
            final_status = "failed"

        # Bulk-load related providers to avoid N+1 queries.
        provider_ids = {c.provider_id for c in candidates if c.provider_id}
        provider_map = {}
        provider_website_map = {}
        if provider_ids:
            providers = db.query(Provider).filter(Provider.id.in_(provider_ids)).all()
            for p in providers:
                provider_map[p.id] = p.name
                provider_website_map[p.id] = p.website

        # Bulk-load endpoints (display name is the api_format).
        endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id}
        endpoint_map = {}
        if endpoint_ids:
            endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all()
            endpoint_map = {e.id: e.api_format for e in endpoints}

        # Bulk-load keys; raw key values are never returned, only masked previews.
        key_ids = {c.key_id for c in candidates if c.key_id}
        key_map = {}
        key_preview_map = {}
        key_capabilities_map = {}
        if key_ids:
            keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all()
            for k in keys:
                key_map[k.id] = k.name
                key_capabilities_map[k.id] = k.capabilities
                key_preview_map[k.id] = self._mask_api_key(k.api_key or "")

        # Build the per-candidate response objects.
        candidate_responses: List[CandidateResponse] = []
        for candidate in candidates:
            provider_name = (
                provider_map.get(candidate.provider_id) if candidate.provider_id else None
            )
            provider_website = (
                provider_website_map.get(candidate.provider_id) if candidate.provider_id else None
            )
            endpoint_name = (
                endpoint_map.get(candidate.endpoint_id) if candidate.endpoint_id else None
            )
            key_name = (
                key_map.get(candidate.key_id) if candidate.key_id else None
            )
            key_preview = (
                key_preview_map.get(candidate.key_id) if candidate.key_id else None
            )
            key_capabilities = (
                key_capabilities_map.get(candidate.key_id) if candidate.key_id else None
            )
            candidate_responses.append(
                CandidateResponse(
                    id=candidate.id,
                    request_id=candidate.request_id,
                    candidate_index=candidate.candidate_index,
                    retry_index=candidate.retry_index,
                    provider_id=candidate.provider_id,
                    provider_name=provider_name,
                    provider_website=provider_website,
                    endpoint_id=candidate.endpoint_id,
                    endpoint_name=endpoint_name,
                    key_id=candidate.key_id,
                    key_name=key_name,
                    key_preview=key_preview,
                    key_capabilities=key_capabilities,
                    required_capabilities=candidate.required_capabilities,
                    status=candidate.status,
                    skip_reason=candidate.skip_reason,
                    is_cached=candidate.is_cached,
                    status_code=candidate.status_code,
                    error_type=candidate.error_type,
                    error_message=candidate.error_message,
                    latency_ms=candidate.latency_ms,
                    concurrent_requests=candidate.concurrent_requests,
                    extra_data=candidate.extra_data,
                    created_at=candidate.created_at,
                    started_at=candidate.started_at,
                    finished_at=candidate.finished_at,
                )
            )
        response = RequestTraceResponse(
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
            candidates=candidate_responses,
        )
        context.add_audit_metadata(
            action="trace_request_detail",
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
        )
        return response
@dataclass
class AdminProviderFailureRateAdapter(AdminApiAdapter):
    """Report recent success/failure statistics for one provider."""

    provider_id: str
    limit: int  # number of most-recent attempts to aggregate

    async def handle(self, context):  # type: ignore[override]
        stats = RequestCandidateService.get_candidate_stats_by_provider(
            db=context.db,
            provider_id=self.provider_id,
            limit=self.limit,
        )
        context.add_audit_metadata(
            action="trace_provider_failure_rate",
            provider_id=self.provider_id,
            limit=self.limit,
            total_attempts=stats.get("total_attempts"),
            failure_rate=stats.get("failure_rate"),
        )
        return stats

View File

@@ -0,0 +1,410 @@
"""
Provider Query API 端点
用于查询提供商的余额、使用记录等信息
"""
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
# 初始化适配器注册
from src.plugins.provider_query import init # noqa
from src.plugins.provider_query import get_query_registry
from src.plugins.provider_query.base import QueryCapability
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
# ============ Request/Response Models ============
class BalanceQueryRequest(BaseModel):
    """Balance query request."""

    provider_id: str
    api_key_id: Optional[str] = None  # when omitted, the provider's first active API key is used
class UsageSummaryQueryRequest(BaseModel):
    """Usage-summary query request."""

    provider_id: str
    api_key_id: Optional[str] = None  # when omitted, the provider's first active API key is used
    period: str = "month"  # day, week, month, year
class ModelsQueryRequest(BaseModel):
    """Model-list query request."""

    provider_id: str
    api_key_id: Optional[str] = None  # when omitted, the provider's first active API key is used
# ============ API Endpoints ============
@router.get("/adapters")
async def list_adapters(
    current_user: User = Depends(get_current_user),
):
    """
    List every registered provider-query adapter.

    Returns:
        Adapter list wrapped in a success envelope.
    """
    return {"success": True, "data": get_query_registry().list_adapters()}
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
    provider_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Report which query capabilities are available for a provider.

    Args:
        provider_id: Provider ID.

    Returns:
        The capability list, or an empty list with a message when no adapter exists.
    """
    from sqlalchemy import select

    provider = (
        await db.execute(select(Provider).where(Provider.id == provider_id))
    ).scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    capabilities = get_query_registry().get_capabilities_for_provider(provider.name)
    payload = {"provider_id": provider_id, "provider_name": provider.name}
    if capabilities is None:
        # No adapter is registered for this provider type.
        payload["capabilities"] = []
        payload["has_adapter"] = False
        payload["message"] = "No query adapter available for this provider"
    else:
        payload["capabilities"] = [c.name for c in capabilities]
        payload["has_adapter"] = True
    return {"success": True, "data": payload}
@router.post("/balance")
async def query_balance(
    request: BalanceQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query a provider's account balance via its query adapter.

    Args:
        request: Query request (provider_id plus optional api_key_id).

    Returns:
        Balance information wrapped in a success/data/provider envelope.

    Raises:
        HTTPException: 404 if the provider or named API key is missing,
            400 if no active API key can be found.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider with endpoints and their keys eagerly (one round trip).
    result = await db.execute(
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")
    # Resolve the API key to use for the upstream call.
    api_key_value = None
    endpoint_config = None
    if request.api_key_id:
        # Explicit key requested: search every endpoint for it.
        # NOTE(review): this path does not check is_active — presumably so admins
        # can query a disabled key; confirm that is intentional.
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id == request.api_key_id:
                    api_key_value = api_key.api_key
                    endpoint_config = {
                        "base_url": endpoint.base_url,
                        "api_format": endpoint.api_format if endpoint.api_format else None,
                    }
                    break
            if api_key_value:
                break
        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        # No key specified: first active key on an active endpoint.
        for endpoint in provider.endpoints:
            if endpoint.is_active and endpoint.api_keys:
                for api_key in endpoint.api_keys:
                    if api_key.is_active:
                        api_key_value = api_key.api_key
                        endpoint_config = {
                            "base_url": endpoint.base_url,
                            "api_format": endpoint.api_format if endpoint.api_format else None,
                        }
                        break
                if api_key_value:
                    break
        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")
    # Delegate the actual balance lookup to the provider-specific adapter.
    registry = get_query_registry()
    query_result = await registry.query_provider_balance(
        provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
    )
    if not query_result.success:
        # Failure is logged but still returned to the caller with success=False.
        logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
@router.post("/usage-summary")
async def query_usage_summary(
    request: UsageSummaryQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query a provider's usage summary for a period.

    Args:
        request: Query request (provider_id, optional api_key_id, period).

    Returns:
        Usage summary wrapped in a success/data/provider envelope.

    Raises:
        HTTPException: 404 if the provider or named API key is missing,
            400 if no active API key can be found.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load provider with endpoints and keys eagerly.
    result = await db.execute(
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")
    # API key resolution follows the same rules as /balance.
    api_key_value = None
    endpoint_config = None
    if request.api_key_id:
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id == request.api_key_id:
                    api_key_value = api_key.api_key
                    # NOTE(review): unlike /balance, api_format is not passed here — confirm.
                    endpoint_config = {"base_url": endpoint.base_url}
                    break
            if api_key_value:
                break
        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        for endpoint in provider.endpoints:
            if endpoint.is_active and endpoint.api_keys:
                for api_key in endpoint.api_keys:
                    if api_key.is_active:
                        api_key_value = api_key.api_key
                        endpoint_config = {"base_url": endpoint.base_url}
                        break
                if api_key_value:
                    break
        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")
    # Delegate to the provider-specific adapter.
    registry = get_query_registry()
    query_result = await registry.query_provider_usage(
        provider_type=provider.name,
        api_key=api_key_value,
        period=request.period,
        endpoint_config=endpoint_config,
    )
    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
@router.post("/models")
async def query_available_models(
    request: ModelsQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query the models a provider exposes upstream.

    Args:
        request: Query request (provider_id, optional api_key_id).

    Returns:
        Model list wrapped in a success/data/provider envelope.

    Raises:
        HTTPException: 404 if the provider or named API key is missing,
            400 if no active key or no adapter is available.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load provider with endpoints and keys eagerly.
    result = await db.execute(
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")
    # API key resolution follows the same rules as /balance.
    api_key_value = None
    endpoint_config = None
    if request.api_key_id:
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id == request.api_key_id:
                    api_key_value = api_key.api_key
                    endpoint_config = {"base_url": endpoint.base_url}
                    break
            if api_key_value:
                break
        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        for endpoint in provider.endpoints:
            if endpoint.is_active and endpoint.api_keys:
                for api_key in endpoint.api_keys:
                    if api_key.is_active:
                        api_key_value = api_key.api_key
                        endpoint_config = {"base_url": endpoint.base_url}
                        break
                if api_key_value:
                    break
        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")
    # This endpoint calls the adapter directly (not via the registry helper).
    registry = get_query_registry()
    adapter = registry.get_adapter_for_provider(provider.name)
    if not adapter:
        raise HTTPException(
            status_code=400, detail=f"No query adapter available for provider: {provider.name}"
        )
    query_result = await adapter.query_available_models(
        api_key=api_key_value, endpoint_config=endpoint_config
    )
    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
    provider_id: str,
    api_key_id: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Clear the query-result cache for a provider.

    Args:
        provider_id: Provider ID.
        api_key_id: Optional; when given, only that API key's cache entries are cleared.

    Returns:
        Clear result.

    Raises:
        HTTPException: 404 if the provider does not exist.
    """
    from sqlalchemy import select

    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")
    registry = get_query_registry()
    adapter = registry.get_adapter_for_provider(provider.name)
    # No adapter means nothing to clear; still report success for idempotency.
    if adapter:
        if api_key_id:
            # Look up the raw key value — the adapter caches by key string, not key ID.
            # (Removed an unused `selectinload` import that was previously here.)
            result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
            api_key = result.scalar_one_or_none()
            if api_key:
                adapter.clear_cache(api_key.api_key)
        else:
            adapter.clear_cache()
    return {"success": True, "message": "Cache cleared successfully"}

View File

@@ -0,0 +1,272 @@
"""
提供商策略管理 API 端点
"""
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
router = APIRouter(prefix="/api/admin/provider-strategy", tags=["Provider Strategy"])
pipeline = ApiRequestPipeline()
class ProviderBillingUpdate(BaseModel):
    """Payload for updating a provider's billing / quota configuration."""

    billing_type: ProviderBillingType
    monthly_quota_usd: Optional[float] = None
    quota_reset_day: int = Field(default=30, ge=1, le=365)  # reset cycle length, in days
    quota_last_reset_at: Optional[str] = None  # start of the current cycle (datetime string)
    quota_expires_at: Optional[str] = None
    rpm_limit: Optional[int] = Field(default=None, ge=0)
    provider_priority: int = Field(default=100, ge=0, le=200)
@router.put("/providers/{provider_id}/billing")
async def update_provider_billing(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update billing/quota settings for one provider."""
    billing_adapter = AdminProviderBillingAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=billing_adapter, http_request=request, db=db, mode=billing_adapter.mode
    )
@router.get("/providers/{provider_id}/stats")
async def get_provider_stats(
    provider_id: str,
    request: Request,
    hours: int = 24,
    db: Session = Depends(get_db),
):
    """Aggregated usage/quota statistics for one provider over the last *hours*."""
    stats_adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
    return await pipeline.run(
        adapter=stats_adapter, http_request=request, db=db, mode=stats_adapter.mode
    )
@router.delete("/providers/{provider_id}/quota")
async def reset_provider_quota(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Reset a monthly-quota provider's usage counters to zero."""
    reset_adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=reset_adapter, http_request=request, db=db, mode=reset_adapter.mode
    )
@router.get("/strategies")
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
    """List registered load-balancer strategy plugins."""
    strategies_adapter = AdminListStrategiesAdapter()
    return await pipeline.run(
        adapter=strategies_adapter, http_request=request, db=db, mode=strategies_adapter.mode
    )
class AdminProviderBillingAdapter(AdminApiAdapter):
    """Validate and persist a provider's billing configuration."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):
        db = context.db
        payload = context.ensure_json_body()
        # Validate the raw JSON body against the pydantic schema; only the first
        # validation error is surfaced (translated) to the caller.
        try:
            config = ProviderBillingUpdate.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")
        # Apply scalar billing fields directly.
        provider.billing_type = config.billing_type
        provider.monthly_quota_usd = config.monthly_quota_usd
        provider.quota_reset_day = config.quota_reset_day
        provider.rpm_limit = config.rpm_limit
        provider.provider_priority = config.provider_priority
        from dateutil import parser
        from sqlalchemy import func
        from src.models.database import Usage
        if config.quota_last_reset_at:
            # NOTE(review): parser.parse may yield naive or aware datetimes depending
            # on the input string — confirm it matches Usage.created_at's convention.
            new_reset_at = parser.parse(config.quota_last_reset_at)
            provider.quota_last_reset_at = new_reset_at
            # Re-sync the spent amount from historical usage inside the new cycle,
            # so moving the cycle start keeps monthly_used_usd consistent.
            period_usage = (
                db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
                .filter(
                    Usage.provider_id == self.provider_id,
                    Usage.created_at >= new_reset_at,
                )
                .scalar()
            )
            provider.monthly_used_usd = float(period_usage or 0)
            logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
        if config.quota_expires_at:
            provider.quota_expires_at = parser.parse(config.quota_expires_at)
        db.commit()
        db.refresh(provider)
        logger.info(f"Updated billing config for provider {provider.name}")
        return JSONResponse(
            {
                "message": "Provider billing config updated successfully",
                "provider": {
                    "id": provider.id,
                    "name": provider.name,
                    "billing_type": provider.billing_type.value,
                    "provider_priority": provider.provider_priority,
                },
            }
        )
class AdminProviderStatsAdapter(AdminApiAdapter):
    """Aggregate usage-tracking windows into a stats report for one provider."""

    def __init__(self, provider_id: str, hours: int):
        self.provider_id = provider_id  # provider to report on
        self.hours = hours  # lookback window, in hours

    async def handle(self, context):
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")
        # NOTE(review): datetime.now() is naive — confirm window_start is stored
        # with the same convention, otherwise the lookback window is skewed.
        since = datetime.now() - timedelta(hours=self.hours)
        stats = (
            db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == self.provider_id,
                ProviderUsageTracking.window_start >= since,
            )
            .all()
        )
        # Aggregate across all tracking windows in range.
        total_requests = sum(s.total_requests for s in stats)
        total_success = sum(s.successful_requests for s in stats)
        total_failures = sum(s.failed_requests for s in stats)
        # Unweighted mean of per-window averages (not weighted by request count).
        avg_response_time = sum(s.avg_response_time_ms for s in stats) / len(stats) if stats else 0
        total_cost = sum(s.total_cost_usd for s in stats)
        return JSONResponse(
            {
                "provider_id": self.provider_id,
                "provider_name": provider.name,
                "period_hours": self.hours,
                "billing_info": {
                    "billing_type": provider.billing_type.value,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "monthly_used_usd": provider.monthly_used_usd,
                    # Remaining quota is only meaningful when a quota is configured.
                    "quota_remaining_usd": (
                        provider.monthly_quota_usd - provider.monthly_used_usd
                        if provider.monthly_quota_usd is not None
                        else None
                    ),
                    "quota_expires_at": (
                        provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
                    ),
                },
                "rpm_info": {
                    "rpm_limit": provider.rpm_limit,
                    "rpm_used": provider.rpm_used,
                    "rpm_reset_at": (
                        provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
                    ),
                },
                "usage_stats": {
                    "total_requests": total_requests,
                    "successful_requests": total_success,
                    "failed_requests": total_failures,
                    "success_rate": total_success / total_requests if total_requests > 0 else 0,
                    "avg_response_time_ms": round(avg_response_time, 2),
                    "total_cost_usd": round(total_cost, 4),
                },
            }
        )
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
    """Zero out a monthly-quota provider's usage and RPM counters."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise HTTPException(status_code=404, detail="Provider not found")
        # Guard: resetting only makes sense for quota-billed providers.
        if provider.billing_type != ProviderBillingType.MONTHLY_QUOTA:
            raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")
        previous_used = provider.monthly_used_usd
        provider.monthly_used_usd = 0.0
        provider.rpm_used = 0
        provider.rpm_reset_at = None
        db.commit()
        logger.info(f"Manually reset quota for provider {provider.name}")
        return JSONResponse(
            {
                "message": "Provider quota reset successfully",
                "provider_name": provider.name,
                "previous_used": previous_used,
                "current_used": 0.0,
            }
        )
class AdminListStrategiesAdapter(AdminApiAdapter):
    """Enumerate registered load-balancer strategy plugins, highest priority first."""

    async def handle(self, context):
        from src.plugins.manager import get_plugin_manager

        lb_plugins = get_plugin_manager().plugins.get("load_balancer", {})
        strategies = []
        for name, plugin in lb_plugins.items():
            try:
                # Metadata is optional on a plugin; missing attributes fall back
                # to the same defaults the API has always reported.
                meta = getattr(plugin, "metadata", None)
                strategies.append(
                    {
                        "name": getattr(plugin, "name", name),
                        "priority": getattr(plugin, "priority", 0),
                        "version": getattr(meta, "version", "1.0.0"),
                        "description": getattr(meta, "description", ""),
                        "author": getattr(meta, "author", "Unknown"),
                    }
                )
            except Exception as exc:  # pragma: no cover
                logger.error(f"Error accessing plugin {name}: {exc}")
                continue
        strategies.sort(key=lambda entry: entry["priority"], reverse=True)
        return JSONResponse({"strategies": strategies, "total": len(strategies)})

View File

@@ -0,0 +1,20 @@
"""Provider admin routes export."""
from fastapi import APIRouter
from .models import router as models_router
from .routes import router as routes_router
from .summary import router as summary_router
# Aggregate router mounted at /api/admin/providers; sub-routers are merged in order.
router = APIRouter(prefix="/api/admin/providers", tags=["Admin - Providers"])

# Provider CRUD
router.include_router(routes_router)
# Provider summary & health monitor
router.include_router(summary_router)
# Provider models management
router.include_router(models_router)

__all__ = ["router"]

View File

@@ -0,0 +1,443 @@
"""
Provider 模型管理 API
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ModelCreate,
ModelResponse,
ModelUpdate,
)
from src.models.pydantic_models import (
BatchAssignModelsToProviderRequest,
BatchAssignModelsToProviderResponse,
)
from src.models.database import (
GlobalModel,
Model,
ModelMapping,
Provider,
)
from src.models.pydantic_models import (
ProviderAvailableSourceModel,
ProviderAvailableSourceModelsResponse,
)
from src.services.model.service import ModelService
router = APIRouter(tags=["Model Management"])
pipeline = ApiRequestPipeline()
@router.get("/{provider_id}/models", response_model=List[ModelResponse])
async def list_provider_models(
    provider_id: str,
    request: Request,
    is_active: Optional[bool] = None,
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
) -> List[ModelResponse]:
    """Paginated list of a provider's models (admin)."""
    listing_adapter = AdminListProviderModelsAdapter(
        provider_id=provider_id, is_active=is_active, skip=skip, limit=limit
    )
    return await pipeline.run(
        adapter=listing_adapter, http_request=request, db=db, mode=listing_adapter.mode
    )
@router.post("/{provider_id}/models", response_model=ModelResponse)
async def create_provider_model(
    provider_id: str,
    model_data: ModelCreate,
    request: Request,
    db: Session = Depends(get_db),
) -> ModelResponse:
    """Create a model under a provider (admin)."""
    create_adapter = AdminCreateProviderModelAdapter(
        provider_id=provider_id, model_data=model_data
    )
    return await pipeline.run(
        adapter=create_adapter, http_request=request, db=db, mode=create_adapter.mode
    )
@router.get("/{provider_id}/models/{model_id}", response_model=ModelResponse)
async def get_provider_model(
    provider_id: str,
    model_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> ModelResponse:
    """Fetch one model's details (admin)."""
    detail_adapter = AdminGetProviderModelAdapter(provider_id=provider_id, model_id=model_id)
    return await pipeline.run(
        adapter=detail_adapter, http_request=request, db=db, mode=detail_adapter.mode
    )
@router.patch("/{provider_id}/models/{model_id}", response_model=ModelResponse)
async def update_provider_model(
    provider_id: str,
    model_id: str,
    model_data: ModelUpdate,
    request: Request,
    db: Session = Depends(get_db),
) -> ModelResponse:
    """Partially update one model (admin)."""
    update_adapter = AdminUpdateProviderModelAdapter(
        provider_id=provider_id, model_id=model_id, model_data=model_data
    )
    return await pipeline.run(
        adapter=update_adapter, http_request=request, db=db, mode=update_adapter.mode
    )
@router.delete("/{provider_id}/models/{model_id}")
async def delete_provider_model(
    provider_id: str,
    model_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Delete one model (admin)."""
    delete_adapter = AdminDeleteProviderModelAdapter(provider_id=provider_id, model_id=model_id)
    return await pipeline.run(
        adapter=delete_adapter, http_request=request, db=db, mode=delete_adapter.mode
    )
@router.post("/{provider_id}/models/batch", response_model=List[ModelResponse])
async def batch_create_provider_models(
    provider_id: str,
    models_data: List[ModelCreate],
    request: Request,
    db: Session = Depends(get_db),
) -> List[ModelResponse]:
    """Create several models in one call (admin)."""
    batch_adapter = AdminBatchCreateModelsAdapter(provider_id=provider_id, models_data=models_data)
    return await pipeline.run(
        adapter=batch_adapter, http_request=request, db=db, mode=batch_adapter.mode
    )
@router.get(
    "/{provider_id}/available-source-models",
    response_model=ProviderAvailableSourceModelsResponse,
)
async def get_provider_available_source_models(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    List every unified model name (source_model) this provider can serve.

    Covers both ModelMapping-mapped models and direct models where
    Model.provider_model_name itself acts as the unified name.
    """
    source_adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=source_adapter, http_request=request, db=db, mode=source_adapter.mode
    )
@router.post(
    "/{provider_id}/assign-global-models",
    response_model=BatchAssignModelsToProviderResponse,
)
async def batch_assign_global_models_to_provider(
    provider_id: str,
    payload: BatchAssignModelsToProviderRequest,
    request: Request,
    db: Session = Depends(get_db),
) -> BatchAssignModelsToProviderResponse:
    """Bulk-link GlobalModels to a provider (pricing/capability config inherited)."""
    assign_adapter = AdminBatchAssignModelsToProviderAdapter(
        provider_id=provider_id, payload=payload
    )
    return await pipeline.run(
        adapter=assign_adapter, http_request=request, db=db, mode=assign_adapter.mode
    )
# -------- Adapters --------
@dataclass
class AdminListProviderModelsAdapter(AdminApiAdapter):
    """Paginated listing of one provider's models."""

    provider_id: str
    is_active: Optional[bool]
    skip: int
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Validate the provider exists before listing.
        if db.query(Provider).filter(Provider.id == self.provider_id).first() is None:
            raise NotFoundException("Provider not found", "provider")
        records = ModelService.get_models_by_provider(
            db, self.provider_id, self.skip, self.limit, self.is_active
        )
        return [ModelService.convert_to_response(record) for record in records]
@dataclass
class AdminCreateProviderModelAdapter(AdminApiAdapter):
    """Create a model under a provider (admin only).

    Raises NotFoundException when the provider does not exist and wraps any
    service-layer failure in InvalidRequestException.
    """

    provider_id: str
    model_data: ModelCreate

    async def handle(self, context): # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")
        try:
            model = ModelService.create_model(db, self.provider_id, self.model_data)
            logger.info(f"Model created: {model.provider_model_name} for provider {provider.name} by {context.user.username}")
            return ModelService.convert_to_response(model)
        except Exception as exc:
            # Chain the cause so the original traceback survives for debugging.
            raise InvalidRequestException(str(exc)) from exc
@dataclass
class AdminGetProviderModelAdapter(AdminApiAdapter):
    """Fetch a single model scoped to one provider."""

    provider_id: str
    model_id: str

    async def handle(self, context): # type: ignore[override]
        row = (
            context.db.query(Model)
            .filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
            .first()
        )
        if row is None:
            raise NotFoundException("Model not found", "model")
        return ModelService.convert_to_response(row)
@dataclass
class AdminUpdateProviderModelAdapter(AdminApiAdapter):
    """Update one model scoped to a provider (admin only)."""

    provider_id: str
    model_id: str
    model_data: ModelUpdate

    async def handle(self, context): # type: ignore[override]
        db = context.db
        model = (
            db.query(Model)
            .filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
            .first()
        )
        if not model:
            raise NotFoundException("Model not found", "model")
        try:
            updated_model = ModelService.update_model(db, self.model_id, self.model_data)
            logger.info(f"Model updated: {updated_model.provider_model_name} by {context.user.username}")
            return ModelService.convert_to_response(updated_model)
        except Exception as exc:
            # Chain the cause so the original traceback survives for debugging.
            raise InvalidRequestException(str(exc)) from exc
@dataclass
class AdminDeleteProviderModelAdapter(AdminApiAdapter):
    """Delete one model scoped to a provider (admin only)."""

    provider_id: str
    model_id: str

    async def handle(self, context): # type: ignore[override]
        db = context.db
        model = (
            db.query(Model)
            .filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
            .first()
        )
        if not model:
            raise NotFoundException("Model not found", "model")
        # Capture the name first — the ORM object becomes unusable after delete.
        model_name = model.provider_model_name
        try:
            ModelService.delete_model(db, self.model_id)
            logger.info(f"Model deleted: {model_name} by {context.user.username}")
            return {"message": f"Model '{model_name}' deleted successfully"}
        except Exception as exc:
            # Chain the cause so the original traceback survives for debugging.
            raise InvalidRequestException(str(exc)) from exc
@dataclass
class AdminBatchCreateModelsAdapter(AdminApiAdapter):
    """Create several models for one provider in a single service call."""

    provider_id: str
    models_data: List[ModelCreate]

    async def handle(self, context): # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")
        try:
            models = ModelService.batch_create_models(db, self.provider_id, self.models_data)
            logger.info(f"Batch created {len(models)} models for provider {provider.name} by {context.user.username}")
            return [ModelService.convert_to_response(model) for model in models]
        except Exception as exc:
            # Chain the cause so the original traceback survives for debugging.
            raise InvalidRequestException(str(exc)) from exc
@dataclass
class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
    """List every GlobalModel (unified model name) a provider supports."""

    # Provider whose unified model names are being listed.
    provider_id: str
    async def handle(self, context): # type: ignore[override]
        """
        Return all GlobalModels supported by this Provider.

        Plan-A logic:
        1. Query all Models of this Provider.
        2. Resolve each Model's GlobalModel via Model.global_model_id.
        3. Collect every alias (ModelMapping.source_model) that targets that
           GlobalModel — either provider-specific or global (provider_id NULL).
        """
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")
        # 1. Load all active Models of this provider; eager-load GlobalModel
        #    so the loop below does not lazy-load one row at a time.
        models = (
            db.query(Model)
            .options(joinedload(Model.global_model))
            .filter(Model.provider_id == self.provider_id, Model.is_active == True)
            .all()
        )
        # 2. Build a dict keyed by GlobalModel name; the first Model seen for
        #    a given GlobalModel supplies the price/capability snapshot.
        global_models_dict: Dict[str, Dict[str, Any]] = {}
        for model in models:
            global_model = model.global_model
            # Skip models without an (active) GlobalModel.
            if not global_model or not global_model.is_active:
                continue
            global_model_name = global_model.name
            # Initialize the entry the first time this GlobalModel is seen.
            if global_model_name not in global_models_dict:
                # Fetch all active aliases/mappings targeting this GlobalModel.
                # NOTE(review): one query per distinct GlobalModel (N+1) —
                # consider batching if providers carry many models.
                alias_rows = (
                    db.query(ModelMapping.source_model)
                    .filter(
                        ModelMapping.target_global_model_id == global_model.id,
                        ModelMapping.is_active == True,
                        or_(
                            ModelMapping.provider_id == self.provider_id,
                            ModelMapping.provider_id.is_(None),
                        ),
                    )
                    .all()
                )
                # Rows are one-element tuples; unwrap to plain strings.
                alias_list = [alias[0] for alias in alias_rows]
                global_models_dict[global_model_name] = {
                    "global_model_name": global_model_name,
                    "display_name": global_model.display_name,
                    "provider_model_name": model.provider_model_name,
                    "has_alias": len(alias_list) > 0,
                    "aliases": alias_list,
                    "model_id": model.id,
                    # Effective prices resolve Model overrides vs GlobalModel defaults.
                    "price": {
                        "input_price_per_1m": model.get_effective_input_price(),
                        "output_price_per_1m": model.get_effective_output_price(),
                        "cache_creation_price_per_1m": model.get_effective_cache_creation_price(),
                        "cache_read_price_per_1m": model.get_effective_cache_read_price(),
                        "price_per_request": model.get_effective_price_per_request(),
                    },
                    "capabilities": {
                        "supports_vision": bool(model.supports_vision),
                        "supports_function_calling": bool(model.supports_function_calling),
                        "supports_streaming": bool(model.supports_streaming),
                    },
                    "is_active": bool(model.is_active),
                }
        # Stable, name-sorted output for the response payload.
        models_list = [
            ProviderAvailableSourceModel(**global_models_dict[name])
            for name in sorted(global_models_dict.keys())
        ]
        return ProviderAvailableSourceModelsResponse(models=models_list, total=len(models_list))
@dataclass
class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
    """Batch-associate GlobalModels with a Provider.

    Each association creates a Model row that inherits price/capability
    configuration from its GlobalModel. Failures are collected per item so
    one bad id does not abort the whole batch.
    """

    provider_id: str
    payload: BatchAssignModelsToProviderRequest

    async def handle(self, context): # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")
        success = []
        errors = []
        for global_model_id in self.payload.global_model_ids:
            try:
                global_model = (
                    db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
                )
                if not global_model:
                    errors.append(
                        {"global_model_id": global_model_id, "error": "GlobalModel not found"}
                    )
                    continue
                # Skip ids already associated with this provider.
                existing = (
                    db.query(Model)
                    .filter(
                        Model.provider_id == self.provider_id,
                        Model.global_model_id == global_model_id,
                    )
                    .first()
                )
                if existing:
                    errors.append(
                        {
                            "global_model_id": global_model_id,
                            "global_model_name": global_model.name,
                            "error": "Already associated",
                        }
                    )
                    continue
                # Create the association; config is inherited from GlobalModel.
                new_model = Model(
                    provider_id=self.provider_id,
                    global_model_id=global_model_id,
                    provider_model_name=global_model.name,
                    is_active=True,
                )
                # Use a SAVEPOINT so a failed INSERT rolls back only this item:
                # previously a failed flush left the session dirty, poisoning
                # every later iteration and the final commit.
                with db.begin_nested():
                    db.add(new_model)
                    db.flush()
                success.append(
                    {
                        "global_model_id": global_model_id,
                        "global_model_name": global_model.name,
                        "model_id": new_model.id,
                    }
                )
            except Exception as e:
                errors.append({"global_model_id": global_model_id, "error": str(e)})
        db.commit()
        logger.info(
            f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
        )
        return BatchAssignModelsToProviderResponse(success=success, errors=errors)

View File

@@ -0,0 +1,249 @@
"""管理员 Provider 管理路由。"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, Query, Request
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
# Router and shared request pipeline for the provider CRUD endpoints.
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
@router.get("/")
async def list_providers(
request: Request,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=500),
is_active: Optional[bool] = None,
db: Session = Depends(get_db),
):
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/")
async def create_provider(request: Request, db: Session = Depends(get_db)):
adapter = AdminCreateProviderAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{provider_id}")
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/{provider_id}")
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminListProvidersAdapter(AdminApiAdapter):
    """List providers (key masked) with pagination and optional active filter."""

    def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
        self.skip = skip
        self.limit = limit
        self.is_active = is_active

    async def handle(self, context): # type: ignore[override]
        db = context.db
        query = db.query(Provider)
        if self.is_active is not None:
            query = query.filter(Provider.is_active == self.is_active)
        rows = query.offset(self.skip).limit(self.limit).all()
        data = [self._serialize(row) for row in rows]
        context.add_audit_metadata(
            action="list_providers",
            filter_is_active=self.is_active,
            limit=self.limit,
            skip=self.skip,
            result_count=len(data),
        )
        return data

    @staticmethod
    def _serialize(provider: Provider) -> dict:
        """Shape one Provider row for the list response; never expose the raw key."""
        api_format = getattr(provider, "api_format", None)
        updated = provider.updated_at
        return {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
            "api_format": api_format.value if api_format else None,
            "base_url": getattr(provider, "base_url", None),
            "api_key": "***" if getattr(provider, "api_key", None) else None,
            "priority": getattr(provider, "priority", provider.provider_priority),
            "is_active": provider.is_active,
            "created_at": provider.created_at.isoformat(),
            "updated_at": updated.isoformat() if updated else None,
        }
class AdminCreateProviderAdapter(AdminApiAdapter):
    """Validate and create a new Provider from the JSON request body.

    Raises InvalidRequestException on validation failure or duplicate name;
    any DB failure is rolled back before being re-raised.
    """

    async def handle(self, context): # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()
        try:
            # Pydantic validation (also performs SQL-injection / XSS / SSRF checks).
            validated_data = CreateProviderRequest.model_validate(payload)
        except ValidationError as exc:
            # Convert Pydantic errors into one readable message; chain the cause.
            errors = []
            for error in exc.errors():
                field = " -> ".join(str(x) for x in error["loc"])
                errors.append(f"{field}: {error['msg']}")
            raise InvalidRequestException("输入验证失败: " + "; ".join(errors)) from exc
        try:
            # Reject duplicate provider names.
            existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
            if existing:
                raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")
            # Normalize billing type to its enum (defaults to pay-as-you-go).
            billing_type = (
                ProviderBillingType(validated_data.billing_type)
                if validated_data.billing_type
                else ProviderBillingType.PAY_AS_YOU_GO
            )
            provider = Provider(
                name=validated_data.name,
                display_name=validated_data.display_name,
                description=validated_data.description,
                website=validated_data.website,
                billing_type=billing_type,
                monthly_quota_usd=validated_data.monthly_quota_usd,
                quota_reset_day=validated_data.quota_reset_day,
                quota_last_reset_at=validated_data.quota_last_reset_at,
                quota_expires_at=validated_data.quota_expires_at,
                rpm_limit=validated_data.rpm_limit,
                provider_priority=validated_data.provider_priority,
                is_active=validated_data.is_active,
                rate_limit=validated_data.rate_limit,
                concurrent_limit=validated_data.concurrent_limit,
                config=validated_data.config,
            )
            db.add(provider)
            db.commit()
            db.refresh(provider)
            context.add_audit_metadata(
                action="create_provider",
                provider_id=provider.id,
                provider_name=provider.name,
                billing_type=provider.billing_type.value if provider.billing_type else None,
                is_active=provider.is_active,
                provider_priority=provider.provider_priority,
            )
            return {
                "id": provider.id,
                "name": provider.name,
                "display_name": provider.display_name,
                "message": "提供商创建成功",
            }
        except Exception:
            # Roll back on any failure so the session stays usable; the
            # original exception (InvalidRequestException included) propagates.
            db.rollback()
            raise
class AdminUpdateProviderAdapter(AdminApiAdapter):
    """Validate and apply a partial update to an existing Provider."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context): # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("提供商不存在", "provider")
        try:
            # Pydantic validation (also performs SQL-injection / XSS / SSRF checks).
            validated_data = UpdateProviderRequest.model_validate(payload)
        except ValidationError as exc:
            # Convert Pydantic errors into one readable message; chain the cause.
            errors = []
            for error in exc.errors():
                field = " -> ".join(str(x) for x in error["loc"])
                errors.append(f"{field}: {error['msg']}")
            raise InvalidRequestException("输入验证失败: " + "; ".join(errors)) from exc
        try:
            # Apply only the fields the client actually sent.
            update_data = validated_data.model_dump(exclude_unset=True)
            for field, value in update_data.items():
                if field == "billing_type" and value is not None:
                    # billing_type is stored as an enum.
                    setattr(provider, field, ProviderBillingType(value))
                else:
                    setattr(provider, field, value)
            provider.updated_at = datetime.now(timezone.utc)
            db.commit()
            db.refresh(provider)
            context.add_audit_metadata(
                action="update_provider",
                provider_id=provider.id,
                changed_fields=list(update_data.keys()),
                is_active=provider.is_active,
                provider_priority=provider.provider_priority,
            )
            return {
                "id": provider.id,
                "name": provider.name,
                "is_active": provider.is_active,
                "message": "提供商更新成功",
            }
        except Exception:
            # Roll back on any failure so the session stays usable; the
            # original exception (InvalidRequestException included) propagates.
            db.rollback()
            raise
class AdminDeleteProviderAdapter(AdminApiAdapter):
    """Delete one provider by id, auditing its name/id before removal."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context): # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise NotFoundException("提供商不存在", "provider")
        # Record audit info before the row disappears.
        context.add_audit_metadata(
            action="delete_provider",
            provider_id=provider.id,
            provider_name=provider.name,
        )
        db.delete(provider)
        db.commit()
        return {"message": "提供商已删除"}

View File

@@ -0,0 +1,348 @@
"""
Provider 摘要与健康监控 API
"""
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import case, func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
Model,
Provider,
ProviderAPIKey,
ProviderEndpoint,
RequestCandidate,
)
from src.models.endpoint_models import (
EndpointHealthEvent,
EndpointHealthMonitor,
ProviderEndpointHealthMonitorResponse,
ProviderUpdateRequest,
ProviderWithEndpointsSummary,
)
# Router and shared request pipeline for the provider summary/health endpoints.
router = APIRouter(tags=["Provider Summary"])
pipeline = ApiRequestPipeline()
@router.get("/summary", response_model=List[ProviderWithEndpointsSummary])
async def get_providers_summary(
request: Request,
db: Session = Depends(get_db),
) -> List[ProviderWithEndpointsSummary]:
"""获取所有 Providers 的摘要信息(包含 Endpoints 和 Keys 统计)"""
adapter = AdminProviderSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{provider_id}/summary", response_model=ProviderWithEndpointsSummary)
async def get_provider_summary(
provider_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ProviderWithEndpointsSummary:
"""获取单个 Provider 的摘要信息(包含 Endpoints 和 Keys 统计)"""
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
raise NotFoundException(f"Provider {provider_id} not found")
return _build_provider_summary(db, provider)
@router.get("/{provider_id}/health-monitor", response_model=ProviderEndpointHealthMonitorResponse)
async def get_provider_health_monitor(
provider_id: str,
request: Request,
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
per_endpoint_limit: int = Query(48, ge=10, le=200, description="每个端点的事件数量"),
db: Session = Depends(get_db),
) -> ProviderEndpointHealthMonitorResponse:
"""获取 Provider 下所有端点的健康监控时间线"""
adapter = AdminProviderHealthMonitorAdapter(
provider_id=provider_id,
lookback_hours=lookback_hours,
per_endpoint_limit=per_endpoint_limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch("/{provider_id}", response_model=ProviderWithEndpointsSummary)
async def update_provider_settings(
provider_id: str,
update_data: ProviderUpdateRequest,
request: Request,
db: Session = Depends(get_db),
) -> ProviderWithEndpointsSummary:
"""更新 Provider 基础配置display_name, description, priority, weight 等)"""
adapter = AdminUpdateProviderSettingsAdapter(provider_id=provider_id, update_data=update_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndpointsSummary:
    """Assemble endpoint/key/model statistics and health scores for one provider."""
    endpoints = (
        db.query(ProviderEndpoint).filter(ProviderEndpoint.provider_id == provider.id).all()
    )
    endpoint_ids = [ep.id for ep in endpoints]
    total_endpoints = len(endpoints)
    active_endpoints = sum(1 for ep in endpoints if ep.is_active)
    # Key statistics via a single aggregate query.
    total_keys = 0
    active_keys = 0
    if endpoint_ids:
        key_stats = (
            db.query(
                func.count(ProviderAPIKey.id).label("total"),
                func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
            )
            .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
            .first()
        )
        total_keys = key_stats.total or 0
        active_keys = int(key_stats.active or 0)
    # Model statistics, same single-query pattern.
    model_stats = (
        db.query(
            func.count(Model.id).label("total"),
            func.sum(case((Model.is_active == True, 1), else_=0)).label("active"),
        )
        .filter(Model.provider_id == provider.id)
        .first()
    )
    total_models = model_stats.total or 0
    active_models = int(model_stats.active or 0)
    api_formats = [ep.api_format for ep in endpoints]
    # Load every key for all endpoints at once (avoids N+1 queries),
    # then group them by endpoint id.
    all_keys = (
        db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
        if endpoint_ids
        else []
    )
    keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
    for key in all_keys:
        keys_by_endpoint.setdefault(key.endpoint_id, []).append(key)
    # Average health per endpoint; endpoints with no keys (or no scored keys)
    # default to a perfect 1.0.
    endpoint_health_map: dict[str, float] = {}
    for ep in endpoints:
        scores = [
            k.health_score for k in keys_by_endpoint.get(ep.id, []) if k.health_score is not None
        ]
        endpoint_health_map[ep.id] = (sum(scores) / len(scores)) if scores else 1.0
    all_health_scores = list(endpoint_health_map.values())
    avg_health_score = (
        sum(all_health_scores) / len(all_health_scores) if all_health_scores else 1.0
    )
    unhealthy_endpoints = sum(1 for score in all_health_scores if score < 0.5)
    # Active-key count per endpoint for the detail rows.
    active_keys_by_endpoint = {
        endpoint_id: sum(1 for k in keys if k.is_active)
        for endpoint_id, keys in keys_by_endpoint.items()
    }
    endpoint_health_details = [
        {
            "api_format": ep.api_format,
            "health_score": endpoint_health_map.get(ep.id, 1.0),
            "is_active": ep.is_active,
            "active_keys": active_keys_by_endpoint.get(ep.id, 0),
        }
        for ep in endpoints
    ]
    return ProviderWithEndpointsSummary(
        id=provider.id,
        name=provider.name,
        display_name=provider.display_name,
        description=provider.description,
        website=provider.website,
        provider_priority=provider.provider_priority,
        is_active=provider.is_active,
        billing_type=provider.billing_type.value if provider.billing_type else None,
        monthly_quota_usd=provider.monthly_quota_usd,
        monthly_used_usd=provider.monthly_used_usd,
        quota_reset_day=provider.quota_reset_day,
        quota_last_reset_at=provider.quota_last_reset_at,
        quota_expires_at=provider.quota_expires_at,
        rpm_limit=provider.rpm_limit,
        rpm_used=provider.rpm_used,
        rpm_reset_at=provider.rpm_reset_at,
        total_endpoints=total_endpoints,
        active_endpoints=active_endpoints,
        total_keys=total_keys,
        active_keys=active_keys,
        total_models=total_models,
        active_models=active_models,
        avg_health_score=avg_health_score,
        unhealthy_endpoints=unhealthy_endpoints,
        api_formats=api_formats,
        endpoint_health_details=endpoint_health_details,
        created_at=provider.created_at,
        updated_at=provider.updated_at,
    )
# -------- Adapters --------
@dataclass
class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
    """Build a per-endpoint health timeline for one provider."""

    # Target provider id.
    provider_id: str
    # How many hours of RequestCandidate history to consider.
    lookback_hours: int
    # Maximum number of events returned per endpoint.
    per_endpoint_limit: int
    async def handle(self, context): # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")
        endpoints = (
            db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            .all()
        )
        now = datetime.now(timezone.utc)
        since = now - timedelta(hours=self.lookback_hours)
        endpoint_ids = [endpoint.id for endpoint in endpoints]
        if not endpoint_ids:
            # No endpoints: return an empty timeline but still audit the call.
            response = ProviderEndpointHealthMonitorResponse(
                provider_id=provider.id,
                provider_name=provider.display_name or provider.name,
                generated_at=now,
                endpoints=[],
            )
            context.add_audit_metadata(
                action="provider_health_monitor",
                provider_id=self.provider_id,
                endpoint_count=0,
                lookback_hours=self.lookback_hours,
            )
            return response
        # Fetch a bounded superset of recent attempts (newest first); the 2x
        # slack gives each endpoint a chance to fill its per-endpoint quota.
        limit_rows = max(200, self.per_endpoint_limit * max(1, len(endpoint_ids)) * 2)
        attempts_query = (
            db.query(RequestCandidate)
            .filter(
                RequestCandidate.endpoint_id.in_(endpoint_ids),
                RequestCandidate.created_at >= since,
            )
            .order_by(RequestCandidate.created_at.desc())
        )
        attempts = attempts_query.limit(limit_rows).all()
        # Bucket attempts per endpoint, keeping at most per_endpoint_limit of
        # the newest rows for each (attempts arrive newest-first).
        buffered_attempts: Dict[str, List[RequestCandidate]] = {eid: [] for eid in endpoint_ids}
        counters: Dict[str, int] = {eid: 0 for eid in endpoint_ids}
        for attempt in attempts:
            if not attempt.endpoint_id or attempt.endpoint_id not in buffered_attempts:
                continue
            if counters[attempt.endpoint_id] >= self.per_endpoint_limit:
                continue
            buffered_attempts[attempt.endpoint_id].append(attempt)
            counters[attempt.endpoint_id] += 1
        endpoint_monitors: List[EndpointHealthMonitor] = []
        for endpoint in endpoints:
            # Reverse to chronological (oldest -> newest) order for the timeline.
            attempt_list = list(reversed(buffered_attempts.get(endpoint.id, [])))
            events: List[EndpointHealthEvent] = []
            for attempt in attempt_list:
                # Prefer the latest meaningful timestamp available per attempt.
                event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
                events.append(
                    EndpointHealthEvent(
                        timestamp=event_timestamp,
                        status=attempt.status,
                        status_code=attempt.status_code,
                        latency_ms=attempt.latency_ms,
                        error_type=attempt.error_type,
                        error_message=attempt.error_message,
                    )
                )
            success_count = sum(1 for event in events if event.status == "success")
            failed_count = sum(1 for event in events if event.status == "failed")
            skipped_count = sum(1 for event in events if event.status == "skipped")
            total_attempts = len(events)
            # With no events at all, report a perfect success rate.
            success_rate = success_count / total_attempts if total_attempts else 1.0
            last_event_at = events[-1].timestamp if events else None
            endpoint_monitors.append(
                EndpointHealthMonitor(
                    endpoint_id=endpoint.id,
                    api_format=endpoint.api_format,
                    is_active=endpoint.is_active,
                    total_attempts=total_attempts,
                    success_count=success_count,
                    failed_count=failed_count,
                    skipped_count=skipped_count,
                    success_rate=success_rate,
                    last_event_at=last_event_at,
                    events=events,
                )
            )
        response = ProviderEndpointHealthMonitorResponse(
            provider_id=provider.id,
            provider_name=provider.display_name or provider.name,
            generated_at=now,
            endpoints=endpoint_monitors,
        )
        context.add_audit_metadata(
            action="provider_health_monitor",
            provider_id=self.provider_id,
            endpoint_count=len(endpoint_monitors),
            lookback_hours=self.lookback_hours,
            per_endpoint_limit=self.per_endpoint_limit,
        )
        return response
class AdminProviderSummaryAdapter(AdminApiAdapter):
    """Summaries for every provider, ordered by priority then creation time."""

    async def handle(self, context): # type: ignore[override]
        session = context.db
        ordered = (
            session.query(Provider)
            .order_by(Provider.provider_priority.asc(), Provider.created_at.asc())
            .all()
        )
        summaries = []
        for provider in ordered:
            summaries.append(_build_provider_summary(session, provider))
        return summaries
@dataclass
class AdminUpdateProviderSettingsAdapter(AdminApiAdapter):
    """Apply a partial (PATCH-style) update to a provider's base settings."""

    provider_id: str
    update_data: ProviderUpdateRequest

    async def handle(self, context): # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise NotFoundException("Provider not found", "provider")
        # Only fields explicitly supplied by the client are applied.
        changes = self.update_data.model_dump(exclude_unset=True)
        if changes.get("billing_type") is not None:
            # billing_type is stored as an enum.
            changes["billing_type"] = ProviderBillingType(changes["billing_type"])
        for attr, new_value in changes.items():
            setattr(provider, attr, new_value)
        db.commit()
        db.refresh(provider)
        admin_name = context.user.username if context.user else "admin"
        logger.info(f"Provider {provider.name} updated by {admin_name}: {changes}")
        return _build_provider_summary(db, provider)

View File

@@ -0,0 +1,14 @@
"""
安全管理 API
提供 IP 黑白名单管理等安全功能
"""
from fastapi import APIRouter
from .ip_management import router as ip_router
# Aggregate all security sub-routers (currently only IP management) here.
router = APIRouter()
router.include_router(ip_router)
# Public API of this package.
__all__ = ["router"]

View File

@@ -0,0 +1,202 @@
"""
IP 安全管理接口
提供 IP 黑白名单管理和速率限制统计
"""
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.adapter import ApiMode
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.services.rate_limit.ip_limiter import IPRateLimiter
# Router and shared request pipeline for the IP security admin endpoints.
router = APIRouter(prefix="/api/admin/security/ip", tags=["IP Security"])
pipeline = ApiRequestPipeline()
# ========== Pydantic 模型 ==========
class AddIPToBlacklistRequest(BaseModel):
    """Request body for adding an IP to the blacklist."""
    # Target IP address (format validated downstream by the limiter).
    ip_address: str = Field(..., description="IP 地址")
    # Human-readable justification stored with the blacklist entry.
    reason: str = Field(..., min_length=1, max_length=200, description="加入黑名单的原因")
    # Expiry in seconds; None means the entry never expires.
    ttl: Optional[int] = Field(None, gt=0, description="过期时间None 表示永久")
class RemoveIPFromBlacklistRequest(BaseModel):
    """Request body for removing an IP from the blacklist."""
    # IP address to delete from the blacklist.
    ip_address: str = Field(..., description="IP 地址")
class AddIPToWhitelistRequest(BaseModel):
    """Request body for adding an IP or CIDR block to the whitelist."""
    # Single IP or CIDR range.
    ip_address: str = Field(..., description="IP 地址或 CIDR 格式(如 192.168.1.0/24")
class RemoveIPFromWhitelistRequest(BaseModel):
    """Request body for removing an IP from the whitelist."""
    # IP address to delete from the whitelist.
    ip_address: str = Field(..., description="IP 地址")
# ========== API 端点 ==========
@router.post("/blacklist")
async def add_to_blacklist(request: Request, db: Session = Depends(get_db)):
"""Add IP to blacklist"""
adapter = AddToBlacklistAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.delete("/blacklist/{ip_address}")
async def remove_from_blacklist(ip_address: str, request: Request, db: Session = Depends(get_db)):
"""Remove IP from blacklist"""
adapter = RemoveFromBlacklistAdapter(ip_address=ip_address)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.get("/blacklist/stats")
async def get_blacklist_stats(request: Request, db: Session = Depends(get_db)):
"""Get blacklist statistics"""
adapter = GetBlacklistStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.post("/whitelist")
async def add_to_whitelist(request: Request, db: Session = Depends(get_db)):
"""Add IP to whitelist"""
adapter = AddToWhitelistAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.delete("/whitelist/{ip_address}")
async def remove_from_whitelist(ip_address: str, request: Request, db: Session = Depends(get_db)):
"""Remove IP from whitelist"""
adapter = RemoveFromWhitelistAdapter(ip_address=ip_address)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
@router.get("/whitelist")
async def get_whitelist(request: Request, db: Session = Depends(get_db)):
"""Get whitelist"""
adapter = GetWhitelistAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.ADMIN)
# ========== 适配器实现 ==========
class AddToBlacklistAdapter(AuthenticatedApiAdapter):
    """Validate the request body and add an IP to the blacklist.

    Raises InvalidRequestException on bad input and HTTP 500 when the
    limiter backend (Redis) is unavailable.
    """

    async def handle(self, context): # type: ignore[override]
        payload = context.ensure_json_body()
        try:
            req = AddIPToBlacklistRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                # Chain the ValidationError so the original context survives.
                raise InvalidRequestException(translate_pydantic_error(errors[0])) from e
            raise InvalidRequestException("请求数据验证失败") from e
        success = await IPRateLimiter.add_to_blacklist(req.ip_address, req.reason, req.ttl)
        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="添加 IP 到黑名单失败Redis 不可用",
            )
        return {
            "success": True,
            "message": f"IP {req.ip_address} 已加入黑名单",
            "reason": req.reason,
            # Mixed type by design: int seconds, or the literal string for "permanent".
            "ttl": req.ttl or "永久",
        }
class RemoveFromBlacklistAdapter(AuthenticatedApiAdapter):
    """Remove a single IP from the blacklist (404 when absent)."""

    def __init__(self, ip_address: str):
        self.ip_address = ip_address

    async def handle(self, context): # type: ignore[override]
        removed = await IPRateLimiter.remove_from_blacklist(self.ip_address)
        if not removed:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在黑名单中"
            )
        return {"success": True, "message": f"IP {self.ip_address} 已从黑名单移除"}
class GetBlacklistStatsAdapter(AuthenticatedApiAdapter):
    """Report aggregate blacklist statistics from the rate limiter."""

    async def handle(self, context): # type: ignore[override]
        return await IPRateLimiter.get_blacklist_stats()
class AddToWhitelistAdapter(AuthenticatedApiAdapter):
    """Validate the request body and add an IP/CIDR to the whitelist."""

    async def handle(self, context): # type: ignore[override]
        payload = context.ensure_json_body()
        try:
            req = AddIPToWhitelistRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                # Chain the ValidationError so the original context survives.
                raise InvalidRequestException(translate_pydantic_error(errors[0])) from e
            raise InvalidRequestException("请求数据验证失败") from e
        success = await IPRateLimiter.add_to_whitelist(req.ip_address)
        if not success:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                # Was an f-string with no placeholders (F541); plain literal now.
                detail="添加 IP 到白名单失败(无效的 IP 格式或 Redis 不可用)",
            )
        return {"success": True, "message": f"IP {req.ip_address} 已加入白名单"}
class RemoveFromWhitelistAdapter(AuthenticatedApiAdapter):
    """Remove a single IP from the whitelist (404 when absent)."""

    def __init__(self, ip_address: str):
        self.ip_address = ip_address

    async def handle(self, context): # type: ignore[override]
        removed = await IPRateLimiter.remove_from_whitelist(self.ip_address)
        if not removed:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在白名单中"
            )
        return {"success": True, "message": f"IP {self.ip_address} 已从白名单移除"}
class GetWhitelistAdapter(AuthenticatedApiAdapter):
    """Return the whitelist entries together with their count."""

    async def handle(self, context): # type: ignore[override]
        entries = list(await IPRateLimiter.get_whitelist())
        return {"whitelist": entries, "total": len(entries)}

312
src/api/admin/system.py Normal file
View File

@@ -0,0 +1,312 @@
"""系统设置API端点。"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
from src.database import get_db
from src.models.api import SystemSettingsRequest, SystemSettingsResponse
from src.models.database import ApiKey, Provider, Usage, User
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
pipeline = ApiRequestPipeline()
@router.get("/settings")
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
"""获取系统设置(管理员)"""
adapter = AdminGetSystemSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/settings")
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
"""更新系统设置(管理员)"""
adapter = AdminUpdateSystemSettingsAdapter()
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
@router.get("/configs")
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
"""获取所有系统配置(管理员)"""
adapter = AdminGetAllConfigsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/configs/{key}")
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
"""获取特定系统配置(管理员)"""
adapter = AdminGetSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/configs/{key}")
async def set_system_config(
key: str,
request: Request,
db: Session = Depends(get_db),
):
"""设置系统配置(管理员)"""
adapter = AdminSetSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/configs/{key}")
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
"""删除系统配置(管理员)"""
adapter = AdminDeleteSystemConfigAdapter(key=key)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
adapter = AdminSystemStatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/cleanup")
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
"""Manually trigger usage record cleanup task"""
adapter = AdminTriggerCleanupAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/api-formats")
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
"""获取所有可用的API格式列表"""
adapter = AdminGetApiFormatsAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 系统设置适配器 --------
class AdminGetSystemSettingsAdapter(AdminApiAdapter):
    """Read the global settings and package them into a SystemSettingsResponse."""

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        # Usage tracking is persisted as the string "true"/"false"; default "true".
        tracking_flag = SystemConfigService.get_config(session, "enable_usage_tracking", "true")
        return SystemSettingsResponse(
            default_provider=SystemConfigService.get_default_provider(session),
            default_model=SystemConfigService.get_config(session, "default_model"),
            enable_usage_tracking=tracking_flag == "true",
        )
class AdminUpdateSystemSettingsAdapter(AdminApiAdapter):
    """Validate and persist system-wide settings from the request body.

    Accepts partial updates: each field of SystemSettingsRequest is applied
    only when it is not None; an empty string clears the stored value.
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()
        try:
            settings_request = SystemSettingsRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                # Report only the first validation error, translated for the UI.
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")
        if settings_request.default_provider is not None:
            # A non-empty provider name must reference an existing, active provider.
            provider = (
                db.query(Provider)
                .filter(
                    Provider.name == settings_request.default_provider,
                    Provider.is_active.is_(True),
                )
                .first()
            )
            if not provider and settings_request.default_provider != "":
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用",
                )
            if settings_request.default_provider:
                SystemConfigService.set_default_provider(db, settings_request.default_provider)
            else:
                # Empty string clears the stored default provider.
                SystemConfigService.delete_config(db, "default_provider")
        if settings_request.default_model is not None:
            if settings_request.default_model:
                SystemConfigService.set_config(db, "default_model", settings_request.default_model)
            else:
                # Empty string clears the stored default model.
                SystemConfigService.delete_config(db, "default_model")
        if settings_request.enable_usage_tracking is not None:
            # Persisted as the lowercase string "true"/"false".
            SystemConfigService.set_config(
                db,
                "enable_usage_tracking",
                str(settings_request.enable_usage_tracking).lower(),
            )
        return {"message": "系统设置更新成功"}
class AdminGetAllConfigsAdapter(AdminApiAdapter):
    """Return every system config entry as stored by SystemConfigService."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        return SystemConfigService.get_all_configs(db)
@dataclass
class AdminGetSystemConfigAdapter(AdminApiAdapter):
    """Look up a single system config entry by key; 404 when absent."""

    key: str

    async def handle(self, context):  # type: ignore[override]
        stored = SystemConfigService.get_config(context.db, self.key)
        if stored is not None:
            return {"key": self.key, "value": stored}
        raise NotFoundException(f"配置项 '{self.key}' 不存在")
@dataclass
class AdminSetSystemConfigAdapter(AdminApiAdapter):
    """Create or update a system config entry from the JSON request body."""

    key: str

    async def handle(self, context):  # type: ignore[override]
        body = context.ensure_json_body()
        value = body.get("value")
        description = body.get("description")
        saved = SystemConfigService.set_config(context.db, self.key, value, description)
        return {
            "key": saved.key,
            "value": saved.value,
            "description": saved.description,
            "updated_at": saved.updated_at.isoformat(),
        }
@dataclass
class AdminDeleteSystemConfigAdapter(AdminApiAdapter):
    """Delete a system config entry; 404 when the key does not exist."""

    key: str

    async def handle(self, context):  # type: ignore[override]
        if not SystemConfigService.delete_config(context.db, self.key):
            raise NotFoundException(f"配置项 '{self.key}' 不存在")
        return {"message": f"配置项 '{self.key}' 已删除"}
class AdminSystemStatsAdapter(AdminApiAdapter):
    """Aggregate headline counts for the admin dashboard."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user_total = db.query(User).count()
        user_active = db.query(User).filter(User.is_active.is_(True)).count()
        provider_total = db.query(Provider).count()
        provider_active = db.query(Provider).filter(Provider.is_active.is_(True)).count()
        return {
            "users": {"total": user_total, "active": user_active},
            "providers": {"total": provider_total, "active": provider_active},
            "api_keys": db.query(ApiKey).count(),
            "requests": db.query(Usage).count(),
        }
class AdminTriggerCleanupAdapter(AdminApiAdapter):
    """Manually run the usage-record cleanup task and report before/after stats."""

    @staticmethod
    def _usage_stats(db):
        """Count total usage rows, rows with body fields, and rows with header fields.

        Returns a (total, with_body, with_headers) tuple; extracted to avoid the
        duplicated before/after query blocks in the original implementation.
        """
        total = db.query(Usage).count()
        with_body = (
            db.query(Usage)
            .filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
            .count()
        )
        with_headers = (
            db.query(Usage)
            .filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
            .count()
        )
        return total, with_body, with_headers

    async def handle(self, context):  # type: ignore[override]
        """Trigger the cleanup task manually and return a before/after summary."""
        # Local imports kept function-scoped as in the original (avoids import cycles);
        # unused `timedelta` and `func` imports were dropped.
        from datetime import datetime, timezone

        from src.services.system.cleanup_scheduler import get_cleanup_scheduler

        db = context.db
        # Snapshot statistics before cleanup.
        total_before, with_body_before, with_headers_before = self._usage_stats(db)
        # Run the scheduler's cleanup pass.
        cleanup_scheduler = get_cleanup_scheduler()
        await cleanup_scheduler._perform_cleanup()
        # Snapshot statistics after cleanup.
        total_after, with_body_after, with_headers_after = self._usage_stats(db)
        return {
            "message": "清理任务执行完成",
            "stats": {
                "total_records": {
                    "before": total_before,
                    "after": total_after,
                    "deleted": total_before - total_after,
                },
                "body_fields": {
                    "before": with_body_before,
                    "after": with_body_after,
                    "cleaned": with_body_before - with_body_after,
                },
                "header_fields": {
                    "before": with_headers_before,
                    "after": with_headers_after,
                    "cleaned": with_headers_before - with_headers_after,
                },
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
class AdminGetApiFormatsAdapter(AdminApiAdapter):
    """Enumerate every supported API format with its metadata."""

    async def handle(self, context):  # type: ignore[override]
        """Return all available API formats."""
        from src.core.api_format_metadata import API_FORMAT_DEFINITIONS
        from src.core.enums import APIFormat

        _ = context  # parameter kept to satisfy the adapter interface

        def describe(fmt):
            # Missing definitions fall back to a root path and no aliases.
            definition = API_FORMAT_DEFINITIONS.get(fmt)
            return {
                "value": fmt.value,
                "label": fmt.value,
                "default_path": definition.default_path if definition else "/",
                "aliases": list(definition.aliases) if definition else [],
            }

        return {"formats": [describe(fmt) for fmt in APIFormat]}

View File

@@ -0,0 +1,5 @@
"""Usage admin routes export."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,818 @@
"""管理员使用情况统计路由。"""
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
ProviderAPIKey,
ProviderEndpoint,
RequestCandidate,
Usage,
User,
)
from src.services.usage.service import UsageService
router = APIRouter(prefix="/api/admin/usage", tags=["Admin - Usage"])
pipeline = ApiRequestPipeline()
# ==================== RESTful Routes ====================
@router.get("/aggregation/stats")
async def get_usage_aggregation(
request: Request,
group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"),
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = Query(20, ge=1, le=100),
db: Session = Depends(get_db),
):
"""
Get usage aggregation by specified dimension.
- group_by=model: Aggregate by model
- group_by=user: Aggregate by user
- group_by=provider: Aggregate by provider
- group_by=api_format: Aggregate by API format
"""
if group_by == "model":
adapter = AdminUsageByModelAdapter(start_date=start_date, end_date=end_date, limit=limit)
elif group_by == "user":
adapter = AdminUsageByUserAdapter(start_date=start_date, end_date=end_date, limit=limit)
elif group_by == "provider":
adapter = AdminUsageByProviderAdapter(start_date=start_date, end_date=end_date, limit=limit)
elif group_by == "api_format":
adapter = AdminUsageByApiFormatAdapter(start_date=start_date, end_date=end_date, limit=limit)
else:
raise HTTPException(
status_code=400,
detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format"
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/stats")
async def get_usage_stats(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
db: Session = Depends(get_db),
):
adapter = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/records")
async def get_usage_records(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
user_id: Optional[str] = None,
username: Optional[str] = None,
model: Optional[str] = None,
provider: Optional[str] = None,
status: Optional[str] = None, # stream, standard, error
limit: int = Query(100, ge=1, le=500),
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
):
adapter = AdminUsageRecordsAdapter(
start_date=start_date,
end_date=end_date,
user_id=user_id,
username=username,
model=model,
provider=provider,
status=status,
limit=limit,
offset=offset,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/active")
async def get_active_requests(
request: Request,
ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"),
db: Session = Depends(get_db),
):
"""
获取活跃请求的状态(轻量级接口,用于前端轮询)
- 如果提供 ids 参数,只返回这些 ID 对应请求的最新状态
- 如果不提供 ids返回所有 pending/streaming 状态的请求
"""
adapter = AdminActiveRequestsAdapter(ids=ids)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# NOTE: This route must be defined AFTER all other routes to avoid matching
# routes like /stats, /records, /active, etc.
@router.get("/{usage_id}")
async def get_usage_detail(
usage_id: str,
request: Request,
db: Session = Depends(get_db),
):
"""
Get detailed information of a specific usage record.
Includes request/response headers and body.
"""
adapter = AdminUsageDetailAdapter(usage_id=usage_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
class AdminUsageStatsAdapter(AdminApiAdapter):
    """Overall usage summary: totals, cache stats, error rate, activity heatmap."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]):
        # Optional inclusive bounds applied to Usage.created_at.
        self.start_date = start_date
        self.end_date = end_date

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        query = db.query(Usage)
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        total_stats = query.with_entities(
            func.count(Usage.id).label("total_requests"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        ).first()
        # Cache statistics
        cache_stats = query.with_entities(
            func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
            func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
            func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
            func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
        ).first()
        # Error statistics: HTTP status >= 400 or an explicit error message
        error_count = query.filter(
            (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
        ).count()
        # NOTE(review): the heatmap always covers the last 365 days regardless of
        # the start/end filters above — confirm this is intentional.
        activity_heatmap = UsageService.get_daily_activity(
            db=db,
            window_days=365,
            include_actual_cost=True,
        )
        context.add_audit_metadata(
            action="usage_stats",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
        )
        total_requests = total_stats.total_requests if total_stats else 0
        avg_response_time_ms = float(total_stats.avg_response_time_ms or 0) if total_stats else 0
        # Convert milliseconds to seconds for the response payload.
        avg_response_time = avg_response_time_ms / 1000.0
        return {
            "total_requests": total_requests,
            "total_tokens": int(total_stats.total_tokens or 0),
            "total_cost": float(total_stats.total_cost or 0),
            "total_actual_cost": float(total_stats.total_actual_cost or 0),
            "avg_response_time": round(avg_response_time, 2),
            "error_count": error_count,
            "error_rate": (
                round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0
            ),
            "cache_stats": {
                "cache_creation_tokens": (
                    int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
                ),
                "cache_read_tokens": int(cache_stats.cache_read_tokens or 0) if cache_stats else 0,
                "cache_creation_cost": (
                    float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
                ),
                "cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
            },
            "activity_heatmap": activity_heatmap,
        }
class AdminUsageByModelAdapter(AdminApiAdapter):
    """Aggregate completed usage per model, ordered by request count."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        # Optional inclusive date bounds and a cap on returned rows.
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        query = db.query(
            Usage.model,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
        )
        # Exclude pending/streaming requests (unfinished requests must not be counted)
        query = query.filter(Usage.status.notin_(["pending", "streaming"]))
        # Exclude unknown/pending providers (the request never reached any provider)
        query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        query = query.group_by(Usage.model).order_by(func.count(Usage.id).desc()).limit(self.limit)
        stats = query.all()
        context.add_audit_metadata(
            action="usage_by_model",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(stats),
        )
        return [
            {
                "model": model,
                "request_count": count,
                "total_tokens": int(tokens or 0),
                "total_cost": float(cost or 0),
                "actual_cost": float(actual_cost or 0),
            }
            for model, count, tokens, cost, actual_cost in stats
        ]
class AdminUsageByUserAdapter(AdminApiAdapter):
    """Aggregate usage per user, ordered by request count.

    NOTE(review): unlike the model/provider/api_format aggregations, this query
    does not exclude pending/streaming usage rows — confirm whether unfinished
    requests should be counted here.
    """

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        # Optional inclusive date bounds and a cap on returned rows.
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        query = (
            db.query(
                User.id,
                User.email,
                User.username,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
            )
            .join(Usage, Usage.user_id == User.id)
            .group_by(User.id, User.email, User.username)
        )
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        query = query.order_by(func.count(Usage.id).desc()).limit(self.limit)
        stats = query.all()
        context.add_audit_metadata(
            action="usage_by_user",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(stats),
        )
        return [
            {
                "user_id": user_id,
                "email": email,
                "username": username,
                "request_count": count,
                "total_tokens": int(tokens or 0),
                "total_cost": float(cost or 0),
            }
            for user_id, email, username, count, tokens, cost in stats
        ]
class AdminUsageByProviderAdapter(AdminApiAdapter):
    """Aggregate per-provider attempt counts, success rate, and cost/token totals."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        # Optional inclusive date bounds and a cap on returned rows.
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Count attempts and success rate per provider from request_candidates,
        # which correctly covers fallback scenarios (one request may try
        # several providers).
        # Fix: dropped the unused `Integer` import from the original.
        from sqlalchemy import case
        attempt_query = db.query(
            RequestCandidate.provider_id,
            func.count(RequestCandidate.id).label("attempt_count"),
            func.sum(
                case((RequestCandidate.status == "success", 1), else_=0)
            ).label("success_count"),
            func.sum(
                case((RequestCandidate.status == "failed", 1), else_=0)
            ).label("failed_count"),
            func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"),
        ).filter(
            RequestCandidate.provider_id.isnot(None),
            # Count only attempts that actually executed (excludes
            # available/skipped candidates).
            RequestCandidate.status.in_(["success", "failed"]),
        )
        if self.start_date:
            attempt_query = attempt_query.filter(RequestCandidate.created_at >= self.start_date)
        if self.end_date:
            attempt_query = attempt_query.filter(RequestCandidate.created_at <= self.end_date)
        attempt_stats = (
            attempt_query.group_by(RequestCandidate.provider_id)
            .order_by(func.count(RequestCandidate.id).desc())
            .limit(self.limit)
            .all()
        )
        # Token and cost totals come from the Usage table (successful requests).
        usage_query = db.query(
            Usage.provider_id,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        ).filter(
            Usage.provider_id.isnot(None),
            # Exclude pending/streaming (unfinished) requests.
            Usage.status.notin_(["pending", "streaming"]),
        )
        if self.start_date:
            usage_query = usage_query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            usage_query = usage_query.filter(Usage.created_at <= self.end_date)
        usage_stats = usage_query.group_by(Usage.provider_id).all()
        usage_map = {str(u.provider_id): u for u in usage_stats}
        # Collect every provider ID seen in either result set.
        provider_ids = set()
        for stat in attempt_stats:
            if stat.provider_id:
                provider_ids.add(stat.provider_id)
        for stat in usage_stats:
            if stat.provider_id:
                provider_ids.add(stat.provider_id)
        # Resolve provider IDs to display names in a single query.
        provider_map = {}
        if provider_ids:
            providers_data = (
                db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
            )
            provider_map = {str(p.id): p.name for p in providers_data}
        context.add_audit_metadata(
            action="usage_by_provider",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(attempt_stats),
        )
        result = []
        for stat in attempt_stats:
            provider_id_str = str(stat.provider_id) if stat.provider_id else None
            attempt_count = stat.attempt_count or 0
            success_count = int(stat.success_count or 0)
            failed_count = int(stat.failed_count or 0)
            success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0
            # Token/cost figures come from usage_map; absent entries mean no
            # completed usage rows for this provider in the window.
            usage_stat = usage_map.get(provider_id_str)
            result.append({
                "provider_id": provider_id_str,
                "provider": provider_map.get(provider_id_str, "Unknown"),
                "request_count": attempt_count,  # number of attempts
                "total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0,
                "total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0,
                "actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0,
                "avg_response_time_ms": float(stat.avg_latency_ms or 0),
                "success_rate": round(success_rate, 2),
                "error_count": failed_count,
            })
        return result
class AdminUsageByApiFormatAdapter(AdminApiAdapter):
    """Aggregate completed usage per API format, ordered by request count."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        # Optional inclusive date bounds and a cap on returned rows.
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        query = db.query(
            Usage.api_format,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        )
        # Exclude pending/streaming (unfinished) requests
        query = query.filter(Usage.status.notin_(["pending", "streaming"]))
        # Exclude unknown/pending providers
        query = query.filter(Usage.provider.notin_(["unknown", "pending"]))
        # Count only rows that recorded an api_format
        query = query.filter(Usage.api_format.isnot(None))
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        query = (
            query.group_by(Usage.api_format)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
        )
        stats = query.all()
        context.add_audit_metadata(
            action="usage_by_api_format",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(stats),
        )
        return [
            {
                "api_format": api_format or "unknown",
                "request_count": count,
                "total_tokens": int(tokens or 0),
                "total_cost": float(cost or 0),
                "actual_cost": float(actual_cost or 0),
                "avg_response_time_ms": float(avg_response_time or 0),
            }
            for api_format, count, tokens, cost, actual_cost, avg_response_time in stats
        ]
class AdminUsageRecordsAdapter(AdminApiAdapter):
    """Paginated usage-record listing with user/model/provider/status filters.

    Returns records joined with user, endpoint, and provider-key details, plus
    a per-record fallback flag derived from request_candidates.
    """

    def __init__(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime],
        user_id: Optional[str],
        username: Optional[str],
        model: Optional[str],
        provider: Optional[str],
        status: Optional[str],
        limit: int,
        offset: int,
    ):
        self.start_date = start_date
        self.end_date = end_date
        self.user_id = user_id
        self.username = username
        self.model = model
        self.provider = provider
        self.status = status
        self.limit = limit
        self.offset = offset

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        query = (
            db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
            .outerjoin(User, Usage.user_id == User.id)
            .outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
            .outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
        )
        if self.user_id:
            query = query.filter(Usage.user_id == self.user_id)
        if self.username:
            # Fuzzy match on username
            query = query.filter(User.username.ilike(f"%{self.username}%"))
        if self.model:
            # Fuzzy match on model name
            query = query.filter(Usage.model.ilike(f"%{self.model}%"))
        if self.provider:
            # Search by provider name (via the Provider table)
            query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
            query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
        if self.status:
            # Status filtering:
            # legacy values (based on is_stream and status_code): stream, standard, error
            # new values (based on the status column): pending, streaming, completed,
            # failed, active
            # Fix: boolean comparisons use SQLAlchemy's .is_() instead of == True/False,
            # consistent with the .is_(True)/.isnot(None) style used elsewhere.
            if self.status == "stream":
                query = query.filter(Usage.is_stream.is_(True))
            elif self.status == "standard":
                query = query.filter(Usage.is_stream.is_(False))
            elif self.status == "error":
                query = query.filter(
                    (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
                )
            elif self.status in ("pending", "streaming", "completed", "failed"):
                # New-style filter: match the status column directly
                query = query.filter(Usage.status == self.status)
            elif self.status == "active":
                # Active requests: pending or streaming state
                query = query.filter(Usage.status.in_(["pending", "streaming"]))
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)
        total = query.count()
        records = (
            query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
        )
        request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
        fallback_map = {}
        if request_ids:
            # Count only candidates that actually executed (success or failed;
            # not skipped/pending/available)
            executed_counts = (
                db.query(RequestCandidate.request_id, func.count(RequestCandidate.id))
                .filter(
                    RequestCandidate.request_id.in_(request_ids),
                    RequestCandidate.status.in_(["success", "failed"]),
                )
                .group_by(RequestCandidate.request_id)
                .all()
            )
            # More than one executed candidate means a provider fallback occurred
            fallback_map = {req_id: count > 1 for req_id, count in executed_counts}
        context.add_audit_metadata(
            action="usage_records",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            user_id=self.user_id,
            username=self.username,
            model=self.model,
            provider=self.provider,
            status=self.status,
            limit=self.limit,
            offset=self.offset,
            total=total,
        )
        # Build a provider_id -> provider-name map up front to avoid N+1 queries
        provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
        provider_map = {}
        if provider_ids:
            providers_data = (
                db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
            )
            provider_map = {str(p.id): p.name for p in providers_data}
        data = []
        for usage, user, endpoint, api_key in records:
            actual_cost = (
                float(usage.actual_total_cost_usd)
                if usage.actual_total_cost_usd is not None
                else 0.0
            )
            rate_multiplier = (
                float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
            )
            # Provider name precedence: joined Provider table > usage.provider column
            provider_name = usage.provider
            if usage.provider_id and str(usage.provider_id) in provider_map:
                provider_name = provider_map[str(usage.provider_id)]
            data.append(
                {
                    "id": usage.id,
                    "user_id": user.id if user else None,
                    "user_email": user.email if user else "已删除用户",
                    "username": user.username if user else "已删除用户",
                    "provider": provider_name,
                    "model": usage.model,
                    "target_model": usage.target_model,  # mapped target model name
                    "input_tokens": usage.input_tokens,
                    "output_tokens": usage.output_tokens,
                    "cache_creation_input_tokens": usage.cache_creation_input_tokens,
                    "cache_read_input_tokens": usage.cache_read_input_tokens,
                    "total_tokens": usage.total_tokens,
                    "cost": float(usage.total_cost_usd),
                    "actual_cost": actual_cost,
                    "rate_multiplier": rate_multiplier,
                    "response_time_ms": usage.response_time_ms,
                    "created_at": usage.created_at.isoformat(),
                    "is_stream": usage.is_stream,
                    "input_price_per_1m": usage.input_price_per_1m,
                    "output_price_per_1m": usage.output_price_per_1m,
                    "cache_creation_price_per_1m": usage.cache_creation_price_per_1m,
                    "cache_read_price_per_1m": usage.cache_read_price_per_1m,
                    "status_code": usage.status_code,
                    "error_message": usage.error_message,
                    "status": usage.status,  # request state: pending, streaming, completed, failed
                    "has_fallback": fallback_map.get(usage.request_id, False),
                    "api_format": usage.api_format
                    or (endpoint.api_format if endpoint and endpoint.api_format else None),
                    "api_key_name": api_key.name if api_key else None,
                    "request_metadata": usage.request_metadata,  # provider response metadata
                }
            )
        return {
            "records": data,
            "total": total,
            "limit": self.limit,
            "offset": self.offset,
        }
class AdminActiveRequestsAdapter(AdminApiAdapter):
    """Lightweight adapter for polling the status of active requests."""

    def __init__(self, ids: Optional[str]):
        # Comma-separated request IDs, or None to list all active requests.
        self.ids = ids

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Only the columns the polling UI needs — keeps the query cheap.
        # Hoisted into one tuple to remove the duplicated column lists.
        columns = (
            Usage.id,
            Usage.status,
            Usage.input_tokens,
            Usage.output_tokens,
            Usage.total_cost_usd,
            Usage.response_time_ms,
        )
        if self.ids:
            # Query the latest status of the specified request IDs.
            # Fix: loop variable renamed from `id` to avoid shadowing the builtin.
            id_list = [raw_id.strip() for raw_id in self.ids.split(",") if raw_id.strip()]
            if not id_list:
                return {"requests": []}
            records = db.query(*columns).filter(Usage.id.in_(id_list)).all()
        else:
            # No IDs given: return the 50 most recent pending/streaming requests.
            records = (
                db.query(*columns)
                .filter(Usage.status.in_(["pending", "streaming"]))
                .order_by(Usage.created_at.desc())
                .limit(50)
                .all()
            )
        return {
            "requests": [
                {
                    "id": r.id,
                    "status": r.status,
                    "input_tokens": r.input_tokens,
                    "output_tokens": r.output_tokens,
                    "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
                    "response_time_ms": r.response_time_ms,
                }
                for r in records
            ]
        }
@dataclass
class AdminUsageDetailAdapter(AdminApiAdapter):
    """Get detailed usage record with request/response body"""

    # Primary key of the Usage row to fetch.
    usage_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first()
        if not usage_record:
            raise HTTPException(status_code=404, detail="Usage record not found")
        user = db.query(User).filter(User.id == usage_record.user_id).first()
        api_key = db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first()
        # Fetch tiered-pricing information for this record, if configured.
        tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record)
        context.add_audit_metadata(
            action="usage_detail",
            usage_id=self.usage_id,
        )
        return {
            "id": usage_record.id,
            "request_id": usage_record.request_id,
            "user": {
                "id": user.id if user else None,
                "username": user.username if user else "Unknown",
                "email": user.email if user else None,
            },
            "api_key": {
                "id": api_key.id if api_key else None,
                "name": api_key.name if api_key else None,
                "display": api_key.get_display_key() if api_key else None,
            },
            "provider": usage_record.provider,
            "api_format": usage_record.api_format,
            "model": usage_record.model,
            "target_model": usage_record.target_model,
            "tokens": {
                "input": usage_record.input_tokens,
                "output": usage_record.output_tokens,
                "total": usage_record.total_tokens,
            },
            "cost": {
                "input": usage_record.input_cost_usd,
                "output": usage_record.output_cost_usd,
                "total": usage_record.total_cost_usd,
            },
            "cache_creation_input_tokens": usage_record.cache_creation_input_tokens,
            "cache_read_input_tokens": usage_record.cache_read_input_tokens,
            "cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0),
            "cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0),
            "request_cost": getattr(usage_record, "request_cost_usd", 0.0),
            "input_price_per_1m": usage_record.input_price_per_1m,
            "output_price_per_1m": usage_record.output_price_per_1m,
            "cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m,
            "cache_read_price_per_1m": usage_record.cache_read_price_per_1m,
            "price_per_request": usage_record.price_per_request,
            "request_type": usage_record.request_type,
            "is_stream": usage_record.is_stream,
            "status_code": usage_record.status_code,
            "error_message": usage_record.error_message,
            "response_time_ms": usage_record.response_time_ms,
            "created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
            "request_headers": usage_record.request_headers,
            "request_body": usage_record.get_request_body(),
            "provider_request_headers": usage_record.provider_request_headers,
            "response_headers": usage_record.response_headers,
            "response_body": usage_record.get_response_body(),
            "metadata": usage_record.request_metadata,
            "tiered_pricing": tiered_pricing_info,
        }

    async def _get_tiered_pricing_info(self, db, usage_record) -> dict | None:
        """Return tiered-pricing details for the record, or None when the
        model has no tier configuration."""
        from src.services.model.cost import ModelCostService
        # Total input context used for tier matching:
        # input + cache creation + cache read tokens.
        input_tokens = usage_record.input_tokens or 0
        cache_creation_tokens = usage_record.cache_creation_input_tokens or 0
        cache_read_tokens = usage_record.cache_read_input_tokens or 0
        total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens
        # Fetch the model's tier configuration together with its source.
        cost_service = ModelCostService(db)
        pricing_result = await cost_service.get_tiered_pricing_with_source_async(
            usage_record.provider, usage_record.model
        )
        if not pricing_result:
            return None
        tiered_pricing = pricing_result.get("pricing")
        pricing_source = pricing_result.get("source")  # 'provider' or 'global'
        if not tiered_pricing or not tiered_pricing.get("tiers"):
            return None
        tiers = tiered_pricing.get("tiers", [])
        if not tiers:
            return None
        # Find the tier the total input context falls into:
        # first tier whose up_to bound is None (unbounded) or >= the context size.
        tier_index = None
        matched_tier = None
        for i, tier in enumerate(tiers):
            up_to = tier.get("up_to")
            if up_to is None or total_input_context <= up_to:
                tier_index = i
                matched_tier = tier
                break
        # No tier matched: fall back to the last (highest) tier.
        if tier_index is None and tiers:
            tier_index = len(tiers) - 1
            matched_tier = tiers[-1]
        return {
            "total_input_context": total_input_context,
            "tier_index": tier_index,
            "tier_count": len(tiers),
            "current_tier": matched_tier,
            "tiers": tiers,
            "source": pricing_source,  # pricing source: 'provider' or 'global'
        }

View File

@@ -0,0 +1,5 @@
"""User admin routes export."""
from .routes import router
__all__ = ["router"]

View File

@@ -0,0 +1,488 @@
"""用户管理 API 端点。"""
from datetime import datetime, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from pydantic import ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
from src.core.logger import logger
from src.database import get_db
from src.models.admin_requests import UpdateUserRequest
from src.models.api import CreateApiKeyRequest, CreateUserRequest
from src.models.database import ApiKey, User, UserRole
from src.services.user.apikey import ApiKeyService
from src.services.user.service import UserService
router = APIRouter(prefix="/api/admin/users", tags=["Admin - Users"])
pipeline = ApiRequestPipeline()
# 管理员端点
@router.post("")
async def create_user_endpoint(request: Request, db: Session = Depends(get_db)):
    """Create a new user through the admin request pipeline."""
    adapter = AdminCreateUserAdapter()
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.get("")
async def list_users(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    role: Optional[str] = None,
    is_active: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """List users with pagination and optional role / active-state filters."""
    adapter = AdminListUsersAdapter(skip=skip, limit=limit, role=role, is_active=is_active)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.get("/{user_id}")
async def get_user(user_id: str, request: Request, db: Session = Depends(get_db)):
    """Fetch a single user's detail by id (UUID string)."""
    adapter = AdminGetUserAdapter(user_id=user_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.put("/{user_id}")
async def update_user(
    user_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update a user's attributes from a JSON body."""
    adapter = AdminUpdateUserAdapter(user_id=user_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.delete("/{user_id}")
async def delete_user(user_id: str, request: Request, db: Session = Depends(get_db)):
    """Delete a user by id (UUID string)."""
    adapter = AdminDeleteUserAdapter(user_id=user_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.patch("/{user_id}/quota")
async def reset_user_quota(user_id: str, request: Request, db: Session = Depends(get_db)):
    """Reset user quota (set used_usd to 0)"""
    adapter = AdminResetUserQuotaAdapter(user_id=user_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.get("/{user_id}/api-keys")
async def get_user_api_keys(
    user_id: str,
    request: Request,
    is_active: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """List all of a user's API keys (standalone keys are excluded)."""
    adapter = AdminGetUserKeysAdapter(user_id=user_id, is_active=is_active)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.post("/{user_id}/api-keys")
async def create_user_api_key(
    user_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Create an API key on behalf of the given user."""
    adapter = AdminCreateUserKeyAdapter(user_id=user_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
@router.delete("/{user_id}/api-keys/{key_id}")
async def delete_user_api_key(
    user_id: str,
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Delete one of the user's API keys."""
    adapter = AdminDeleteUserKeyAdapter(user_id=user_id, key_id=key_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
# ============== 管理员适配器实现 ==============
class AdminCreateUserAdapter(AdminApiAdapter):
    """Create a user from an admin-supplied JSON payload."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()

        # Validate the body against the request schema; surface the first
        # validation error in a translated, user-facing form.
        try:
            req = CreateUserRequest.model_validate(payload)
        except ValidationError as exc:
            details = exc.errors()
            if details:
                raise InvalidRequestException(translate_pydantic_error(details[0]))
            raise InvalidRequestException("请求数据验证失败")

        # Accept either a UserRole enum value or its case-insensitive name.
        try:
            role = req.role if hasattr(req.role, "value") else UserRole[req.role.upper()]
        except (KeyError, AttributeError):
            raise InvalidRequestException("角色参数不合法")

        try:
            user = UserService.create_user(
                db=db,
                email=req.email,
                username=req.username,
                password=req.password,
                role=role,
                quota_usd=req.quota_usd,
            )
        except ValueError as exc:
            raise InvalidRequestException(str(exc))

        context.add_audit_metadata(
            action="create_user",
            target_user_id=user.id,
            target_email=user.email,
            target_username=user.username,
            target_role=user.role.value,
            quota_usd=user.quota_usd,
            is_active=user.is_active,
        )
        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "role": user.role.value,
            "allowed_providers": user.allowed_providers,
            "allowed_endpoints": user.allowed_endpoints,
            "allowed_models": user.allowed_models,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            "total_usd": getattr(user, "total_usd", 0),
            "is_active": user.is_active,
            "created_at": user.created_at.isoformat(),
        }
class AdminListUsersAdapter(AdminApiAdapter):
    """List users with pagination and optional role / active-state filters.

    Fix: an invalid ``role`` query value previously raised a bare ``KeyError``
    from ``UserRole[...]`` (surfacing as a 500); it is now rejected with an
    ``InvalidRequestException``, matching the create-user adapter.
    """

    def __init__(self, skip: int, limit: int, role: Optional[str], is_active: Optional[bool]):
        self.skip = skip
        self.limit = limit
        self.role = role
        self.is_active = is_active

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Translate the role string (case-insensitive) into the enum.
        role_enum = None
        if self.role:
            try:
                role_enum = UserRole[self.role.upper()]
            except KeyError as e:
                raise InvalidRequestException("角色参数不合法") from e
        users = UserService.list_users(db, self.skip, self.limit, role_enum, self.is_active)
        return [
            {
                "id": u.id,
                "email": u.email,
                "username": u.username,
                "role": u.role.value,
                "quota_usd": u.quota_usd,
                "used_usd": u.used_usd,
                # Older rows may lack total_usd; default to 0.
                "total_usd": getattr(u, "total_usd", 0),
                "is_active": u.is_active,
                "created_at": u.created_at.isoformat(),
            }
            for u in users
        ]
class AdminGetUserAdapter(AdminApiAdapter):
    """Fetch a single user's full detail record."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user = UserService.get_user(db, self.user_id)
        if user is None:
            raise NotFoundException("用户不存在", "user")

        context.add_audit_metadata(
            action="get_user_detail",
            target_user_id=user.id,
            target_role=user.role.value,
            include_history=bool(user.last_login_at),
        )

        def _iso(value):
            # Serialize optional datetimes; None stays None.
            return value.isoformat() if value else None

        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "role": user.role.value,
            "allowed_providers": user.allowed_providers,
            "allowed_endpoints": user.allowed_endpoints,
            "allowed_models": user.allowed_models,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            "total_usd": getattr(user, "total_usd", 0),
            "is_active": user.is_active,
            "created_at": user.created_at.isoformat(),
            "updated_at": _iso(user.updated_at),
            "last_login_at": _iso(user.last_login_at),
        }
class AdminUpdateUserAdapter(AdminApiAdapter):
    """Update a user's profile, role, quota or active flag.

    Only fields explicitly present in the payload are applied
    (``model_dump(exclude_unset=True)``).

    Fixes: an invalid ``role`` string previously raised a bare ``KeyError``
    (surfacing as a 500) — it now raises ``InvalidRequestException``; the
    no-op branch that reassigned an enum role to itself was removed.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        existing_user = UserService.get_user(db, self.user_id)
        if not existing_user:
            raise NotFoundException("用户不存在", "user")

        payload = context.ensure_json_body()
        try:
            request = UpdateUserRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0])) from e
            raise InvalidRequestException("请求数据验证失败") from e

        update_data = request.model_dump(exclude_unset=True)
        role_value = update_data.get("role")
        # A UserRole enum passes through untouched; a string name is resolved
        # case-insensitively, rejecting unknown names as a client error.
        if role_value and not hasattr(role_value, "value"):
            try:
                update_data["role"] = UserRole[role_value.upper()]
            except KeyError as e:
                raise InvalidRequestException("角色参数不合法") from e

        user = UserService.update_user(db, self.user_id, **update_data)
        if not user:
            raise NotFoundException("用户不存在", "user")

        context.add_audit_metadata(
            action="update_user",
            target_user_id=user.id,
            updated_fields=list(update_data.keys()),
            role_before=existing_user.role.value if existing_user.role else None,
            role_after=user.role.value,
            quota_usd=user.quota_usd,
            is_active=user.is_active,
        )
        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "role": user.role.value,
            "allowed_providers": user.allowed_providers,
            "allowed_endpoints": user.allowed_endpoints,
            "allowed_models": user.allowed_models,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            "total_usd": getattr(user, "total_usd", 0),
            "is_active": user.is_active,
            "created_at": user.created_at.isoformat(),
            "updated_at": user.updated_at.isoformat() if user.updated_at else None,
        }
class AdminDeleteUserAdapter(AdminApiAdapter):
    """Delete a user, refusing to remove the last remaining admin.

    Fix: "user not found" previously raised ``HTTPException(404)`` while every
    sibling adapter in this module raises ``NotFoundException`` — switched to
    the domain exception for consistent error handling through the pipeline.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user = UserService.get_user(db, self.user_id)
        if not user:
            raise NotFoundException("用户不存在", "user")

        # Never allow the system to end up without any administrator.
        if user.role == UserRole.ADMIN:
            admin_count = db.query(User).filter(User.role == UserRole.ADMIN).count()
            if admin_count <= 1:
                raise InvalidRequestException("不能删除最后一个管理员账户")

        if not UserService.delete_user(db, self.user_id):
            # Deleted concurrently between the lookup and the delete.
            raise NotFoundException("用户不存在", "user")

        context.add_audit_metadata(
            action="delete_user",
            target_user_id=user.id,
            target_email=user.email,
            target_role=user.role.value,
        )
        return {"message": "用户删除成功"}
class AdminResetUserQuotaAdapter(AdminApiAdapter):
    """Reset a user's consumed quota (``used_usd``) back to zero."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user = UserService.get_user(db, self.user_id)
        if user is None:
            raise NotFoundException("用户不存在", "user")

        # Zero out current usage; the lifetime total is kept as-is (and
        # defaults to 0 when the attribute is missing on older rows).
        user.used_usd = 0.0
        user.total_usd = getattr(user, "total_usd", 0)
        user.updated_at = datetime.now(timezone.utc)
        db.commit()

        context.add_audit_metadata(
            action="reset_user_quota",
            target_user_id=user.id,
            quota_usd=user.quota_usd,
            used_usd=user.used_usd,
            total_usd=user.total_usd,
        )
        return {
            "message": "配额已重置",
            "user_id": user.id,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            "total_usd": user.total_usd,
        }
class AdminGetUserKeysAdapter(AdminApiAdapter):
    """List a user's API keys (standalone keys are excluded by the service)."""

    def __init__(self, user_id: str, is_active: Optional[bool]):
        self.user_id = user_id
        self.is_active = is_active

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # The user must exist before we enumerate their keys.
        user = db.query(User).filter(User.id == self.user_id).first()
        if user is None:
            raise NotFoundException("用户不存在", "user")

        api_keys = ApiKeyService.list_user_api_keys(
            db=db, user_id=self.user_id, is_active=self.is_active
        )
        context.add_audit_metadata(
            action="list_user_api_keys",
            target_user_id=self.user_id,
            total=len(api_keys),
        )

        def _serialize(key):
            # One JSON-friendly row per key; only the masked display form of
            # the key is exposed, never the raw secret.
            return {
                "id": key.id,
                "name": key.name,
                "key_display": key.get_display_key(),
                "is_active": key.is_active,
                "total_requests": key.total_requests,
                "total_cost_usd": float(key.total_cost_usd or 0),
                "rate_limit": key.rate_limit,
                "expires_at": key.expires_at.isoformat() if key.expires_at else None,
                "last_used_at": key.last_used_at.isoformat() if key.last_used_at else None,
                "created_at": key.created_at.isoformat(),
            }

        return {
            "api_keys": [_serialize(key) for key in api_keys],
            "total": len(api_keys),
            "user_email": user.email,
            "username": user.username,
        }
class AdminCreateUserKeyAdapter(AdminApiAdapter):
    """Create a regular (non-standalone) API key on behalf of a user."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()

        try:
            key_data = CreateApiKeyRequest.model_validate(payload)
        except ValidationError as exc:
            details = exc.errors()
            if details:
                raise InvalidRequestException(translate_pydantic_error(details[0]))
            raise InvalidRequestException("请求数据验证失败")

        user = db.query(User).filter(User.id == self.user_id).first()
        if user is None:
            raise NotFoundException("用户不存在", "user")

        # NOTE(review): `or 100` also maps an explicit rate_limit of 0 to 100 —
        # confirm 0 is not meant to be a valid (e.g. unlimited) value.
        api_key, plain_key = ApiKeyService.create_api_key(
            db=db,
            user_id=self.user_id,
            name=key_data.name,
            allowed_providers=key_data.allowed_providers,
            allowed_models=key_data.allowed_models,
            rate_limit=key_data.rate_limit or 100,
            expire_days=key_data.expire_days,
            initial_balance_usd=None,  # regular keys carry no balance cap
            is_standalone=False,  # explicitly not a standalone key
        )
        logger.info(f"管理员为用户创建API Key: 用户 {user.email}, Key ID {api_key.id}")
        context.add_audit_metadata(
            action="create_user_api_key",
            target_user_id=self.user_id,
            key_id=api_key.id,
        )
        return {
            "id": api_key.id,
            "key": plain_key,  # returned only once, at creation time
            "name": api_key.name,
            "key_display": api_key.get_display_key(),
            "rate_limit": api_key.rate_limit,
            "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
            "created_at": api_key.created_at.isoformat(),
            "message": "API Key创建成功请妥善保存完整密钥",
        }
class AdminDeleteUserKeyAdapter(AdminApiAdapter):
    """Delete one of a user's regular (non-standalone) API keys."""

    def __init__(self, user_id: str, key_id: str):
        self.user_id = user_id
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # The key must exist, belong to this user, and be a regular key.
        # `== False` is a SQLAlchemy column comparison, not a Python bool test.
        api_key = (
            db.query(ApiKey)
            .filter(
                ApiKey.id == self.key_id,
                ApiKey.user_id == self.user_id,
                ApiKey.is_standalone == False,  # noqa: E712
            )
            .first()
        )
        if api_key is None:
            raise NotFoundException("API Key不存在或不属于该用户", "api_key")

        db.delete(api_key)
        db.commit()
        logger.info(f"管理员删除用户API Key: 用户ID {self.user_id}, Key ID {self.key_id}")
        context.add_audit_metadata(
            action="delete_user_api_key",
            target_user_id=self.user_id,
            key_id=self.key_id,
        )
        return {"message": "API Key已删除"}