mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
Initial commit
This commit is contained in:
30
src/api/admin/__init__.py
Normal file
30
src/api/admin/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Admin API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .adaptive import router as adaptive_router
|
||||
from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_strategy import router as provider_strategy_router
|
||||
from .providers import router as providers_router
|
||||
from .security import router as security_router
|
||||
from .system import router as system_router
|
||||
from .usage import router as usage_router
|
||||
from .users import router as users_router
|
||||
|
||||
router = APIRouter()

# Mount every admin sub-router, preserving the original registration order.
_SUB_ROUTERS = (
    system_router,
    users_router,
    providers_router,
    api_keys_router,
    usage_router,
    monitoring_router,
    endpoints_router,
    provider_strategy_router,
    adaptive_router,
    models_router,
    security_router,
)
for _sub_router in _SUB_ROUTERS:
    router.include_router(_sub_router)

__all__ = ["router"]
|
||||
377
src/api/admin/adaptive.py
Normal file
377
src/api/admin/adaptive.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
自适应并发管理 API 端点
|
||||
|
||||
设计原则:
|
||||
- 自适应模式由 max_concurrent 字段决定:
|
||||
- max_concurrent = NULL:启用自适应模式,系统自动学习并调整并发限制
|
||||
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
|
||||
- learned_max_concurrent:自适应模式下学习到的并发限制值
|
||||
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey
|
||||
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
|
||||
|
||||
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ==================== Pydantic Models ====================
|
||||
|
||||
|
||||
class EnableAdaptiveRequest(BaseModel):
    """Request body for switching a key between adaptive and fixed-limit mode."""

    # true = adaptive mode (max_concurrent becomes NULL), false = fixed-limit mode
    enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
    # Only honoured when enabled=false; validated to the 1..100 range
    fixed_limit: Optional[int] = Field(
        None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
    )
|
||||
|
||||
|
||||
class AdaptiveStatsResponse(BaseModel):
    """Per-key adaptive-concurrency statistics returned by the stats endpoint."""

    adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
    max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制(NULL=自适应)")
    effective_limit: Optional[int] = Field(
        None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
    )
    learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
    # 429-error counters, split by cause (concurrency limit vs requests-per-minute)
    concurrent_429_count: int
    rpm_429_count: int
    # Most recent 429 occurrence, if any — presumably an ISO timestamp string
    # produced by the adaptive manager; TODO confirm against get_adjustment_stats
    last_429_at: Optional[str]
    last_429_type: Optional[str]
    # Total number of limit adjustments plus the most recent adjustment entries
    adjustment_count: int
    recent_adjustments: List[dict]
|
||||
|
||||
|
||||
class KeyListItem(BaseModel):
    """One row in the adaptive-key listing."""

    id: str
    name: Optional[str]
    endpoint_id: str
    is_adaptive: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL)")
    max_concurrent: Optional[int] = Field(None, description="固定并发限制(NULL=自适应)")
    effective_limit: Optional[int] = Field(None, description="当前有效限制")
    learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
    # 429-error counters, split by cause (concurrency limit vs requests-per-minute)
    concurrent_429_count: int
    rpm_429_count: int
