Files
Aether/src/api/admin/models/global_models.py

278 lines
9.2 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
GlobalModel Admin API
提供 GlobalModel CRUD 操作接口
"""
from dataclasses import dataclass
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db
from src.models.pydantic_models import (
BatchAssignToProvidersRequest,
BatchAssignToProvidersResponse,
GlobalModelCreate,
GlobalModelListResponse,
GlobalModelResponse,
GlobalModelUpdate,
GlobalModelWithStats,
)
from src.services.model.global_model import GlobalModelService
# Router for the GlobalModel admin endpoints; every handler below delegates to the
# shared request pipeline with an adapter object.
router = APIRouter(tags=["Admin - Global Models"], prefix="/global")
pipeline = ApiRequestPipeline()
@router.get("", response_model=GlobalModelListResponse)
async def list_global_models(
    request: Request,
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
    is_active: Optional[bool] = Query(None),
    search: Optional[str] = Query(None),
    db: Session = Depends(get_db),
) -> GlobalModelListResponse:
    """List GlobalModels, paginated via skip/limit and optionally filtered.

    ``is_active`` and ``search`` are forwarded unchanged to the list adapter.
    """
    list_adapter = AdminListGlobalModelsAdapter(
        search=search,
        is_active=is_active,
        limit=limit,
        skip=skip,
    )
    return await pipeline.run(
        adapter=list_adapter,
        http_request=request,
        db=db,
        mode=list_adapter.mode,
    )
@router.get("/{global_model_id}", response_model=GlobalModelWithStats)
async def get_global_model(
    request: Request,
    global_model_id: str,
    db: Session = Depends(get_db),
) -> GlobalModelWithStats:
    """Return a single GlobalModel together with its aggregate statistics."""
    detail_adapter = AdminGetGlobalModelAdapter(global_model_id=global_model_id)
    result = await pipeline.run(
        adapter=detail_adapter,
        http_request=request,
        db=db,
        mode=detail_adapter.mode,
    )
    return result
@router.post("", response_model=GlobalModelResponse, status_code=201)
async def create_global_model(
    request: Request,
    payload: GlobalModelCreate,
    db: Session = Depends(get_db),
) -> GlobalModelResponse:
    """Create a GlobalModel from the request body (responds 201 on success)."""
    create_adapter = AdminCreateGlobalModelAdapter(payload=payload)
    result = await pipeline.run(
        adapter=create_adapter,
        http_request=request,
        db=db,
        mode=create_adapter.mode,
    )
    return result
@router.patch("/{global_model_id}", response_model=GlobalModelResponse)
async def update_global_model(
    request: Request,
    global_model_id: str,
    payload: GlobalModelUpdate,
    db: Session = Depends(get_db),
) -> GlobalModelResponse:
    """Update (PATCH) the GlobalModel identified by ``global_model_id``."""
    update_adapter = AdminUpdateGlobalModelAdapter(
        payload=payload,
        global_model_id=global_model_id,
    )
    result = await pipeline.run(
        adapter=update_adapter,
        http_request=request,
        db=db,
        mode=update_adapter.mode,
    )
    return result
@router.delete("/{global_model_id}", status_code=204)
async def delete_global_model(
    request: Request,
    global_model_id: str,
    db: Session = Depends(get_db),
):
    """Delete a GlobalModel (cascading to all linked provider model implementations).

    Whatever the pipeline returns is discarded; on success the endpoint replies
    204 with no body.
    """
    delete_adapter = AdminDeleteGlobalModelAdapter(global_model_id=global_model_id)
    await pipeline.run(
        adapter=delete_adapter,
        http_request=request,
        db=db,
        mode=delete_adapter.mode,
    )
@router.post(
    "/{global_model_id}/assign-to-providers", response_model=BatchAssignToProvidersResponse
)
async def batch_assign_to_providers(
    request: Request,
    global_model_id: str,
    payload: BatchAssignToProvidersRequest,
    db: Session = Depends(get_db),
) -> BatchAssignToProvidersResponse:
    """Add implementations of one GlobalModel to several Providers in one request."""
    assign_adapter = AdminBatchAssignToProvidersAdapter(
        payload=payload,
        global_model_id=global_model_id,
    )
    result = await pipeline.run(
        adapter=assign_adapter,
        http_request=request,
        db=db,
        mode=assign_adapter.mode,
    )
    return result
# ========== Adapters ==========
@dataclass
class AdminListGlobalModelsAdapter(AdminApiAdapter):
    """List GlobalModels and annotate each with its distinct-provider count."""

    skip: int
    limit: int
    is_active: Optional[bool]
    search: Optional[str]

    async def handle(self, context):  # type: ignore[override]
        from sqlalchemy import func

        from src.models.database import Model

        models = GlobalModelService.list_global_models(
            db=context.db,
            skip=self.skip,
            limit=self.limit,
            is_active=self.is_active,
            search=self.search,
        )

        # Count distinct providers per GlobalModel with ONE grouped query for the
        # whole page, instead of issuing a separate COUNT query per model (N+1).
        provider_counts = {}
        model_ids = [gm.id for gm in models]
        if model_ids:  # skip the query entirely for an empty page (avoids IN ())
            rows = (
                context.db.query(
                    Model.global_model_id,
                    func.count(func.distinct(Model.provider_id)),
                )
                .filter(Model.global_model_id.in_(model_ids))
                .group_by(Model.global_model_id)
                .all()
            )
            provider_counts = {gm_id: count for gm_id, count in rows}

        model_responses = []
        for gm in models:
            response = GlobalModelResponse.model_validate(gm)
            # GlobalModels with no linked Model rows are absent from the grouped
            # result, so they default to 0 (same as the per-model COUNT returned).
            response.provider_count = provider_counts.get(gm.id, 0)
            # usage_count is read directly from the GlobalModel row by model_validate.
            model_responses.append(response)

        return GlobalModelListResponse(
            models=model_responses,
            # NOTE(review): this is the size of the returned page, not the total
            # number of GlobalModels matching the filters — confirm clients expect that.
            total=len(models),
        )
@dataclass
class AdminGetGlobalModelAdapter(AdminApiAdapter):
    """Fetch one GlobalModel and merge its service-computed statistics into the response."""

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        db = context.db
        target_id = self.global_model_id

        entity = GlobalModelService.get_global_model(db, target_id)
        stats = GlobalModelService.get_global_model_stats(db, target_id)

        # Start from the plain response fields, then append the aggregate stats.
        base_fields = GlobalModelResponse.model_validate(entity).model_dump()
        return GlobalModelWithStats(
            **base_fields,
            total_models=stats["total_models"],
            total_providers=stats["total_providers"],
            price_range=stats["price_range"],
        )
@dataclass
class AdminCreateGlobalModelAdapter(AdminApiAdapter):
    """Create a GlobalModel from a validated ``GlobalModelCreate`` payload."""

    payload: GlobalModelCreate

    async def handle(self, context):  # type: ignore[override]
        data = self.payload
        # The service is handed tiered pricing as a plain dict, so serialize the
        # TieredPricingConfig model first.
        # NOTE(review): assumes default_tiered_pricing is never None — confirm in schema.
        tiered_pricing = data.default_tiered_pricing.model_dump()

        global_model = GlobalModelService.create_global_model(
            db=context.db,
            name=data.name,
            display_name=data.display_name,
            is_active=data.is_active,
            default_price_per_request=data.default_price_per_request,  # per-request billing
            default_tiered_pricing=tiered_pricing,  # tiered billing
            supported_capabilities=data.supported_capabilities,  # key capability settings
            config=data.config,  # model configuration JSON
        )
        logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")
        return GlobalModelResponse.model_validate(global_model)
@dataclass
class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
    """Update a GlobalModel and invalidate the caches keyed by its name."""

    global_model_id: str
    payload: GlobalModelUpdate

    async def handle(self, context):  # type: ignore[override]
        from src.models.database import GlobalModel

        # Capture the current name as a plain string *before* updating. If the
        # update renames the model, cache entries keyed by the old name must be
        # invalidated too, otherwise they keep serving stale data. (The ORM object
        # the service returns may be this same instance, so the old name cannot be
        # read back from it afterwards.)
        existing = (
            context.db.query(GlobalModel).filter(GlobalModel.id == self.global_model_id).first()
        )
        previous_name = existing.name if existing else None

        global_model = GlobalModelService.update_global_model(
            db=context.db,
            global_model_id=self.global_model_id,
            update_data=self.payload,
        )
        logger.info(f"GlobalModel 已更新: id={global_model.id} name={global_model.name}")

        # Invalidate caches for the current name and, if it changed, the previous one.
        from src.services.cache.invalidation import get_cache_invalidation_service

        cache_service = get_cache_invalidation_service()
        cache_service.on_global_model_changed(global_model.name)
        if previous_name and previous_name != global_model.name:
            cache_service.on_global_model_changed(previous_name)

        return GlobalModelResponse.model_validate(global_model)
@dataclass
class AdminDeleteGlobalModelAdapter(AdminApiAdapter):
    """Delete a GlobalModel (cascading to its provider model implementations) and drop its caches."""

    global_model_id: str

    async def handle(self, context):  # type: ignore[override]
        from src.models.database import GlobalModel

        # Read the row first: the cache-invalidation hook takes the model name,
        # which can no longer be looked up once the row is deleted.
        doomed = (
            context.db.query(GlobalModel).filter(GlobalModel.id == self.global_model_id).first()
        )
        cached_name = None if not doomed else doomed.name

        GlobalModelService.delete_global_model(context.db, self.global_model_id)
        logger.info(f"GlobalModel 已删除: id={self.global_model_id}")

        if not cached_name:
            return None

        from src.services.cache.invalidation import get_cache_invalidation_service

        get_cache_invalidation_service().on_global_model_changed(cached_name)
        return None
@dataclass
class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
    """Attach implementations of one GlobalModel to several Providers at once."""

    global_model_id: str
    payload: BatchAssignToProvidersRequest

    async def handle(self, context):  # type: ignore[override]
        request_body = self.payload
        result = GlobalModelService.batch_assign_to_providers(
            db=context.db,
            global_model_id=self.global_model_id,
            provider_ids=request_body.provider_ids,
            create_models=request_body.create_models,
        )
        # `result` must expose sized "success" and "errors" entries (used for the
        # log summary) and is splatted straight into the response model.
        logger.info(
            f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}"
        )
        return BatchAssignToProvidersResponse(**result)