mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 18:22:28 +08:00
refactor(backend): update model catalog and provider APIs after mappings removal
This commit is contained in:
@@ -1,38 +1,26 @@
|
||||
"""
|
||||
统一模型目录 Admin API
|
||||
|
||||
阶段一:基于 ModelMapping 和 Model 的聚合视图
|
||||
基于 GlobalModel 的聚合视图
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import func, or_
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignError,
|
||||
BatchAssignModelMappingRequest,
|
||||
BatchAssignModelMappingResponse,
|
||||
BatchAssignProviderResult,
|
||||
DeleteModelMappingResponse,
|
||||
ModelCapabilities,
|
||||
ModelCatalogItem,
|
||||
ModelCatalogProviderDetail,
|
||||
ModelCatalogResponse,
|
||||
ModelPriceRange,
|
||||
OrphanedModel,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
)
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -47,24 +35,13 @@ async def get_model_catalog(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
|
||||
async def batch_assign_model_mappings(
|
||||
request: Request,
|
||||
payload: BatchAssignModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelMappingResponse:
|
||||
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
"""管理员查询统一模型目录
|
||||
|
||||
新架构说明:
|
||||
架构说明:
|
||||
1. 以 GlobalModel 为中心聚合数据
|
||||
2. ModelMapping 表提供别名信息(provider_id=NULL 表示全局)
|
||||
3. Model 表提供关联提供商和价格
|
||||
2. Model 表提供关联提供商和价格
|
||||
"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -75,29 +52,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
|
||||
)
|
||||
|
||||
# 2. 获取所有活跃的别名(含全局和 Provider 特定)
|
||||
aliases_rows: List[ModelMapping] = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 GlobalModel ID 组织别名
|
||||
aliases_by_global_model: Dict[str, List[str]] = {}
|
||||
for alias_row in aliases_rows:
|
||||
if not alias_row.target_global_model_id:
|
||||
continue
|
||||
gm_id = alias_row.target_global_model_id
|
||||
if gm_id not in aliases_by_global_model:
|
||||
aliases_by_global_model[gm_id] = []
|
||||
if alias_row.source_model not in aliases_by_global_model[gm_id]:
|
||||
aliases_by_global_model[gm_id].append(alias_row.source_model)
|
||||
|
||||
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
# 2. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
models: List[Model] = (
|
||||
db.query(Model)
|
||||
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||
@@ -111,7 +66,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
if model.global_model_id:
|
||||
models_by_global_model.setdefault(model.global_model_id, []).append(model)
|
||||
|
||||
# 4. 为每个 GlobalModel 构建 catalog item
|
||||
# 3. 为每个 GlobalModel 构建 catalog item
|
||||
catalog_items: List[ModelCatalogItem] = []
|
||||
|
||||
for gm in global_models:
|
||||
@@ -168,7 +123,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=bool(model.is_active),
|
||||
mapping_id=None, # 新架构中不再有 mapping_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -187,7 +141,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
global_model_name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
aliases=aliases_by_global_model.get(gm_id, []),
|
||||
providers=provider_entries,
|
||||
price_range=price_range,
|
||||
total_providers=len(provider_entries),
|
||||
@@ -195,238 +148,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
)
|
||||
)
|
||||
|
||||
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
|
||||
orphaned_rows = (
|
||||
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
|
||||
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
or_(GlobalModel.id == None, GlobalModel.is_active == False),
|
||||
)
|
||||
.group_by(ModelMapping.source_model, GlobalModel.name)
|
||||
.all()
|
||||
)
|
||||
orphaned_models = [
|
||||
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
|
||||
for row in orphaned_rows
|
||||
if row[0]
|
||||
]
|
||||
|
||||
return ModelCatalogResponse(
|
||||
models=catalog_items,
|
||||
total=len(catalog_items),
|
||||
orphaned_models=orphaned_models,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
|
||||
payload: BatchAssignModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
created: List[BatchAssignProviderResult] = []
|
||||
errors: List[BatchAssignError] = []
|
||||
|
||||
for provider_config in self.payload.providers:
|
||||
provider_id = provider_config.provider_id
|
||||
try:
|
||||
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
|
||||
)
|
||||
continue
|
||||
|
||||
model_id: Optional[str] = None
|
||||
created_model = False
|
||||
|
||||
if provider_config.create_model:
|
||||
model_data = provider_config.model_data
|
||||
if not model_data:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
|
||||
)
|
||||
continue
|
||||
|
||||
existing_model = ModelService.get_model_by_name(
|
||||
db, provider_id, model_data.provider_model_name
|
||||
)
|
||||
if existing_model:
|
||||
model_id = existing_model.id
|
||||
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
|
||||
model_data.provider_model_name,
|
||||
provider.name,
|
||||
)
|
||||
else:
|
||||
model = ModelService.create_model(db, provider_id, model_data)
|
||||
model_id = model.id
|
||||
created_model = True
|
||||
else:
|
||||
model_id = provider_config.model_id
|
||||
if not model_id:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
|
||||
)
|
||||
continue
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == model_id, Model.provider_id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
|
||||
)
|
||||
continue
|
||||
|
||||
# 批量分配功能需要适配 GlobalModel 架构
|
||||
# 参见 docs/optimization-backlog.md 中的待办项
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id,
|
||||
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error("批量添加模型映射失败(需要适配新架构)")
|
||||
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
|
||||
|
||||
return BatchAssignModelMappingResponse(
|
||||
success=len(created) > 0,
|
||||
created_mappings=created,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
|
||||
async def update_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
payload: UpdateModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> UpdateModelMappingResponse:
|
||||
"""更新模型映射"""
|
||||
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
|
||||
async def delete_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> DeleteModelMappingResponse:
|
||||
"""删除模型映射"""
|
||||
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
|
||||
"""更新模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
payload: UpdateModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
update_data = self.payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return UpdateModelMappingResponse(
|
||||
success=True,
|
||||
mapping_id=mapping.id,
|
||||
message="映射更新成功",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
|
||||
"""删除模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return DeleteModelMappingResponse(
|
||||
success=True,
|
||||
message=f"映射 {self.mapping_id} 已删除",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user