|
||||
|
||||
|
||||
# ==================== API Endpoints ====================
|
||||
|
||||
|
||||
@router.get(
    "/keys",
    response_model=List[KeyListItem],
    summary="获取所有启用自适应模式的Key",
)
async def list_adaptive_keys(
    request: Request,
    endpoint_id: Optional[str] = Query(None, description="按 Endpoint 过滤"),
    db: Session = Depends(get_db),
):
    """List every provider key currently running in adaptive mode.

    Optionally filtered by ``endpoint_id``.
    """
    handler = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.patch(
    "/keys/{key_id}/mode",
    summary="Toggle key's concurrency control mode",
)
async def toggle_adaptive_mode(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Switch a key between adaptive and fixed-limit concurrency control.

    Body fields:
    - enabled: true = adaptive mode (max_concurrent=NULL), false = fixed limit
    - fixed_limit: required when enabled is false
    """
    handler = ToggleAdaptiveModeAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.get(
    "/keys/{key_id}/stats",
    response_model=AdaptiveStatsResponse,
    summary="获取Key的自适应统计",
)
async def get_adaptive_stats(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Return adaptive-concurrency statistics for one key.

    Covers the current configuration, the learned limit, 429 error
    counters and the adjustment history.
    """
    handler = GetAdaptiveStatsAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.delete(
    "/keys/{key_id}/learning",
    summary="Reset key's learning state",
)
async def reset_adaptive_learning(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Reset the adaptive learning state for one key.

    Clears the learned concurrency limit, the 429 counters and the
    adjustment history. The max_concurrent configuration (which decides
    whether the key is adaptive) is left untouched.
    """
    handler = ResetAdaptiveLearningAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.patch(
    "/keys/{key_id}/limit",
    summary="Set key to fixed concurrency limit mode",
)
async def set_concurrent_limit(
    key_id: str,
    request: Request,
    limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
    db: Session = Depends(get_db),
):
    """Pin the key to a fixed concurrency limit.

    Once set, the key stops auto-adjusting. Use PATCH /keys/{key_id}/mode
    to restore adaptive mode.
    """
    handler = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.get(
    "/summary",
    summary="获取自适应并发的全局统计",
)
async def get_adaptive_summary(
    request: Request,
    db: Session = Depends(get_db),
):
    """Return a global summary of adaptive concurrency.

    Includes the number of adaptive keys, total 429 errors, and the
    number of limit adjustments made so far.
    """
    handler = AdaptiveSummaryAdapter()
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
# ==================== Pipeline 适配器 ====================
|
||||
|
||||
|
||||
@dataclass
class ListAdaptiveKeysAdapter(AdminApiAdapter):
    """List every provider key whose max_concurrent is NULL (adaptive mode)."""

    endpoint_id: Optional[str] = None

    async def handle(self, context):  # type: ignore[override]
        # Adaptive mode is encoded as max_concurrent IS NULL.
        q = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
        if self.endpoint_id:
            q = q.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)

        items = []
        for record in q.all():
            fixed = record.max_concurrent
            items.append(
                KeyListItem(
                    id=record.id,
                    name=record.name,
                    endpoint_id=record.endpoint_id,
                    is_adaptive=fixed is None,
                    max_concurrent=fixed,
                    # Adaptive keys report the learned value as their effective limit.
                    effective_limit=record.learned_max_concurrent if fixed is None else fixed,
                    learned_max_concurrent=record.learned_max_concurrent,
                    concurrent_429_count=record.concurrent_429_count or 0,
                    rpm_429_count=record.rpm_429_count or 0,
                )
            )
        return items
|
||||
|
||||
|
||||
@dataclass
class ToggleAdaptiveModeAdapter(AdminApiAdapter):
    """Switch a provider key between adaptive and fixed-limit concurrency.

    Adaptive mode is represented by ``max_concurrent = NULL``; a fixed limit
    is any non-NULL value. Returns the resulting mode and effective limit.

    Raises:
        HTTPException: 404 when the key does not exist, 400 when disabling
            adaptive mode without supplying a fixed limit.
        InvalidRequestException: when the JSON body fails validation.
    """

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        key = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not key:
            raise HTTPException(status_code=404, detail="Key not found")

        payload = context.ensure_json_body()
        try:
            body = EnableAdaptiveRequest.model_validate(payload)
        except ValidationError as e:
            # Chain the pydantic error so the original cause survives in tracebacks.
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0])) from e
            raise InvalidRequestException("请求数据验证失败") from e

        if body.enabled:
            # Enable adaptive mode: a NULL max_concurrent lets the system learn a limit.
            key.max_concurrent = None
            message = "已切换为自适应模式,系统将自动学习并调整并发限制"
        else:
            # Fixed-limit mode requires an explicit limit value.
            if body.fixed_limit is None:
                raise HTTPException(
                    status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
                )
            key.max_concurrent = body.fixed_limit
            message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"

        context.db.commit()
        context.db.refresh(key)

        is_adaptive = key.max_concurrent is None
        return {
            "message": message,
            "key_id": key.id,
            "is_adaptive": is_adaptive,
            "max_concurrent": key.max_concurrent,
            "effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
        }
|
||||
|
||||
|
||||
@dataclass
class GetAdaptiveStatsAdapter(AdminApiAdapter):
    """Fetch adaptive-concurrency statistics for a single provider key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        record = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if record is None:
            raise HTTPException(status_code=404, detail="Key not found")

        stats = get_adaptive_manager().get_adjustment_stats(record)

        # Map the manager's stat keys onto the response model explicitly, one per field.
        field_names = (
            "adaptive_mode",
            "max_concurrent",
            "effective_limit",
            "learned_limit",
            "concurrent_429_count",
            "rpm_429_count",
            "last_429_at",
            "last_429_type",
            "adjustment_count",
            "recent_adjustments",
        )
        return AdaptiveStatsResponse(**{name: stats[name] for name in field_names})
|
||||
|
||||
|
||||
@dataclass
class ResetAdaptiveLearningAdapter(AdminApiAdapter):
    """Clear the learned limit, 429 counters and adjustment history for one key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        record = (
            context.db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.id == self.key_id)
            .first()
        )
        if record is None:
            raise HTTPException(status_code=404, detail="Key not found")

        get_adaptive_manager().reset_learning(context.db, record)
        return {"message": "学习状态已重置", "key_id": record.id}
|
||||
|
||||
|
||||
@dataclass
class SetConcurrentLimitAdapter(AdminApiAdapter):
    """Force a key into fixed-limit mode with the given concurrency cap."""

    key_id: str
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if record is None:
            raise HTTPException(status_code=404, detail="Key not found")

        # Remember which mode the key was in before the change, for the response.
        previous_mode = "adaptive" if record.max_concurrent is None else "fixed"
        record.max_concurrent = self.limit
        db.commit()
        db.refresh(record)

        return {
            "message": f"已设置为固定限制模式,并发限制为 {self.limit}",
            "key_id": record.id,
            "is_adaptive": False,
            "max_concurrent": record.max_concurrent,
            "previous_mode": previous_mode,
        }
|
||||
|
||||
|
||||
class AdaptiveSummaryAdapter(AdminApiAdapter):
    """Aggregate adaptive-concurrency statistics across all adaptive keys."""

    async def handle(self, context):  # type: ignore[override]
        # Adaptive mode is encoded as max_concurrent IS NULL.
        keys = (
            context.db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.max_concurrent.is_(None))
            .all()
        )

        concurrent_429 = 0
        rpm_429 = 0
        adjustments = 0
        recent = []
        for record in keys:
            concurrent_429 += record.concurrent_429_count or 0
            rpm_429 += record.rpm_429_count or 0
            history = record.adjustment_history or []
            adjustments += len(history)
            # Collect the last three adjustments per key, tagged with the key's identity.
            for entry in history[-3:]:
                recent.append({"key_id": record.id, "key_name": record.name, **entry})

        recent.sort(key=lambda item: item.get("timestamp", ""), reverse=True)

        return {
            "total_adaptive_keys": len(keys),
            "total_concurrent_429_errors": concurrent_429,
            "total_rpm_429_errors": rpm_429,
            "total_adjustments": adjustments,
            "recent_adjustments": recent[:10],
        }
|
||||
5
src/api/admin/api_keys/__init__.py
Normal file
5
src/api/admin/api_keys/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""API key admin routes export."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
497
src/api/admin/api_keys/routes.py
Normal file
497
src/api/admin/api_keys/routes.py
Normal file
@@ -0,0 +1,497 @@
|
||||
"""管理员独立余额 API Key 管理路由。
|
||||
|
||||
独立余额Key:不关联用户配额,有独立余额限制,用于给非注册用户使用。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import CreateApiKeyRequest
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("")
async def list_standalone_api_keys(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=500),
    is_active: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """Page through all standalone-balance API keys."""
    handler = AdminListStandaloneKeysAdapter(skip=skip, limit=limit, is_active=is_active)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.post("")
async def create_standalone_api_key(
    request: Request,
    key_data: CreateApiKeyRequest,
    db: Session = Depends(get_db),
):
    """Create a standalone-balance API key (a positive balance limit is mandatory)."""
    handler = AdminCreateStandaloneKeyAdapter(key_data=key_data)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.put("/{key_id}")
async def update_api_key(
    key_id: str, request: Request, key_data: CreateApiKeyRequest, db: Session = Depends(get_db)
):
    """Update a standalone-balance key (name, expiry, balance limit, ...)."""
    handler = AdminUpdateApiKeyAdapter(key_id=key_id, key_data=key_data)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.patch("/{key_id}")
async def toggle_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
    """Flip an API key's active flag."""
    handler = AdminToggleApiKeyAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.delete("/{key_id}")
async def delete_api_key(key_id: str, request: Request, db: Session = Depends(get_db)):
    """Delete an API key by id."""
    handler = AdminDeleteApiKeyAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
@router.patch("/{key_id}/balance")
async def add_balance_to_key(
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Adjust balance for standalone API key (positive to add, negative to deduct).

    Body: {"amount_usd": float} — must be a non-zero number. Deductions are
    rejected when they exceed the key's current balance.

    Raises:
        HTTPException: 400 on missing/invalid/zero amount or insufficient
            balance, 404 when the key does not exist.
    """
    # Read the adjustment amount from the request body.
    body = await request.json()
    amount_usd = body.get("amount_usd")

    if amount_usd is None:
        raise HTTPException(status_code=400, detail="缺少必需参数: amount_usd")

    # BUGFIX: coerce to float BEFORE the zero check. The original compared the
    # raw value with 0 first, so a string payload like "0" (which is != 0)
    # slipped past the guard and reached the service layer as 0.0.
    try:
        amount_usd = float(amount_usd)
    except (ValueError, TypeError):
        raise HTTPException(status_code=400, detail="调整金额必须是有效数字") from None

    if amount_usd == 0:
        raise HTTPException(status_code=400, detail="调整金额不能为 0")

    # For deductions, verify the key exists, is standalone, and has enough balance.
    if amount_usd < 0:
        api_key = db.query(ApiKey).filter(ApiKey.id == key_id).first()
        if not api_key:
            raise HTTPException(status_code=404, detail="API密钥不存在")

        if not api_key.is_standalone:
            raise HTTPException(status_code=400, detail="只能为独立余额Key调整余额")

        if api_key.current_balance_usd is not None:
            if abs(amount_usd) > api_key.current_balance_usd:
                raise HTTPException(
                    status_code=400,
                    detail=f"扣除金额 ${abs(amount_usd):.2f} 超过当前余额 ${api_key.current_balance_usd:.2f}",
                )

    adapter = AdminAddBalanceAdapter(key_id=key_id, amount_usd=amount_usd)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{key_id}")
async def get_api_key_detail(
    key_id: str,
    request: Request,
    include_key: bool = Query(False, description="Include full decrypted key in response"),
    db: Session = Depends(get_db),
):
    """Fetch one API key's details; optionally include the decrypted secret."""
    if include_key:
        handler = AdminGetFullKeyAdapter(key_id=key_id)
    else:
        # Metadata only — the secret itself stays hidden.
        handler = AdminGetKeyDetailAdapter(key_id=key_id)
    return await pipeline.run(
        adapter=handler, http_request=request, db=db, mode=handler.mode
    )
|
||||
|
||||
|
||||
class AdminListStandaloneKeysAdapter(AdminApiAdapter):
    """List standalone-balance keys with paging and an optional active filter."""

    def __init__(
        self,
        skip: int,
        limit: int,
        is_active: Optional[bool],
    ):
        self.skip = skip
        self.limit = limit
        self.is_active = is_active

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Standalone-balance keys only. Use the SQLAlchemy `is_(True)` operator
        # instead of `== True` (flake8 E712); the generated filter is equivalent.
        query = db.query(ApiKey).filter(ApiKey.is_standalone.is_(True))

        if self.is_active is not None:
            query = query.filter(ApiKey.is_active == self.is_active)

        total = query.count()
        api_keys = (
            query.order_by(ApiKey.created_at.desc()).offset(self.skip).limit(self.limit).all()
        )

        context.add_audit_metadata(
            action="list_standalone_api_keys",
            filter_is_active=self.is_active,
            limit=self.limit,
            skip=self.skip,
            total=total,
        )

        return {
            "api_keys": [
                {
                    "id": api_key.id,
                    "user_id": api_key.user_id,  # creator's user id
                    "name": api_key.name,
                    "key_display": api_key.get_display_key(),
                    "is_active": api_key.is_active,
                    "is_standalone": api_key.is_standalone,
                    "current_balance_usd": api_key.current_balance_usd,
                    "balance_used_usd": float(api_key.balance_used_usd or 0),
                    "total_requests": api_key.total_requests,
                    "total_cost_usd": float(api_key.total_cost_usd or 0),
                    "rate_limit": api_key.rate_limit,
                    "allowed_providers": api_key.allowed_providers,
                    "allowed_api_formats": api_key.allowed_api_formats,
                    "allowed_models": api_key.allowed_models,
                    "last_used_at": (
                        api_key.last_used_at.isoformat() if api_key.last_used_at else None
                    ),
                    "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
                    "created_at": api_key.created_at.isoformat(),
                    "updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
                    "auto_delete_on_expiry": api_key.auto_delete_on_expiry,
                }
                for api_key in api_keys
            ],
            "total": total,
            "limit": self.limit,
            "skip": self.skip,
        }
|
||||
|
||||
|
||||
class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
    """Create a standalone-balance API key (requires a positive initial balance)."""

    def __init__(self, key_data: CreateApiKeyRequest):
        self.key_data = key_data

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        data = self.key_data

        # A standalone key is unusable without funds: demand a positive balance.
        if not data.initial_balance_usd or data.initial_balance_usd <= 0:
            raise HTTPException(
                status_code=400,
                detail="创建独立余额Key必须设置有效的初始余额(initial_balance_usd > 0)",
            )

        # The key is attached to the requesting admin account (from the context).
        creator_id = context.user.id

        created, secret = ApiKeyService.create_api_key(
            db=db,
            user_id=creator_id,  # owned by the creating admin
            name=data.name,
            allowed_providers=data.allowed_providers,
            allowed_api_formats=data.allowed_api_formats,
            allowed_models=data.allowed_models,
            rate_limit=data.rate_limit or 100,
            expire_days=data.expire_days,
            initial_balance_usd=data.initial_balance_usd,
            is_standalone=True,  # mark as standalone-balance key
            auto_delete_on_expiry=data.auto_delete_on_expiry,
        )

        logger.info(f"管理员创建独立余额Key: ID {created.id}, 初始余额 ${data.initial_balance_usd}")

        context.add_audit_metadata(
            action="create_standalone_api_key",
            key_id=created.id,
            initial_balance_usd=data.initial_balance_usd,
        )

        return {
            "id": created.id,
            "key": secret,  # the full secret is only ever returned at creation time
            "name": created.name,
            "key_display": created.get_display_key(),
            "is_standalone": True,
            "current_balance_usd": created.current_balance_usd,
            "balance_used_usd": 0.0,
            "rate_limit": created.rate_limit,
            "expires_at": created.expires_at.isoformat() if created.expires_at else None,
            "created_at": created.created_at.isoformat(),
            "message": "独立余额Key创建成功,请妥善保存完整密钥,后续将无法查看",
        }
|
||||
|
||||
|
||||
class AdminUpdateApiKeyAdapter(AdminApiAdapter):
    """Update a standalone-balance key (name, limits, expiry, access lists).

    Raises:
        NotFoundException: when the key does not exist or the service-layer
            update fails.
    """

    def __init__(self, key_id: str, key_data: CreateApiKeyRequest):
        self.key_id = key_id
        self.key_data = key_data

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not api_key:
            raise NotFoundException("API密钥不存在", "api_key")

        # Collect only the fields the caller actually wants to change.
        update_data = {}
        if self.key_data.name is not None:
            update_data["name"] = self.key_data.name
        if self.key_data.rate_limit is not None:
            update_data["rate_limit"] = self.key_data.rate_limit
        if (
            hasattr(self.key_data, "auto_delete_on_expiry")
            and self.key_data.auto_delete_on_expiry is not None
        ):
            update_data["auto_delete_on_expiry"] = self.key_data.auto_delete_on_expiry

        # Access restriction lists (an empty list clears the restriction).
        if hasattr(self.key_data, "allowed_providers"):
            update_data["allowed_providers"] = self.key_data.allowed_providers
        if hasattr(self.key_data, "allowed_api_formats"):
            update_data["allowed_api_formats"] = self.key_data.allowed_api_formats
        if hasattr(self.key_data, "allowed_models"):
            update_data["allowed_models"] = self.key_data.allowed_models

        # Expiry handling:
        #   expire_days > 0   -> expire that many days from now
        #   expire_days <= 0  -> never expire
        #   explicit null     -> never expire
        #   field omitted     -> leave the current expiry untouched
        if self.key_data.expire_days is not None:
            if self.key_data.expire_days > 0:
                from datetime import timedelta

                update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
                    days=self.key_data.expire_days
                )
            else:
                update_data["expires_at"] = None
        elif "expire_days" in self.key_data.model_fields_set:
            # BUGFIX: the previous `hasattr(self.key_data, "expire_days")` test is
            # always true for a declared pydantic field, so ANY update request that
            # omitted expire_days silently reset the key to never-expire. Only an
            # explicitly supplied null should do that; pydantic v2's
            # model_fields_set distinguishes "sent as null" from "omitted".
            update_data["expires_at"] = None

        # Delegate the actual persistence to the service layer.
        updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
        if not updated_key:
            raise NotFoundException("更新失败", "api_key")

        logger.info(f"管理员更新独立余额Key: ID {self.key_id}, 更新字段 {list(update_data.keys())}")

        context.add_audit_metadata(
            action="update_standalone_api_key",
            key_id=self.key_id,
            updated_fields=list(update_data.keys()),
        )

        return {
            "id": updated_key.id,
            "name": updated_key.name,
            "key_display": updated_key.get_display_key(),
            "is_active": updated_key.is_active,
            "current_balance_usd": updated_key.current_balance_usd,
            "balance_used_usd": float(updated_key.balance_used_usd or 0),
            "rate_limit": updated_key.rate_limit,
            "expires_at": updated_key.expires_at.isoformat() if updated_key.expires_at else None,
            "updated_at": updated_key.updated_at.isoformat() if updated_key.updated_at else None,
            "message": "API密钥已更新",
        }
|
||||
|
||||
|
||||
class AdminToggleApiKeyAdapter(AdminApiAdapter):
    """Flip an API key's active flag and audit the transition."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if record is None:
            raise NotFoundException("API密钥不存在", "api_key")

        record.is_active = not record.is_active
        record.updated_at = datetime.now(timezone.utc)
        db.commit()
        db.refresh(record)

        state_label = "启用" if record.is_active else "禁用"
        logger.info(f"管理员切换API密钥状态: Key ID {self.key_id}, 新状态 {state_label}")

        context.add_audit_metadata(
            action="toggle_api_key",
            target_key_id=record.id,
            user_id=record.user_id,
            new_status="enabled" if record.is_active else "disabled",
        )

        return {
            "id": record.id,
            "is_active": record.is_active,
            "message": f"API密钥已{state_label}",
        }
|
||||
|
||||
|
||||
class AdminDeleteApiKeyAdapter(AdminApiAdapter):
    """Delete an API key, recording its owner in the audit metadata."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if record is None:
            raise HTTPException(status_code=404, detail="API密钥不存在")

        # Capture the owner before the row disappears.
        owner = record.user
        db.delete(record)
        db.commit()

        logger.info(f"管理员删除API密钥: Key ID {self.key_id}, 用户 {owner.email if owner else '未知'}")

        context.add_audit_metadata(
            action="delete_api_key",
            target_key_id=self.key_id,
            user_id=owner.id if owner else None,
            user_email=owner.email if owner else None,
        )
        return {"message": "API密钥已删除"}
|
||||
|
||||
|
||||
class AdminAddBalanceAdapter(AdminApiAdapter):
    """Apply a balance adjustment to a standalone-balance key."""

    def __init__(self, key_id: str, amount_usd: float):
        self.key_id = key_id
        self.amount_usd = amount_usd

    async def handle(self, context):  # type: ignore[override]
        # Delegate the actual adjustment to the service layer.
        updated = ApiKeyService.add_balance(context.db, self.key_id, self.amount_usd)
        if not updated:
            raise NotFoundException("余额充值失败:Key不存在或不是独立余额Key", "api_key")

        logger.info(f"管理员为独立余额Key充值: ID {self.key_id}, 充值 ${self.amount_usd:.4f}")

        context.add_audit_metadata(
            action="add_balance_to_key",
            key_id=self.key_id,
            amount_usd=self.amount_usd,
            new_current_balance=updated.current_balance_usd,
        )

        return {
            "id": updated.id,
            "name": updated.name,
            "current_balance_usd": updated.current_balance_usd,
            "balance_used_usd": float(updated.balance_used_usd or 0),
            "message": f"余额充值成功,充值 ${self.amount_usd:.2f},当前余额 ${updated.current_balance_usd:.2f}",
        }
|
||||
|
||||
|
||||
class AdminGetFullKeyAdapter(AdminApiAdapter):
    """Decrypt and return the full plaintext of a stored API key.

    Every access is logged and recorded in the audit trail.

    Raises:
        NotFoundException: when the key does not exist.
        HTTPException: 400 when no encrypted secret is stored, 500 when
            decryption fails.
    """

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        from src.core.crypto import crypto_service

        db = context.db

        api_key = db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if not api_key:
            raise NotFoundException("API密钥不存在", "api_key")

        # The plaintext can only be recovered when the ciphertext was stored.
        if not api_key.key_encrypted:
            raise HTTPException(status_code=400, detail="该密钥没有存储完整密钥信息")

        try:
            full_key = crypto_service.decrypt(api_key.key_encrypted)
        except Exception as e:
            logger.error(f"解密API密钥失败: Key ID {self.key_id}, 错误: {e}")
            # Chain the cause so the decryption failure stays visible in tracebacks.
            raise HTTPException(status_code=500, detail="解密密钥失败") from e

        logger.info(f"管理员查看完整API密钥: Key ID {self.key_id}")

        context.add_audit_metadata(
            action="view_full_api_key",
            key_id=self.key_id,
            key_name=api_key.name,
        )

        return {
            "key": full_key,
        }
|
||||
|
||||
|
||||
class AdminGetKeyDetailAdapter(AdminApiAdapter):
    """Return an API key's metadata without exposing the secret itself."""

    def __init__(self, key_id: str):
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        record = context.db.query(ApiKey).filter(ApiKey.id == self.key_id).first()
        if record is None:
            raise NotFoundException("API密钥不存在", "api_key")

        context.add_audit_metadata(
            action="get_api_key_detail",
            key_id=self.key_id,
        )

        def iso(ts):
            # Render optional datetimes as ISO-8601 strings, None when unset.
            return ts.isoformat() if ts else None

        return {
            "id": record.id,
            "user_id": record.user_id,
            "name": record.name,
            "key_display": record.get_display_key(),
            "is_active": record.is_active,
            "is_standalone": record.is_standalone,
            "current_balance_usd": record.current_balance_usd,
            "balance_used_usd": float(record.balance_used_usd or 0),
            "total_requests": record.total_requests,
            "total_cost_usd": float(record.total_cost_usd or 0),
            "rate_limit": record.rate_limit,
            "allowed_providers": record.allowed_providers,
            "allowed_api_formats": record.allowed_api_formats,
            "allowed_models": record.allowed_models,
            "last_used_at": iso(record.last_used_at),
            "expires_at": iso(record.expires_at),
            "created_at": record.created_at.isoformat(),
            "updated_at": iso(record.updated_at),
        }
|
||||
24
src/api/admin/endpoints/__init__.py
Normal file
24
src/api/admin/endpoints/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Endpoint management API routers."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .concurrency import router as concurrency_router
|
||||
from .health import router as health_router
|
||||
from .keys import router as keys_router
|
||||
from .routes import router as routes_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/endpoints", tags=["Endpoint Management"])
|
||||
|
||||
# Endpoint CRUD
|
||||
router.include_router(routes_router)
|
||||
|
||||
# Endpoint Keys management
|
||||
router.include_router(keys_router)
|
||||
|
||||
# Health monitoring
|
||||
router.include_router(health_router)
|
||||
|
||||
# Concurrency control
|
||||
router.include_router(concurrency_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
116
src/api/admin/endpoints/concurrency.py
Normal file
116
src/api/admin/endpoints/concurrency.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Endpoint 并发控制管理 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ConcurrencyStatusResponse,
|
||||
ResetConcurrencyRequest,
|
||||
)
|
||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||
|
||||
router = APIRouter(tags=["Concurrency Control"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_endpoint_concurrency(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Endpoint 当前并发状态"""
|
||||
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
|
||||
async def get_key_concurrency(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ConcurrencyStatusResponse:
|
||||
"""获取 Key 当前并发状态"""
|
||||
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/concurrency")
|
||||
async def reset_concurrency(
|
||||
request: ResetConcurrencyRequest,
|
||||
http_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Reset concurrency counters (admin function, use with caution)"""
|
||||
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
    """Report the live concurrency usage of one endpoint."""

    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        target = (
            context.db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if target is None:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        manager = await get_concurrency_manager()
        # Only the endpoint-level counter is relevant here; the key counter
        # from the tuple is discarded.
        current, _ = await manager.get_current_concurrency(endpoint_id=self.endpoint_id)

        return ConcurrencyStatusResponse(
            endpoint_id=self.endpoint_id,
            endpoint_current_concurrency=current,
            endpoint_max_concurrent=target.max_concurrent,
        )
|
||||
|
||||
|
||||
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
    """Report the live concurrency usage of one provider key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        target = (
            context.db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.id == self.key_id)
            .first()
        )
        if target is None:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        manager = await get_concurrency_manager()
        # Only the key-level counter matters; the endpoint counter from the
        # tuple is discarded.
        _, current = await manager.get_current_concurrency(key_id=self.key_id)

        return ConcurrencyStatusResponse(
            key_id=self.key_id,
            key_current_concurrency=current,
            key_max_concurrent=target.max_concurrent,
        )
|
||||
|
||||
|
||||
@dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter):
    """Reset concurrency counters for an endpoint and/or a key.

    Either identifier may be None; the concurrency manager decides which
    counters to clear.
    """

    endpoint_id: Optional[str]
    key_id: Optional[str]

    async def handle(self, context):  # type: ignore[override]
        manager = await get_concurrency_manager()
        await manager.reset_concurrency(endpoint_id=self.endpoint_id, key_id=self.key_id)
        return {"message": "并发计数已重置"}
|
||||
476
src/api/admin/endpoints/health.py
Normal file
476
src/api/admin/endpoints/health.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
Endpoint 健康监控 API
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, RequestCandidate
|
||||
from src.models.endpoint_models import (
|
||||
ApiFormatHealthMonitor,
|
||||
ApiFormatHealthMonitorResponse,
|
||||
EndpointHealthEvent,
|
||||
HealthStatusResponse,
|
||||
HealthSummaryResponse,
|
||||
)
|
||||
from src.services.health.endpoint import EndpointHealthService
|
||||
from src.services.health.monitor import health_monitor
|
||||
|
||||
router = APIRouter(tags=["Endpoint Health"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/health/summary", response_model=HealthSummaryResponse)
|
||||
async def get_health_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthSummaryResponse:
|
||||
"""获取健康状态摘要"""
|
||||
adapter = AdminHealthSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/status")
|
||||
async def get_endpoint_health_status(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取端点健康状态(简化视图,与用户端点统一)
|
||||
|
||||
与 /health/api-formats 的区别:
|
||||
- /health/status: 返回聚合的时间线状态(50个时间段),基于 Usage 表
|
||||
- /health/api-formats: 返回详细的事件列表,基于 RequestCandidate 表
|
||||
"""
|
||||
adapter = AdminEndpointHealthStatusAdapter(lookback_hours=lookback_hours)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/api-formats", response_model=ApiFormatHealthMonitorResponse)
|
||||
async def get_api_format_health_monitor(
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
per_format_limit: int = Query(60, ge=10, le=200, description="每个 API 格式的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ApiFormatHealthMonitorResponse:
|
||||
"""获取按 API 格式聚合的健康监控时间线(详细事件列表)"""
|
||||
adapter = AdminApiFormatHealthMonitorAdapter(
|
||||
lookback_hours=lookback_hours,
|
||||
per_format_limit=per_format_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/health/key/{key_id}", response_model=HealthStatusResponse)
|
||||
async def get_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> HealthStatusResponse:
|
||||
"""获取 Key 健康状态"""
|
||||
adapter = AdminKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/health/keys/{key_id}")
|
||||
async def recover_key_health(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Recover key health status
|
||||
|
||||
Resets health_score to 1.0, closes circuit breaker,
|
||||
cancels auto-disable, and resets all failure counts.
|
||||
|
||||
Parameters:
|
||||
- key_id: Key ID (path parameter)
|
||||
"""
|
||||
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/health/keys")
|
||||
async def recover_all_keys_health(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
Batch recover all circuit-broken keys
|
||||
|
||||
Finds all keys with circuit_breaker_open=True and:
|
||||
1. Resets health_score to 1.0
|
||||
2. Closes circuit breaker
|
||||
3. Resets failure counts
|
||||
"""
|
||||
adapter = AdminRecoverAllKeysHealthAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
class AdminHealthSummaryAdapter(AdminApiAdapter):
    """Fetch the monitor's aggregate health snapshot and wrap it."""

    async def handle(self, context):  # type: ignore[override]
        snapshot = health_monitor.get_all_health_status(context.db)
        return HealthSummaryResponse(**snapshot)
|
||||
|
||||
|
||||
@dataclass
class AdminEndpointHealthStatusAdapter(AdminApiAdapter):
    """Admin endpoint health-status adapter.

    Unified with the user-facing endpoint but includes admin-only fields
    and always bypasses the cache for real-time data.
    """

    # How many hours of history to aggregate.
    lookback_hours: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # Shared service computes health grouped by API format (admin view).
        # EndpointHealthService is already imported at module level; the
        # former function-local re-import was redundant and has been removed.
        result = EndpointHealthService.get_endpoint_health_by_format(
            db=db,
            lookback_hours=self.lookback_hours,
            include_admin_fields=True,  # include admin-only fields
            use_cache=False,  # admins skip the cache to see live state
        )

        context.add_audit_metadata(
            action="endpoint_health_status",
            format_count=len(result),
            lookback_hours=self.lookback_hours,
        )

        return result
|
||||
|
||||
|
||||
@dataclass
class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
    """Build the per-API-format health monitor payload.

    Aggregates terminal-status RequestCandidate rows and a Usage-based
    bucketed timeline over a lookback window, grouped by the API format of
    each active endpoint.
    """

    # Hours of history to include in statistics and events.
    lookback_hours: int
    # Cap on the number of individual events returned per API format.
    per_format_limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        now = datetime.now(timezone.utc)
        since = now - timedelta(hours=self.lookback_hours)

        # 1. All active API formats with their distinct active provider counts.
        active_formats = (
            db.query(
                ProviderEndpoint.api_format,
                func.count(func.distinct(ProviderEndpoint.provider_id)).label("provider_count"),
            )
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .group_by(ProviderEndpoint.api_format)
            .all()
        )

        # Map api_format -> provider_count (normalize enum members to str).
        all_formats: Dict[str, int] = {}
        for api_format_enum, provider_count in active_formats:
            api_format = (
                api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
            )
            all_formats[api_format] = provider_count

        # 1.1 Count active API keys per active API format.
        active_keys = (
            db.query(
                ProviderEndpoint.api_format,
                func.count(ProviderAPIKey.id).label("key_count"),
            )
            .join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
                ProviderAPIKey.is_active.is_(True),
            )
            .group_by(ProviderEndpoint.api_format)
            .all()
        )

        # Map api_format -> key_count.
        key_counts: Dict[str, int] = {}
        for api_format_enum, key_count in active_keys:
            api_format = (
                api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
            )
            key_counts[api_format] = key_count

        # 1.2 Endpoint IDs per API format, consumed by the Usage timeline
        #     generation below.
        endpoint_rows = (
            db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .all()
        )
        endpoint_map: Dict[str, List[str]] = defaultdict(list)
        for api_format_enum, endpoint_id in endpoint_rows:
            api_format = (
                api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
            )
            endpoint_map[api_format].append(endpoint_id)

        # 2. Real per-format request status distribution within the window.
        #    Only terminal statuses are counted: success, failed, skipped.
        final_statuses = ["success", "failed", "skipped"]
        status_counts_query = (
            db.query(
                ProviderEndpoint.api_format,
                RequestCandidate.status,
                func.count(RequestCandidate.id).label("count"),
            )
            .join(RequestCandidate, ProviderEndpoint.id == RequestCandidate.endpoint_id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .group_by(ProviderEndpoint.api_format, RequestCandidate.status)
            .all()
        )

        # Map api_format -> {status: count}.
        status_counts: Dict[str, Dict[str, int]] = {}
        for api_format_enum, status, count in status_counts_query:
            api_format = (
                api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
            )
            if api_format not in status_counts:
                status_counts[api_format] = {"success": 0, "failed": 0, "skipped": 0}
            status_counts[api_format][status] = count

        # 3. Fetch recent RequestCandidate rows (bounded) for the event lists,
        #    using final_statuses above to exclude in-flight states.
        limit_rows = max(500, self.per_format_limit * 10)
        rows = (
            db.query(
                RequestCandidate,
                ProviderEndpoint.api_format,
                ProviderEndpoint.provider_id,
            )
            .join(ProviderEndpoint, RequestCandidate.endpoint_id == ProviderEndpoint.id)
            .filter(
                RequestCandidate.created_at >= since,
                RequestCandidate.status.in_(final_statuses),
            )
            .order_by(RequestCandidate.created_at.desc())
            .limit(limit_rows)
            .all()
        )

        grouped_attempts: Dict[str, List[RequestCandidate]] = {}

        for attempt, api_format_enum, provider_id in rows:
            api_format = (
                api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
            )
            if api_format not in grouped_attempts:
                grouped_attempts[api_format] = []

            # Keep only the most recent per_format_limit rows per API format.
            if len(grouped_attempts[api_format]) < self.per_format_limit:
                grouped_attempts[api_format].append(attempt)

        # 4. Emit a monitor entry for every active format, including formats
        #    with no recorded requests in the window.
        monitors: List[ApiFormatHealthMonitor] = []
        for api_format in all_formats:
            attempts = grouped_attempts.get(api_format, [])
            # Real windowed statistics.
            # Only terminal statuses (success, failed, skipped) are counted;
            # intermediate states (available, pending, used, started) are not.
            format_stats = status_counts.get(api_format, {"success": 0, "failed": 0, "skipped": 0})
            real_success_count = format_stats.get("success", 0)
            real_failed_count = format_stats.get("failed", 0)
            real_skipped_count = format_stats.get("skipped", 0)
            # total_attempts covers terminal-status requests only.
            total_attempts = real_success_count + real_failed_count + real_skipped_count

            # Timeline events in chronological (oldest-first) order.
            attempts_sorted = list(reversed(attempts))
            events: List[EndpointHealthEvent] = []
            for attempt in attempts_sorted:
                # Prefer the finish time, fall back to start/creation time.
                event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
                events.append(
                    EndpointHealthEvent(
                        timestamp=event_timestamp,
                        status=attempt.status,
                        status_code=attempt.status_code,
                        latency_ms=attempt.latency_ms,
                        error_type=attempt.error_type,
                        error_message=attempt.error_message,
                    )
                )

            # success_rate = success / (success + failed).
            # skipped is not a failure and is excluded from the denominator.
            # With no completed requests the rate defaults to 1.0 (grey state).
            actual_completed = real_success_count + real_failed_count
            success_rate = real_success_count / actual_completed if actual_completed > 0 else 1.0
            last_event_at = events[-1].timestamp if events else None

            # Usage-table-based bucketed health timeline for the window.
            timeline_data = EndpointHealthService._generate_timeline_from_usage(
                db=db,
                endpoint_ids=endpoint_map.get(api_format, []),
                now=now,
                lookback_hours=self.lookback_hours,
            )

            monitors.append(
                ApiFormatHealthMonitor(
                    api_format=api_format,
                    total_attempts=total_attempts,  # real total request count
                    success_count=real_success_count,  # real success count
                    failed_count=real_failed_count,  # real failure count
                    skipped_count=real_skipped_count,  # real skipped count
                    success_rate=success_rate,  # rate derived from real stats
                    provider_count=all_formats[api_format],
                    key_count=key_counts.get(api_format, 0),
                    last_event_at=last_event_at,
                    events=events,  # capped at per_format_limit (timeline view)
                    timeline=timeline_data.get("timeline", []),
                    time_range_start=timeline_data.get("time_range_start"),
                    time_range_end=timeline_data.get("time_range_end"),
                )
            )

        response = ApiFormatHealthMonitorResponse(
            generated_at=now,
            formats=monitors,
        )
        context.add_audit_metadata(
            action="api_format_health_monitor",
            format_count=len(monitors),
            lookback_hours=self.lookback_hours,
            per_format_limit=self.per_format_limit,
        )
        return response
|
||||
|
||||
|
||||
@dataclass
class AdminKeyHealthAdapter(AdminApiAdapter):
    """Expose the health monitor's view of one provider key."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        data = health_monitor.get_key_health(context.db, self.key_id)
        if not data:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        return HealthStatusResponse(
            key_id=data["key_id"],
            key_health_score=data["health_score"],
            key_consecutive_failures=data["consecutive_failures"],
            key_last_failure_at=data["last_failure_at"],
            key_is_active=data["is_active"],
            key_statistics=data["statistics"],
            circuit_breaker_open=data["circuit_breaker_open"],
            circuit_breaker_open_at=data["circuit_breaker_open_at"],
            next_probe_at=data["next_probe_at"],
        )
|
||||
|
||||
|
||||
@dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
    """Fully recover a provider key's health state.

    Resets the health score to 1.0, closes the circuit breaker, clears
    failure counters and probe scheduling, and re-activates the key if it
    had been disabled.
    """

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not key:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        # Reset all health / circuit-breaker state to a pristine condition.
        key.health_score = 1.0
        key.consecutive_failures = 0
        key.last_failure_at = None
        key.circuit_breaker_open = False
        key.circuit_breaker_open_at = None
        key.next_probe_at = None
        if not key.is_active:
            key.is_active = True

        db.commit()

        # Fix: removed the unused `admin_name` local (assigned, never read).
        logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")

        return {
            "message": "Key已完全恢复",
            "details": {
                "health_score": 1.0,
                "circuit_breaker_open": False,
                "is_active": True,
            },
        }
|
||||
|
||||
|
||||
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
    """Batch-recover the health state of every circuit-broken key."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # Find every key whose circuit breaker is currently open.
        # Fix: use `.is_(True)` like every other boolean filter in this
        # module instead of `== True` (SQLAlchemy identity comparison, E712).
        circuit_open_keys = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open.is_(True)).all()
        )

        if not circuit_open_keys:
            return {
                "message": "没有需要恢复的 Key",
                "recovered_count": 0,
                "recovered_keys": [],
            }

        recovered_keys = []
        for key in circuit_open_keys:
            # Reset health / circuit-breaker state for each affected key.
            key.health_score = 1.0
            key.consecutive_failures = 0
            key.last_failure_at = None
            key.circuit_breaker_open = False
            key.circuit_breaker_open_at = None
            key.next_probe_at = None
            recovered_keys.append(
                {
                    "key_id": key.id,
                    "key_name": key.name,
                    "endpoint_id": key.endpoint_id,
                }
            )

        db.commit()

        # Reset the health monitor's open-circuit gauge to match the DB.
        from src.services.health.monitor import HealthMonitor, health_open_circuits

        HealthMonitor._open_circuit_keys = 0
        health_open_circuits.set(0)

        # Fix: removed the unused `admin_name` local (assigned, never read).
        logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")

        return {
            "message": f"已恢复 {len(recovered_keys)} 个 Key",
            "recovered_count": len(recovered_keys),
            "recovered_keys": recovered_keys,
        }
|
||||
425
src/api/admin/endpoints/keys.py
Normal file
425
src/api/admin/endpoints/keys.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Endpoint API Keys 管理
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.key_capabilities import get_capability
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
BatchUpdateKeyPriorityRequest,
|
||||
EndpointAPIKeyCreate,
|
||||
EndpointAPIKeyResponse,
|
||||
EndpointAPIKeyUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Keys"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
|
||||
async def list_endpoint_keys(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0, description="跳过的记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[EndpointAPIKeyResponse]:
|
||||
"""获取 Endpoint 的所有 Keys"""
|
||||
adapter = AdminListEndpointKeysAdapter(
|
||||
endpoint_id=endpoint_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
|
||||
async def add_endpoint_key(
|
||||
endpoint_id: str,
|
||||
key_data: EndpointAPIKeyCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""为 Endpoint 添加 Key"""
|
||||
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
|
||||
async def update_endpoint_key(
|
||||
key_id: str,
|
||||
key_data: EndpointAPIKeyUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> EndpointAPIKeyResponse:
|
||||
"""更新 Endpoint Key"""
|
||||
adapter = AdminUpdateEndpointKeyAdapter(key_id=key_id, key_data=key_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/keys/grouped-by-format")
|
||||
async def get_keys_grouped_by_format(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取按 API 格式分组的所有 Keys(用于全局优先级管理)"""
|
||||
adapter = AdminGetKeysGroupedByFormatAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/keys/{key_id}")
|
||||
async def delete_endpoint_key(
|
||||
key_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除 Endpoint Key"""
|
||||
adapter = AdminDeleteEndpointKeyAdapter(key_id=key_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}/keys/batch-priority")
|
||||
async def batch_update_key_priority(
|
||||
endpoint_id: str,
|
||||
request: Request,
|
||||
priority_data: BatchUpdateKeyPriorityRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""批量更新 Endpoint 下 Keys 的优先级(用于拖动排序)"""
|
||||
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
    """List an endpoint's keys with masked key previews and derived metrics.

    Raises NotFoundException when the endpoint does not exist.
    """

    # Endpoint whose keys are listed.
    endpoint_id: str
    # Pagination offset.
    skip: int
    # Pagination page size.
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Keys ordered by internal priority then creation time, paginated.
        keys = (
            db.query(ProviderAPIKey)
            .filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
            .order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
            .offset(self.skip)
            .limit(self.limit)
            .all()
        )

        result: List[EndpointAPIKeyResponse] = []
        for key in keys:
            # Decrypt only to build a masked preview; the raw key is never
            # returned by this adapter (api_key_plain stays None).
            try:
                decrypted_key = crypto_service.decrypt(key.api_key)
                masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
            except Exception:
                # Decryption failure is surfaced as a sentinel, not an error.
                masked_key = "***ERROR***"

            # Derived metrics; divisions guarded against zero counters.
            success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
            avg_response_time_ms = (
                key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
            )

            # max_concurrent is NULL in adaptive mode, a number in fixed mode.
            is_adaptive = key.max_concurrent is None
            key_dict = key.__dict__.copy()
            # Strip SQLAlchemy instance bookkeeping before serialization.
            key_dict.pop("_sa_instance_state", None)
            key_dict.update(
                {
                    "api_key_masked": masked_key,
                    "api_key_plain": None,
                    "success_rate": success_rate,
                    "avg_response_time_ms": round(avg_response_time_ms, 2),
                    "is_adaptive": is_adaptive,
                    "effective_limit": (
                        key.learned_max_concurrent if is_adaptive else key.max_concurrent
                    ),
                }
            )
            result.append(EndpointAPIKeyResponse(**key_dict))

        return result
|
||||
|
||||
|
||||
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
    """Create a new provider key under an endpoint.

    Validates the endpoint exists and that the payload's endpoint_id matches
    the path, encrypts the key at rest, and returns the created record with
    a masked preview plus the plaintext key (shown once at creation).
    """

    # Target endpoint (path parameter).
    endpoint_id: str
    # Creation payload.
    key_data: EndpointAPIKeyCreate

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Payload and path must agree on the owning endpoint.
        if self.key_data.endpoint_id != self.endpoint_id:
            raise InvalidRequestException("endpoint_id 不匹配")

        # Keys are stored encrypted at rest.
        encrypted_key = crypto_service.encrypt(self.key_data.api_key)
        now = datetime.now(timezone.utc)
        # max_concurrent=NULL means adaptive mode; a number is a fixed limit.
        new_key = ProviderAPIKey(
            id=str(uuid.uuid4()),
            endpoint_id=self.endpoint_id,
            api_key=encrypted_key,
            name=self.key_data.name,
            note=self.key_data.note,
            rate_multiplier=self.key_data.rate_multiplier,
            internal_priority=self.key_data.internal_priority,
            max_concurrent=self.key_data.max_concurrent,  # NULL = adaptive mode
            rate_limit=self.key_data.rate_limit,
            daily_limit=self.key_data.daily_limit,
            monthly_limit=self.key_data.monthly_limit,
            allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
            capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
            request_count=0,
            success_count=0,
            error_count=0,
            total_response_time_ms=0,
            is_active=True,
            last_used_at=None,
            created_at=now,
            updated_at=now,
        )

        db.add(new_key)
        db.commit()
        db.refresh(new_key)

        # Log only the key's tail to avoid leaking the secret into logs.
        logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")

        masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
        is_adaptive = new_key.max_concurrent is None
        response_dict = new_key.__dict__.copy()
        # Strip SQLAlchemy instance bookkeeping before serialization.
        response_dict.pop("_sa_instance_state", None)
        response_dict.update(
            {
                "api_key_masked": masked_key,
                # Plaintext is returned once, on creation only.
                "api_key_plain": self.key_data.api_key,
                "success_rate": 0.0,
                "avg_response_time_ms": 0.0,
                "is_adaptive": is_adaptive,
                "effective_limit": (
                    new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
                ),
            }
        )

        return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
    """Partially update a provider key.

    Only fields explicitly set in the payload are applied; a new api_key is
    re-encrypted before storage. Returns the refreshed record with a masked
    preview and derived metrics.
    """

    # Key to update.
    key_id: str
    # Partial update payload.
    key_data: EndpointAPIKeyUpdate

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if not key:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        # exclude_unset keeps this a true partial update.
        update_data = self.key_data.model_dump(exclude_unset=True)
        if "api_key" in update_data:
            # Replacement key material is encrypted before hitting the DB.
            update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])

        for field, value in update_data.items():
            setattr(key, field, value)
        key.updated_at = datetime.now(timezone.utc)

        db.commit()
        db.refresh(key)

        # Log field names only, never values (api_key may be among them).
        logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")

        # Build a masked preview; decryption failure yields a sentinel.
        try:
            decrypted_key = crypto_service.decrypt(key.api_key)
            masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
        except Exception:
            masked_key = "***ERROR***"

        # Derived metrics; divisions guarded against zero counters.
        success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
        avg_response_time_ms = (
            key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
        )

        # max_concurrent is NULL in adaptive mode, a number in fixed mode.
        is_adaptive = key.max_concurrent is None
        response_dict = key.__dict__.copy()
        # Strip SQLAlchemy instance bookkeeping before serialization.
        response_dict.pop("_sa_instance_state", None)
        response_dict.update(
            {
                "api_key_masked": masked_key,
                "api_key_plain": None,
                "success_rate": success_rate,
                "avg_response_time_ms": round(avg_response_time_ms, 2),
                "is_adaptive": is_adaptive,
                "effective_limit": (
                    key.learned_max_concurrent if is_adaptive else key.max_concurrent
                ),
            }
        )
        return EndpointAPIKeyResponse(**response_dict)
|
||||
|
||||
|
||||
@dataclass
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
    """Delete a single provider API key by its id."""

    key_id: str

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        record = session.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
        if record is None:
            raise NotFoundException(f"Key {self.key_id} 不存在")

        # Remember the parent endpoint before the row disappears.
        parent_endpoint = record.endpoint_id
        try:
            session.delete(record)
            session.commit()
        except Exception as exc:
            # Keep the session usable after a failed flush/commit.
            session.rollback()
            logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
            raise

        logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={parent_endpoint}")
        return {"message": f"Key {self.key_id} 已删除"}
|
||||
|
||||
|
||||
class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
    """List every active key, grouped by the API format of its endpoint."""

    @staticmethod
    def _masked(encrypted: str) -> str:
        """Decrypt a stored key and return a masked preview for display."""
        try:
            plain = crypto_service.decrypt(encrypted)
            return f"{plain[:8]}***{plain[-4:]}"
        except Exception:
            return "***ERROR***"

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        rows = (
            db.query(ProviderAPIKey, ProviderEndpoint, Provider)
            .join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(
                ProviderAPIKey.is_active.is_(True),
                ProviderEndpoint.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .order_by(
                ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
            )
            .all()
        )

        grouped: Dict[str, List[dict]] = {}
        for key, endpoint, provider in rows:
            bucket = grouped.setdefault(endpoint.api_format, [])

            masked_key = self._masked(key.api_key)

            # Health metrics; None when no data has been recorded yet.
            requests = key.request_count
            successes = key.success_count
            success_rate = successes / requests if requests > 0 else None
            avg_response_time_ms = (
                round(key.total_response_time_ms / successes, 2) if successes > 0 else None
            )

            # Collapse the capabilities dict into the short names of the
            # enabled capabilities.
            caps_list: List[str] = []
            for cap_name, enabled in (key.capabilities or {}).items():
                if not enabled:
                    continue
                cap_def = get_capability(cap_name)
                caps_list.append(cap_def.short_name if cap_def else cap_name)

            bucket.append(
                {
                    "id": key.id,
                    "name": key.name,
                    "api_key_masked": masked_key,
                    "internal_priority": key.internal_priority,
                    "global_priority": key.global_priority,
                    "rate_multiplier": key.rate_multiplier,
                    "is_active": key.is_active,
                    "circuit_breaker_open": key.circuit_breaker_open,
                    "provider_name": provider.display_name or provider.name,
                    "endpoint_base_url": endpoint.base_url,
                    "api_format": endpoint.api_format,
                    "capabilities": caps_list,
                    "success_rate": success_rate,
                    "avg_response_time_ms": avg_response_time_ms,
                    "request_count": key.request_count,
                }
            )

        # The grouped mapping is consumed directly by the frontend.
        return grouped
|
||||
|
||||
|
||||
@dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
    """Batch-update the internal priorities of keys under one endpoint.

    Raises NotFoundException for an unknown endpoint and
    InvalidRequestException when any key is missing or foreign.
    """

    endpoint_id: str  # endpoint that must own every key in the request
    priority_data: BatchUpdateKeyPriorityRequest  # (key_id, internal_priority) pairs

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Collect every key id targeted by the request.
        key_ids = [item.key_id for item in self.priority_data.priorities]

        # Verify that all keys exist and belong to this endpoint.
        keys = (
            db.query(ProviderAPIKey)
            .filter(
                ProviderAPIKey.id.in_(key_ids),
                ProviderAPIKey.endpoint_id == self.endpoint_id,
            )
            .all()
        )

        if len(keys) != len(key_ids):
            found_ids = {k.id for k in keys}
            missing_ids = set(key_ids) - found_ids
            raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")

        # Apply the new priorities, skipping keys whose value is unchanged.
        key_map = {k.id: k for k in keys}
        updated_count = 0
        for item in self.priority_data.priorities:
            key = key_map.get(item.key_id)
            if key and key.internal_priority != item.internal_priority:
                key.internal_priority = item.internal_priority
                key.updated_at = datetime.now(timezone.utc)
                updated_count += 1

        db.commit()

        logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
        return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
|
||||
345
src/api/admin/endpoints/routes.py
Normal file
345
src/api/admin/endpoints/routes.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""
|
||||
ProviderEndpoint CRUD 管理 API
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.models.endpoint_models import (
|
||||
ProviderEndpointCreate,
|
||||
ProviderEndpointResponse,
|
||||
ProviderEndpointUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Endpoint Management"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
    provider_id: str,
    request: Request,
    skip: int = Query(0, ge=0, description="跳过的记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
    db: Session = Depends(get_db),
) -> List[ProviderEndpointResponse]:
    """Return all endpoints belonging to the given provider (paginated)."""
    adapter = AdminListProviderEndpointsAdapter(
        provider_id=provider_id,
        skip=skip,
        limit=limit,
    )
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/providers/{provider_id}/endpoints", response_model=ProviderEndpointResponse)
async def create_provider_endpoint(
    provider_id: str,
    endpoint_data: ProviderEndpointCreate,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Create a new endpoint for the given provider."""
    adapter = AdminCreateProviderEndpointAdapter(
        provider_id=provider_id,
        endpoint_data=endpoint_data,
    )
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def get_endpoint(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Return the details of a single endpoint."""
    adapter = AdminGetProviderEndpointAdapter(endpoint_id=endpoint_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{endpoint_id}", response_model=ProviderEndpointResponse)
async def update_endpoint(
    endpoint_id: str,
    endpoint_data: ProviderEndpointUpdate,
    request: Request,
    db: Session = Depends(get_db),
) -> ProviderEndpointResponse:
    """Apply a partial update to an endpoint."""
    adapter = AdminUpdateProviderEndpointAdapter(
        endpoint_id=endpoint_id,
        endpoint_data=endpoint_data,
    )
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{endpoint_id}")
async def delete_endpoint(
    endpoint_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> dict:
    """Delete an endpoint (cascades to all of its associated keys)."""
    adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
class AdminListProviderEndpointsAdapter(AdminApiAdapter):
    """List a provider's endpoints with per-endpoint key counts."""

    provider_id: str  # provider whose endpoints are listed
    skip: int  # pagination offset
    limit: int  # pagination page size

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        endpoints = (
            db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            .order_by(ProviderEndpoint.created_at.desc())
            .offset(self.skip)
            .limit(self.limit)
            .all()
        )

        # Aggregate key counts with two grouped queries instead of one
        # query per endpoint.
        endpoint_ids = [ep.id for ep in endpoints]
        total_keys_map = {}
        active_keys_map = {}
        if endpoint_ids:
            total_rows = (
                db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
                .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
                .group_by(ProviderAPIKey.endpoint_id)
                .all()
            )
            total_keys_map = {row.endpoint_id: row.total for row in total_rows}

            active_rows = (
                db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
                .filter(
                    and_(
                        ProviderAPIKey.endpoint_id.in_(endpoint_ids),
                        ProviderAPIKey.is_active.is_(True),
                    )
                )
                .group_by(ProviderAPIKey.endpoint_id)
                .all()
            )
            active_keys_map = {row.endpoint_id: row.active for row in active_rows}

        result: List[ProviderEndpointResponse] = []
        for endpoint in endpoints:
            endpoint_dict = {
                **endpoint.__dict__,
                "provider_name": provider.name,
                "api_format": endpoint.api_format,
                "total_keys": total_keys_map.get(endpoint.id, 0),
                "active_keys": active_keys_map.get(endpoint.id, 0),
            }
            # Strip the SQLAlchemy bookkeeping attribute before validation.
            endpoint_dict.pop("_sa_instance_state", None)
            result.append(ProviderEndpointResponse(**endpoint_dict))

        return result
|
||||
|
||||
|
||||
@dataclass
class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
    """Create a new endpoint under a provider.

    A provider may have at most one endpoint per API format; duplicates are
    rejected with InvalidRequestException.
    """

    provider_id: str  # provider the endpoint is created under
    endpoint_data: ProviderEndpointCreate  # endpoint configuration payload

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        # The path parameter must agree with the payload.
        if self.endpoint_data.provider_id != self.provider_id:
            raise InvalidRequestException("provider_id 不匹配")

        # Enforce uniqueness of (provider, api_format).
        existing = (
            db.query(ProviderEndpoint)
            .filter(
                and_(
                    ProviderEndpoint.provider_id == self.provider_id,
                    ProviderEndpoint.api_format == self.endpoint_data.api_format,
                )
            )
            .first()
        )
        if existing:
            raise InvalidRequestException(
                f"Provider {provider.name} 已存在 {self.endpoint_data.api_format} 格式的 Endpoint"
            )

        now = datetime.now(timezone.utc)
        new_endpoint = ProviderEndpoint(
            id=str(uuid.uuid4()),
            provider_id=self.provider_id,
            api_format=self.endpoint_data.api_format,
            base_url=self.endpoint_data.base_url,
            headers=self.endpoint_data.headers,
            timeout=self.endpoint_data.timeout,
            max_retries=self.endpoint_data.max_retries,
            max_concurrent=self.endpoint_data.max_concurrent,
            rate_limit=self.endpoint_data.rate_limit,
            is_active=True,
            config=self.endpoint_data.config,
            created_at=now,
            updated_at=now,
        )

        db.add(new_endpoint)
        db.commit()
        db.refresh(new_endpoint)

        logger.info(f"[OK] 创建 Endpoint: Provider={provider.name}, Format={self.endpoint_data.api_format}, ID={new_endpoint.id}")

        # api_format is passed explicitly below, so drop it (and the
        # SQLAlchemy bookkeeping attribute) from the dict expansion.
        endpoint_dict = {
            k: v
            for k, v in new_endpoint.__dict__.items()
            if k not in {"api_format", "_sa_instance_state"}
        }
        return ProviderEndpointResponse(
            **endpoint_dict,
            provider_name=provider.name,
            api_format=new_endpoint.api_format,
            total_keys=0,
            active_keys=0,
        )
|
||||
|
||||
|
||||
@dataclass
class AdminGetProviderEndpointAdapter(AdminApiAdapter):
    """Fetch a single endpoint together with its provider and key counts."""

    endpoint_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        row = (
            db.query(ProviderEndpoint, Provider)
            .join(Provider, ProviderEndpoint.provider_id == Provider.id)
            .filter(ProviderEndpoint.id == self.endpoint_id)
            .first()
        )
        if row is None:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        endpoint_obj, provider = row

        # Count all keys and the active subset for this endpoint.
        base_keys = db.query(ProviderAPIKey).filter(
            ProviderAPIKey.endpoint_id == self.endpoint_id
        )
        total_keys = base_keys.count()
        active_keys = base_keys.filter(ProviderAPIKey.is_active.is_(True)).count()

        # api_format is supplied explicitly below; exclude it (and the
        # SQLAlchemy bookkeeping attribute) from the dict expansion.
        payload = {
            field: value
            for field, value in endpoint_obj.__dict__.items()
            if field not in {"api_format", "_sa_instance_state"}
        }
        return ProviderEndpointResponse(
            **payload,
            provider_name=provider.name,
            api_format=endpoint_obj.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
|
||||
|
||||
|
||||
@dataclass
class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
    """Apply a partial update to an endpoint and return the refreshed view."""

    endpoint_id: str  # endpoint to modify
    endpoint_data: ProviderEndpointUpdate  # fields to change; unset fields are ignored

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Only apply fields the caller actually supplied.
        update_data = self.endpoint_data.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(endpoint, field, value)
        endpoint.updated_at = datetime.now(timezone.utc)

        db.commit()
        db.refresh(endpoint)

        provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
        logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")

        # Recompute key counts for the response.
        total_keys = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
        )
        active_keys = (
            db.query(ProviderAPIKey)
            .filter(
                and_(
                    ProviderAPIKey.endpoint_id == self.endpoint_id,
                    ProviderAPIKey.is_active.is_(True),
                )
            )
            .count()
        )

        # api_format is passed explicitly below; drop it (and the SQLAlchemy
        # bookkeeping attribute) from the dict expansion.
        endpoint_dict = {
            k: v
            for k, v in endpoint.__dict__.items()
            if k not in {"api_format", "_sa_instance_state"}
        }
        return ProviderEndpointResponse(
            **endpoint_dict,
            provider_name=provider.name if provider else "Unknown",
            api_format=endpoint.api_format,
            total_keys=total_keys,
            active_keys=active_keys,
        )
|
||||
|
||||
|
||||
@dataclass
class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
    """Delete an endpoint; associated keys are removed by cascade."""

    endpoint_id: str  # endpoint to delete

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        endpoint = (
            db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
        )
        if not endpoint:
            raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")

        # Count the keys before deleting so the response can report the
        # cascade size.
        keys_count = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
        )
        try:
            db.delete(endpoint)
            db.commit()
        except Exception as exc:
            # Roll back on failure so the session stays usable — mirrors the
            # error handling used when deleting individual keys.
            db.rollback()
            logger.error(f"删除 Endpoint 失败: ID={self.endpoint_id}, Error={exc}")
            raise

        logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")

        return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
|
||||
16
src/api/admin/models/__init__.py
Normal file
16
src/api/admin/models/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
模型管理相关 Admin API
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .global_models import router as global_models_router
|
||||
from .mappings import router as mappings_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(global_models_router)
|
||||
router.include_router(mappings_router)
|
||||
432
src/api/admin/models/catalog.py
Normal file
432
src/api/admin/models/catalog.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""
|
||||
统一模型目录 Admin API
|
||||
|
||||
阶段一:基于 ModelMapping 和 Model 的聚合视图
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import func, or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignError,
|
||||
BatchAssignModelMappingRequest,
|
||||
BatchAssignModelMappingResponse,
|
||||
BatchAssignProviderResult,
|
||||
DeleteModelMappingResponse,
|
||||
ModelCapabilities,
|
||||
ModelCatalogItem,
|
||||
ModelCatalogProviderDetail,
|
||||
ModelCatalogResponse,
|
||||
ModelPriceRange,
|
||||
OrphanedModel,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
)
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("", response_model=ModelCatalogResponse)
async def get_model_catalog(
    request: Request,
    db: Session = Depends(get_db),
) -> ModelCatalogResponse:
    """Return the unified model catalog aggregated over GlobalModels."""
    adapter = AdminGetModelCatalogAdapter()
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
async def batch_assign_model_mappings(
    request: Request,
    payload: BatchAssignModelMappingRequest,
    db: Session = Depends(get_db),
) -> BatchAssignModelMappingResponse:
    """Batch-assign model mappings across multiple providers."""
    adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
class AdminGetModelCatalogAdapter(AdminApiAdapter):
    """Admin view of the unified model catalog.

    Architecture notes:
    1. Data is aggregated around GlobalModel records.
    2. The ModelMapping table supplies aliases (provider_id=NULL means global).
    3. The Model table supplies provider associations and prices.
    """

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db

        # 1. All active GlobalModel rows.
        global_models: List[GlobalModel] = (
            db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
        )

        # 2. All active global aliases (provider_id IS NULL).
        aliases_rows: List[ModelMapping] = (
            db.query(ModelMapping)
            .options(joinedload(ModelMapping.target_global_model))
            .filter(
                ModelMapping.is_active == True,
                ModelMapping.provider_id.is_(None),
            )
            .all()
        )

        # Group alias names by the GlobalModel they point to, de-duplicated.
        aliases_by_global_model: Dict[str, List[str]] = {}
        for alias_row in aliases_rows:
            if not alias_row.target_global_model_id:
                continue
            gm_id = alias_row.target_global_model_id
            if gm_id not in aliases_by_global_model:
                aliases_by_global_model[gm_id] = []
            if alias_row.source_model not in aliases_by_global_model[gm_id]:
                aliases_by_global_model[gm_id].append(alias_row.source_model)

        # 3. All active Model implementations (global_model eagerly loaded so
        #    effective prices can be computed without extra queries).
        models: List[Model] = (
            db.query(Model)
            .options(joinedload(Model.provider), joinedload(Model.global_model))
            .filter(Model.is_active == True)
            .all()
        )

        # Group provider implementations by GlobalModel id.
        models_by_global_model: Dict[str, List[Model]] = {}
        for model in models:
            if model.global_model_id:
                models_by_global_model.setdefault(model.global_model_id, []).append(model)

        # 4. Build one catalog item per GlobalModel.
        catalog_items: List[ModelCatalogItem] = []

        for gm in global_models:
            gm_id = gm.id
            provider_entries: List[ModelCatalogProviderDetail] = []
            # Capability flags start from the GlobalModel defaults and are
            # OR-ed with every provider implementation below.
            capability_flags = {
                "supports_vision": gm.default_supports_vision or False,
                "supports_function_calling": gm.default_supports_function_calling or False,
                "supports_streaming": gm.default_supports_streaming or False,
            }

            # Walk every provider implementation of this GlobalModel.
            for model in models_by_global_model.get(gm_id, []):
                provider = model.provider
                if not provider:
                    continue

                # Effective prices fall back to GlobalModel defaults.
                effective_input = model.get_effective_input_price()
                effective_output = model.get_effective_output_price()
                effective_tiered = model.get_effective_tiered_pricing()
                tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1

                # Effective capability values may widen the catalog flags.
                capability_flags["supports_vision"] = (
                    capability_flags["supports_vision"] or model.get_effective_supports_vision()
                )
                capability_flags["supports_function_calling"] = (
                    capability_flags["supports_function_calling"]
                    or model.get_effective_supports_function_calling()
                )
                capability_flags["supports_streaming"] = (
                    capability_flags["supports_streaming"]
                    or model.get_effective_supports_streaming()
                )

                provider_entries.append(
                    ModelCatalogProviderDetail(
                        provider_id=provider.id,
                        provider_name=provider.name,
                        provider_display_name=provider.display_name,
                        model_id=model.id,
                        target_model=model.provider_model_name,
                        # Expose effective (fallback-resolved) prices.
                        input_price_per_1m=effective_input,
                        output_price_per_1m=effective_output,
                        cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
                        cache_read_price_per_1m=model.get_effective_cache_read_price(),
                        cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
                        price_per_request=model.get_effective_price_per_request(),
                        effective_tiered_pricing=effective_tiered,
                        tier_count=tier_count,
                        supports_vision=model.get_effective_supports_vision(),
                        supports_function_calling=model.get_effective_supports_function_calling(),
                        supports_streaming=model.get_effective_supports_streaming(),
                        is_active=bool(model.is_active),
                        mapping_id=None,  # no mapping_id in the new architecture
                    )
                )

            # The catalog shows the GlobalModel's first pricing tier
            # (not an aggregate over providers).
            tiered = gm.default_tiered_pricing or {}
            first_tier = tiered.get("tiers", [{}])[0] if tiered.get("tiers") else {}
            price_range = ModelPriceRange(
                min_input=first_tier.get("input_price_per_1m", 0),
                max_input=first_tier.get("input_price_per_1m", 0),
                min_output=first_tier.get("output_price_per_1m", 0),
                max_output=first_tier.get("output_price_per_1m", 0),
            )

            catalog_items.append(
                ModelCatalogItem(
                    global_model_name=gm.name,
                    display_name=gm.display_name,
                    description=gm.description,
                    aliases=aliases_by_global_model.get(gm_id, []),
                    providers=provider_entries,
                    price_range=price_range,
                    total_providers=len(provider_entries),
                    capabilities=ModelCapabilities(**capability_flags),
                )
            )

        # 5. Orphaned aliases: mappings whose target GlobalModel is missing
        #    or inactive.
        orphaned_rows = (
            db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
            .outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
            .filter(
                ModelMapping.is_active == True,
                ModelMapping.provider_id.is_(None),
                or_(GlobalModel.id == None, GlobalModel.is_active == False),
            )
            .group_by(ModelMapping.source_model, GlobalModel.name)
            .all()
        )
        orphaned_models = [
            OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
            for row in orphaned_rows
            if row[0]
        ]

        return ModelCatalogResponse(
            models=catalog_items,
            total=len(catalog_items),
            orphaned_models=orphaned_models,
        )
|
||||
|
||||
|
||||
@dataclass
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
    """Batch-assign model mappings across providers.

    NOTE: the actual mapping creation is currently disabled pending
    adaptation to the GlobalModel architecture; each provider entry is
    validated and then reported as unsupported (see
    docs/optimization-backlog.md).
    """

    payload: BatchAssignModelMappingRequest  # providers and model configs to process

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db
        created: List[BatchAssignProviderResult] = []
        errors: List[BatchAssignError] = []

        for provider_config in self.payload.providers:
            provider_id = provider_config.provider_id
            try:
                provider: Optional[Provider] = (
                    db.query(Provider).filter(Provider.id == provider_id).first()
                )
                if not provider:
                    errors.append(
                        BatchAssignError(provider_id=provider_id, error="Provider 不存在")
                    )
                    continue

                model_id: Optional[str] = None

                if provider_config.create_model:
                    model_data = provider_config.model_data
                    if not model_data:
                        errors.append(
                            BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
                        )
                        continue

                    # Reuse an existing model with the same provider-side
                    # name instead of creating a duplicate.
                    existing_model = ModelService.get_model_by_name(
                        db, provider_id, model_data.provider_model_name
                    )
                    if existing_model:
                        model_id = existing_model.id
                        logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
                            model_data.provider_model_name,
                            provider.name,
                        )
                    else:
                        model = ModelService.create_model(db, provider_id, model_data)
                        model_id = model.id
                else:
                    model_id = provider_config.model_id
                    if not model_id:
                        errors.append(
                            BatchAssignError(provider_id=provider_id, error="缺少 model_id")
                        )
                        continue
                    model = (
                        db.query(Model)
                        .filter(Model.id == model_id, Model.provider_id == provider_id)
                        .first()
                    )
                    if not model:
                        errors.append(
                            BatchAssignError(
                                provider_id=provider_id, error="模型不存在或不属于当前 Provider")
                        )
                        continue

                # Mapping creation still needs to be adapted to the
                # GlobalModel architecture; see docs/optimization-backlog.md.
                errors.append(
                    BatchAssignError(
                        provider_id=provider_id,
                        error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
                    )
                )
                continue

            except Exception as exc:
                db.rollback()
                # logger.exception preserves the traceback that a plain
                # logger.error(...) call would silently discard.
                logger.exception("批量添加模型映射失败(需要适配新架构)")
                errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))

        return BatchAssignModelMappingResponse(
            success=len(created) > 0,
            created_mappings=created,
            errors=errors,
        )
|
||||
|
||||
|
||||
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
async def update_model_mapping(
    request: Request,
    mapping_id: str,
    payload: UpdateModelMappingRequest,
    db: Session = Depends(get_db),
) -> UpdateModelMappingResponse:
    """Update a model mapping (alias)."""
    adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
async def delete_model_mapping(
    request: Request,
    mapping_id: str,
    db: Session = Depends(get_db),
) -> DeleteModelMappingResponse:
    """Delete a model mapping (alias)."""
    adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
    """Update a model mapping (alias) record.

    Validates any referenced Provider / GlobalModel rows, rejects duplicates
    of (source_model, provider_id), and invalidates the mapping cache on
    success.
    """

    mapping_id: str  # id of the ModelMapping row to modify
    payload: UpdateModelMappingRequest  # partial update; unset fields are ignored

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db

        mapping: Optional[ModelMapping] = (
            db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
        )

        if not mapping:
            raise HTTPException(status_code=404, detail="映射不存在")

        update_data = self.payload.model_dump(exclude_unset=True)

        if "provider_id" in update_data:
            new_provider_id = update_data["provider_id"]
            if new_provider_id:
                # NOTE(review): a falsy provider_id leaves the field
                # unchanged rather than clearing it to NULL — confirm
                # this is the intended behavior.
                provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
                if not provider:
                    raise HTTPException(status_code=404, detail="Provider 不存在")
                mapping.provider_id = new_provider_id

        if "target_global_model_id" in update_data:
            # The target must exist and be active.
            target_model = (
                db.query(GlobalModel)
                .filter(
                    GlobalModel.id == update_data["target_global_model_id"],
                    GlobalModel.is_active == True,
                )
                .first()
            )
            if not target_model:
                raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
            mapping.target_global_model_id = update_data["target_global_model_id"]

        if "source_model" in update_data:
            new_source = update_data["source_model"].strip()
            if not new_source:
                raise HTTPException(status_code=400, detail="source_model 不能为空")
            mapping.source_model = new_source

        if "is_active" in update_data:
            mapping.is_active = update_data["is_active"]

        # Reject if another mapping already covers the same
        # (source_model, provider_id) pair after the edits above.
        duplicate = (
            db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == mapping.source_model,
                ModelMapping.provider_id == mapping.provider_id,
                ModelMapping.id != mapping.id,
            )
            .first()
        )
        if duplicate:
            raise HTTPException(status_code=400, detail="映射已存在")

        db.commit()
        db.refresh(mapping)

        # Drop any cached resolution for this alias so the change is visible.
        cache_service = get_cache_invalidation_service()
        cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)

        return UpdateModelMappingResponse(
            success=True,
            mapping_id=mapping.id,
            message="映射更新成功",
        )
|
||||
|
||||
|
||||
@dataclass
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
    """Delete a model mapping (alias) and invalidate its cache entry."""

    mapping_id: str

    async def handle(self, context):  # type: ignore[override]
        db: Session = context.db

        target: Optional[ModelMapping] = (
            db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="映射不存在")

        # Capture the alias identity before the row is removed, since the
        # cache invalidation below needs it.
        alias_name = target.source_model
        alias_provider = target.provider_id

        db.delete(target)
        db.commit()

        get_cache_invalidation_service().on_model_mapping_changed(alias_name, alias_provider)

        return DeleteModelMappingResponse(
            success=True,
            message=f"映射 {self.mapping_id} 已删除",
        )
|
||||
292
src/api/admin/models/global_models.py
Normal file
292
src/api/admin/models/global_models.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
GlobalModel Admin API
|
||||
|
||||
提供 GlobalModel 的 CRUD 操作接口
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignToProvidersRequest,
|
||||
BatchAssignToProvidersResponse,
|
||||
GlobalModelCreate,
|
||||
GlobalModelListResponse,
|
||||
GlobalModelResponse,
|
||||
GlobalModelUpdate,
|
||||
GlobalModelWithStats,
|
||||
)
|
||||
from src.services.model.global_model import GlobalModelService
|
||||
|
||||
router = APIRouter(prefix="/global", tags=["Admin - Global Models"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("", response_model=GlobalModelListResponse)
async def list_global_models(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    is_active: Optional[bool] = Query(None),
    search: Optional[str] = Query(None),
    db: Session = Depends(get_db),
) -> GlobalModelListResponse:
    """List GlobalModels with optional active-state and search filters."""
    list_adapter = AdminListGlobalModelsAdapter(
        skip=skip, limit=limit, is_active=is_active, search=search
    )
    return await pipeline.run(
        adapter=list_adapter, http_request=request, db=db, mode=list_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/{global_model_id}", response_model=GlobalModelWithStats)
async def get_global_model(
    request: Request,
    global_model_id: str,
    db: Session = Depends(get_db),
) -> GlobalModelWithStats:
    """Fetch one GlobalModel together with its aggregate statistics."""
    detail_adapter = AdminGetGlobalModelAdapter(global_model_id=global_model_id)
    return await pipeline.run(
        adapter=detail_adapter, http_request=request, db=db, mode=detail_adapter.mode
    )
|
||||
|
||||
|
||||
@router.post("", response_model=GlobalModelResponse, status_code=201)
async def create_global_model(
    request: Request,
    payload: GlobalModelCreate,
    db: Session = Depends(get_db),
) -> GlobalModelResponse:
    """Create a new GlobalModel from the validated payload."""
    create_adapter = AdminCreateGlobalModelAdapter(payload=payload)
    return await pipeline.run(
        adapter=create_adapter, http_request=request, db=db, mode=create_adapter.mode
    )
|
||||
|
||||
|
||||
@router.patch("/{global_model_id}", response_model=GlobalModelResponse)
async def update_global_model(
    request: Request,
    global_model_id: str,
    payload: GlobalModelUpdate,
    db: Session = Depends(get_db),
) -> GlobalModelResponse:
    """Apply a partial update to an existing GlobalModel."""
    update_adapter = AdminUpdateGlobalModelAdapter(
        global_model_id=global_model_id, payload=payload
    )
    return await pipeline.run(
        adapter=update_adapter, http_request=request, db=db, mode=update_adapter.mode
    )
|
||||
|
||||
|
||||
@router.delete("/{global_model_id}", status_code=204)
async def delete_global_model(
    request: Request,
    global_model_id: str,
    db: Session = Depends(get_db),
):
    """Delete a GlobalModel (cascades to every provider-level implementation)."""
    delete_adapter = AdminDeleteGlobalModelAdapter(global_model_id=global_model_id)
    await pipeline.run(
        adapter=delete_adapter, http_request=request, db=db, mode=delete_adapter.mode
    )
    # 204 No Content — the pipeline result is intentionally discarded.
    return None
|
||||
|
||||
|
||||
@router.post(
    "/{global_model_id}/assign-to-providers", response_model=BatchAssignToProvidersResponse
)
async def batch_assign_to_providers(
    request: Request,
    global_model_id: str,
    payload: BatchAssignToProvidersRequest,
    db: Session = Depends(get_db),
) -> BatchAssignToProvidersResponse:
    """Attach this GlobalModel to multiple providers in a single call."""
    assign_adapter = AdminBatchAssignToProvidersAdapter(
        global_model_id=global_model_id, payload=payload
    )
    return await pipeline.run(
        adapter=assign_adapter, http_request=request, db=db, mode=assign_adapter.mode
    )
|
||||
|
||||
|
||||
# ========== Adapters ==========
|
||||
|
||||
|
||||
@dataclass
class AdminListGlobalModelsAdapter(AdminApiAdapter):
    """List GlobalModels, enriching each entry with provider/alias counts.

    Fix: the previous implementation issued two COUNT queries per returned
    model (classic N+1). This version aggregates the counts for the whole
    page in two grouped queries keyed by global_model_id.
    """

    skip: int
    limit: int
    is_active: Optional[bool]
    search: Optional[str]

    async def handle(self, context):  # type: ignore[override]
        from sqlalchemy import func

        from src.models.database import Model, ModelMapping

        models = GlobalModelService.list_global_models(
            db=context.db,
            skip=self.skip,
            limit=self.limit,
            is_active=self.is_active,
            search=self.search,
        )

        model_ids = [gm.id for gm in models]
        provider_counts: dict = {}
        alias_counts: dict = {}
        if model_ids:
            # Distinct provider count per GlobalModel — one grouped query.
            provider_rows = (
                context.db.query(
                    Model.global_model_id,
                    func.count(func.distinct(Model.provider_id)),
                )
                .filter(Model.global_model_id.in_(model_ids))
                .group_by(Model.global_model_id)
                .all()
            )
            provider_counts = dict(provider_rows)

            # Alias (mapping) count per GlobalModel — one grouped query.
            alias_rows = (
                context.db.query(
                    ModelMapping.target_global_model_id,
                    func.count(ModelMapping.id),
                )
                .filter(ModelMapping.target_global_model_id.in_(model_ids))
                .group_by(ModelMapping.target_global_model_id)
                .all()
            )
            alias_counts = dict(alias_rows)

        model_responses = []
        for gm in models:
            response = GlobalModelResponse.model_validate(gm)
            response.provider_count = provider_counts.get(gm.id, 0)
            response.alias_count = alias_counts.get(gm.id, 0)
            # usage_count is read straight off the GlobalModel row by model_validate.
            model_responses.append(response)

        # NOTE(review): `total` reflects only the current page (len of the
        # slice), not the full filtered count — preserved as-is; confirm
        # whether the UI expects a true total before changing.
        return GlobalModelListResponse(
            models=model_responses,
            total=len(models),
        )
|
||||
|
||||
|
||||
@dataclass
class AdminGetGlobalModelAdapter(AdminApiAdapter):
    """Fetch a single GlobalModel and merge in its aggregate statistics."""

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        model = GlobalModelService.get_global_model(db, self.global_model_id)
        model_stats = GlobalModelService.get_global_model_stats(db, self.global_model_id)

        base_fields = GlobalModelResponse.model_validate(model).model_dump()
        return GlobalModelWithStats(
            **base_fields,
            total_models=model_stats["total_models"],
            total_providers=model_stats["total_providers"],
            price_range=model_stats["price_range"],
        )
|
||||
|
||||
|
||||
@dataclass
class AdminCreateGlobalModelAdapter(AdminApiAdapter):
    """Create a GlobalModel from the validated request payload."""

    payload: GlobalModelCreate

    async def handle(self, context):  # type: ignore[override]
        p = self.payload
        created = GlobalModelService.create_global_model(
            db=context.db,
            name=p.name,
            display_name=p.display_name,
            description=p.description,
            official_url=p.official_url,
            icon_url=p.icon_url,
            is_active=p.is_active,
            # Per-request pricing
            default_price_per_request=p.default_price_per_request,
            # Tiered pricing arrives as a Pydantic model; the service expects a dict.
            default_tiered_pricing=p.default_tiered_pricing.model_dump(),
            # Default capability flags
            default_supports_vision=p.default_supports_vision,
            default_supports_function_calling=p.default_supports_function_calling,
            default_supports_streaming=p.default_supports_streaming,
            default_supports_extended_thinking=p.default_supports_extended_thinking,
            # Key capability configuration
            supported_capabilities=p.supported_capabilities,
        )

        logger.info(f"GlobalModel 已创建: id={created.id} name={created.name}")

        return GlobalModelResponse.model_validate(created)
|
||||
|
||||
|
||||
@dataclass
class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
    """Apply a partial update to a GlobalModel and invalidate related caches."""

    global_model_id: str
    payload: GlobalModelUpdate

    async def handle(self, context):  # type: ignore[override]
        updated = GlobalModelService.update_global_model(
            db=context.db,
            global_model_id=self.global_model_id,
            update_data=self.payload,
        )

        logger.info(f"GlobalModel 已更新: id={updated.id} name={updated.name}")

        # Invalidate caches keyed by the model's name.
        from src.services.cache.invalidation import get_cache_invalidation_service

        get_cache_invalidation_service().on_global_model_changed(updated.name)

        return GlobalModelResponse.model_validate(updated)
|
||||
|
||||
|
||||
@dataclass
class AdminDeleteGlobalModelAdapter(AdminApiAdapter):
    """Delete a GlobalModel (cascades to every provider implementation)."""

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        from src.models.database import GlobalModel

        # Look up the name first — it is the cache key and the row is about
        # to disappear.
        row = (
            context.db.query(GlobalModel)
            .filter(GlobalModel.id == self.global_model_id)
            .first()
        )
        cached_name = row.name if row else None

        GlobalModelService.delete_global_model(context.db, self.global_model_id)

        logger.info(f"GlobalModel 已删除: id={self.global_model_id}")

        if cached_name:
            from src.services.cache.invalidation import get_cache_invalidation_service

            get_cache_invalidation_service().on_global_model_changed(cached_name)

        return None
|
||||
|
||||
|
||||
@dataclass
class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
    """Attach a GlobalModel to multiple providers in one service call.

    Returns the service's per-provider success/error breakdown.
    """

    global_model_id: str
    payload: BatchAssignToProvidersRequest

    async def handle(self, context):  # type: ignore[override]
        result = GlobalModelService.batch_assign_to_providers(
            db=context.db,
            global_model_id=self.global_model_id,
            provider_ids=self.payload.provider_ids,
            create_models=self.payload.create_models,
        )

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "批量为 Provider 添加 GlobalModel: global_model_id=%s success=%d errors=%d",
            self.global_model_id,
            len(result["success"]),
            len(result["errors"]),
        )

        return BatchAssignToProvidersResponse(**result)
|
||||
303
src/api/admin/models/mappings.py
Normal file
303
src/api/admin/models/mappings.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""模型映射管理 API
|
||||
|
||||
提供模型映射的 CRUD 操作。
|
||||
|
||||
模型映射(Mapping)用于将源模型映射到目标模型,例如:
|
||||
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
|
||||
- 用于处理 Provider 不支持请求模型的情况
|
||||
|
||||
映射必须关联到特定的 Provider。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelMappingCreate,
|
||||
ModelMappingResponse,
|
||||
ModelMappingUpdate,
|
||||
)
|
||||
from src.models.database import GlobalModel, ModelMapping, Provider, User
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
|
||||
|
||||
|
||||
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
    """Convert a ModelMapping ORM row into its API response model.

    `scope` is derived: provider-bound mappings are "provider", mappings
    without a provider are "global".
    """
    target_model = mapping.target_global_model
    owning_provider = mapping.provider
    return ModelMappingResponse(
        id=mapping.id,
        source_model=mapping.source_model,
        target_global_model_id=mapping.target_global_model_id,
        target_global_model_name=target_model.name if target_model else None,
        target_global_model_display_name=(
            target_model.display_name if target_model else None
        ),
        provider_id=mapping.provider_id,
        provider_name=owning_provider.name if owning_provider else None,
        scope="provider" if mapping.provider_id else "global",
        mapping_type=mapping.mapping_type,
        is_active=mapping.is_active,
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
    )
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelMappingResponse])
async def list_mappings(
    provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
    source_model: Optional[str] = Query(None, description="按源模型名筛选"),
    target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
    scope: Optional[str] = Query(None, description="global 或 provider"),
    mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
    is_active: Optional[bool] = Query(None, description="按状态筛选"),
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
    db: Session = Depends(get_db),
):
    """List model mappings with optional filters and pagination."""
    query = db.query(ModelMapping).options(
        joinedload(ModelMapping.target_global_model),
        joinedload(ModelMapping.provider),
    )

    # Apply each filter only when the caller supplied it.
    if provider_id is not None:
        query = query.filter(ModelMapping.provider_id == provider_id)
    if scope == "global":
        query = query.filter(ModelMapping.provider_id.is_(None))
    elif scope == "provider":
        query = query.filter(ModelMapping.provider_id.isnot(None))
    if mapping_type is not None:
        query = query.filter(ModelMapping.mapping_type == mapping_type)
    if source_model:
        # Case-insensitive substring match on the source model name.
        query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
    if target_global_model_id is not None:
        query = query.filter(
            ModelMapping.target_global_model_id == target_global_model_id
        )
    if is_active is not None:
        query = query.filter(ModelMapping.is_active == is_active)

    page = query.offset(skip).limit(limit).all()
    return [_serialize_mapping(row) for row in page]
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
async def get_mapping(
    mapping_id: str,
    db: Session = Depends(get_db),
):
    """Fetch a single model mapping by id; 404 when it does not exist."""
    row = (
        db.query(ModelMapping)
        .options(
            joinedload(ModelMapping.target_global_model),
            joinedload(ModelMapping.provider),
        )
        .filter(ModelMapping.id == mapping_id)
        .first()
    )
    if row is None:
        raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
    return _serialize_mapping(row)
|
||||
|
||||
|
||||
@router.post("", response_model=ModelMappingResponse, status_code=201)
async def create_mapping(
    data: ModelMappingCreate,
    db: Session = Depends(get_db),
):
    """Create a model mapping.

    Validates the payload, enforces source/provider uniqueness, persists
    the row, and invalidates the mapping cache.

    Raises:
        HTTPException 400: empty source_model, bad mapping_type, or duplicate.
        HTTPException 404: unknown target model or provider.
    """
    source_model = data.source_model.strip()
    if not source_model:
        raise HTTPException(status_code=400, detail="source_model 不能为空")

    # Validate mapping_type
    if data.mapping_type not in ("alias", "mapping"):
        raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")

    # The target GlobalModel must exist and be active.
    target_model = (
        db.query(GlobalModel)
        .filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active.is_(True))
        .first()
    )
    if not target_model:
        raise HTTPException(
            status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
        )

    # Provider is optional (NULL means a global mapping); validate when set.
    provider = None
    provider_id = data.provider_id
    if provider_id:
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")

    # Reject duplicates: (source_model, provider_id) must be unique.
    existing = (
        db.query(ModelMapping)
        .filter(
            ModelMapping.source_model == source_model,
            ModelMapping.provider_id == provider_id,
        )
        .first()
    )
    if existing:
        raise HTTPException(status_code=400, detail="映射已存在")

    mapping = ModelMapping(
        id=str(uuid.uuid4()),
        source_model=source_model,
        target_global_model_id=data.target_global_model_id,
        provider_id=provider_id,
        mapping_type=data.mapping_type,
        is_active=data.is_active,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )

    db.add(mapping)
    try:
        db.commit()
    except IntegrityError:
        # Fix: a concurrent request can insert the same (source, provider)
        # pair between our duplicate check and the commit; surface it as the
        # same 400 instead of a 500.
        db.rollback()
        raise HTTPException(status_code=400, detail="映射已存在")

    # Reload with relationships eagerly loaded for serialization.
    mapping = (
        db.query(ModelMapping)
        .options(
            joinedload(ModelMapping.target_global_model),
            joinedload(ModelMapping.provider),
        )
        .filter(ModelMapping.id == mapping.id)
        .first()
    )

    # Lazy %-style args: formatted only if INFO is enabled.
    logger.info(
        "创建模型映射: %s -> %s (Provider: %s, ID: %s)",
        source_model,
        target_model.name,
        provider.name if provider else "global",
        mapping.id,
    )

    cache_service = get_cache_invalidation_service()
    cache_service.on_model_mapping_changed(source_model, provider_id)

    return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
async def update_mapping(
    mapping_id: str,
    data: ModelMappingUpdate,
    db: Session = Depends(get_db),
):
    """Partially update a model mapping.

    Fix: all payload validation (mapping_type, provider, target model,
    source_model) now happens BEFORE the ORM object is mutated, so a
    rejected request never leaves half-applied changes in the session.

    Raises:
        HTTPException 400: empty source_model, bad mapping_type, or duplicate.
        HTTPException 404: unknown mapping, provider, or target model.
    """
    mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
    if not mapping:
        raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")

    update_data = data.model_dump(exclude_unset=True)

    # ---- Validate everything first (no session mutation yet) ----
    if "mapping_type" in update_data and update_data["mapping_type"] not in (
        "alias",
        "mapping",
    ):
        raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")

    new_source = None
    if "source_model" in update_data:
        new_source = update_data["source_model"].strip()
        if not new_source:
            raise HTTPException(status_code=400, detail="source_model 不能为空")

    if "provider_id" in update_data:
        new_provider_id = update_data["provider_id"]
        if new_provider_id:
            provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
            if not provider:
                raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")

    if "target_global_model_id" in update_data:
        target_model = (
            db.query(GlobalModel)
            .filter(
                GlobalModel.id == update_data["target_global_model_id"],
                GlobalModel.is_active.is_(True),
            )
            .first()
        )
        if not target_model:
            raise HTTPException(
                status_code=404,
                detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
            )

    # ---- Apply the validated changes ----
    if "provider_id" in update_data:
        mapping.provider_id = update_data["provider_id"]
    if "target_global_model_id" in update_data:
        mapping.target_global_model_id = update_data["target_global_model_id"]
    if new_source is not None:
        mapping.source_model = new_source

    # Enforce (source_model, provider_id) uniqueness against other rows.
    duplicate = (
        db.query(ModelMapping)
        .filter(
            ModelMapping.source_model == mapping.source_model,
            ModelMapping.provider_id == mapping.provider_id,
            ModelMapping.id != mapping_id,
        )
        .first()
    )
    if duplicate:
        raise HTTPException(status_code=400, detail="映射已存在")

    if "mapping_type" in update_data:
        mapping.mapping_type = update_data["mapping_type"]
    if "is_active" in update_data:
        mapping.is_active = update_data["is_active"]

    mapping.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(mapping)

    logger.info("更新模型映射 (ID: %s)", mapping.id)

    # Reload with relationships eagerly loaded for serialization.
    mapping = (
        db.query(ModelMapping)
        .options(
            joinedload(ModelMapping.target_global_model),
            joinedload(ModelMapping.provider),
        )
        .filter(ModelMapping.id == mapping.id)
        .first()
    )

    cache_service = get_cache_invalidation_service()
    cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)

    return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
    mapping_id: str,
    db: Session = Depends(get_db),
):
    """Delete a model mapping and invalidate its cache entry."""
    row = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
    if row is None:
        raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")

    # Capture the cache key fields before the row is removed.
    deleted_source = row.source_model
    deleted_provider = row.provider_id

    logger.info(f"删除模型映射: {deleted_source} -> {row.target_global_model_id} (ID: {row.id})")

    db.delete(row)
    db.commit()

    get_cache_invalidation_service().on_model_mapping_changed(
        deleted_source, deleted_provider
    )

    return None
|
||||
14
src/api/admin/monitoring/__init__.py
Normal file
14
src/api/admin/monitoring/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Aggregated admin monitoring routers."""

from fastapi import APIRouter

from .audit import router as audit_router
from .cache import router as cache_router
from .trace import router as trace_router

router = APIRouter()
# Mount every monitoring sub-router on the shared admin router.
for _sub in (audit_router, cache_router, trace_router):
    router.include_router(_sub)

__all__ = ["router"]
|
||||
399
src/api/admin/monitoring/audit.py
Normal file
399
src/api/admin/monitoring/audit.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""管理员监控与审计端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_query
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
AuditEventType,
|
||||
AuditLog,
|
||||
Provider,
|
||||
Usage,
|
||||
)
|
||||
from src.models.database import User as DBUser
|
||||
from src.services.health.monitor import HealthMonitor
|
||||
from src.services.system.audit import audit_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring", tags=["Admin - Monitoring"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/audit-logs")
async def get_audit_logs(
    request: Request,
    user_id: Optional[str] = Query(None, description="用户ID筛选 (支持UUID)"),
    event_type: Optional[str] = Query(None, description="事件类型筛选"),
    days: int = Query(7, description="查询天数"),
    limit: int = Query(100, description="返回数量限制"),
    offset: int = Query(0, description="偏移量"),
    db: Session = Depends(get_db),
):
    """Paginated audit-log listing with optional user/event filters."""
    logs_adapter = AdminGetAuditLogsAdapter(
        user_id=user_id,
        event_type=event_type,
        days=days,
        limit=limit,
        offset=offset,
    )
    return await pipeline.run(
        adapter=logs_adapter, http_request=request, db=db, mode=logs_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/system-status")
async def get_system_status(request: Request, db: Session = Depends(get_db)):
    """Snapshot of platform counters and today's usage."""
    status_adapter = AdminSystemStatusAdapter()
    return await pipeline.run(
        adapter=status_adapter, http_request=request, db=db, mode=status_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/suspicious-activities")
async def get_suspicious_activities(
    request: Request,
    hours: int = Query(24, description="时间范围(小时)"),
    db: Session = Depends(get_db),
):
    """List suspicious audit events from the last `hours` hours."""
    activity_adapter = AdminSuspiciousActivitiesAdapter(hours=hours)
    return await pipeline.run(
        adapter=activity_adapter, http_request=request, db=db, mode=activity_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/user-behavior/{user_id}")
async def analyze_user_behavior(
    user_id: str,
    request: Request,
    days: int = Query(30, description="分析天数"),
    db: Session = Depends(get_db),
):
    """Behavior analysis for one user over the last `days` days."""
    behavior_adapter = AdminUserBehaviorAdapter(user_id=user_id, days=days)
    return await pipeline.run(
        adapter=behavior_adapter, http_request=request, db=db, mode=behavior_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/resilience-status")
async def get_resilience_status(request: Request, db: Session = Depends(get_db)):
    """Health score and error statistics from the resilience manager."""
    resilience_adapter = AdminResilienceStatusAdapter()
    return await pipeline.run(
        adapter=resilience_adapter, http_request=request, db=db, mode=resilience_adapter.mode
    )
|
||||
|
||||
|
||||
@router.delete("/resilience/error-stats")
async def reset_error_stats(request: Request, db: Session = Depends(get_db)):
    """Reset resilience error statistics."""
    reset_adapter = AdminResetErrorStatsAdapter()
    return await pipeline.run(
        adapter=reset_adapter, http_request=request, db=db, mode=reset_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/resilience/circuit-history")
async def get_circuit_history(
    request: Request,
    limit: int = Query(50, ge=1, le=200),
    db: Session = Depends(get_db),
):
    """Recent circuit-breaker state transitions (most recent `limit` entries)."""
    history_adapter = AdminCircuitHistoryAdapter(limit=limit)
    return await pipeline.run(
        adapter=history_adapter, http_request=request, db=db, mode=history_adapter.mode
    )
|
||||
|
||||
|
||||
@dataclass
class AdminGetAuditLogsAdapter(AdminApiAdapter):
    """Paginated audit-log listing joined with the acting user's identity."""

    user_id: Optional[str]
    event_type: Optional[str]
    days: int
    limit: int
    offset: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        since = datetime.now(timezone.utc) - timedelta(days=self.days)

        # Outer join so logs whose user was deleted still appear (user=None).
        query = (
            db.query(AuditLog, DBUser)
            .outerjoin(DBUser, AuditLog.user_id == DBUser.id)
            .filter(AuditLog.created_at >= since)
        )
        if self.user_id:
            query = query.filter(AuditLog.user_id == self.user_id)
        if self.event_type:
            query = query.filter(AuditLog.event_type == self.event_type)

        total, rows = paginate_query(
            query.order_by(AuditLog.created_at.desc()), self.limit, self.offset
        )

        items = [
            {
                "id": log.id,
                "event_type": log.event_type,
                "user_id": log.user_id,
                "user_email": user.email if user else None,
                "user_username": user.username if user else None,
                "description": log.description,
                "ip_address": log.ip_address,
                "status_code": log.status_code,
                "error_message": log.error_message,
                "metadata": log.event_metadata,
                "created_at": log.created_at.isoformat() if log.created_at else None,
            }
            for log, user in rows
        ]

        meta = PaginationMeta(
            total=total, limit=self.limit, offset=self.offset, count=len(items)
        )
        payload = build_pagination_payload(
            items,
            meta,
            filters={
                "user_id": self.user_id,
                "event_type": self.event_type,
                "days": self.days,
            },
        )

        context.add_audit_metadata(
            action="monitor_audit_logs",
            filter_user_id=self.user_id,
            filter_event_type=self.event_type,
            days=self.days,
            limit=self.limit,
            offset=self.offset,
            total=total,
            result_count=meta.count,
        )
        return payload
|
||||
|
||||
|
||||
class AdminSystemStatusAdapter(AdminApiAdapter):
    """System status snapshot: entity counts, today's usage, recent errors."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        now = datetime.now(timezone.utc)

        # Entity totals and active subsets.
        user_total = db.query(func.count(DBUser.id)).scalar()
        user_active = (
            db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar()
        )
        provider_total = db.query(func.count(Provider.id)).scalar()
        provider_active = (
            db.query(func.count(Provider.id)).filter(Provider.is_active.is_(True)).scalar()
        )
        key_total = db.query(func.count(ApiKey.id)).scalar()
        key_active = (
            db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar()
        )

        # Usage since midnight UTC.
        day_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        requests_today = (
            db.query(func.count(Usage.id)).filter(Usage.created_at >= day_start).scalar()
        )
        tokens_today = (
            db.query(func.sum(Usage.total_tokens))
            .filter(Usage.created_at >= day_start)
            .scalar()
            or 0
        )
        cost_today = (
            db.query(func.sum(Usage.total_cost_usd))
            .filter(Usage.created_at >= day_start)
            .scalar()
            or 0
        )

        # Failures and suspicious events within the last hour.
        error_count = (
            db.query(AuditLog)
            .filter(
                AuditLog.event_type.in_(
                    [
                        AuditEventType.REQUEST_FAILED.value,
                        AuditEventType.SUSPICIOUS_ACTIVITY.value,
                    ]
                ),
                AuditLog.created_at >= datetime.now(timezone.utc) - timedelta(hours=1),
            )
            .count()
        )

        context.add_audit_metadata(
            action="system_status_snapshot",
            total_users=int(user_total or 0),
            active_users=int(user_active or 0),
            total_providers=int(provider_total or 0),
            active_providers=int(provider_active or 0),
            total_api_keys=int(key_total or 0),
            active_api_keys=int(key_active or 0),
            today_requests=int(requests_today or 0),
            today_tokens=int(tokens_today or 0),
            today_cost=float(cost_today or 0.0),
            recent_errors=int(error_count or 0),
        )

        return {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "users": {"total": user_total, "active": user_active},
            "providers": {"total": provider_total, "active": provider_active},
            "api_keys": {"total": key_total, "active": key_active},
            "today_stats": {
                "requests": requests_today,
                "tokens": tokens_today,
                "cost_usd": f"${cost_today:.4f}",
            },
            "recent_errors": error_count,
        }
|
||||
|
||||
|
||||
@dataclass
class AdminSuspiciousActivitiesAdapter(AdminApiAdapter):
    """List suspicious audit events recorded within the last N hours."""

    hours: int

    async def handle(self, context):  # type: ignore[override]
        records = audit_service.get_suspicious_activities(
            db=context.db, hours=self.hours, limit=100
        )

        def _to_dict(record):
            # Serialize a single audit row for the JSON response.
            return {
                "id": record.id,
                "event_type": record.event_type,
                "user_id": record.user_id,
                "description": record.description,
                "ip_address": record.ip_address,
                "metadata": record.event_metadata,
                "created_at": record.created_at.isoformat() if record.created_at else None,
            }

        context.add_audit_metadata(
            action="monitor_suspicious_activity",
            hours=self.hours,
            result_count=len(records),
        )
        return {
            "activities": [_to_dict(r) for r in records],
            "count": len(records),
            "time_range_hours": self.hours,
        }
|
||||
|
||||
|
||||
@dataclass
class AdminUserBehaviorAdapter(AdminApiAdapter):
    """Admin adapter that runs a behaviour analysis for a single user."""

    user_id: str  # target user's id
    days: int  # analysis window, in days

    async def handle(self, context):  # type: ignore[override]
        analysis = audit_service.analyze_user_behavior(
            db=context.db, user_id=self.user_id, days=self.days
        )
        context.add_audit_metadata(
            action="monitor_user_behavior",
            target_user_id=self.user_id,
            days=self.days,
            contains_summary=bool(analysis),
        )
        return analysis
|
||||
|
||||
|
||||
class AdminResilienceStatusAdapter(AdminApiAdapter):
    """Admin adapter reporting a health snapshot of the resilience subsystem."""

    async def handle(self, context):  # type: ignore[override]
        try:
            from src.core.resilience import resilience_manager
        except ImportError as exc:
            # The resilience subsystem is optional; surface a 503 when it is absent.
            raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc

        error_stats = resilience_manager.get_error_stats()

        # Serialize the ten most recent error records into JSON-safe dicts.
        recent_errors = []
        for entry in resilience_manager.last_errors[-10:]:
            recent_errors.append(
                {
                    "error_id": entry["error_id"],
                    "error_type": entry["error_type"],
                    "operation": entry["operation"],
                    "timestamp": entry["timestamp"].isoformat(),
                    "context": entry.get("context", {}),
                }
            )

        total_errors = error_stats.get("total_errors", 0)
        breakers = error_stats.get("circuit_breakers", {})
        open_breaker_count = sum(
            1 for state in breakers.values() if state.get("state") == "open"
        )
        # Heuristic: each error costs 2 points, each open breaker 20, floor at 0.
        health_score = max(0, 100 - (total_errors * 2) - (open_breaker_count * 20))

        if health_score > 80:
            status_label = "healthy"
        elif health_score > 50:
            status_label = "degraded"
        else:
            status_label = "critical"

        context.add_audit_metadata(
            action="resilience_status",
            health_score=health_score,
            error_total=error_stats.get("total_errors") if isinstance(error_stats, dict) else None,
            open_circuit_breakers=open_breaker_count,
        )
        return {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "health_score": health_score,
            "status": status_label,
            "error_statistics": error_stats,
            "recent_errors": recent_errors,
            "recommendations": _get_health_recommendations(error_stats, health_score),
        }
|
||||
|
||||
|
||||
class AdminResetErrorStatsAdapter(AdminApiAdapter):
    """Admin adapter that clears accumulated resilience error statistics."""

    async def handle(self, context):  # type: ignore[override]
        try:
            from src.core.resilience import resilience_manager
        except ImportError as exc:
            raise HTTPException(status_code=503, detail="韧性管理系统未启用") from exc

        # Snapshot the stats before wiping so the caller can see what was reset.
        previous = resilience_manager.get_error_stats()
        resilience_manager.error_stats.clear()
        resilience_manager.last_errors.clear()

        logger.info(f"管理员 {context.user.email if context.user else 'unknown'} 重置了错误统计")

        context.add_audit_metadata(
            action="reset_error_stats",
            previous_total_errors=(
                previous.get("total_errors") if isinstance(previous, dict) else None
            ),
        )

        return {
            "message": "错误统计已重置",
            "previous_stats": previous,
            "reset_by": context.user.email if context.user else None,
            "reset_at": datetime.now(timezone.utc).isoformat(),
        }
|
||||
|
||||
|
||||
class AdminCircuitHistoryAdapter(AdminApiAdapter):
    """Admin adapter exposing recent circuit-breaker state transitions."""

    def __init__(self, limit: int = 50):
        super().__init__()
        # Maximum number of history entries to return.
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        entries = HealthMonitor.get_circuit_history(self.limit)
        context.add_audit_metadata(
            action="circuit_history",
            limit=self.limit,
            result_count=len(entries),
        )
        return {"items": entries, "count": len(entries)}
|
||||
|
||||
|
||||
def _get_health_recommendations(error_stats: dict, health_score: int) -> List[str]:
|
||||
recommendations: List[str] = []
|
||||
if health_score < 50:
|
||||
recommendations.append("系统健康状况严重,请立即检查错误日志")
|
||||
if error_stats.get("total_errors", 0) > 100:
|
||||
recommendations.append("错误频率过高,建议检查系统配置和外部依赖")
|
||||
|
||||
circuit_breakers = error_stats.get("circuit_breakers", {})
|
||||
open_breakers = [k for k, v in circuit_breakers.items() if v.get("state") == "open"]
|
||||
if open_breakers:
|
||||
recommendations.append(f"以下服务熔断器已打开:{', '.join(open_breakers)}")
|
||||
|
||||
if health_score > 90:
|
||||
recommendations.append("系统运行良好")
|
||||
return recommendations
|
||||
871
src/api/admin/monitoring/cache.py
Normal file
871
src/api/admin/monitoring/cache.py
Normal file
@@ -0,0 +1,871 @@
|
||||
"""
|
||||
缓存监控端点
|
||||
|
||||
提供缓存亲和性统计、管理和监控功能
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.cache.affinity_manager import get_affinity_manager
|
||||
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
def mask_api_key(api_key: Optional[str], prefix_len: int = 8, suffix_len: int = 4) -> Optional[str]:
    """
    Mask an API key, showing only a prefix, asterisks, and a suffix.

    Example: sk-jhiId-xxxxxxxxxxxAABB -> sk-jhiId-********AABB

    Args:
        api_key: The raw API key; falsy values are passed through as None.
        prefix_len: Number of leading characters to keep visible (default 8).
        suffix_len: Number of trailing characters to keep visible (default 4).

    Returns:
        The masked key, or None when api_key is falsy.
    """
    if not api_key:
        return None
    total_visible = prefix_len + suffix_len
    if len(api_key) <= total_visible:
        # BUGFIX: the old branch returned api_key[:prefix_len] + "********",
        # which exposed the ENTIRE secret whenever len(api_key) <= prefix_len
        # (e.g. a 6-char key with the default prefix_len=8). Cap the visible
        # prefix so at least the last `suffix_len` characters stay hidden.
        visible = min(prefix_len, max(0, len(api_key) - suffix_len))
        return api_key[:visible] + "********"
    return f"{api_key[:prefix_len]}********{api_key[-suffix_len:]}"
|
||||
|
||||
|
||||
def decrypt_and_mask(encrypted_key: Optional[str], prefix_len: int = 8) -> Optional[str]:
    """
    Decrypt an encrypted API key, then return its masked form.

    Args:
        encrypted_key: The encrypted API key; falsy values yield None.
        prefix_len: Number of leading characters left visible after masking.

    Returns:
        The masked plaintext key, or None when decryption/masking fails.
    """
    if not encrypted_key:
        return None
    try:
        plaintext = crypto_service.decrypt(encrypted_key)
        masked = mask_api_key(plaintext, prefix_len)
    except Exception:
        # Deliberate best-effort: callers treat None as "key unavailable".
        return None
    return masked
|
||||
|
||||
|
||||
def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
    """
    Resolve a user identifier (username / email / user_id / api_key_id) to a user_id.

    Supported input formats:
    1. User UUID (36 chars, with dashes)
    2. Username
    3. Email
    4. API Key ID (36-char UUID)

    Returns:
        The matching user_id (UUID string), or None when nothing matches.
    """
    identifier = identifier.strip()

    # 1. Try the identifier as a User UUID first.
    user = db.query(User).filter(User.id == identifier).first()
    if user:
        logger.debug(f"通过User ID解析: {identifier[:8]}... -> {user.username}")
        return user.id

    # 2. Try it as a username.
    user = db.query(User).filter(User.username == identifier).first()
    if user:
        logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
        return user.id

    # 3. Try it as an email address.
    user = db.query(User).filter(User.email == identifier).first()
    if user:
        logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
        return user.id

    # 4. Fall back to treating it as an API Key ID and return the key's owner.
    api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
    if api_key:
        logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
        return api_key.user_id

    # Nothing matched.
    logger.debug(f"无法识别的用户标识符: {identifier}")
    return None
|
||||
|
||||
|
||||
@router.get("/stats")
async def get_cache_stats(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Fetch cache-affinity statistics.

    Returns cache hit rate, number of cached users, provider/key switch
    counts, and the cache reservation configuration.
    """
    stats_adapter = AdminCacheStatsAdapter()
    return await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.get("/affinity/{user_identifier}")
async def get_user_affinity(
    user_identifier: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Look up every cache affinity for a given user.

    Args:
        user_identifier: One of the following formats:
            * username, e.g. yuanhonghu
            * email, e.g. user@example.com
            * user UUID, e.g. 550e8400-e29b-41d4-a716-446655440000
            * API Key ID, e.g. 660e8400-e29b-41d4-a716-446655440000

    Returns:
        The user's info plus one affinity record per endpoint.
    """
    affinity_adapter = AdminGetUserAffinityAdapter(user_identifier=user_identifier)
    return await pipeline.run(
        adapter=affinity_adapter,
        http_request=request,
        db=db,
        mode=affinity_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.get("/affinities")
async def list_affinities(
    request: Request,
    keyword: Optional[str] = None,
    limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
    offset: int = Query(0, ge=0, description="偏移量"),
    db: Session = Depends(get_db),
):
    """
    List all cache affinities, optionally filtered by keyword.

    Args:
        keyword: Optional filter — username / email / User ID / API Key ID,
            or a fuzzy-match string.
    """
    listing_adapter = AdminListAffinitiesAdapter(keyword=keyword, limit=limit, offset=offset)
    return await pipeline.run(
        adapter=listing_adapter,
        http_request=request,
        db=db,
        mode=listing_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.delete("/users/{user_identifier}")
async def clear_user_cache(
    user_identifier: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Clear cache affinity for a specific user.

    Args:
        user_identifier: User identifier (username, email, user_id, or API Key ID).
    """
    clear_adapter = AdminClearUserCacheAdapter(user_identifier=user_identifier)
    return await pipeline.run(
        adapter=clear_adapter,
        http_request=request,
        db=db,
        mode=clear_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.delete("")
async def clear_all_cache(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Clear all cache affinities.

    Warning: this affects every user — use with caution.
    """
    clear_all_adapter = AdminClearAllCacheAdapter()
    return await pipeline.run(
        adapter=clear_all_adapter,
        http_request=request,
        db=db,
        mode=clear_all_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}")
async def clear_provider_cache(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Clear cache affinities for a specific provider.

    Args:
        provider_id: Provider ID.
    """
    provider_adapter = AdminClearProviderCacheAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=provider_adapter,
        http_request=request,
        db=db,
        mode=provider_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.get("/config")
async def get_cache_config(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Fetch cache-related configuration.

    Returns the cache TTL and cache reservation ratio.
    """
    config_adapter = AdminCacheConfigAdapter()
    return await pipeline.run(
        adapter=config_adapter,
        http_request=request,
        db=db,
        mode=config_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.get("/metrics", response_class=PlainTextResponse)
async def get_cache_metrics(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Expose cache scheduling metrics in Prometheus text format (Grafana-friendly).
    """
    metrics_adapter = AdminCacheMetricsAdapter()
    return await pipeline.run(
        adapter=metrics_adapter,
        http_request=request,
        db=db,
        mode=metrics_adapter.mode,
    )
|
||||
|
||||
|
||||
# -------- 缓存监控适配器 --------
|
||||
|
||||
|
||||
class AdminCacheStatsAdapter(AdminApiAdapter):
    """Admin adapter returning cache-affinity scheduler statistics."""

    async def handle(self, context):  # type: ignore[override]
        try:
            redis_conn = get_redis_client_sync()
            cache_scheduler = await get_cache_aware_scheduler(redis_conn)
            snapshot = await cache_scheduler.get_stats()
            logger.info("缓存统计信息查询成功")
            context.add_audit_metadata(
                action="cache_stats",
                scheduler=snapshot.get("scheduler"),
                total_affinities=snapshot.get("total_affinities"),
                cache_hit_rate=snapshot.get("cache_hit_rate"),
                provider_switches=snapshot.get("provider_switches"),
            )
            return {"status": "ok", "data": snapshot}
        except Exception as exc:
            # Map any failure (Redis, scheduler, audit) to a 500 response.
            logger.exception(f"获取缓存统计信息失败: {exc}")
            raise HTTPException(status_code=500, detail=f"获取缓存统计失败: {exc}")
|
||||
|
||||
|
||||
class AdminCacheMetricsAdapter(AdminApiAdapter):
    """Admin adapter exporting cache metrics in Prometheus text format."""

    async def handle(self, context):  # type: ignore[override]
        try:
            redis_conn = get_redis_client_sync()
            cache_scheduler = await get_cache_aware_scheduler(redis_conn)
            snapshot = await cache_scheduler.get_stats()
            body = self._format_prometheus(snapshot)
            context.add_audit_metadata(
                action="cache_metrics_export",
                scheduler=snapshot.get("scheduler"),
                metrics_lines=body.count("\n"),
            )
            return PlainTextResponse(body)
        except Exception as exc:
            logger.exception(f"导出缓存指标失败: {exc}")
            raise HTTPException(status_code=500, detail=f"导出缓存指标失败: {exc}")

    def _format_prometheus(self, stats: Dict[str, Any]) -> str:
        """Render scheduler/affinity stats as Prometheus gauges (text exposition format)."""
        sched = stats.get("scheduler_metrics", {})
        aff = stats.get("affinity_stats", {})

        # (source dict, metric name, HELP text, source key, default value)
        specs: List[Tuple[Dict[str, Any], str, str, str, float]] = [
            (sched, "cache_scheduler_total_batches", "Total batches pulled from provider list", "total_batches", 0),
            (sched, "cache_scheduler_last_batch_size", "Size of the latest candidate batch", "last_batch_size", 0),
            (sched, "cache_scheduler_total_candidates", "Total candidates enumerated by scheduler", "total_candidates", 0),
            (sched, "cache_scheduler_last_candidate_count", "Number of candidates in the most recent batch", "last_candidate_count", 0),
            (sched, "cache_scheduler_cache_hits", "Cache hits counted during scheduling", "cache_hits", 0),
            (sched, "cache_scheduler_cache_misses", "Cache misses counted during scheduling", "cache_misses", 0),
            (sched, "cache_scheduler_cache_hit_rate", "Cache hit rate during scheduling", "cache_hit_rate", 0.0),
            (sched, "cache_scheduler_concurrency_denied", "Times candidate rejected due to concurrency limits", "concurrency_denied", 0),
            (sched, "cache_scheduler_avg_candidates_per_batch", "Average candidates per batch", "avg_candidates_per_batch", 0.0),
            (aff, "cache_affinity_total", "Total cache affinities stored", "total_affinities", 0),
            (aff, "cache_affinity_hits", "Affinity cache hits", "cache_hits", 0),
            (aff, "cache_affinity_misses", "Affinity cache misses", "cache_misses", 0),
            (aff, "cache_affinity_hit_rate", "Affinity cache hit rate", "cache_hit_rate", 0.0),
            (aff, "cache_affinity_invalidations", "Affinity invalidations", "cache_invalidations", 0),
            (aff, "cache_affinity_provider_switches", "Affinity provider switches", "provider_switches", 0),
            (aff, "cache_affinity_key_switches", "Affinity key switches", "key_switches", 0),
        ]

        out: List[str] = []
        for source, name, help_text, key, default in specs:
            out.append(f"# HELP {name} {help_text}")
            out.append(f"# TYPE {name} gauge")
            out.append(f"{name} {float(source.get(key, default))}")

        scheduler_name = stats.get("scheduler", "cache_aware")
        out.append(f'cache_scheduler_info{{scheduler="{scheduler_name}"}} 1')

        return "\n".join(out) + "\n"
|
||||
|
||||
|
||||
@dataclass
class AdminGetUserAffinityAdapter(AdminApiAdapter):
    """Admin adapter: list every cache affinity belonging to one user.

    ``user_identifier`` accepts a username, email, user UUID, or API Key ID
    and is resolved to a user_id via ``resolve_user_identifier``.
    """

    user_identifier: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        try:
            user_id = resolve_user_identifier(db, self.user_identifier)
            if not user_id:
                raise HTTPException(
                    status_code=404,
                    detail=f"无法识别的用户标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
                )

            user = db.query(User).filter(User.id == user_id).first()
            redis_client = get_redis_client_sync()
            affinity_mgr = await get_affinity_manager(redis_client)

            # Fetch all cache affinities, then keep the ones owned by this user.
            all_affinities = await affinity_mgr.list_affinities()
            user_affinities = [aff for aff in all_affinities if aff.get("user_id") == user_id]

            if not user_affinities:
                # No affinities: still return user info so the caller can confirm the match.
                response = {
                    "status": "not_found",
                    "message": f"用户 {user.username} ({user.email}) 没有缓存亲和性",
                    "user_info": {
                        "user_id": user_id,
                        "username": user.username,
                        "email": user.email,
                    },
                    "affinities": [],
                }
                context.add_audit_metadata(
                    action="cache_user_affinity",
                    user_identifier=self.user_identifier,
                    resolved_user_id=user_id,
                    affinity_count=0,
                    status="not_found",
                )
                return response

            response = {
                "status": "ok",
                "user_info": {
                    "user_id": user_id,
                    "username": user.username,
                    "email": user.email,
                },
                "affinities": [
                    {
                        "provider_id": aff["provider_id"],
                        "endpoint_id": aff["endpoint_id"],
                        "key_id": aff["key_id"],
                        "api_format": aff.get("api_format"),
                        "model_name": aff.get("model_name"),
                        "created_at": aff["created_at"],
                        "expire_at": aff["expire_at"],
                        "request_count": aff["request_count"],
                    }
                    for aff in user_affinities
                ],
                "total_endpoints": len(user_affinities),
            }
            context.add_audit_metadata(
                action="cache_user_affinity",
                user_identifier=self.user_identifier,
                resolved_user_id=user_id,
                affinity_count=len(user_affinities),
                status="ok",
            )
            return response
        except HTTPException:
            # Re-raise 4xx/5xx responses unchanged (e.g. the 404 above).
            raise
        except Exception as exc:
            logger.exception(f"查询用户缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"查询失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
class AdminListAffinitiesAdapter(AdminApiAdapter):
    """Admin adapter: list all cache affinities, optionally filtered by keyword.

    The keyword may be an API Key ID, a resolvable user identifier
    (username/email/User ID), or a free-text string used for fuzzy matching.
    """

    keyword: Optional[str]
    limit: int
    offset: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        redis_client = get_redis_client_sync()
        if not redis_client:
            raise HTTPException(status_code=503, detail="Redis未初始化,无法获取缓存亲和性")

        affinity_mgr = await get_affinity_manager(redis_client)
        matched_user_id = None
        matched_api_key_id = None
        raw_affinities: List[Dict[str, Any]] = []

        if self.keyword:
            # First, check whether the keyword is an API Key ID (affinity_key).
            api_key = db.query(ApiKey).filter(ApiKey.id == self.keyword).first()
            if api_key:
                # Filter directly by affinity_key.
                matched_api_key_id = str(api_key.id)
                matched_user_id = str(api_key.user_id)
                all_affinities = await affinity_mgr.list_affinities()
                raw_affinities = [
                    aff for aff in all_affinities if aff.get("affinity_key") == matched_api_key_id
                ]
            else:
                # Otherwise try to resolve it as a user identifier.
                user_id = resolve_user_identifier(db, self.keyword)
                if user_id:
                    matched_user_id = user_id
                    # Collect all API Key IDs belonging to this user.
                    user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
                    user_api_key_ids = {str(k.id) for k in user_api_keys}
                    # Keep only affinities owned by one of this user's API keys.
                    all_affinities = await affinity_mgr.list_affinities()
                    raw_affinities = [
                        aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
                    ]
                else:
                    # Not a recognizable identifier: keep everything and
                    # fuzzy-match against the serialized items below.
                    raw_affinities = await affinity_mgr.list_affinities()
        else:
            raw_affinities = await affinity_mgr.list_affinities()

        # Collect every affinity_key (an API Key ID).
        affinity_keys = {
            item.get("affinity_key") for item in raw_affinities if item.get("affinity_key")
        }

        # Batch-load the user API key rows referenced by those affinity keys.
        user_api_key_map: Dict[str, ApiKey] = {}
        if affinity_keys:
            user_api_keys = db.query(ApiKey).filter(ApiKey.id.in_(list(affinity_keys))).all()
            user_api_key_map = {str(k.id): k for k in user_api_keys}

        # Collect the owning user_ids and batch-load the users.
        user_ids = {str(k.user_id) for k in user_api_key_map.values()}
        user_map: Dict[str, User] = {}
        if user_ids:
            users = db.query(User).filter(User.id.in_(list(user_ids))).all()
            user_map = {str(user.id): user for user in users}

        # Collect all provider_id / endpoint_id / key_id values.
        provider_ids = {
            item.get("provider_id") for item in raw_affinities if item.get("provider_id")
        }
        endpoint_ids = {
            item.get("endpoint_id") for item in raw_affinities if item.get("endpoint_id")
        }
        key_ids = {item.get("key_id") for item in raw_affinities if item.get("key_id")}

        # Batch-load Provider / Endpoint / Key rows.
        from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint

        provider_map = {}
        if provider_ids:
            providers = db.query(Provider).filter(Provider.id.in_(list(provider_ids))).all()
            provider_map = {p.id: p for p in providers}

        endpoint_map = {}
        if endpoint_ids:
            endpoints = (
                db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(list(endpoint_ids))).all()
            )
            endpoint_map = {e.id: e for e in endpoints}

        key_map = {}
        if key_ids:
            keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(list(key_ids))).all()
            key_map = {k.id: k for k in keys}

        # Collect model_name values (actually global_model_id) and batch-load GlobalModel.
        from src.models.database import GlobalModel

        global_model_ids = {
            item.get("model_name") for item in raw_affinities if item.get("model_name")
        }
        global_model_map: Dict[str, GlobalModel] = {}
        if global_model_ids:
            # model_name may be a UUID-style global_model_id or a raw model name.
            global_models = db.query(GlobalModel).filter(
                GlobalModel.id.in_(list(global_model_ids))
            ).all()
            global_model_map = {str(gm.id): gm for gm in global_models}

        keyword_lower = self.keyword.lower() if self.keyword else None
        items = []
        for affinity in raw_affinities:
            affinity_key = affinity.get("affinity_key")
            if not affinity_key:
                continue

            # Map affinity_key (API Key ID) to its API key row and owning user.
            user_api_key = user_api_key_map.get(affinity_key)
            user = user_map.get(str(user_api_key.user_id)) if user_api_key else None
            user_id = str(user_api_key.user_id) if user_api_key else None

            provider_id = affinity.get("provider_id")
            endpoint_id = affinity.get("endpoint_id")
            key_id = affinity.get("key_id")

            provider = provider_map.get(provider_id)
            endpoint = endpoint_map.get(endpoint_id)
            key = key_map.get(key_id)

            # Mask the user's API key (decrypt key_encrypted, then mask).
            user_api_key_masked = None
            if user_api_key and user_api_key.key_encrypted:
                user_api_key_masked = decrypt_and_mask(user_api_key.key_encrypted)

            # Mask the provider key (decrypt api_key, then mask).
            provider_key_masked = None
            if key and key.api_key:
                provider_key_masked = decrypt_and_mask(key.api_key)

            item = {
                "affinity_key": affinity_key,
                "user_api_key_name": user_api_key.name if user_api_key else None,
                "user_api_key_prefix": user_api_key_masked,
                "is_standalone": user_api_key.is_standalone if user_api_key else False,
                "user_id": user_id,
                "username": user.username if user else None,
                "email": user.email if user else None,
                "provider_id": provider_id,
                "provider_name": provider.display_name if provider else None,
                "endpoint_id": endpoint_id,
                "endpoint_api_format": (
                    endpoint.api_format if endpoint and endpoint.api_format else None
                ),
                "endpoint_url": endpoint.base_url if endpoint else None,
                "key_id": key_id,
                "key_name": key.name if key else None,
                "key_prefix": provider_key_masked,
                "rate_multiplier": key.rate_multiplier if key else 1.0,
                "model_name": (
                    global_model_map.get(affinity.get("model_name")).name
                    if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
                    else affinity.get("model_name")  # fall back to the raw value when no GlobalModel matches
                ),
                "model_display_name": (
                    global_model_map.get(affinity.get("model_name")).display_name
                    if affinity.get("model_name") and global_model_map.get(affinity.get("model_name"))
                    else None
                ),
                "api_format": affinity.get("api_format"),
                "created_at": affinity.get("created_at"),
                "expire_at": affinity.get("expire_at"),
                "request_count": affinity.get("request_count", 0),
            }

            # Fuzzy matching only applies when the keyword was not already
            # resolved to an exact user or API key above.
            if keyword_lower and not matched_user_id and not matched_api_key_id:
                searchable = [
                    item["affinity_key"],
                    item["user_api_key_name"] or "",
                    item["user_id"] or "",
                    item["username"] or "",
                    item["email"] or "",
                    item["provider_id"] or "",
                    item["key_id"] or "",
                ]
                if not any(keyword_lower in str(value).lower() for value in searchable if value):
                    continue

            items.append(item)

        # Newest-expiring first, then apply limit/offset pagination.
        items.sort(key=lambda x: x.get("expire_at") or 0, reverse=True)
        paged_items, meta = paginate_sequence(items, self.limit, self.offset)
        payload = build_pagination_payload(
            paged_items,
            meta,
            matched_user_id=matched_user_id,
        )
        response = {
            "status": "ok",
            "data": payload,
        }
        result_count = meta.count if hasattr(meta, "count") else len(paged_items)
        context.add_audit_metadata(
            action="cache_affinity_list",
            keyword=self.keyword,
            matched_user_id=matched_user_id,
            matched_api_key_id=matched_api_key_id,
            limit=self.limit,
            offset=self.offset,
            result_count=result_count,
        )
        return response
|
||||
|
||||
|
||||
@dataclass
class AdminClearUserCacheAdapter(AdminApiAdapter):
    """Admin adapter: clear cache affinities for one API key or one whole user.

    ``user_identifier`` may be an API Key ID (clears only that key's
    affinities) or a username / email / user UUID (clears every affinity
    across all of the user's API keys).
    """

    user_identifier: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        try:
            redis_client = get_redis_client_sync()
            affinity_mgr = await get_affinity_manager(redis_client)

            # First, check whether the identifier is directly an API Key ID (affinity_key).
            api_key = db.query(ApiKey).filter(ApiKey.id == self.user_identifier).first()
            if api_key:
                # Clear by affinity_key directly.
                affinity_key = str(api_key.id)
                user = db.query(User).filter(User.id == api_key.user_id).first()

                all_affinities = await affinity_mgr.list_affinities()
                target_affinities = [
                    aff for aff in all_affinities if aff.get("affinity_key") == affinity_key
                ]

                count = 0
                for aff in target_affinities:
                    api_format = aff.get("api_format")
                    model_name = aff.get("model_name")
                    endpoint_id = aff.get("endpoint_id")
                    # api_format and model_name are required by the invalidation key.
                    if api_format and model_name:
                        await affinity_mgr.invalidate_affinity(
                            affinity_key, api_format, model_name, endpoint_id=endpoint_id
                        )
                        count += 1

                logger.info(f"已清除API Key缓存亲和性: api_key_name={api_key.name}, affinity_key={affinity_key[:8]}..., 清除数量={count}")

                response = {
                    "status": "ok",
                    "message": f"已清除 API Key {api_key.name} 的缓存亲和性",
                    "user_info": {
                        "user_id": str(api_key.user_id),
                        "username": user.username if user else None,
                        "email": user.email if user else None,
                        "api_key_id": affinity_key,
                        "api_key_name": api_key.name,
                    },
                }
                context.add_audit_metadata(
                    action="cache_clear_api_key",
                    user_identifier=self.user_identifier,
                    resolved_api_key_id=affinity_key,
                    cleared_count=count,
                )
                return response

            # Not an API Key ID — try resolving it as a user identifier.
            user_id = resolve_user_identifier(db, self.user_identifier)
            if not user_id:
                raise HTTPException(
                    status_code=404,
                    detail=f"无法识别的标识符: {self.user_identifier}。支持用户名、邮箱、User ID或API Key ID",
                )

            user = db.query(User).filter(User.id == user_id).first()

            # Collect all of this user's API keys.
            user_api_keys = db.query(ApiKey).filter(ApiKey.user_id == user_id).all()
            user_api_key_ids = {str(k.id) for k in user_api_keys}

            # Fetch every affinity owned by one of those keys and invalidate each.
            all_affinities = await affinity_mgr.list_affinities()
            user_affinities = [
                aff for aff in all_affinities if aff.get("affinity_key") in user_api_key_ids
            ]

            count = 0
            for aff in user_affinities:
                affinity_key = aff.get("affinity_key")
                api_format = aff.get("api_format")
                model_name = aff.get("model_name")
                endpoint_id = aff.get("endpoint_id")
                if affinity_key and api_format and model_name:
                    await affinity_mgr.invalidate_affinity(
                        affinity_key, api_format, model_name, endpoint_id=endpoint_id
                    )
                    count += 1

            logger.info(f"已清除用户缓存亲和性: username={user.username}, user_id={user_id[:8]}..., 清除数量={count}")

            response = {
                "status": "ok",
                "message": f"已清除用户 {user.username} 的所有缓存亲和性",
                "user_info": {"user_id": user_id, "username": user.username, "email": user.email},
            }
            context.add_audit_metadata(
                action="cache_clear_user",
                user_identifier=self.user_identifier,
                resolved_user_id=user_id,
                cleared_count=count,
            )
            return response
        except HTTPException:
            # Re-raise HTTP errors unchanged (e.g. the 404 above).
            raise
        except Exception as exc:
            logger.exception(f"清除用户缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
class AdminClearAllCacheAdapter(AdminApiAdapter):
    """Admin adapter that wipes every cache affinity (affects all users)."""

    async def handle(self, context):  # type: ignore[override]
        try:
            manager = await get_affinity_manager(get_redis_client_sync())
            removed = await manager.clear_all()
            logger.warning(f"已清除所有缓存亲和性(管理员操作): {removed} 个")
            context.add_audit_metadata(
                action="cache_clear_all",
                cleared_count=removed,
            )
            return {"status": "ok", "message": "已清除所有缓存亲和性", "count": removed}
        except Exception as exc:
            logger.exception(f"清除所有缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
class AdminClearProviderCacheAdapter(AdminApiAdapter):
    """Admin operation: invalidate every cache-affinity entry bound to one provider."""

    # Target provider's primary key (UUID string).
    provider_id: str

    async def handle(self, context):  # type: ignore[override]
        """Invalidate all affinities for ``provider_id`` and return the count cleared."""
        try:
            redis_client = get_redis_client_sync()
            affinity_mgr = await get_affinity_manager(redis_client)
            count = await affinity_mgr.invalidate_all_for_provider(self.provider_id)
            # Only a short id prefix is logged to keep log lines compact.
            logger.info(f"已清除Provider缓存亲和性: provider_id={self.provider_id[:8]}..., count={count}")
            context.add_audit_metadata(
                action="cache_clear_provider",
                provider_id=self.provider_id,
                cleared_count=count,
            )
            return {
                "status": "ok",
                "message": "已清除Provider的缓存亲和性",
                "provider_id": self.provider_id,
                "count": count,
            }
        except Exception as exc:
            # Any failure (Redis, manager) surfaces as a 500 with the cause attached.
            logger.exception(f"清除Provider缓存亲和性失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
class AdminCacheConfigAdapter(AdminApiAdapter):
    """Read-only admin view of the cache-affinity / dynamic-reservation configuration."""

    async def handle(self, context):  # type: ignore[override]
        """Return current cache TTL, static reservation ratio and dynamic reservation config."""
        # Local imports avoid pulling heavy service modules at router import time.
        from src.services.cache.affinity_manager import CacheAffinityManager
        from src.services.cache.aware_scheduler import CacheAwareScheduler
        from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager

        # Fetch the dynamic reservation manager's live configuration/stats.
        reservation_manager = get_adaptive_reservation_manager()
        reservation_stats = reservation_manager.get_stats()

        response = {
            "status": "ok",
            "data": {
                "cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
                "cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
                "dynamic_reservation": {
                    "enabled": True,
                    "config": reservation_stats["config"],
                    "description": {
                        "probe_phase_requests": "探测阶段请求数阈值",
                        "probe_reservation": "探测阶段预留比例",
                        "stable_min_reservation": "稳定阶段最小预留比例",
                        "stable_max_reservation": "稳定阶段最大预留比例",
                        "low_load_threshold": "低负载阈值(低于此值使用最小预留)",
                        "high_load_threshold": "高负载阈值(高于此值根据置信度使用较高预留)",
                    },
                },
                "description": {
                    "cache_ttl": "缓存亲和性有效期(秒)",
                    "cache_reservation_ratio": "静态预留比例(已被动态预留替代)",
                    "dynamic_reservation": "动态预留机制配置",
                },
            },
        }
        context.add_audit_metadata(
            action="cache_config",
            cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
            cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
            dynamic_reservation_enabled=True,
        )
        return response
|
||||
280
src/api/admin/monitoring/trace.py
Normal file
280
src/api/admin/monitoring/trace.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
请求链路追踪 API 端点
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderEndpoint, ProviderAPIKey
|
||||
from src.services.request.candidate import RequestCandidateService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring/trace", tags=["Admin - Monitoring: Trace"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
class CandidateResponse(BaseModel):
    """Serialized view of one provider/key candidate attempt within a request trace."""

    id: str
    request_id: str
    candidate_index: int
    retry_index: int = 0  # Retry sequence number (0-based)
    provider_id: Optional[str] = None
    provider_name: Optional[str] = None
    provider_website: Optional[str] = None  # Provider's public website
    endpoint_id: Optional[str] = None
    endpoint_name: Optional[str] = None  # Display name for the endpoint (its api_format)
    key_id: Optional[str] = None
    key_name: Optional[str] = None  # Human-readable key name
    key_preview: Optional[str] = None  # Masked key preview (e.g. sk-***abc)
    key_capabilities: Optional[dict] = None  # Capabilities supported by the key
    required_capabilities: Optional[dict] = None  # Capability tags this request actually required
    status: str  # 'pending', 'success', 'failed', 'skipped'
    skip_reason: Optional[str] = None
    is_cached: bool = False
    # --- execution result fields ---
    status_code: Optional[int] = None
    error_type: Optional[str] = None
    error_message: Optional[str] = None
    latency_ms: Optional[int] = None
    concurrent_requests: Optional[int] = None
    extra_data: Optional[dict] = None
    created_at: datetime
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None

    class Config:
        # Allow construction directly from ORM objects (pydantic v2 name).
        from_attributes = True
|
||||
|
||||
|
||||
class RequestTraceResponse(BaseModel):
    """Full trace for one request: aggregate status plus every candidate attempt."""

    request_id: str
    total_candidates: int
    final_status: str  # 'success', 'failed', 'streaming', 'pending'
    total_latency_ms: int
    candidates: List[CandidateResponse]
|
||||
|
||||
|
||||
@router.get("/{request_id}", response_model=RequestTraceResponse)
async def get_request_trace(
    request_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Return the complete execution trace for a single request."""
    trace_adapter = AdminGetRequestTraceAdapter(request_id=request_id)
    return await pipeline.run(
        adapter=trace_adapter, http_request=request, db=db, mode=trace_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/stats/provider/{provider_id}")
async def get_provider_failure_rate(
    provider_id: str,
    request: Request,
    limit: int = Query(100, ge=1, le=1000, description="统计最近的尝试数量"),
    db: Session = Depends(get_db),
):
    """Return failure-rate statistics for one provider (admin only).

    ``limit`` caps how many of the most recent attempts are included.
    """
    stats_adapter = AdminProviderFailureRateAdapter(provider_id=provider_id, limit=limit)
    return await pipeline.run(
        adapter=stats_adapter, http_request=request, db=db, mode=stats_adapter.mode
    )
|
||||
|
||||
|
||||
# -------- 请求追踪适配器 --------
|
||||
|
||||
|
||||
@dataclass
class AdminGetRequestTraceAdapter(AdminApiAdapter):
    """Build a full :class:`RequestTraceResponse` for one request id.

    Loads all candidate rows, derives the aggregate status/latency, then
    bulk-loads provider/endpoint/key metadata so the response is assembled
    without N+1 queries.
    """

    # Request id whose trace is being reconstructed.
    request_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # Fetch only the candidate rows for this request.
        candidates = RequestCandidateService.get_candidates_by_request_id(db, self.request_id)

        # No recorded candidates -> the request id is unknown, return 404.
        if not candidates:
            raise HTTPException(status_code=404, detail="Request not found")

        # Total latency counts only finished candidates ("success" or "failed").
        # The explicit "is not None" check keeps legitimate 0ms fast responses in the sum.
        total_latency = sum(
            c.latency_ms
            for c in candidates
            if c.status in ("success", "failed") and c.latency_ms is not None
        )

        # Derive the aggregate status:
        # 1. status="success" counts as success regardless of status_code
        #    (a streaming request whose client disconnected (499) still succeeded
        #    if the provider returned data).
        # 2. A 2xx status_code also counts as success, for non-streaming requests
        #    or legacy rows where status was not set correctly.
        # 3. status="streaming" means a streaming response is still in flight.
        # 4. status="pending" means the candidate has not started executing yet.
        has_success = any(
            c.status == "success"
            or (c.status_code is not None and 200 <= c.status_code < 300)
            for c in candidates
        )
        has_streaming = any(c.status == "streaming" for c in candidates)
        has_pending = any(c.status == "pending" for c in candidates)

        if has_success:
            final_status = "success"
        elif has_streaming:
            # At least one candidate is still streaming.
            final_status = "streaming"
        elif has_pending:
            # At least one candidate is still waiting to run.
            final_status = "pending"
        else:
            final_status = "failed"

        # Bulk-load provider info to avoid N+1 queries.
        provider_ids = {c.provider_id for c in candidates if c.provider_id}
        provider_map = {}
        provider_website_map = {}
        if provider_ids:
            providers = db.query(Provider).filter(Provider.id.in_(provider_ids)).all()
            for p in providers:
                provider_map[p.id] = p.name
                provider_website_map[p.id] = p.website

        # Bulk-load endpoint info (display name is the endpoint's api_format).
        endpoint_ids = {c.endpoint_id for c in candidates if c.endpoint_id}
        endpoint_map = {}
        if endpoint_ids:
            endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.id.in_(endpoint_ids)).all()
            endpoint_map = {e.id: e.api_format for e in endpoints}

        # Bulk-load key info and precompute masked previews.
        key_ids = {c.key_id for c in candidates if c.key_id}
        key_map = {}
        key_preview_map = {}
        key_capabilities_map = {}
        if key_ids:
            keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.id.in_(key_ids)).all()
            for k in keys:
                key_map[k.id] = k.name
                key_capabilities_map[k.id] = k.capabilities
                # Build a masked preview: keep the recognizable prefix and last 4 chars.
                api_key = k.api_key or ""
                if len(api_key) > 8:
                    # Detect common key-prefix patterns.
                    prefix_end = 0
                    for prefix in ["sk-", "key-", "api-", "ak-"]:
                        if api_key.lower().startswith(prefix):
                            prefix_end = len(prefix)
                            break
                    if prefix_end > 0:
                        key_preview_map[k.id] = f"{api_key[:prefix_end]}***{api_key[-4:]}"
                    else:
                        key_preview_map[k.id] = f"{api_key[:3]}***{api_key[-4:]}"
                elif len(api_key) > 4:
                    key_preview_map[k.id] = f"***{api_key[-4:]}"
                else:
                    # Too short to reveal anything safely.
                    key_preview_map[k.id] = "***"

        # Assemble the per-candidate response objects.
        candidate_responses: List[CandidateResponse] = []
        for candidate in candidates:
            provider_name = (
                provider_map.get(candidate.provider_id) if candidate.provider_id else None
            )
            provider_website = (
                provider_website_map.get(candidate.provider_id) if candidate.provider_id else None
            )
            endpoint_name = (
                endpoint_map.get(candidate.endpoint_id) if candidate.endpoint_id else None
            )
            key_name = (
                key_map.get(candidate.key_id) if candidate.key_id else None
            )
            key_preview = (
                key_preview_map.get(candidate.key_id) if candidate.key_id else None
            )
            key_capabilities = (
                key_capabilities_map.get(candidate.key_id) if candidate.key_id else None
            )

            candidate_responses.append(
                CandidateResponse(
                    id=candidate.id,
                    request_id=candidate.request_id,
                    candidate_index=candidate.candidate_index,
                    retry_index=candidate.retry_index,
                    provider_id=candidate.provider_id,
                    provider_name=provider_name,
                    provider_website=provider_website,
                    endpoint_id=candidate.endpoint_id,
                    endpoint_name=endpoint_name,
                    key_id=candidate.key_id,
                    key_name=key_name,
                    key_preview=key_preview,
                    key_capabilities=key_capabilities,
                    required_capabilities=candidate.required_capabilities,
                    status=candidate.status,
                    skip_reason=candidate.skip_reason,
                    is_cached=candidate.is_cached,
                    status_code=candidate.status_code,
                    error_type=candidate.error_type,
                    error_message=candidate.error_message,
                    latency_ms=candidate.latency_ms,
                    concurrent_requests=candidate.concurrent_requests,
                    extra_data=candidate.extra_data,
                    created_at=candidate.created_at,
                    started_at=candidate.started_at,
                    finished_at=candidate.finished_at,
                )
            )

        response = RequestTraceResponse(
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
            candidates=candidate_responses,
        )
        context.add_audit_metadata(
            action="trace_request_detail",
            request_id=self.request_id,
            total_candidates=len(candidates),
            final_status=final_status,
            total_latency_ms=total_latency,
        )
        return response
|
||||
|
||||
|
||||
@dataclass
class AdminProviderFailureRateAdapter(AdminApiAdapter):
    """Compute failure-rate statistics over a provider's recent candidate attempts."""

    provider_id: str
    limit: int

    async def handle(self, context):  # type: ignore[override]
        stats = RequestCandidateService.get_candidate_stats_by_provider(
            db=context.db,
            provider_id=self.provider_id,
            limit=self.limit,
        )
        # Record the headline numbers in the audit trail alongside the query params.
        context.add_audit_metadata(
            action="trace_provider_failure_rate",
            provider_id=self.provider_id,
            limit=self.limit,
            total_attempts=stats.get("total_attempts"),
            failure_rate=stats.get("failure_rate"),
        )
        return stats
|
||||
410
src/api/admin/provider_query.py
Normal file
410
src/api/admin/provider_query.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Provider Query API 端点
|
||||
用于查询提供商的余额、使用记录等信息
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||
|
||||
# 初始化适配器注册
|
||||
from src.plugins.provider_query import init # noqa
|
||||
from src.plugins.provider_query import get_query_registry
|
||||
from src.plugins.provider_query.base import QueryCapability
|
||||
from src.utils.auth_utils import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
||||
|
||||
|
||||
# ============ Request/Response Models ============
|
||||
|
||||
|
||||
class BalanceQueryRequest(BaseModel):
    """Request body for querying a provider's remaining balance."""

    provider_id: str
    api_key_id: Optional[str] = None  # If omitted, the provider's first active API key is used
|
||||
|
||||
|
||||
class UsageSummaryQueryRequest(BaseModel):
    """Request body for querying a provider's aggregated usage."""

    provider_id: str
    api_key_id: Optional[str] = None  # If omitted, the provider's first active API key is used
    period: str = "month"  # day, week, month, year
|
||||
|
||||
|
||||
class ModelsQueryRequest(BaseModel):
    """Request body for listing the models a provider exposes."""

    provider_id: str
    api_key_id: Optional[str] = None  # If omitted, the provider's first active API key is used
|
||||
|
||||
|
||||
# ============ API Endpoints ============
|
||||
|
||||
|
||||
@router.get("/adapters")
async def list_adapters(
    current_user: User = Depends(get_current_user),
):
    """List every registered provider-query adapter.

    Returns:
        Envelope dict with a ``success`` flag and the adapter list under ``data``.
    """
    return {"success": True, "data": get_query_registry().list_adapters()}
|
||||
|
||||
|
||||
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
    provider_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return the query capabilities supported for a provider.

    Args:
        provider_id: Provider primary key.

    Returns:
        Envelope with the capability names, or an empty list plus a message
        when no query adapter is registered for the provider.

    Raises:
        HTTPException: 404 if the provider does not exist.
    """
    # Look up the provider row.
    from sqlalchemy import select

    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()

    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    registry = get_query_registry()
    capabilities = registry.get_capabilities_for_provider(provider.name)

    # None means no adapter is registered at all (distinct from an empty capability set).
    if capabilities is None:
        return {
            "success": True,
            "data": {
                "provider_id": provider_id,
                "provider_name": provider.name,
                "capabilities": [],
                "has_adapter": False,
                "message": "No query adapter available for this provider",
            },
        }

    return {
        "success": True,
        "data": {
            "provider_id": provider_id,
            "provider_name": provider.name,
            "capabilities": [c.name for c in capabilities],
            "has_adapter": True,
        },
    }
|
||||
|
||||
|
||||
@router.post("/balance")
async def query_balance(
    request: BalanceQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query a provider's remaining balance via its registered query adapter.

    Args:
        request: Provider id plus an optional specific API key id.

    Returns:
        Envelope with the adapter result and basic provider info.

    Raises:
        HTTPException: 404 if the provider or the named API key is missing;
            400 if the provider has no active API key at all.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider with its endpoints and their API keys eagerly.
    result = await db.execute(
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = result.scalar_one_or_none()

    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    # Resolve the API key to use for the upstream query.
    api_key_value = None
    endpoint_config = None

    if request.api_key_id:
        # Find the explicitly requested API key across all endpoints.
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id == request.api_key_id:
                    api_key_value = api_key.api_key
                    endpoint_config = {
                        "base_url": endpoint.base_url,
                        "api_format": endpoint.api_format if endpoint.api_format else None,
                    }
                    break
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        # Fall back to the first active key on an active endpoint.
        for endpoint in provider.endpoints:
            if endpoint.is_active and endpoint.api_keys:
                for api_key in endpoint.api_keys:
                    if api_key.is_active:
                        api_key_value = api_key.api_key
                        endpoint_config = {
                            "base_url": endpoint.base_url,
                            "api_format": endpoint.api_format if endpoint.api_format else None,
                        }
                        break
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Run the balance query through the adapter registry.
    registry = get_query_registry()
    query_result = await registry.query_provider_balance(
        provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
    )

    # Failures are reported in the envelope rather than raised; log for operators.
    if not query_result.success:
        logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")

    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
|
||||
|
||||
|
||||
@router.post("/usage-summary")
async def query_usage_summary(
    request: UsageSummaryQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query a provider's aggregated usage over the requested period.

    Args:
        request: Provider id, optional API key id, and period
            (day / week / month / year).

    Returns:
        Envelope with the adapter result and basic provider info.

    Raises:
        HTTPException: 404 if the provider or the named API key is missing;
            400 if the provider has no active API key at all.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider with its endpoints and their API keys eagerly.
    result = await db.execute(
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = result.scalar_one_or_none()

    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    # Resolve the API key (same selection logic as the balance endpoint).
    api_key_value = None
    endpoint_config = None

    if request.api_key_id:
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id == request.api_key_id:
                    api_key_value = api_key.api_key
                    endpoint_config = {"base_url": endpoint.base_url}
                    break
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        for endpoint in provider.endpoints:
            if endpoint.is_active and endpoint.api_keys:
                for api_key in endpoint.api_keys:
                    if api_key.is_active:
                        api_key_value = api_key.api_key
                        endpoint_config = {"base_url": endpoint.base_url}
                        break
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Run the usage query through the adapter registry.
    registry = get_query_registry()
    query_result = await registry.query_provider_usage(
        provider_type=provider.name,
        api_key=api_key_value,
        period=request.period,
        endpoint_config=endpoint_config,
    )

    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
|
||||
|
||||
|
||||
@router.post("/models")
async def query_available_models(
    request: ModelsQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query the models a provider exposes, via its query adapter.

    Args:
        request: Provider id plus an optional specific API key id.

    Returns:
        Envelope with the adapter result and basic provider info.

    Raises:
        HTTPException: 404 if the provider or the named API key is missing;
            400 if no active API key exists or no adapter is registered.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider with its endpoints and their API keys eagerly.
    result = await db.execute(
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = result.scalar_one_or_none()

    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    # Resolve the API key (same selection logic as the balance endpoint).
    api_key_value = None
    endpoint_config = None

    if request.api_key_id:
        for endpoint in provider.endpoints:
            for api_key in endpoint.api_keys:
                if api_key.id == request.api_key_id:
                    api_key_value = api_key.api_key
                    endpoint_config = {"base_url": endpoint.base_url}
                    break
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=404, detail="API Key not found")
    else:
        for endpoint in provider.endpoints:
            if endpoint.is_active and endpoint.api_keys:
                for api_key in endpoint.api_keys:
                    if api_key.is_active:
                        api_key_value = api_key.api_key
                        endpoint_config = {"base_url": endpoint.base_url}
                        break
            if api_key_value:
                break

        if not api_key_value:
            raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Model listing talks to the adapter directly (no registry-level helper).
    registry = get_query_registry()
    adapter = registry.get_adapter_for_provider(provider.name)

    if not adapter:
        raise HTTPException(
            status_code=400, detail=f"No query adapter available for provider: {provider.name}"
        )

    query_result = await adapter.query_available_models(
        api_key=api_key_value, endpoint_config=endpoint_config
    )

    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
        },
    }
|
||||
|
||||
|
||||
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
    provider_id: str,
    api_key_id: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Clear the provider-query result cache.

    Args:
        provider_id: Provider primary key.
        api_key_id: Optional; when given, only the cache entries for that
            API key are cleared, otherwise the whole adapter cache is wiped.

    Returns:
        Envelope confirming the cache was cleared. The operation is
        best-effort: when no adapter is registered for the provider, or the
        given ``api_key_id`` does not exist, nothing is cleared and success
        is still reported.

    Raises:
        HTTPException: 404 if the provider does not exist.
    """
    from sqlalchemy import select

    # Look up the provider row.
    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()

    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    registry = get_query_registry()
    adapter = registry.get_adapter_for_provider(provider.name)

    if adapter:
        if api_key_id:
            # The cache is keyed by the raw API key value, so resolve the id first.
            result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
            api_key = result.scalar_one_or_none()
            if api_key:
                adapter.clear_cache(api_key.api_key)
        else:
            adapter.clear_cache()

    return {"success": True, "message": "Cache cleared successfully"}
|
||||
272
src/api/admin/provider_strategy.py
Normal file
272
src/api/admin/provider_strategy.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
提供商策略管理 API 端点
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider
|
||||
from src.models.database_extensions import ProviderUsageTracking
|
||||
|
||||
router = APIRouter(prefix="/api/admin/provider-strategy", tags=["Provider Strategy"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
class ProviderBillingUpdate(BaseModel):
    """Request body for updating a provider's billing/quota configuration."""

    billing_type: ProviderBillingType
    monthly_quota_usd: Optional[float] = None
    quota_reset_day: int = Field(default=30, ge=1, le=365)  # Reset cycle length, in days
    quota_last_reset_at: Optional[str] = None  # Start of the current quota cycle (parsed by dateutil)
    quota_expires_at: Optional[str] = None
    rpm_limit: Optional[int] = Field(default=None, ge=0)
    provider_priority: int = Field(default=100, ge=0, le=200)
|
||||
|
||||
|
||||
@router.put("/providers/{provider_id}/billing")
async def update_provider_billing(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update a provider's billing configuration (admin only)."""
    billing_adapter = AdminProviderBillingAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=billing_adapter, http_request=request, db=db, mode=billing_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/stats")
async def get_provider_stats(
    provider_id: str,
    request: Request,
    hours: int = 24,
    db: Session = Depends(get_db),
):
    """Return billing, RPM and usage statistics for the last ``hours`` hours."""
    stats_adapter = AdminProviderStatsAdapter(provider_id=provider_id, hours=hours)
    return await pipeline.run(
        adapter=stats_adapter, http_request=request, db=db, mode=stats_adapter.mode
    )
|
||||
|
||||
|
||||
@router.delete("/providers/{provider_id}/quota")
async def reset_provider_quota(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Reset provider quota usage to zero"""
    reset_adapter = AdminProviderResetQuotaAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=reset_adapter, http_request=request, db=db, mode=reset_adapter.mode
    )
|
||||
|
||||
|
||||
@router.get("/strategies")
async def list_available_strategies(request: Request, db: Session = Depends(get_db)):
    """List the load-balancer strategies exposed by the plugin manager."""
    strategies_adapter = AdminListStrategiesAdapter()
    return await pipeline.run(
        adapter=strategies_adapter, http_request=request, db=db, mode=strategies_adapter.mode
    )
|
||||
|
||||
|
||||
class AdminProviderBillingAdapter(AdminApiAdapter):
    """Validate and apply a :class:`ProviderBillingUpdate` to one provider.

    When a new cycle start (``quota_last_reset_at``) is supplied, the usage
    accumulated since that timestamp is re-summed from the Usage table so the
    stored ``monthly_used_usd`` matches the new cycle.
    """

    def __init__(self, provider_id: str):
        # Target provider's primary key.
        self.provider_id = provider_id

    async def handle(self, context):
        db = context.db
        payload = context.ensure_json_body()
        try:
            config = ProviderBillingUpdate.model_validate(payload)
        except ValidationError as e:
            # Surface only the first validation error, translated for the client.
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Apply the straightforward scalar fields.
        provider.billing_type = config.billing_type
        provider.monthly_quota_usd = config.monthly_quota_usd
        provider.quota_reset_day = config.quota_reset_day
        provider.rpm_limit = config.rpm_limit
        provider.provider_priority = config.provider_priority

        from dateutil import parser
        from sqlalchemy import func

        from src.models.database import Usage

        if config.quota_last_reset_at:
            new_reset_at = parser.parse(config.quota_last_reset_at)
            provider.quota_last_reset_at = new_reset_at

            # Re-sync historical usage for the new cycle window.
            period_usage = (
                db.query(func.coalesce(func.sum(Usage.total_cost_usd), 0))
                .filter(
                    Usage.provider_id == self.provider_id,
                    Usage.created_at >= new_reset_at,
                )
                .scalar()
            )
            provider.monthly_used_usd = float(period_usage or 0)
            logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")

        # NOTE(review): a falsy quota_expires_at leaves the old value in place —
        # there is no way to clear an expiry through this endpoint; confirm intended.
        if config.quota_expires_at:
            provider.quota_expires_at = parser.parse(config.quota_expires_at)

        db.commit()
        db.refresh(provider)

        logger.info(f"Updated billing config for provider {provider.name}")

        return JSONResponse(
            {
                "message": "Provider billing config updated successfully",
                "provider": {
                    "id": provider.id,
                    "name": provider.name,
                    "billing_type": provider.billing_type.value,
                    "provider_priority": provider.provider_priority,
                },
            }
        )
|
||||
|
||||
|
||||
class AdminProviderStatsAdapter(AdminApiAdapter):
    """Aggregate billing, RPM and usage-tracking stats for one provider."""

    def __init__(self, provider_id: str, hours: int):
        # Target provider's primary key.
        self.provider_id = provider_id
        # Size of the look-back window, in hours.
        self.hours = hours

    async def handle(self, context):
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")

        # NOTE(review): naive local time — assumes window_start is stored in the
        # same (naive) timezone; confirm against how tracking rows are written.
        since = datetime.now() - timedelta(hours=self.hours)
        stats = (
            db.query(ProviderUsageTracking)
            .filter(
                ProviderUsageTracking.provider_id == self.provider_id,
                ProviderUsageTracking.window_start >= since,
            )
            .all()
        )

        # Aggregate across all tracking windows in the period.
        total_requests = sum(s.total_requests for s in stats)
        total_success = sum(s.successful_requests for s in stats)
        total_failures = sum(s.failed_requests for s in stats)
        # Unweighted mean of per-window averages; 0 when there are no windows.
        avg_response_time = sum(s.avg_response_time_ms for s in stats) / len(stats) if stats else 0
        total_cost = sum(s.total_cost_usd for s in stats)

        return JSONResponse(
            {
                "provider_id": self.provider_id,
                "provider_name": provider.name,
                "period_hours": self.hours,
                "billing_info": {
                    "billing_type": provider.billing_type.value,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "monthly_used_usd": provider.monthly_used_usd,
                    "quota_remaining_usd": (
                        provider.monthly_quota_usd - provider.monthly_used_usd
                        if provider.monthly_quota_usd is not None
                        else None
                    ),
                    "quota_expires_at": (
                        provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
                    ),
                },
                "rpm_info": {
                    "rpm_limit": provider.rpm_limit,
                    "rpm_used": provider.rpm_used,
                    "rpm_reset_at": (
                        provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
                    ),
                },
                "usage_stats": {
                    "total_requests": total_requests,
                    "successful_requests": total_success,
                    "failed_requests": total_failures,
                    "success_rate": total_success / total_requests if total_requests > 0 else 0,
                    "avg_response_time_ms": round(avg_response_time, 2),
                    "total_cost_usd": round(total_cost, 4),
                },
            }
        )
|
||||
|
||||
|
||||
class AdminProviderResetQuotaAdapter(AdminApiAdapter):
    """Manually reset a monthly-quota provider's usage counters to zero."""

    def __init__(self, provider_id: str):
        # ID of the provider whose quota counters are reset.
        self.provider_id = provider_id

    async def handle(self, context):
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise HTTPException(status_code=404, detail="Provider not found")

        # Resetting only makes sense for monthly-quota billing.
        if provider.billing_type != ProviderBillingType.MONTHLY_QUOTA:
            raise HTTPException(status_code=400, detail="Only monthly quota providers can be reset")

        previous_used = provider.monthly_used_usd
        # Clear the spend counter and the RPM window state together.
        provider.monthly_used_usd = 0.0
        provider.rpm_used = 0
        provider.rpm_reset_at = None
        db.commit()

        logger.info(f"Manually reset quota for provider {provider.name}")

        return JSONResponse(
            {
                "message": "Provider quota reset successfully",
                "provider_name": provider.name,
                "previous_used": previous_used,
                "current_used": 0.0,
            }
        )
|
||||
|
||||
|
||||
class AdminListStrategiesAdapter(AdminApiAdapter):
    """List registered load-balancer strategy plugins, highest priority first."""

    async def handle(self, context):
        from src.plugins.manager import get_plugin_manager

        plugin_manager = get_plugin_manager()
        lb_plugins = plugin_manager.plugins.get("load_balancer", {})

        strategies = []
        for name, plugin in lb_plugins.items():
            try:
                # A plugin may or may not expose a metadata object; if it is
                # missing (or None), getattr on it falls back to the defaults,
                # matching the original hasattr-guarded lookups.
                meta = getattr(plugin, "metadata", None)
                strategies.append(
                    {
                        "name": getattr(plugin, "name", name),
                        "priority": getattr(plugin, "priority", 0),
                        "version": getattr(meta, "version", "1.0.0"),
                        "description": getattr(meta, "description", ""),
                        "author": getattr(meta, "author", "Unknown"),
                    }
                )
            except Exception as exc:  # pragma: no cover
                logger.error(f"Error accessing plugin {name}: {exc}")
                continue

        strategies.sort(key=lambda item: item["priority"], reverse=True)
        return JSONResponse({"strategies": strategies, "total": len(strategies)})
|
||||
20
src/api/admin/providers/__init__.py
Normal file
20
src/api/admin/providers/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Provider admin routes export."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .models import router as models_router
|
||||
from .routes import router as routes_router
|
||||
from .summary import router as summary_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/providers", tags=["Admin - Providers"])
|
||||
|
||||
# Provider CRUD
|
||||
router.include_router(routes_router)
|
||||
|
||||
# Provider summary & health monitor
|
||||
router.include_router(summary_router)
|
||||
|
||||
# Provider models management
|
||||
router.include_router(models_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
443
src/api/admin/providers/models.py
Normal file
443
src/api/admin/providers/models.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Provider 模型管理 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelCreate,
|
||||
ModelResponse,
|
||||
ModelUpdate,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignModelsToProviderRequest,
|
||||
BatchAssignModelsToProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
ProviderAvailableSourceModel,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
)
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(tags=["Model Management"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/{provider_id}/models", response_model=List[ModelResponse])
|
||||
async def list_provider_models(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
is_active: Optional[bool] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ModelResponse]:
|
||||
"""获取提供商的所有模型(管理员)"""
|
||||
adapter = AdminListProviderModelsAdapter(
|
||||
provider_id=provider_id,
|
||||
is_active=is_active,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{provider_id}/models", response_model=ModelResponse)
|
||||
async def create_provider_model(
|
||||
provider_id: str,
|
||||
model_data: ModelCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""创建模型(管理员)"""
|
||||
adapter = AdminCreateProviderModelAdapter(provider_id=provider_id, model_data=model_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/models/{model_id}", response_model=ModelResponse)
|
||||
async def get_provider_model(
|
||||
provider_id: str,
|
||||
model_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""获取模型详情(管理员)"""
|
||||
adapter = AdminGetProviderModelAdapter(provider_id=provider_id, model_id=model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{provider_id}/models/{model_id}", response_model=ModelResponse)
|
||||
async def update_provider_model(
|
||||
provider_id: str,
|
||||
model_id: str,
|
||||
model_data: ModelUpdate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ModelResponse:
|
||||
"""更新模型(管理员)"""
|
||||
adapter = AdminUpdateProviderModelAdapter(
|
||||
provider_id=provider_id,
|
||||
model_id=model_id,
|
||||
model_data=model_data,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}/models/{model_id}")
|
||||
async def delete_provider_model(
|
||||
provider_id: str,
|
||||
model_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除模型(管理员)"""
|
||||
adapter = AdminDeleteProviderModelAdapter(provider_id=provider_id, model_id=model_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{provider_id}/models/batch", response_model=List[ModelResponse])
|
||||
async def batch_create_provider_models(
|
||||
provider_id: str,
|
||||
models_data: List[ModelCreate],
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ModelResponse]:
|
||||
"""批量创建模型(管理员)"""
|
||||
adapter = AdminBatchCreateModelsAdapter(provider_id=provider_id, models_data=models_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get(
    "/{provider_id}/available-source-models",
    response_model=ProviderAvailableSourceModelsResponse,
)
async def get_provider_available_source_models(
    provider_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Return every unified model name (source_model) this provider supports.

    Covers both models mapped via ``ModelMapping`` and direct models whose
    ``provider_model_name`` doubles as the unified name.
    """
    adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
|
||||
|
||||
|
||||
@router.post(
    "/{provider_id}/assign-global-models",
    response_model=BatchAssignModelsToProviderResponse,
)
async def batch_assign_global_models_to_provider(
    provider_id: str,
    payload: BatchAssignModelsToProviderRequest,
    request: Request,
    db: Session = Depends(get_db),
) -> BatchAssignModelsToProviderResponse:
    """Associate multiple GlobalModels with a provider in one call (admin).

    New associations inherit price and capability configuration from the
    GlobalModel.
    """
    adapter = AdminBatchAssignModelsToProviderAdapter(
        provider_id=provider_id, payload=payload
    )
    return await pipeline.run(
        adapter=adapter, http_request=request, db=db, mode=adapter.mode
    )
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
class AdminListProviderModelsAdapter(AdminApiAdapter):
    """Pipeline adapter: paginated model listing for one provider."""

    provider_id: str
    is_active: Optional[bool]
    skip: int
    limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Existence check before listing so the caller gets a 404 rather
        # than an empty list for an unknown provider.
        if db.query(Provider).filter(Provider.id == self.provider_id).first() is None:
            raise NotFoundException("Provider not found", "provider")

        rows = ModelService.get_models_by_provider(
            db, self.provider_id, self.skip, self.limit, self.is_active
        )
        return [ModelService.convert_to_response(row) for row in rows]
|
||||
|
||||
|
||||
@dataclass
class AdminCreateProviderModelAdapter(AdminApiAdapter):
    """Pipeline adapter: create a model under a provider.

    Raises NotFoundException when the provider does not exist, and wraps any
    service-layer failure in InvalidRequestException (cause chained).
    """

    provider_id: str
    model_data: ModelCreate

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")

        try:
            model = ModelService.create_model(db, self.provider_id, self.model_data)
            logger.info(f"Model created: {model.provider_model_name} for provider {provider.name} by {context.user.username}")
            return ModelService.convert_to_response(model)
        except Exception as exc:
            # Chain the original exception (PEP 3134) so the real cause is
            # preserved in tracebacks instead of being flattened to a string.
            raise InvalidRequestException(str(exc)) from exc
|
||||
|
||||
|
||||
@dataclass
class AdminGetProviderModelAdapter(AdminApiAdapter):
    """Pipeline adapter: fetch one model scoped to its provider."""

    provider_id: str
    model_id: str

    async def handle(self, context):  # type: ignore[override]
        # Scope the lookup by both IDs so a model from a different provider
        # cannot be read through this route.
        record = (
            context.db.query(Model)
            .filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
            .first()
        )
        if record is None:
            raise NotFoundException("Model not found", "model")
        return ModelService.convert_to_response(record)
|
||||
|
||||
|
||||
@dataclass
class AdminUpdateProviderModelAdapter(AdminApiAdapter):
    """Pipeline adapter: partially update a model scoped to its provider.

    Raises NotFoundException for unknown model/provider pairs and wraps
    service-layer failures in InvalidRequestException (cause chained).
    """

    provider_id: str
    model_id: str
    model_data: ModelUpdate

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        model = (
            db.query(Model)
            .filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
            .first()
        )
        if not model:
            raise NotFoundException("Model not found", "model")

        try:
            updated_model = ModelService.update_model(db, self.model_id, self.model_data)
            logger.info(f"Model updated: {updated_model.provider_model_name} by {context.user.username}")
            return ModelService.convert_to_response(updated_model)
        except Exception as exc:
            # Chain the original exception (PEP 3134) so the real cause is
            # preserved instead of being flattened to a string.
            raise InvalidRequestException(str(exc)) from exc
|
||||
|
||||
|
||||
@dataclass
class AdminDeleteProviderModelAdapter(AdminApiAdapter):
    """Pipeline adapter: delete a model scoped to its provider.

    Raises NotFoundException for unknown model/provider pairs and wraps
    service-layer failures in InvalidRequestException (cause chained).
    """

    provider_id: str
    model_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        model = (
            db.query(Model)
            .filter(Model.id == self.model_id, Model.provider_id == self.provider_id)
            .first()
        )
        if not model:
            raise NotFoundException("Model not found", "model")

        # Capture the name before deletion for the log/response message.
        model_name = model.provider_model_name
        try:
            ModelService.delete_model(db, self.model_id)
            logger.info(f"Model deleted: {model_name} by {context.user.username}")
            return {"message": f"Model '{model_name}' deleted successfully"}
        except Exception as exc:
            # Chain the original exception (PEP 3134) for debuggability.
            raise InvalidRequestException(str(exc)) from exc
|
||||
|
||||
|
||||
@dataclass
class AdminBatchCreateModelsAdapter(AdminApiAdapter):
    """Pipeline adapter: create several models for one provider at once.

    Raises NotFoundException when the provider does not exist; any failure in
    the service layer is wrapped in InvalidRequestException (cause chained).
    """

    provider_id: str
    models_data: List[ModelCreate]

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")

        try:
            models = ModelService.batch_create_models(db, self.provider_id, self.models_data)
            logger.info(f"Batch created {len(models)} models for provider {provider.name} by {context.user.username}")
            return [ModelService.convert_to_response(model) for model in models]
        except Exception as exc:
            # Chain the original exception (PEP 3134) so the real cause is
            # preserved instead of being flattened to a string.
            raise InvalidRequestException(str(exc)) from exc
|
||||
|
||||
|
||||
@dataclass
class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
    """Pipeline adapter: list every GlobalModel this provider supports."""

    provider_id: str

    async def handle(self, context):  # type: ignore[override]
        """
        Return all GlobalModels supported by this Provider.

        Plan-A logic:
        1. Query all Models of this Provider.
        2. Resolve each Model's GlobalModel via Model.global_model_id.
        3. Collect all aliases pointing at that GlobalModel (ModelMapping.alias).
        """
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")

        # 1. Load all active Models of the provider (eager-load GlobalModel
        #    to avoid a lazy-load round trip per row).
        models = (
            db.query(Model)
            .options(joinedload(Model.global_model))
            .filter(Model.provider_id == self.provider_id, Model.is_active == True)
            .all()
        )

        # 2. Build a dict keyed by GlobalModel name; the first Model seen for
        #    a given GlobalModel supplies provider_model_name/price/capabilities.
        global_models_dict: Dict[str, Dict[str, Any]] = {}

        for model in models:
            global_model = model.global_model
            # Skip models without a GlobalModel link or whose GlobalModel is inactive.
            if not global_model or not global_model.is_active:
                continue

            global_model_name = global_model.name

            # Only process each GlobalModel once (dedup across Models).
            if global_model_name not in global_models_dict:
                # Collect all aliases/mappings pointing at this GlobalModel,
                # either provider-specific or global (provider_id IS NULL).
                # NOTE(review): this query runs once per distinct GlobalModel
                # (N+1 pattern) — acceptable for small catalogs, confirm scale.
                alias_rows = (
                    db.query(ModelMapping.source_model)
                    .filter(
                        ModelMapping.target_global_model_id == global_model.id,
                        ModelMapping.is_active == True,
                        or_(
                            ModelMapping.provider_id == self.provider_id,
                            ModelMapping.provider_id.is_(None),
                        ),
                    )
                    .all()
                )
                # Rows are single-column tuples; unwrap to plain strings.
                alias_list = [alias[0] for alias in alias_rows]

                global_models_dict[global_model_name] = {
                    "global_model_name": global_model_name,
                    "display_name": global_model.display_name,
                    "provider_model_name": model.provider_model_name,
                    "has_alias": len(alias_list) > 0,
                    "aliases": alias_list,
                    "model_id": model.id,
                    # Effective prices fall back to GlobalModel config when the
                    # Model has no override (per the get_effective_* helpers).
                    "price": {
                        "input_price_per_1m": model.get_effective_input_price(),
                        "output_price_per_1m": model.get_effective_output_price(),
                        "cache_creation_price_per_1m": model.get_effective_cache_creation_price(),
                        "cache_read_price_per_1m": model.get_effective_cache_read_price(),
                        "price_per_request": model.get_effective_price_per_request(),
                    },
                    "capabilities": {
                        "supports_vision": bool(model.supports_vision),
                        "supports_function_calling": bool(model.supports_function_calling),
                        "supports_streaming": bool(model.supports_streaming),
                    },
                    "is_active": bool(model.is_active),
                }

        # Deterministic output order: sort by GlobalModel name.
        models_list = [
            ProviderAvailableSourceModel(**global_models_dict[name])
            for name in sorted(global_models_dict.keys())
        ]

        return ProviderAvailableSourceModelsResponse(models=models_list, total=len(models_list))
|
||||
|
||||
|
||||
@dataclass
class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
    """Batch-associate GlobalModels with a Provider.

    For each requested GlobalModel a new ``Model`` row is created that
    inherits the GlobalModel's configuration. Failures are collected per
    item so one bad ID does not abort the whole batch; a single commit is
    issued at the end.
    """

    provider_id: str
    payload: BatchAssignModelsToProviderRequest

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("Provider not found", "provider")

        success = []  # associations created in this batch
        errors = []  # per-item failures: missing, duplicate, or unexpected

        for global_model_id in self.payload.global_model_ids:
            try:
                global_model = (
                    db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
                )
                if not global_model:
                    errors.append(
                        {"global_model_id": global_model_id, "error": "GlobalModel not found"}
                    )
                    continue

                # Skip GlobalModels already associated with this provider.
                existing = (
                    db.query(Model)
                    .filter(
                        Model.provider_id == self.provider_id,
                        Model.global_model_id == global_model_id,
                    )
                    .first()
                )
                if existing:
                    errors.append(
                        {
                            "global_model_id": global_model_id,
                            "global_model_name": global_model.name,
                            "error": "Already associated",
                        }
                    )
                    continue

                # Create a new Model row inheriting the GlobalModel's config.
                new_model = Model(
                    provider_id=self.provider_id,
                    global_model_id=global_model_id,
                    provider_model_name=global_model.name,
                    is_active=True,
                )
                db.add(new_model)
                # flush() assigns the new primary key without committing yet.
                db.flush()

                success.append(
                    {
                        "global_model_id": global_model_id,
                        "global_model_name": global_model.name,
                        "model_id": new_model.id,
                    }
                )
            except Exception as e:
                # NOTE(review): a failed flush may leave the session in a state
                # that affects later iterations — confirm rollback semantics.
                errors.append({"global_model_id": global_model_id, "error": str(e)})

        # Single commit covering every successful flush above.
        db.commit()
        logger.info(
            f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
        )

        return BatchAssignModelsToProviderResponse(success=success, errors=errors)
|
||||
249
src/api/admin/providers/routes.py
Normal file
249
src/api/admin/providers/routes.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""管理员 Provider 管理路由。"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.database import get_db
|
||||
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
||||
from src.models.database import Provider
|
||||
|
||||
router = APIRouter(tags=["Provider CRUD"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_providers(
|
||||
request: Request,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
is_active: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = AdminListProvidersAdapter(skip=skip, limit=limit, is_active=is_active)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_provider(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminCreateProviderAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{provider_id}")
|
||||
async def update_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminUpdateProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{provider_id}")
|
||||
async def delete_provider(provider_id: str, request: Request, db: Session = Depends(get_db)):
|
||||
adapter = AdminDeleteProviderAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminListProvidersAdapter(AdminApiAdapter):
    """Paginated provider listing with API key material masked."""

    def __init__(self, skip: int, limit: int, is_active: Optional[bool]):
        self.skip = skip
        self.limit = limit
        self.is_active = is_active

    async def handle(self, context):  # type: ignore[override]
        query = context.db.query(Provider)
        if self.is_active is not None:
            query = query.filter(Provider.is_active == self.is_active)
        providers = query.offset(self.skip).limit(self.limit).all()

        data = [self._serialize(provider) for provider in providers]

        context.add_audit_metadata(
            action="list_providers",
            filter_is_active=self.is_active,
            limit=self.limit,
            skip=self.skip,
            result_count=len(data),
        )
        return data

    @staticmethod
    def _serialize(provider) -> dict:
        # Optional columns are read defensively; older rows/schemas may
        # lack them, in which case sensible fallbacks are used.
        api_format = getattr(provider, "api_format", None)
        return {
            "id": provider.id,
            "name": provider.name,
            "display_name": provider.display_name,
            "api_format": api_format.value if api_format else None,
            "base_url": getattr(provider, "base_url", None),
            # Never expose key material in list responses — mask it.
            "api_key": "***" if getattr(provider, "api_key", None) else None,
            "priority": getattr(provider, "priority", provider.provider_priority),
            "is_active": provider.is_active,
            "created_at": provider.created_at.isoformat(),
            "updated_at": provider.updated_at.isoformat() if provider.updated_at else None,
        }
|
||||
|
||||
|
||||
class AdminCreateProviderAdapter(AdminApiAdapter):
    """Create a provider from a validated JSON payload.

    The payload is validated with ``CreateProviderRequest`` (which performs
    SQL injection / XSS / SSRF checks), the name is checked for uniqueness,
    and the new row is committed before returning a confirmation dict.
    Raises InvalidRequestException on validation failure or duplicate name.
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()

        try:
            # Pydantic validation (includes SQL injection / XSS / SSRF checks).
            validated_data = CreateProviderRequest.model_validate(payload)
        except ValidationError as exc:
            # Convert Pydantic errors into one readable message, chaining the
            # original exception (PEP 3134) so field details aren't lost.
            errors = []
            for error in exc.errors():
                field = " -> ".join(str(x) for x in error["loc"])
                errors.append(f"{field}: {error['msg']}")
            raise InvalidRequestException("输入验证失败: " + "; ".join(errors)) from exc

        try:
            # Provider names must be unique.
            existing = db.query(Provider).filter(Provider.name == validated_data.name).first()
            if existing:
                raise InvalidRequestException(f"提供商名称 '{validated_data.name}' 已存在")

            # billing_type defaults to pay-as-you-go when omitted.
            billing_type = (
                ProviderBillingType(validated_data.billing_type)
                if validated_data.billing_type
                else ProviderBillingType.PAY_AS_YOU_GO
            )

            provider = Provider(
                name=validated_data.name,
                display_name=validated_data.display_name,
                description=validated_data.description,
                website=validated_data.website,
                billing_type=billing_type,
                monthly_quota_usd=validated_data.monthly_quota_usd,
                quota_reset_day=validated_data.quota_reset_day,
                quota_last_reset_at=validated_data.quota_last_reset_at,
                quota_expires_at=validated_data.quota_expires_at,
                rpm_limit=validated_data.rpm_limit,
                provider_priority=validated_data.provider_priority,
                is_active=validated_data.is_active,
                rate_limit=validated_data.rate_limit,
                concurrent_limit=validated_data.concurrent_limit,
                config=validated_data.config,
            )

            db.add(provider)
            db.commit()
            db.refresh(provider)

            context.add_audit_metadata(
                action="create_provider",
                provider_id=provider.id,
                provider_name=provider.name,
                billing_type=provider.billing_type.value if provider.billing_type else None,
                is_active=provider.is_active,
                provider_priority=provider.provider_priority,
            )

            return {
                "id": provider.id,
                "name": provider.name,
                "display_name": provider.display_name,
                "message": "提供商创建成功",
            }
        except Exception:
            # Single rollback path: both InvalidRequestException and any
            # unexpected error roll back, then re-raise unchanged for the
            # pipeline's error mapping (replaces two identical except blocks).
            db.rollback()
            raise
|
||||
|
||||
|
||||
class AdminUpdateProviderAdapter(AdminApiAdapter):
    """Update an existing provider from a validated JSON payload.

    Only fields present in the payload are applied (``exclude_unset``).
    Raises NotFoundException for unknown providers and
    InvalidRequestException on validation failure.
    """

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()

        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException("提供商不存在", "provider")

        try:
            # Pydantic validation (includes SQL injection / XSS / SSRF checks).
            validated_data = UpdateProviderRequest.model_validate(payload)
        except ValidationError as exc:
            # Convert Pydantic errors into one readable message, chaining the
            # original exception (PEP 3134) so field details aren't lost.
            errors = []
            for error in exc.errors():
                field = " -> ".join(str(x) for x in error["loc"])
                errors.append(f"{field}: {error['msg']}")
            raise InvalidRequestException("输入验证失败: " + "; ".join(errors)) from exc

        try:
            # Apply only the fields the caller actually supplied.
            update_data = validated_data.model_dump(exclude_unset=True)

            for field, value in update_data.items():
                if field == "billing_type" and value is not None:
                    # billing_type arrives as a raw value; coerce to the enum.
                    setattr(provider, field, ProviderBillingType(value))
                else:
                    setattr(provider, field, value)

            provider.updated_at = datetime.now(timezone.utc)
            db.commit()
            db.refresh(provider)

            context.add_audit_metadata(
                action="update_provider",
                provider_id=provider.id,
                changed_fields=list(update_data.keys()),
                is_active=provider.is_active,
                provider_priority=provider.provider_priority,
            )

            return {
                "id": provider.id,
                "name": provider.name,
                "is_active": provider.is_active,
                "message": "提供商更新成功",
            }
        except Exception:
            # Single rollback path: roll back and re-raise unchanged for the
            # pipeline's error mapping (replaces two identical except blocks).
            db.rollback()
            raise
|
||||
|
||||
|
||||
class AdminDeleteProviderAdapter(AdminApiAdapter):
    """Delete a provider after recording an audit trail entry."""

    def __init__(self, provider_id: str):
        self.provider_id = provider_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise NotFoundException("提供商不存在", "provider")

        # Capture audit metadata before deletion, while attributes are loaded.
        context.add_audit_metadata(
            action="delete_provider",
            provider_id=provider.id,
            provider_name=provider.name,
        )
        db.delete(provider)
        db.commit()
        return {"message": "提供商已删除"}
|
||||
348
src/api/admin/providers/summary.py
Normal file
348
src/api/admin/providers/summary.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Provider 摘要与健康监控 API
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import case, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
Model,
|
||||
Provider,
|
||||
ProviderAPIKey,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
)
|
||||
from src.models.endpoint_models import (
|
||||
EndpointHealthEvent,
|
||||
EndpointHealthMonitor,
|
||||
ProviderEndpointHealthMonitorResponse,
|
||||
ProviderUpdateRequest,
|
||||
ProviderWithEndpointsSummary,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["Provider Summary"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/summary", response_model=List[ProviderWithEndpointsSummary])
|
||||
async def get_providers_summary(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> List[ProviderWithEndpointsSummary]:
|
||||
"""获取所有 Providers 的摘要信息(包含 Endpoints 和 Keys 统计)"""
|
||||
adapter = AdminProviderSummaryAdapter()
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/summary", response_model=ProviderWithEndpointsSummary)
|
||||
async def get_provider_summary(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderWithEndpointsSummary:
|
||||
"""获取单个 Provider 的摘要信息(包含 Endpoints 和 Keys 统计)"""
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
raise NotFoundException(f"Provider {provider_id} not found")
|
||||
|
||||
return _build_provider_summary(db, provider)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/health-monitor", response_model=ProviderEndpointHealthMonitorResponse)
|
||||
async def get_provider_health_monitor(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
lookback_hours: int = Query(6, ge=1, le=72, description="回溯的小时数"),
|
||||
per_endpoint_limit: int = Query(48, ge=10, le=200, description="每个端点的事件数量"),
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderEndpointHealthMonitorResponse:
|
||||
"""获取 Provider 下所有端点的健康监控时间线"""
|
||||
|
||||
adapter = AdminProviderHealthMonitorAdapter(
|
||||
provider_id=provider_id,
|
||||
lookback_hours=lookback_hours,
|
||||
per_endpoint_limit=per_endpoint_limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{provider_id}", response_model=ProviderWithEndpointsSummary)
|
||||
async def update_provider_settings(
|
||||
provider_id: str,
|
||||
update_data: ProviderUpdateRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> ProviderWithEndpointsSummary:
|
||||
"""更新 Provider 基础配置(display_name, description, priority, weight 等)"""
|
||||
|
||||
adapter = AdminUpdateProviderSettingsAdapter(provider_id=provider_id, update_data=update_data)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndpointsSummary:
    """Assemble a ProviderWithEndpointsSummary for one provider.

    Uses a handful of aggregate queries (instead of per-endpoint lookups)
    to count endpoints, keys and models, loads all API keys in one shot to
    avoid N+1 queries, and derives per-endpoint health scores from the
    keys attached to each endpoint (defaulting to 1.0 when no score data
    is available).
    """
    endpoints = db.query(ProviderEndpoint).filter(ProviderEndpoint.provider_id == provider.id).all()
    endpoint_ids = [ep.id for ep in endpoints]
    total_endpoints = len(endpoints)
    active_endpoints = sum(1 for ep in endpoints if ep.is_active)

    # Key statistics — one aggregate query for total and active counts.
    total_keys = 0
    active_keys = 0
    if endpoint_ids:
        key_row = (
            db.query(
                func.count(ProviderAPIKey.id).label("total"),
                func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
            )
            .filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
            .first()
        )
        total_keys = key_row.total or 0
        active_keys = int(key_row.active or 0)

    # Model statistics — same single-query pattern.
    model_row = (
        db.query(
            func.count(Model.id).label("total"),
            func.sum(case((Model.is_active == True, 1), else_=0)).label("active"),
        )
        .filter(Model.provider_id == provider.id)
        .first()
    )
    total_models = model_row.total or 0
    active_models = int(model_row.active or 0)

    api_formats = [ep.api_format for ep in endpoints]

    # Load every key once and bucket them per endpoint (avoids N+1 queries).
    keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
    if endpoint_ids:
        loaded_keys = (
            db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
        )
        for key in loaded_keys:
            keys_by_endpoint.setdefault(key.endpoint_id, []).append(key)

    # Per-endpoint health: mean of the keys' known scores, 1.0 when unknown.
    endpoint_health_map: dict[str, float] = {}
    for ep in endpoints:
        scores = [
            k.health_score for k in keys_by_endpoint.get(ep.id, []) if k.health_score is not None
        ]
        endpoint_health_map[ep.id] = sum(scores) / len(scores) if scores else 1.0

    health_values = list(endpoint_health_map.values())
    avg_health_score = sum(health_values) / len(health_values) if health_values else 1.0
    unhealthy_endpoints = sum(1 for value in health_values if value < 0.5)

    endpoint_health_details = [
        {
            "api_format": ep.api_format,
            "health_score": endpoint_health_map.get(ep.id, 1.0),
            "is_active": ep.is_active,
            "active_keys": sum(1 for k in keys_by_endpoint.get(ep.id, []) if k.is_active),
        }
        for ep in endpoints
    ]

    return ProviderWithEndpointsSummary(
        id=provider.id,
        name=provider.name,
        display_name=provider.display_name,
        description=provider.description,
        website=provider.website,
        provider_priority=provider.provider_priority,
        is_active=provider.is_active,
        billing_type=provider.billing_type.value if provider.billing_type else None,
        monthly_quota_usd=provider.monthly_quota_usd,
        monthly_used_usd=provider.monthly_used_usd,
        quota_reset_day=provider.quota_reset_day,
        quota_last_reset_at=provider.quota_last_reset_at,
        quota_expires_at=provider.quota_expires_at,
        rpm_limit=provider.rpm_limit,
        rpm_used=provider.rpm_used,
        rpm_reset_at=provider.rpm_reset_at,
        total_endpoints=total_endpoints,
        active_endpoints=active_endpoints,
        total_keys=total_keys,
        active_keys=active_keys,
        total_models=total_models,
        active_models=active_models,
        avg_health_score=avg_health_score,
        unhealthy_endpoints=unhealthy_endpoints,
        api_formats=api_formats,
        endpoint_health_details=endpoint_health_details,
        created_at=provider.created_at,
        updated_at=provider.updated_at,
    )
|
||||
|
||||
|
||||
# -------- Adapters --------
|
||||
|
||||
|
||||
@dataclass
class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
    """Build a per-endpoint health-event timeline for one provider.

    Fetches recent RequestCandidate rows for every endpoint of the
    provider within a lookback window, keeps at most
    ``per_endpoint_limit`` newest attempts per endpoint, and converts
    them into chronological EndpointHealthEvent lists plus aggregate
    success/failure counters.
    """

    provider_id: str
    lookback_hours: int
    per_endpoint_limit: int

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
        if not provider:
            raise NotFoundException(f"Provider {self.provider_id} 不存在")

        endpoints = (
            db.query(ProviderEndpoint)
            .filter(ProviderEndpoint.provider_id == self.provider_id)
            .all()
        )

        now = datetime.now(timezone.utc)
        since = now - timedelta(hours=self.lookback_hours)

        endpoint_ids = [endpoint.id for endpoint in endpoints]
        if not endpoint_ids:
            # Provider has no endpoints: return an empty timeline but still
            # record the audit metadata for the request.
            response = ProviderEndpointHealthMonitorResponse(
                provider_id=provider.id,
                provider_name=provider.display_name or provider.name,
                generated_at=now,
                endpoints=[],
            )
            context.add_audit_metadata(
                action="provider_health_monitor",
                provider_id=self.provider_id,
                endpoint_count=0,
                lookback_hours=self.lookback_hours,
            )
            return response

        # Over-fetch (2x the per-endpoint budget, at least 200 rows) so that
        # busy endpoints don't starve quiet ones of their event quota.
        limit_rows = max(200, self.per_endpoint_limit * max(1, len(endpoint_ids)) * 2)
        attempts_query = (
            db.query(RequestCandidate)
            .filter(
                RequestCandidate.endpoint_id.in_(endpoint_ids),
                RequestCandidate.created_at >= since,
            )
            .order_by(RequestCandidate.created_at.desc())
        )
        attempts = attempts_query.limit(limit_rows).all()

        # Bucket the newest-first rows per endpoint, capping each bucket at
        # per_endpoint_limit attempts.
        buffered_attempts: Dict[str, List[RequestCandidate]] = {eid: [] for eid in endpoint_ids}
        counters: Dict[str, int] = {eid: 0 for eid in endpoint_ids}

        for attempt in attempts:
            if not attempt.endpoint_id or attempt.endpoint_id not in buffered_attempts:
                continue
            if counters[attempt.endpoint_id] >= self.per_endpoint_limit:
                continue
            buffered_attempts[attempt.endpoint_id].append(attempt)
            counters[attempt.endpoint_id] += 1

        endpoint_monitors: List[EndpointHealthMonitor] = []
        for endpoint in endpoints:
            # Buckets were filled newest-first; reverse to chronological order.
            attempt_list = list(reversed(buffered_attempts.get(endpoint.id, [])))
            events: List[EndpointHealthEvent] = []
            for attempt in attempt_list:
                # Prefer the most specific timestamp available for the event.
                event_timestamp = attempt.finished_at or attempt.started_at or attempt.created_at
                events.append(
                    EndpointHealthEvent(
                        timestamp=event_timestamp,
                        status=attempt.status,
                        status_code=attempt.status_code,
                        latency_ms=attempt.latency_ms,
                        error_type=attempt.error_type,
                        error_message=attempt.error_message,
                    )
                )

            success_count = sum(1 for event in events if event.status == "success")
            failed_count = sum(1 for event in events if event.status == "failed")
            skipped_count = sum(1 for event in events if event.status == "skipped")
            total_attempts = len(events)
            # No attempts in the window counts as fully healthy (rate 1.0).
            success_rate = success_count / total_attempts if total_attempts else 1.0
            last_event_at = events[-1].timestamp if events else None

            endpoint_monitors.append(
                EndpointHealthMonitor(
                    endpoint_id=endpoint.id,
                    api_format=endpoint.api_format,
                    is_active=endpoint.is_active,
                    total_attempts=total_attempts,
                    success_count=success_count,
                    failed_count=failed_count,
                    skipped_count=skipped_count,
                    success_rate=success_rate,
                    last_event_at=last_event_at,
                    events=events,
                )
            )

        response = ProviderEndpointHealthMonitorResponse(
            provider_id=provider.id,
            provider_name=provider.display_name or provider.name,
            generated_at=now,
            endpoints=endpoint_monitors,
        )
        context.add_audit_metadata(
            action="provider_health_monitor",
            provider_id=self.provider_id,
            endpoint_count=len(endpoint_monitors),
            lookback_hours=self.lookback_hours,
            per_endpoint_limit=self.per_endpoint_limit,
        )
        return response
|
||||
|
||||
|
||||
class AdminProviderSummaryAdapter(AdminApiAdapter):
    """List summaries for every provider, ordered by priority then creation time."""

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        ordered = (
            session.query(Provider)
            .order_by(Provider.provider_priority.asc(), Provider.created_at.asc())
            .all()
        )
        summaries = []
        for item in ordered:
            summaries.append(_build_provider_summary(session, item))
        return summaries
|
||||
|
||||
|
||||
@dataclass
class AdminUpdateProviderSettingsAdapter(AdminApiAdapter):
    """Apply a partial update to a provider's base configuration and return its summary."""

    provider_id: str
    update_data: ProviderUpdateRequest

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        provider = session.query(Provider).filter(Provider.id == self.provider_id).first()
        if provider is None:
            raise NotFoundException("Provider not found", "provider")

        # Only fields explicitly present in the request are applied.
        changes = self.update_data.model_dump(exclude_unset=True)
        # Re-wrap the raw billing-type string in its enum before persisting.
        if changes.get("billing_type") is not None:
            changes["billing_type"] = ProviderBillingType(changes["billing_type"])

        for field_name, field_value in changes.items():
            setattr(provider, field_name, field_value)

        session.commit()
        session.refresh(provider)

        operator = context.user.username if context.user else "admin"
        logger.info(f"Provider {provider.name} updated by {operator}: {changes}")

        return _build_provider_summary(session, provider)
|
||||
14
src/api/admin/security/__init__.py
Normal file
14
src/api/admin/security/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
Security management API.

Exposes security features such as IP black/whitelist management.
"""

from fastapi import APIRouter

from .ip_management import router as ip_router

# Aggregate all security sub-routers under one router mounted by the admin package.
router = APIRouter()
router.include_router(ip_router)

__all__ = ["router"]
|
||||
202
src/api/admin/security/ip_management.py
Normal file
202
src/api/admin/security/ip_management.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
IP 安全管理接口
|
||||
|
||||
提供 IP 黑白名单管理和速率限制统计
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiMode
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.services.rate_limit.ip_limiter import IPRateLimiter
|
||||
|
||||
router = APIRouter(prefix="/api/admin/security/ip", tags=["IP Security"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ========== Pydantic 模型 ==========
|
||||
|
||||
|
||||
class AddIPToBlacklistRequest(BaseModel):
    """Request body for adding an IP to the blacklist."""

    # Target IP address to block.
    ip_address: str = Field(..., description="IP 地址")
    # Human-readable justification, stored with the entry.
    reason: str = Field(..., min_length=1, max_length=200, description="加入黑名单的原因")
    # Expiry in seconds; None means the entry never expires.
    ttl: Optional[int] = Field(None, gt=0, description="过期时间(秒),None 表示永久")
|
||||
|
||||
|
||||
class RemoveIPFromBlacklistRequest(BaseModel):
    """Request body for removing an IP from the blacklist.

    NOTE(review): the DELETE route takes the IP from the path, so this
    model appears unused by the current endpoints — confirm before removing.
    """

    ip_address: str = Field(..., description="IP 地址")
|
||||
|
||||
|
||||
class AddIPToWhitelistRequest(BaseModel):
    """Request body for adding an IP (or CIDR range) to the whitelist."""

    # Single address or CIDR notation, e.g. 192.168.1.0/24.
    ip_address: str = Field(..., description="IP 地址或 CIDR 格式(如 192.168.1.0/24)")
|
||||
|
||||
|
||||
class RemoveIPFromWhitelistRequest(BaseModel):
    """Request body for removing an IP from the whitelist.

    NOTE(review): the DELETE route takes the IP from the path, so this
    model appears unused by the current endpoints — confirm before removing.
    """

    ip_address: str = Field(..., description="IP 地址")
|
||||
|
||||
|
||||
# ========== API 端点 ==========
|
||||
|
||||
|
||||
@router.post("/blacklist")
async def add_to_blacklist(request: Request, db: Session = Depends(get_db)):
    """Add IP to blacklist"""
    return await pipeline.run(
        adapter=AddToBlacklistAdapter(), http_request=request, db=db, mode=ApiMode.ADMIN
    )
|
||||
|
||||
|
||||
@router.delete("/blacklist/{ip_address}")
async def remove_from_blacklist(ip_address: str, request: Request, db: Session = Depends(get_db)):
    """Remove IP from blacklist"""
    return await pipeline.run(
        adapter=RemoveFromBlacklistAdapter(ip_address=ip_address),
        http_request=request,
        db=db,
        mode=ApiMode.ADMIN,
    )
|
||||
|
||||
|
||||
@router.get("/blacklist/stats")
async def get_blacklist_stats(request: Request, db: Session = Depends(get_db)):
    """Get blacklist statistics"""
    return await pipeline.run(
        adapter=GetBlacklistStatsAdapter(), http_request=request, db=db, mode=ApiMode.ADMIN
    )
|
||||
|
||||
|
||||
@router.post("/whitelist")
async def add_to_whitelist(request: Request, db: Session = Depends(get_db)):
    """Add IP to whitelist"""
    return await pipeline.run(
        adapter=AddToWhitelistAdapter(), http_request=request, db=db, mode=ApiMode.ADMIN
    )
|
||||
|
||||
|
||||
@router.delete("/whitelist/{ip_address}")
async def remove_from_whitelist(ip_address: str, request: Request, db: Session = Depends(get_db)):
    """Remove IP from whitelist"""
    return await pipeline.run(
        adapter=RemoveFromWhitelistAdapter(ip_address=ip_address),
        http_request=request,
        db=db,
        mode=ApiMode.ADMIN,
    )
|
||||
|
||||
|
||||
@router.get("/whitelist")
async def get_whitelist(request: Request, db: Session = Depends(get_db)):
    """Get whitelist"""
    return await pipeline.run(
        adapter=GetWhitelistAdapter(), http_request=request, db=db, mode=ApiMode.ADMIN
    )
|
||||
|
||||
|
||||
# ========== 适配器实现 ==========
|
||||
|
||||
|
||||
class AddToBlacklistAdapter(AuthenticatedApiAdapter):
    """Validate the request payload and blacklist the given IP via the rate limiter."""

    async def handle(self, context):  # type: ignore[override]
        body = context.ensure_json_body()
        try:
            parsed = AddIPToBlacklistRequest.model_validate(body)
        except ValidationError as exc:
            details = exc.errors()
            if details:
                raise InvalidRequestException(translate_pydantic_error(details[0]))
            raise InvalidRequestException("请求数据验证失败")

        # Redis-backed operation; a falsy result means the store is unavailable.
        added = await IPRateLimiter.add_to_blacklist(parsed.ip_address, parsed.reason, parsed.ttl)
        if not added:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="添加 IP 到黑名单失败(Redis 不可用)",
            )

        return {
            "success": True,
            "message": f"IP {parsed.ip_address} 已加入黑名单",
            "reason": parsed.reason,
            "ttl": parsed.ttl or "永久",
        }
|
||||
|
||||
|
||||
class RemoveFromBlacklistAdapter(AuthenticatedApiAdapter):
    """Remove a single IP from the blacklist; 404 when it was not listed."""

    def __init__(self, ip_address: str):
        self.ip_address = ip_address

    async def handle(self, context):  # type: ignore[override]
        removed = await IPRateLimiter.remove_from_blacklist(self.ip_address)
        if not removed:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在黑名单中"
            )
        return {"success": True, "message": f"IP {self.ip_address} 已从黑名单移除"}
|
||||
|
||||
|
||||
class GetBlacklistStatsAdapter(AuthenticatedApiAdapter):
    """Fetch aggregate blacklist statistics from the rate limiter."""

    async def handle(self, context):  # type: ignore[override]
        return await IPRateLimiter.get_blacklist_stats()
|
||||
|
||||
|
||||
class AddToWhitelistAdapter(AuthenticatedApiAdapter):
    """Validate the request payload and whitelist the given IP or CIDR range.

    Raises:
        InvalidRequestException: when the JSON body fails model validation.
        HTTPException (400): when the limiter rejects the address
            (invalid IP format or Redis unavailable).
    """

    async def handle(self, context):  # type: ignore[override]
        payload = context.ensure_json_body()
        try:
            req = AddIPToWhitelistRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        success = await IPRateLimiter.add_to_whitelist(req.ip_address)

        if not success:
            # Fixed: this detail string had an `f` prefix with no placeholders (F541).
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="添加 IP 到白名单失败(无效的 IP 格式或 Redis 不可用)",
            )

        return {"success": True, "message": f"IP {req.ip_address} 已加入白名单"}
|
||||
|
||||
|
||||
class RemoveFromWhitelistAdapter(AuthenticatedApiAdapter):
    """Remove a single IP from the whitelist; 404 when it was not listed."""

    def __init__(self, ip_address: str):
        self.ip_address = ip_address

    async def handle(self, context):  # type: ignore[override]
        removed = await IPRateLimiter.remove_from_whitelist(self.ip_address)
        if not removed:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail=f"IP {self.ip_address} 不在白名单中"
            )
        return {"success": True, "message": f"IP {self.ip_address} 已从白名单移除"}
|
||||
|
||||
|
||||
class GetWhitelistAdapter(AuthenticatedApiAdapter):
    """Return the whitelist entries together with their count."""

    async def handle(self, context):  # type: ignore[override]
        entries = await IPRateLimiter.get_whitelist()
        return {"whitelist": list(entries), "total": len(entries)}
|
||||
312
src/api/admin/system.py
Normal file
312
src/api/admin/system.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""系统设置API端点。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
|
||||
from src.database import get_db
|
||||
from src.models.api import SystemSettingsRequest, SystemSettingsResponse
|
||||
from src.models.database import ApiKey, Provider, Usage, User
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
@router.get("/settings")
async def get_system_settings(request: Request, db: Session = Depends(get_db)):
    """Read the system settings (admin only)."""
    reader = AdminGetSystemSettingsAdapter()
    return await pipeline.run(adapter=reader, http_request=request, db=db, mode=reader.mode)
|
||||
|
||||
|
||||
@router.put("/settings")
async def update_system_settings(http_request: Request, db: Session = Depends(get_db)):
    """Update the system settings (admin only)."""
    writer = AdminUpdateSystemSettingsAdapter()
    return await pipeline.run(adapter=writer, http_request=http_request, db=db, mode=writer.mode)
|
||||
|
||||
|
||||
@router.get("/configs")
async def get_all_system_configs(request: Request, db: Session = Depends(get_db)):
    """List every system configuration entry (admin only)."""
    reader = AdminGetAllConfigsAdapter()
    return await pipeline.run(adapter=reader, http_request=request, db=db, mode=reader.mode)
|
||||
|
||||
|
||||
@router.get("/configs/{key}")
async def get_system_config(key: str, request: Request, db: Session = Depends(get_db)):
    """Read a single system configuration entry (admin only)."""
    reader = AdminGetSystemConfigAdapter(key=key)
    return await pipeline.run(adapter=reader, http_request=request, db=db, mode=reader.mode)
|
||||
|
||||
|
||||
@router.put("/configs/{key}")
async def set_system_config(
    key: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Create or update a system configuration entry (admin only)."""
    writer = AdminSetSystemConfigAdapter(key=key)
    return await pipeline.run(adapter=writer, http_request=request, db=db, mode=writer.mode)
|
||||
|
||||
|
||||
@router.delete("/configs/{key}")
async def delete_system_config(key: str, request: Request, db: Session = Depends(get_db)):
    """Delete a system configuration entry (admin only)."""
    remover = AdminDeleteSystemConfigAdapter(key=key)
    return await pipeline.run(adapter=remover, http_request=request, db=db, mode=remover.mode)
|
||||
|
||||
|
||||
@router.get("/stats")
async def get_system_stats(request: Request, db: Session = Depends(get_db)):
    """Return global counters for users, providers, API keys and requests."""
    stats = AdminSystemStatsAdapter()
    return await pipeline.run(adapter=stats, http_request=request, db=db, mode=stats.mode)
|
||||
|
||||
|
||||
@router.post("/cleanup")
async def trigger_cleanup(request: Request, db: Session = Depends(get_db)):
    """Manually trigger usage record cleanup task"""
    cleaner = AdminTriggerCleanupAdapter()
    return await pipeline.run(adapter=cleaner, http_request=request, db=db, mode=cleaner.mode)
|
||||
|
||||
|
||||
@router.get("/api-formats")
async def get_api_formats(request: Request, db: Session = Depends(get_db)):
    """List every API format the gateway supports."""
    lister = AdminGetApiFormatsAdapter()
    return await pipeline.run(adapter=lister, http_request=request, db=db, mode=lister.mode)
|
||||
|
||||
|
||||
# -------- 系统设置适配器 --------
|
||||
|
||||
|
||||
class AdminGetSystemSettingsAdapter(AdminApiAdapter):
    """Read the three core system settings into a SystemSettingsResponse."""

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        # Usage tracking defaults to enabled when the config entry is absent.
        tracking_enabled = (
            SystemConfigService.get_config(session, "enable_usage_tracking", "true") == "true"
        )
        return SystemSettingsResponse(
            default_provider=SystemConfigService.get_default_provider(session),
            default_model=SystemConfigService.get_config(session, "default_model"),
            enable_usage_tracking=tracking_enabled,
        )
|
||||
|
||||
|
||||
class AdminUpdateSystemSettingsAdapter(AdminApiAdapter):
    """Apply a partial update to the system settings.

    Each field uses a two-level sentinel: ``None`` means "not provided,
    leave unchanged", while an empty string/falsy value means "clear the
    stored config entry".
    """

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()
        try:
            settings_request = SystemSettingsRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        if settings_request.default_provider is not None:
            provider = (
                db.query(Provider)
                .filter(
                    Provider.name == settings_request.default_provider,
                    Provider.is_active.is_(True),
                )
                .first()
            )

            # A non-empty name must match an active provider; "" means "clear".
            if not provider and settings_request.default_provider != "":
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"提供商 '{settings_request.default_provider}' 不存在或未启用",
                )

            if settings_request.default_provider:
                SystemConfigService.set_default_provider(db, settings_request.default_provider)
            else:
                SystemConfigService.delete_config(db, "default_provider")

        if settings_request.default_model is not None:
            # Same convention: empty string clears the stored default model.
            if settings_request.default_model:
                SystemConfigService.set_config(db, "default_model", settings_request.default_model)
            else:
                SystemConfigService.delete_config(db, "default_model")

        if settings_request.enable_usage_tracking is not None:
            # Stored as the lowercase strings "true"/"false".
            SystemConfigService.set_config(
                db,
                "enable_usage_tracking",
                str(settings_request.enable_usage_tracking).lower(),
            )

        return {"message": "系统设置更新成功"}
|
||||
|
||||
|
||||
class AdminGetAllConfigsAdapter(AdminApiAdapter):
    """Return every stored system configuration entry."""

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        return SystemConfigService.get_all_configs(session)
|
||||
|
||||
|
||||
@dataclass
class AdminGetSystemConfigAdapter(AdminApiAdapter):
    """Read one configuration value by key; 404 when the key is absent."""

    key: str

    async def handle(self, context):  # type: ignore[override]
        stored = SystemConfigService.get_config(context.db, self.key)
        if stored is None:
            raise NotFoundException(f"配置项 '{self.key}' 不存在")
        return {"key": self.key, "value": stored}
|
||||
|
||||
|
||||
@dataclass
class AdminSetSystemConfigAdapter(AdminApiAdapter):
    """Create or replace one configuration entry from the JSON body."""

    key: str

    async def handle(self, context):  # type: ignore[override]
        body = context.ensure_json_body()
        saved = SystemConfigService.set_config(
            context.db,
            self.key,
            body.get("value"),
            body.get("description"),
        )
        return {
            "key": saved.key,
            "value": saved.value,
            "description": saved.description,
            "updated_at": saved.updated_at.isoformat(),
        }
|
||||
|
||||
|
||||
@dataclass
class AdminDeleteSystemConfigAdapter(AdminApiAdapter):
    """Delete one configuration entry by key; 404 when the key is absent."""

    key: str

    async def handle(self, context):  # type: ignore[override]
        was_deleted = SystemConfigService.delete_config(context.db, self.key)
        if not was_deleted:
            raise NotFoundException(f"配置项 '{self.key}' 不存在")
        return {"message": f"配置项 '{self.key}' 已删除"}
|
||||
|
||||
|
||||
class AdminSystemStatsAdapter(AdminApiAdapter):
    """Collect global counters for users, providers, API keys and requests."""

    async def handle(self, context):  # type: ignore[override]
        session = context.db
        return {
            "users": {
                "total": session.query(User).count(),
                "active": session.query(User).filter(User.is_active.is_(True)).count(),
            },
            "providers": {
                "total": session.query(Provider).count(),
                "active": session.query(Provider).filter(Provider.is_active.is_(True)).count(),
            },
            "api_keys": session.query(ApiKey).count(),
            "requests": session.query(Usage).count(),
        }
|
||||
|
||||
|
||||
class AdminTriggerCleanupAdapter(AdminApiAdapter):
    """Manually trigger the usage-record cleanup task and report what changed.

    Returns a dict with before/after counts of total usage records, records
    still carrying request/response bodies, and records still carrying
    request/response headers, plus the deltas and a UTC timestamp.
    """

    async def handle(self, context):  # type: ignore[override]
        # Fixed: removed unused local imports (`timedelta`, sqlalchemy `func`)
        # and deduplicated the identical before/after statistics queries.
        from datetime import datetime, timezone

        from src.services.system.cleanup_scheduler import get_cleanup_scheduler

        db = context.db

        def _snapshot() -> tuple[int, int, int]:
            """Return (total records, records with body fields, records with header fields)."""
            total = db.query(Usage).count()
            with_body = (
                db.query(Usage)
                .filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
                .count()
            )
            with_headers = (
                db.query(Usage)
                .filter((Usage.request_headers.isnot(None)) | (Usage.response_headers.isnot(None)))
                .count()
            )
            return total, with_body, with_headers

        total_before, with_body_before, with_headers_before = _snapshot()

        # Run one cleanup pass synchronously.
        cleanup_scheduler = get_cleanup_scheduler()
        await cleanup_scheduler._perform_cleanup()

        total_after, with_body_after, with_headers_after = _snapshot()

        return {
            "message": "清理任务执行完成",
            "stats": {
                "total_records": {
                    "before": total_before,
                    "after": total_after,
                    "deleted": total_before - total_after,
                },
                "body_fields": {
                    "before": with_body_before,
                    "after": with_body_after,
                    "cleaned": with_body_before - with_body_after,
                },
                "header_fields": {
                    "before": with_headers_before,
                    "after": with_headers_after,
                    "cleaned": with_headers_before - with_headers_after,
                },
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
|
||||
|
||||
|
||||
class AdminGetApiFormatsAdapter(AdminApiAdapter):
    """List every supported API format with its default path and aliases."""

    async def handle(self, context):  # type: ignore[override]
        from src.core.api_format_metadata import API_FORMAT_DEFINITIONS
        from src.core.enums import APIFormat

        _ = context  # kept to satisfy the adapter interface

        entries = []
        for fmt in APIFormat:
            meta = API_FORMAT_DEFINITIONS.get(fmt)
            entries.append(
                {
                    "value": fmt.value,
                    "label": fmt.value,
                    "default_path": meta.default_path if meta else "/",
                    "aliases": list(meta.aliases) if meta else [],
                }
            )

        return {"formats": entries}
|
||||
5
src/api/admin/usage/__init__.py
Normal file
5
src/api/admin/usage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Usage admin routes export."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
818
src/api/admin/usage/routes.py
Normal file
818
src/api/admin/usage/routes.py
Normal file
@@ -0,0 +1,818 @@
|
||||
"""管理员使用情况统计路由。"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
Provider,
|
||||
ProviderAPIKey,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
Usage,
|
||||
User,
|
||||
)
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/usage", tags=["Admin - Usage"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# ==================== RESTful Routes ====================
|
||||
|
||||
|
||||
@router.get("/aggregation/stats")
async def get_usage_aggregation(
    request: Request,
    group_by: str = Query(..., description="Aggregation dimension: model, user, provider, or api_format"),
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    limit: int = Query(20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """
    Get usage aggregation by specified dimension.

    - group_by=model: Aggregate by model
    - group_by=user: Aggregate by user
    - group_by=provider: Aggregate by provider
    - group_by=api_format: Aggregate by API format
    """
    # Dimension name -> adapter class dispatch table.
    adapter_classes = {
        "model": AdminUsageByModelAdapter,
        "user": AdminUsageByUserAdapter,
        "provider": AdminUsageByProviderAdapter,
        "api_format": AdminUsageByApiFormatAdapter,
    }
    adapter_cls = adapter_classes.get(group_by)
    if adapter_cls is None:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid group_by value: {group_by}. Must be one of: model, user, provider, api_format"
        )
    adapter = adapter_cls(start_date=start_date, end_date=end_date, limit=limit)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/stats")
async def get_usage_stats(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    db: Session = Depends(get_db),
):
    """Return aggregate usage statistics for the optional date window."""
    stats = AdminUsageStatsAdapter(start_date=start_date, end_date=end_date)
    return await pipeline.run(adapter=stats, http_request=request, db=db, mode=stats.mode)
|
||||
|
||||
|
||||
@router.get("/records")
async def get_usage_records(
    request: Request,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    user_id: Optional[str] = None,
    username: Optional[str] = None,
    model: Optional[str] = None,
    provider: Optional[str] = None,
    status: Optional[str] = None,  # stream, standard, error
    limit: int = Query(100, ge=1, le=500),
    offset: int = Query(0, ge=0),
    db: Session = Depends(get_db),
):
    """Return paginated usage records matching the given filters."""
    return await pipeline.run(
        adapter=AdminUsageRecordsAdapter(
            start_date=start_date,
            end_date=end_date,
            user_id=user_id,
            username=username,
            model=model,
            provider=provider,
            status=status,
            limit=limit,
            offset=offset,
        ),
        http_request=request,
        db=db,
        mode=AdminUsageRecordsAdapter(
            start_date=start_date,
            end_date=end_date,
            user_id=user_id,
            username=username,
            model=model,
            provider=provider,
            status=status,
            limit=limit,
            offset=offset,
        ).mode,
    ) if False else await _run_usage_records(
        request,
        db,
        start_date,
        end_date,
        user_id,
        username,
        model,
        provider,
        status,
        limit,
        offset,
    )


async def _run_usage_records(
    request: Request,
    db: Session,
    start_date: Optional[datetime],
    end_date: Optional[datetime],
    user_id: Optional[str],
    username: Optional[str],
    model: Optional[str],
    provider: Optional[str],
    status: Optional[str],
    limit: int,
    offset: int,
):
    """Build the records adapter and dispatch it through the pipeline."""
    records = AdminUsageRecordsAdapter(
        start_date=start_date,
        end_date=end_date,
        user_id=user_id,
        username=username,
        model=model,
        provider=provider,
        status=status,
        limit=limit,
        offset=offset,
    )
    return await pipeline.run(adapter=records, http_request=request, db=db, mode=records.mode)
|
||||
|
||||
|
||||
@router.get("/active")
async def get_active_requests(
    request: Request,
    ids: Optional[str] = Query(None, description="逗号分隔的请求 ID 列表,用于查询特定请求的状态"),
    db: Session = Depends(get_db),
):
    """
    Get the status of active requests (lightweight endpoint for frontend polling).

    - If ``ids`` is provided, only the latest status of those requests is returned.
    - If ``ids`` is omitted, all requests currently in pending/streaming state are returned.
    """
    adapter = AdminActiveRequestsAdapter(ids=ids)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
# NOTE: This route must be defined AFTER all other routes to avoid matching
# routes like /stats, /records, /active, etc.
@router.get("/{usage_id}")
async def get_usage_detail(
    usage_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Get detailed information of a specific usage record.

    Includes request/response headers and body.
    """
    adapter = AdminUsageDetailAdapter(usage_id=usage_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
class AdminUsageStatsAdapter(AdminApiAdapter):
    """Compute aggregate usage statistics for the admin dashboard.

    Produces request/token/cost totals, cache statistics, an error count/rate,
    and a 365-day activity heatmap, optionally bounded by start/end dates.
    """

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime]):
        self.start_date = start_date
        self.end_date = end_date

    async def handle(self, context):  # type: ignore[override]
        """Run the aggregation queries and return the stats payload dict."""
        db = context.db
        query = db.query(Usage)
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)

        # Overall totals: request count, tokens, billed/actual cost, avg latency.
        total_stats = query.with_entities(
            func.count(Usage.id).label("total_requests"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("total_actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        ).first()

        # Cache statistics (prompt-cache creation/read tokens and their cost).
        cache_stats = query.with_entities(
            func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
            func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
            func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
            func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
        ).first()

        # Error statistics: an HTTP status >= 400 OR a recorded error message
        # counts as an errored request.
        error_count = query.filter(
            (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
        ).count()

        # Daily activity heatmap over a fixed 365-day window (not bounded by
        # the start/end filters above).
        activity_heatmap = UsageService.get_daily_activity(
            db=db,
            window_days=365,
            include_actual_cost=True,
        )

        context.add_audit_metadata(
            action="usage_stats",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
        )

        total_requests = total_stats.total_requests if total_stats else 0
        avg_response_time_ms = float(total_stats.avg_response_time_ms or 0) if total_stats else 0
        # Convert milliseconds to seconds for the response payload.
        avg_response_time = avg_response_time_ms / 1000.0

        return {
            "total_requests": total_requests,
            "total_tokens": int(total_stats.total_tokens or 0),
            "total_cost": float(total_stats.total_cost or 0),
            "total_actual_cost": float(total_stats.total_actual_cost or 0),
            "avg_response_time": round(avg_response_time, 2),
            "error_count": error_count,
            "error_rate": (
                round((error_count / total_requests) * 100, 2) if total_requests > 0 else 0
            ),
            "cache_stats": {
                "cache_creation_tokens": (
                    int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
                ),
                "cache_read_tokens": int(cache_stats.cache_read_tokens or 0) if cache_stats else 0,
                "cache_creation_cost": (
                    float(cache_stats.cache_creation_cost or 0) if cache_stats else 0
                ),
                "cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
            },
            "activity_heatmap": activity_heatmap,
        }
|
||||
|
||||
|
||||
class AdminUsageByModelAdapter(AdminApiAdapter):
    """Aggregate usage per model, busiest models first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return per-model request/token/cost totals within the date window."""
        db = context.db
        base = db.query(
            Usage.model,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
        )
        # Exclude in-flight requests (pending/streaming) and requests that
        # never reached any provider (unknown/pending provider) from the stats.
        base = base.filter(
            Usage.status.notin_(["pending", "streaming"]),
            Usage.provider.notin_(["unknown", "pending"]),
        )

        if self.start_date:
            base = base.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            base = base.filter(Usage.created_at <= self.end_date)

        rows = (
            base.group_by(Usage.model)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_model",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(rows),
        )

        payload = []
        for model_name, req_count, token_sum, cost_sum, actual_sum in rows:
            payload.append(
                {
                    "model": model_name,
                    "request_count": req_count,
                    "total_tokens": int(token_sum or 0),
                    "total_cost": float(cost_sum or 0),
                    "actual_cost": float(actual_sum or 0),
                }
            )
        return payload
|
||||
|
||||
|
||||
class AdminUsageByUserAdapter(AdminApiAdapter):
    """Aggregate usage per user, heaviest users first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return per-user request/token/cost totals within the date window."""
        db = context.db
        base = (
            db.query(
                User.id,
                User.email,
                User.username,
                func.count(Usage.id).label("request_count"),
                func.sum(Usage.total_tokens).label("total_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
            )
            .join(Usage, Usage.user_id == User.id)
            .group_by(User.id, User.email, User.username)
        )

        if self.start_date:
            base = base.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            base = base.filter(Usage.created_at <= self.end_date)

        rows = base.order_by(func.count(Usage.id).desc()).limit(self.limit).all()

        context.add_audit_metadata(
            action="usage_by_user",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(rows),
        )

        payload = []
        for uid, email, uname, req_count, token_sum, cost_sum in rows:
            payload.append(
                {
                    "user_id": uid,
                    "email": email,
                    "username": uname,
                    "request_count": req_count,
                    "total_tokens": int(token_sum or 0),
                    "total_cost": float(cost_sum or 0),
                }
            )
        return payload
|
||||
|
||||
|
||||
class AdminUsageByProviderAdapter(AdminApiAdapter):
    """Aggregate usage per provider, combining attempt stats and billing stats.

    Attempt counts and success rates come from ``request_candidates`` so that
    fallback scenarios (one request trying several providers) are counted
    correctly; token/cost totals come from the ``Usage`` table.
    """

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return per-provider attempt/success/cost statistics."""
        db = context.db

        # Count attempts and success rate per provider from request_candidates.
        # This correctly accounts for fallback (a request may try multiple
        # providers). `Integer` was previously imported here but never used.
        from sqlalchemy import case

        attempt_query = db.query(
            RequestCandidate.provider_id,
            func.count(RequestCandidate.id).label("attempt_count"),
            func.sum(
                case((RequestCandidate.status == "success", 1), else_=0)
            ).label("success_count"),
            func.sum(
                case((RequestCandidate.status == "failed", 1), else_=0)
            ).label("failed_count"),
            func.avg(RequestCandidate.latency_ms).label("avg_latency_ms"),
        ).filter(
            RequestCandidate.provider_id.isnot(None),
            # Only count attempts that were actually executed
            # (exclude available/skipped states).
            RequestCandidate.status.in_(["success", "failed"]),
        )

        if self.start_date:
            attempt_query = attempt_query.filter(RequestCandidate.created_at >= self.start_date)
        if self.end_date:
            attempt_query = attempt_query.filter(RequestCandidate.created_at <= self.end_date)

        attempt_stats = (
            attempt_query.group_by(RequestCandidate.provider_id)
            .order_by(func.count(RequestCandidate.id).desc())
            .limit(self.limit)
            .all()
        )

        # Token and cost statistics from the Usage table (completed requests).
        usage_query = db.query(
            Usage.provider_id,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        ).filter(
            Usage.provider_id.isnot(None),
            # Exclude in-flight (pending/streaming) requests.
            Usage.status.notin_(["pending", "streaming"]),
        )

        if self.start_date:
            usage_query = usage_query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            usage_query = usage_query.filter(Usage.created_at <= self.end_date)

        usage_stats = usage_query.group_by(Usage.provider_id).all()
        usage_map = {str(u.provider_id): u for u in usage_stats}

        # Collect every provider ID seen in either result set.
        provider_ids = set()
        for stat in attempt_stats:
            if stat.provider_id:
                provider_ids.add(stat.provider_id)
        for stat in usage_stats:
            if stat.provider_id:
                provider_ids.add(stat.provider_id)

        # Resolve provider IDs to display names in one query.
        provider_map = {}
        if provider_ids:
            providers_data = (
                db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
            )
            provider_map = {str(p.id): p.name for p in providers_data}

        context.add_audit_metadata(
            action="usage_by_provider",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(attempt_stats),
        )

        result = []
        for stat in attempt_stats:
            provider_id_str = str(stat.provider_id) if stat.provider_id else None
            attempt_count = stat.attempt_count or 0
            success_count = int(stat.success_count or 0)
            failed_count = int(stat.failed_count or 0)
            success_rate = (success_count / attempt_count * 100) if attempt_count > 0 else 0

            # Token/cost figures come from the Usage-side aggregation, if any.
            usage_stat = usage_map.get(provider_id_str)

            result.append({
                "provider_id": provider_id_str,
                "provider": provider_map.get(provider_id_str, "Unknown"),
                "request_count": attempt_count,  # number of attempts
                "total_tokens": int(usage_stat.total_tokens or 0) if usage_stat else 0,
                "total_cost": float(usage_stat.total_cost or 0) if usage_stat else 0,
                "actual_cost": float(usage_stat.actual_cost or 0) if usage_stat else 0,
                "avg_response_time_ms": float(stat.avg_latency_ms or 0),
                "success_rate": round(success_rate, 2),
                "error_count": failed_count,
            })

        return result
|
||||
|
||||
|
||||
class AdminUsageByApiFormatAdapter(AdminApiAdapter):
    """Aggregate usage per API format, busiest formats first."""

    def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
        self.start_date = start_date
        self.end_date = end_date
        self.limit = limit

    async def handle(self, context):  # type: ignore[override]
        """Return per-api_format request/token/cost/latency totals."""
        db = context.db
        base = db.query(
            Usage.api_format,
            func.count(Usage.id).label("request_count"),
            func.sum(Usage.total_tokens).label("total_tokens"),
            func.sum(Usage.total_cost_usd).label("total_cost"),
            func.sum(Usage.actual_total_cost_usd).label("actual_cost"),
            func.avg(Usage.response_time_ms).label("avg_response_time_ms"),
        )
        # Only count finished requests that reached a real provider and that
        # actually carry an api_format value.
        base = base.filter(
            Usage.status.notin_(["pending", "streaming"]),
            Usage.provider.notin_(["unknown", "pending"]),
            Usage.api_format.isnot(None),
        )

        if self.start_date:
            base = base.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            base = base.filter(Usage.created_at <= self.end_date)

        rows = (
            base.group_by(Usage.api_format)
            .order_by(func.count(Usage.id).desc())
            .limit(self.limit)
            .all()
        )

        context.add_audit_metadata(
            action="usage_by_api_format",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            limit=self.limit,
            result_count=len(rows),
        )

        payload = []
        for fmt, req_count, token_sum, cost_sum, actual_sum, avg_rt in rows:
            payload.append(
                {
                    "api_format": fmt or "unknown",
                    "request_count": req_count,
                    "total_tokens": int(token_sum or 0),
                    "total_cost": float(cost_sum or 0),
                    "actual_cost": float(actual_sum or 0),
                    "avg_response_time_ms": float(avg_rt or 0),
                }
            )
        return payload
|
||||
|
||||
|
||||
class AdminUsageRecordsAdapter(AdminApiAdapter):
    """Paginated listing of usage records with flexible filtering.

    Joins user / endpoint / provider-key information, detects provider
    fallback per request, and returns a ``{records, total, limit, offset}``
    payload suitable for the admin UI table.
    """

    def __init__(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime],
        user_id: Optional[str],
        username: Optional[str],
        model: Optional[str],
        provider: Optional[str],
        status: Optional[str],
        limit: int,
        offset: int,
    ):
        self.start_date = start_date
        self.end_date = end_date
        self.user_id = user_id
        self.username = username
        self.model = model
        self.provider = provider
        self.status = status
        self.limit = limit
        self.offset = offset

    async def handle(self, context):  # type: ignore[override]
        """Build the filtered query, page through it and serialize the rows."""
        db = context.db
        # Outer joins keep usage rows whose user/endpoint/key were deleted.
        query = (
            db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
            .outerjoin(User, Usage.user_id == User.id)
            .outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
            .outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
        )
        if self.user_id:
            query = query.filter(Usage.user_id == self.user_id)
        if self.username:
            # Fuzzy (case-insensitive substring) match on username.
            query = query.filter(User.username.ilike(f"%{self.username}%"))
        if self.model:
            # Fuzzy match on model name.
            query = query.filter(Usage.model.ilike(f"%{self.model}%"))
        if self.provider:
            # Provider-name search via the Provider table.
            query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
            query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
        if self.status:
            # Status filtering.
            # Legacy values (derived from is_stream / status_code): stream, standard, error
            # New values (based on the status column): pending, streaming, completed, failed, active
            if self.status == "stream":
                query = query.filter(Usage.is_stream == True)  # noqa: E712
            elif self.status == "standard":
                query = query.filter(Usage.is_stream == False)  # noqa: E712
            elif self.status == "error":
                query = query.filter(
                    (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
                )
            elif self.status in ("pending", "streaming", "completed", "failed"):
                # New-style filter: match the status column directly.
                query = query.filter(Usage.status == self.status)
            elif self.status == "active":
                # Active requests: pending or streaming.
                query = query.filter(Usage.status.in_(["pending", "streaming"]))
        if self.start_date:
            query = query.filter(Usage.created_at >= self.start_date)
        if self.end_date:
            query = query.filter(Usage.created_at <= self.end_date)

        total = query.count()
        records = (
            query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
        )

        request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
        fallback_map = {}
        if request_ids:
            # Only count candidates that were actually executed (success or
            # failed), excluding skipped/pending/available ones.
            executed_counts = (
                db.query(RequestCandidate.request_id, func.count(RequestCandidate.id))
                .filter(
                    RequestCandidate.request_id.in_(request_ids),
                    RequestCandidate.status.in_(["success", "failed"]),
                )
                .group_by(RequestCandidate.request_id)
                .all()
            )
            # More than one executed candidate means a provider fallback happened.
            fallback_map = {req_id: count > 1 for req_id, count in executed_counts}

        context.add_audit_metadata(
            action="usage_records",
            start_date=self.start_date.isoformat() if self.start_date else None,
            end_date=self.end_date.isoformat() if self.end_date else None,
            user_id=self.user_id,
            username=self.username,
            model=self.model,
            provider=self.provider,
            status=self.status,
            limit=self.limit,
            offset=self.offset,
            total=total,
        )

        # Build a provider_id -> provider name map to avoid N+1 queries.
        provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
        provider_map = {}
        if provider_ids:
            providers_data = (
                db.query(Provider.id, Provider.name).filter(Provider.id.in_(provider_ids)).all()
            )
            provider_map = {str(p.id): p.name for p in providers_data}

        data = []
        for usage, user, endpoint, api_key in records:
            actual_cost = (
                float(usage.actual_total_cost_usd)
                if usage.actual_total_cost_usd is not None
                else 0.0
            )
            rate_multiplier = (
                float(usage.rate_multiplier) if usage.rate_multiplier is not None else 1.0
            )

            # Provider display name precedence: joined Provider row > usage.provider.
            provider_name = usage.provider
            if usage.provider_id and str(usage.provider_id) in provider_map:
                provider_name = provider_map[str(usage.provider_id)]

            data.append(
                {
                    "id": usage.id,
                    "user_id": user.id if user else None,
                    "user_email": user.email if user else "已删除用户",
                    "username": user.username if user else "已删除用户",
                    "provider": provider_name,
                    "model": usage.model,
                    "target_model": usage.target_model,  # target model name after mapping
                    "input_tokens": usage.input_tokens,
                    "output_tokens": usage.output_tokens,
                    "cache_creation_input_tokens": usage.cache_creation_input_tokens,
                    "cache_read_input_tokens": usage.cache_read_input_tokens,
                    "total_tokens": usage.total_tokens,
                    "cost": float(usage.total_cost_usd),
                    "actual_cost": actual_cost,
                    "rate_multiplier": rate_multiplier,
                    "response_time_ms": usage.response_time_ms,
                    "created_at": usage.created_at.isoformat(),
                    "is_stream": usage.is_stream,
                    "input_price_per_1m": usage.input_price_per_1m,
                    "output_price_per_1m": usage.output_price_per_1m,
                    "cache_creation_price_per_1m": usage.cache_creation_price_per_1m,
                    "cache_read_price_per_1m": usage.cache_read_price_per_1m,
                    "status_code": usage.status_code,
                    "error_message": usage.error_message,
                    "status": usage.status,  # request state: pending, streaming, completed, failed
                    "has_fallback": fallback_map.get(usage.request_id, False),
                    "api_format": usage.api_format
                    or (endpoint.api_format if endpoint and endpoint.api_format else None),
                    "api_key_name": api_key.name if api_key else None,
                    "request_metadata": usage.request_metadata,  # provider response metadata
                }
            )

        return {
            "records": data,
            "total": total,
            "limit": self.limit,
            "offset": self.offset,
        }
|
||||
|
||||
|
||||
class AdminActiveRequestsAdapter(AdminApiAdapter):
    """Lightweight adapter for polling the status of active requests.

    Fixes: the ID-parsing comprehension previously shadowed the ``id``
    builtin; the duplicated column list is now hoisted into one tuple.
    """

    def __init__(self, ids: Optional[str]):
        # Comma-separated request IDs, or None to list all active requests.
        self.ids = ids

    async def handle(self, context):  # type: ignore[override]
        """Return a minimal status payload for the requested / active records."""
        db = context.db

        # Minimal set of columns the frontend poller needs.
        columns = (
            Usage.id,
            Usage.status,
            Usage.input_tokens,
            Usage.output_tokens,
            Usage.total_cost_usd,
            Usage.response_time_ms,
        )

        if self.ids:
            # Query the status of the explicitly requested IDs.
            id_list = [req_id.strip() for req_id in self.ids.split(",") if req_id.strip()]
            if not id_list:
                return {"requests": []}

            records = db.query(*columns).filter(Usage.id.in_(id_list)).all()
        else:
            # Query all active requests (pending or streaming), newest first,
            # capped at 50 to keep the polling endpoint cheap.
            records = (
                db.query(*columns)
                .filter(Usage.status.in_(["pending", "streaming"]))
                .order_by(Usage.created_at.desc())
                .limit(50)
                .all()
            )

        return {
            "requests": [
                {
                    "id": r.id,
                    "status": r.status,
                    "input_tokens": r.input_tokens,
                    "output_tokens": r.output_tokens,
                    "cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
                    "response_time_ms": r.response_time_ms,
                }
                for r in records
            ]
        }
|
||||
|
||||
|
||||
@dataclass
class AdminUsageDetailAdapter(AdminApiAdapter):
    """Get detailed usage record with request/response body"""

    # ID of the Usage row to fetch.
    usage_id: str

    async def handle(self, context):  # type: ignore[override]
        """Return the full detail payload for one usage record; 404 if absent."""
        db = context.db
        usage_record = db.query(Usage).filter(Usage.id == self.usage_id).first()
        if not usage_record:
            raise HTTPException(status_code=404, detail="Usage record not found")

        # Owner and API key may have been deleted; both lookups tolerate None.
        user = db.query(User).filter(User.id == usage_record.user_id).first()
        api_key = db.query(ApiKey).filter(ApiKey.id == usage_record.api_key_id).first()

        # Tiered-pricing details (which pricing tier this request fell into).
        tiered_pricing_info = await self._get_tiered_pricing_info(db, usage_record)

        context.add_audit_metadata(
            action="usage_detail",
            usage_id=self.usage_id,
        )

        return {
            "id": usage_record.id,
            "request_id": usage_record.request_id,
            "user": {
                "id": user.id if user else None,
                "username": user.username if user else "Unknown",
                "email": user.email if user else None,
            },
            "api_key": {
                "id": api_key.id if api_key else None,
                "name": api_key.name if api_key else None,
                "display": api_key.get_display_key() if api_key else None,
            },
            "provider": usage_record.provider,
            "api_format": usage_record.api_format,
            "model": usage_record.model,
            "target_model": usage_record.target_model,
            "tokens": {
                "input": usage_record.input_tokens,
                "output": usage_record.output_tokens,
                "total": usage_record.total_tokens,
            },
            "cost": {
                "input": usage_record.input_cost_usd,
                "output": usage_record.output_cost_usd,
                "total": usage_record.total_cost_usd,
            },
            "cache_creation_input_tokens": usage_record.cache_creation_input_tokens,
            "cache_read_input_tokens": usage_record.cache_read_input_tokens,
            "cache_creation_cost": getattr(usage_record, "cache_creation_cost_usd", 0.0),
            "cache_read_cost": getattr(usage_record, "cache_read_cost_usd", 0.0),
            "request_cost": getattr(usage_record, "request_cost_usd", 0.0),
            "input_price_per_1m": usage_record.input_price_per_1m,
            "output_price_per_1m": usage_record.output_price_per_1m,
            "cache_creation_price_per_1m": usage_record.cache_creation_price_per_1m,
            "cache_read_price_per_1m": usage_record.cache_read_price_per_1m,
            "price_per_request": usage_record.price_per_request,
            "request_type": usage_record.request_type,
            "is_stream": usage_record.is_stream,
            "status_code": usage_record.status_code,
            "error_message": usage_record.error_message,
            "response_time_ms": usage_record.response_time_ms,
            "created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
            "request_headers": usage_record.request_headers,
            "request_body": usage_record.get_request_body(),
            "provider_request_headers": usage_record.provider_request_headers,
            "response_headers": usage_record.response_headers,
            "response_body": usage_record.get_response_body(),
            "metadata": usage_record.request_metadata,
            "tiered_pricing": tiered_pricing_info,
        }

    async def _get_tiered_pricing_info(self, db, usage_record) -> dict | None:
        """Resolve tiered-pricing info for this record; ``None`` when the model
        has no tiered pricing configured."""
        from src.services.model.cost import ModelCostService

        # Total input context used for tier matching:
        # input tokens + cache-creation tokens + cache-read tokens.
        input_tokens = usage_record.input_tokens or 0
        cache_creation_tokens = usage_record.cache_creation_input_tokens or 0
        cache_read_tokens = usage_record.cache_read_input_tokens or 0
        total_input_context = input_tokens + cache_creation_tokens + cache_read_tokens

        # Fetch the model's tier configuration, including its source.
        cost_service = ModelCostService(db)
        pricing_result = await cost_service.get_tiered_pricing_with_source_async(
            usage_record.provider, usage_record.model
        )

        if not pricing_result:
            return None

        tiered_pricing = pricing_result.get("pricing")
        pricing_source = pricing_result.get("source")  # 'provider' or 'global'

        if not tiered_pricing or not tiered_pricing.get("tiers"):
            return None

        tiers = tiered_pricing.get("tiers", [])
        if not tiers:
            return None

        # Find the tier the total input context falls into: the first tier
        # whose `up_to` bound is None (open-ended) or >= the context size.
        tier_index = None
        matched_tier = None
        for i, tier in enumerate(tiers):
            up_to = tier.get("up_to")
            if up_to is None or total_input_context <= up_to:
                tier_index = i
                matched_tier = tier
                break

        # If nothing matched, fall back to the last (largest) tier.
        if tier_index is None and tiers:
            tier_index = len(tiers) - 1
            matched_tier = tiers[-1]

        return {
            "total_input_context": total_input_context,
            "tier_index": tier_index,
            "tier_count": len(tiers),
            "current_tier": matched_tier,
            "tiers": tiers,
            "source": pricing_source,  # pricing source: 'provider' or 'global'
        }
|
||||
5
src/api/admin/users/__init__.py
Normal file
5
src/api/admin/users/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""User admin routes export."""
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
488
src/api/admin/users/routes.py
Normal file
488
src/api/admin/users/routes.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""用户管理 API 端点。"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException, translate_pydantic_error
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.admin_requests import UpdateUserRequest
|
||||
from src.models.api import CreateApiKeyRequest, CreateUserRequest
|
||||
from src.models.database import ApiKey, User, UserRole
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
from src.services.user.service import UserService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/admin/users", tags=["Admin - Users"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
# Admin endpoints
|
||||
@router.post("")
async def create_user_endpoint(request: Request, db: Session = Depends(get_db)):
    """Create a new user; the request body is validated inside the adapter."""
    adapter = AdminCreateUserAdapter()
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("")
async def list_users(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    role: Optional[str] = None,
    is_active: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """List users with pagination and optional role / active-state filters."""
    adapter = AdminListUsersAdapter(skip=skip, limit=limit, role=role, is_active=is_active)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{user_id}")
async def get_user(user_id: str, request: Request, db: Session = Depends(get_db)):  # UUID
    """Fetch a single user by ID."""
    adapter = AdminGetUserAdapter(user_id=user_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.put("/{user_id}")
async def update_user(
    user_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Update a user's attributes; the body is validated inside the adapter."""
    adapter = AdminUpdateUserAdapter(user_id=user_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
async def delete_user(user_id: str, request: Request, db: Session = Depends(get_db)):  # UUID
    """Delete a user by ID."""
    adapter = AdminDeleteUserAdapter(user_id=user_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.patch("/{user_id}/quota")
async def reset_user_quota(user_id: str, request: Request, db: Session = Depends(get_db)):
    """Reset user quota (set used_usd to 0)"""
    adapter = AdminResetUserQuotaAdapter(user_id=user_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/{user_id}/api-keys")
async def get_user_api_keys(
    user_id: str,
    request: Request,
    is_active: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """Get all of a user's API keys (standalone keys excluded)."""
    adapter = AdminGetUserKeysAdapter(user_id=user_id, is_active=is_active)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/{user_id}/api-keys")
async def create_user_api_key(
    user_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Create an API key for the user."""
    adapter = AdminCreateUserKeyAdapter(user_id=user_id)
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/{user_id}/api-keys/{key_id}")
async def delete_user_api_key(
    user_id: str,
    key_id: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """Delete one of the user's API keys (admin)."""
    handler = AdminDeleteUserKeyAdapter(user_id=user_id, key_id=key_id)
    return await pipeline.run(adapter=handler, http_request=request, db=db, mode=handler.mode)
|
||||
|
||||
|
||||
# ============== Admin adapter implementations ==============
|
||||
|
||||
|
||||
class AdminCreateUserAdapter(AdminApiAdapter):
    """Admin operation: create a new user account from a JSON request body."""

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()

        # Validate the raw payload against the request schema.
        try:
            create_req = CreateUserRequest.model_validate(payload)
        except ValidationError as exc:
            details = exc.errors()
            if details:
                raise InvalidRequestException(translate_pydantic_error(details[0]))
            raise InvalidRequestException("请求数据验证失败")

        # Normalize role: accept an enum instance as-is, otherwise look up by name.
        try:
            if hasattr(create_req.role, "value"):
                role = create_req.role
            else:
                role = UserRole[create_req.role.upper()]
        except (KeyError, AttributeError):
            raise InvalidRequestException("角色参数不合法")

        try:
            user = UserService.create_user(
                db=db,
                email=create_req.email,
                username=create_req.username,
                password=create_req.password,
                role=role,
                quota_usd=create_req.quota_usd,
            )
        except ValueError as exc:
            raise InvalidRequestException(str(exc))

        context.add_audit_metadata(
            action="create_user",
            target_user_id=user.id,
            target_email=user.email,
            target_username=user.username,
            target_role=user.role.value,
            quota_usd=user.quota_usd,
            is_active=user.is_active,
        )

        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "role": user.role.value,
            "allowed_providers": user.allowed_providers,
            "allowed_endpoints": user.allowed_endpoints,
            "allowed_models": user.allowed_models,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            # total_usd may be absent on some model versions — default to 0.
            "total_usd": getattr(user, "total_usd", 0),
            "is_active": user.is_active,
            "created_at": user.created_at.isoformat(),
        }
|
||||
|
||||
|
||||
class AdminListUsersAdapter(AdminApiAdapter):
    """Admin operation: paginated user listing with optional role/active filters."""

    def __init__(self, skip: int, limit: int, role: Optional[str], is_active: Optional[bool]):
        self.skip = skip
        self.limit = limit
        self.role = role
        self.is_active = is_active

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        # Translate the optional role name into its enum member.
        role_filter = UserRole[self.role.upper()] if self.role else None
        records = UserService.list_users(db, self.skip, self.limit, role_filter, self.is_active)

        def summarize(user):
            # Compact per-user summary for the admin listing.
            return {
                "id": user.id,
                "email": user.email,
                "username": user.username,
                "role": user.role.value,
                "quota_usd": user.quota_usd,
                "used_usd": user.used_usd,
                "total_usd": getattr(user, "total_usd", 0),
                "is_active": user.is_active,
                "created_at": user.created_at.isoformat(),
            }

        return [summarize(record) for record in records]
|
||||
|
||||
|
||||
class AdminGetUserAdapter(AdminApiAdapter):
    """Admin operation: fetch the full detail record for one user."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        record = UserService.get_user(db, self.user_id)
        if not record:
            raise NotFoundException("用户不存在", "user")

        context.add_audit_metadata(
            action="get_user_detail",
            target_user_id=record.id,
            target_role=record.role.value,
            include_history=bool(record.last_login_at),
        )

        def iso_or_none(ts):
            # Serialize an optional timestamp to ISO-8601 (or None).
            return ts.isoformat() if ts else None

        return {
            "id": record.id,
            "email": record.email,
            "username": record.username,
            "role": record.role.value,
            "allowed_providers": record.allowed_providers,
            "allowed_endpoints": record.allowed_endpoints,
            "allowed_models": record.allowed_models,
            "quota_usd": record.quota_usd,
            "used_usd": record.used_usd,
            "total_usd": getattr(record, "total_usd", 0),
            "is_active": record.is_active,
            "created_at": record.created_at.isoformat(),
            "updated_at": iso_or_none(record.updated_at),
            "last_login_at": iso_or_none(record.last_login_at),
        }
|
||||
|
||||
|
||||
class AdminUpdateUserAdapter(AdminApiAdapter):
    """Admin operation: partially update a user (profile, role, quota, status).

    Fix: an invalid role name previously escaped as an uncaught ``KeyError``
    (surfacing as HTTP 500); it is now mapped to ``InvalidRequestException``,
    matching ``AdminCreateUserAdapter``. The redundant self-assignment for
    enum-valued roles was also removed.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        existing_user = UserService.get_user(db, self.user_id)
        if not existing_user:
            raise NotFoundException("用户不存在", "user")

        payload = context.ensure_json_body()
        try:
            request = UpdateUserRequest.model_validate(payload)
        except ValidationError as e:
            errors = e.errors()
            if errors:
                raise InvalidRequestException(translate_pydantic_error(errors[0]))
            raise InvalidRequestException("请求数据验证失败")

        # Only fields explicitly present in the payload are applied.
        update_data = request.model_dump(exclude_unset=True)
        if update_data.get("role"):
            # Accept either an enum instance (kept as-is) or a role name string.
            if not hasattr(update_data["role"], "value"):
                try:
                    update_data["role"] = UserRole[update_data["role"].upper()]
                except KeyError:
                    raise InvalidRequestException("角色参数不合法")

        user = UserService.update_user(db, self.user_id, **update_data)
        if not user:
            raise NotFoundException("用户不存在", "user")

        changed_fields = list(update_data.keys())
        context.add_audit_metadata(
            action="update_user",
            target_user_id=user.id,
            updated_fields=changed_fields,
            role_before=existing_user.role.value if existing_user.role else None,
            role_after=user.role.value,
            quota_usd=user.quota_usd,
            is_active=user.is_active,
        )

        return {
            "id": user.id,
            "email": user.email,
            "username": user.username,
            "role": user.role.value,
            "allowed_providers": user.allowed_providers,
            "allowed_endpoints": user.allowed_endpoints,
            "allowed_models": user.allowed_models,
            "quota_usd": user.quota_usd,
            "used_usd": user.used_usd,
            "total_usd": getattr(user, "total_usd", 0),
            "is_active": user.is_active,
            "created_at": user.created_at.isoformat(),
            "updated_at": user.updated_at.isoformat() if user.updated_at else None,
        }
|
||||
|
||||
|
||||
class AdminDeleteUserAdapter(AdminApiAdapter):
    """Admin operation: delete a user, refusing to remove the last admin.

    Fix: not-found cases now raise ``NotFoundException("用户不存在", "user")``
    like every sibling admin adapter, instead of a raw ``HTTPException`` —
    keeping error payloads consistent across the admin API.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        user = UserService.get_user(db, self.user_id)
        if not user:
            raise NotFoundException("用户不存在", "user")

        # Guard: the system must always retain at least one admin account.
        if user.role == UserRole.ADMIN:
            admin_count = db.query(User).filter(User.role == UserRole.ADMIN).count()
            if admin_count <= 1:
                raise InvalidRequestException("不能删除最后一个管理员账户")

        success = UserService.delete_user(db, self.user_id)
        if not success:
            raise NotFoundException("用户不存在", "user")

        context.add_audit_metadata(
            action="delete_user",
            target_user_id=user.id,
            target_email=user.email,
            target_role=user.role.value,
        )

        return {"message": "用户删除成功"}
|
||||
|
||||
|
||||
class AdminResetUserQuotaAdapter(AdminApiAdapter):
    """Admin operation: zero out a user's consumed quota (used_usd)."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        target = UserService.get_user(db, self.user_id)
        if not target:
            raise NotFoundException("用户不存在", "user")

        target.used_usd = 0.0
        # Lifetime spend is preserved; defaults the attribute to 0 when the
        # model lacks it — assumes total_usd may be optional, TODO confirm.
        target.total_usd = getattr(target, "total_usd", 0)
        target.updated_at = datetime.now(timezone.utc)
        db.commit()

        context.add_audit_metadata(
            action="reset_user_quota",
            target_user_id=target.id,
            quota_usd=target.quota_usd,
            used_usd=target.used_usd,
            total_usd=target.total_usd,
        )

        return {
            "message": "配额已重置",
            "user_id": target.id,
            "quota_usd": target.quota_usd,
            "used_usd": target.used_usd,
            "total_usd": target.total_usd,
        }
|
||||
|
||||
|
||||
class AdminGetUserKeysAdapter(AdminApiAdapter):
    """List a user's API keys (standalone keys are excluded)."""

    def __init__(self, user_id: str, is_active: Optional[bool]):
        self.user_id = user_id
        self.is_active = is_active

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # The target user must exist before enumerating their keys.
        owner = db.query(User).filter(User.id == self.user_id).first()
        if not owner:
            raise NotFoundException("用户不存在", "user")

        # Regular user-owned keys only; standalone keys are out of scope here.
        keys = ApiKeyService.list_user_api_keys(
            db=db, user_id=self.user_id, is_active=self.is_active
        )

        context.add_audit_metadata(
            action="list_user_api_keys",
            target_user_id=self.user_id,
            total=len(keys),
        )

        def describe(key):
            # Per-key summary; only the masked display form of the secret is exposed.
            return {
                "id": key.id,
                "name": key.name,
                "key_display": key.get_display_key(),
                "is_active": key.is_active,
                "total_requests": key.total_requests,
                "total_cost_usd": float(key.total_cost_usd or 0),
                "rate_limit": key.rate_limit,
                "expires_at": key.expires_at.isoformat() if key.expires_at else None,
                "last_used_at": key.last_used_at.isoformat() if key.last_used_at else None,
                "created_at": key.created_at.isoformat(),
            }

        return {
            "api_keys": [describe(k) for k in keys],
            "total": len(keys),
            "user_email": owner.email,
            "username": owner.username,
        }
|
||||
|
||||
|
||||
class AdminCreateUserKeyAdapter(AdminApiAdapter):
    """Create a regular (non-standalone) API key on behalf of a user."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        payload = context.ensure_json_body()

        # Validate the request body against the key-creation schema.
        try:
            key_req = CreateApiKeyRequest.model_validate(payload)
        except ValidationError as exc:
            details = exc.errors()
            if details:
                raise InvalidRequestException(translate_pydantic_error(details[0]))
            raise InvalidRequestException("请求数据验证失败")

        # The target user must exist.
        owner = db.query(User).filter(User.id == self.user_id).first()
        if not owner:
            raise NotFoundException("用户不存在", "user")

        # Regular user key: no balance cap, not standalone.
        api_key, plain_key = ApiKeyService.create_api_key(
            db=db,
            user_id=self.user_id,
            name=key_req.name,
            allowed_providers=key_req.allowed_providers,
            allowed_models=key_req.allowed_models,
            rate_limit=key_req.rate_limit or 100,
            expire_days=key_req.expire_days,
            initial_balance_usd=None,
            is_standalone=False,
        )

        logger.info(f"管理员为用户创建API Key: 用户 {owner.email}, Key ID {api_key.id}")

        context.add_audit_metadata(
            action="create_user_api_key",
            target_user_id=self.user_id,
            key_id=api_key.id,
        )

        return {
            "id": api_key.id,
            "key": plain_key,  # the full secret is only ever returned at creation time
            "name": api_key.name,
            "key_display": api_key.get_display_key(),
            "rate_limit": api_key.rate_limit,
            "expires_at": api_key.expires_at.isoformat() if api_key.expires_at else None,
            "created_at": api_key.created_at.isoformat(),
            "message": "API Key创建成功,请妥善保存完整密钥",
        }
|
||||
|
||||
|
||||
class AdminDeleteUserKeyAdapter(AdminApiAdapter):
    """Delete one of a user's regular (non-standalone) API keys.

    Fix: the SQLAlchemy filter ``ApiKey.is_standalone == False`` (flake8 E712)
    is replaced with the idiomatic ``.is_(False)``; the generated SQL is the
    same ``IS false`` predicate.
    """

    def __init__(self, user_id: str, key_id: str):
        self.user_id = user_id
        self.key_id = key_id

    async def handle(self, context):  # type: ignore[override]
        db = context.db

        # The key must exist, belong to this user, and not be a standalone key.
        api_key = (
            db.query(ApiKey)
            .filter(
                ApiKey.id == self.key_id,
                ApiKey.user_id == self.user_id,
                ApiKey.is_standalone.is_(False),  # only regular keys may be deleted here
            )
            .first()
        )

        if not api_key:
            raise NotFoundException("API Key不存在或不属于该用户", "api_key")

        db.delete(api_key)
        db.commit()

        logger.info(f"管理员删除用户API Key: 用户ID {self.user_id}, Key ID {self.key_id}")

        context.add_audit_metadata(
            action="delete_user_api_key",
            target_user_id=self.user_id,
            key_id=self.key_id,
        )

        return {"message": "API Key已删除"}
|
||||
Reference in New Issue
Block a user