mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(backend): update model catalog and provider APIs after mappings removal
This commit is contained in:
@@ -1,38 +1,26 @@
|
||||
"""
|
||||
统一模型目录 Admin API
|
||||
|
||||
阶段一:基于 ModelMapping 和 Model 的聚合视图
|
||||
基于 GlobalModel 的聚合视图
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import func, or_
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignError,
|
||||
BatchAssignModelMappingRequest,
|
||||
BatchAssignModelMappingResponse,
|
||||
BatchAssignProviderResult,
|
||||
DeleteModelMappingResponse,
|
||||
ModelCapabilities,
|
||||
ModelCatalogItem,
|
||||
ModelCatalogProviderDetail,
|
||||
ModelCatalogResponse,
|
||||
ModelPriceRange,
|
||||
OrphanedModel,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
)
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -47,24 +35,13 @@ async def get_model_catalog(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
|
||||
async def batch_assign_model_mappings(
|
||||
request: Request,
|
||||
payload: BatchAssignModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelMappingResponse:
|
||||
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
"""管理员查询统一模型目录
|
||||
|
||||
新架构说明:
|
||||
架构说明:
|
||||
1. 以 GlobalModel 为中心聚合数据
|
||||
2. ModelMapping 表提供别名信息(provider_id=NULL 表示全局)
|
||||
3. Model 表提供关联提供商和价格
|
||||
2. Model 表提供关联提供商和价格
|
||||
"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -75,29 +52,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
|
||||
)
|
||||
|
||||
# 2. 获取所有活跃的别名(含全局和 Provider 特定)
|
||||
aliases_rows: List[ModelMapping] = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 GlobalModel ID 组织别名
|
||||
aliases_by_global_model: Dict[str, List[str]] = {}
|
||||
for alias_row in aliases_rows:
|
||||
if not alias_row.target_global_model_id:
|
||||
continue
|
||||
gm_id = alias_row.target_global_model_id
|
||||
if gm_id not in aliases_by_global_model:
|
||||
aliases_by_global_model[gm_id] = []
|
||||
if alias_row.source_model not in aliases_by_global_model[gm_id]:
|
||||
aliases_by_global_model[gm_id].append(alias_row.source_model)
|
||||
|
||||
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
# 2. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
models: List[Model] = (
|
||||
db.query(Model)
|
||||
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||
@@ -111,7 +66,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
if model.global_model_id:
|
||||
models_by_global_model.setdefault(model.global_model_id, []).append(model)
|
||||
|
||||
# 4. 为每个 GlobalModel 构建 catalog item
|
||||
# 3. 为每个 GlobalModel 构建 catalog item
|
||||
catalog_items: List[ModelCatalogItem] = []
|
||||
|
||||
for gm in global_models:
|
||||
@@ -168,7 +123,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=bool(model.is_active),
|
||||
mapping_id=None, # 新架构中不再有 mapping_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -187,7 +141,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
global_model_name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
aliases=aliases_by_global_model.get(gm_id, []),
|
||||
providers=provider_entries,
|
||||
price_range=price_range,
|
||||
total_providers=len(provider_entries),
|
||||
@@ -195,238 +148,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
)
|
||||
)
|
||||
|
||||
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
|
||||
orphaned_rows = (
|
||||
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
|
||||
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
or_(GlobalModel.id == None, GlobalModel.is_active == False),
|
||||
)
|
||||
.group_by(ModelMapping.source_model, GlobalModel.name)
|
||||
.all()
|
||||
)
|
||||
orphaned_models = [
|
||||
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
|
||||
for row in orphaned_rows
|
||||
if row[0]
|
||||
]
|
||||
|
||||
return ModelCatalogResponse(
|
||||
models=catalog_items,
|
||||
total=len(catalog_items),
|
||||
orphaned_models=orphaned_models,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
|
||||
payload: BatchAssignModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
created: List[BatchAssignProviderResult] = []
|
||||
errors: List[BatchAssignError] = []
|
||||
|
||||
for provider_config in self.payload.providers:
|
||||
provider_id = provider_config.provider_id
|
||||
try:
|
||||
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
|
||||
)
|
||||
continue
|
||||
|
||||
model_id: Optional[str] = None
|
||||
created_model = False
|
||||
|
||||
if provider_config.create_model:
|
||||
model_data = provider_config.model_data
|
||||
if not model_data:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
|
||||
)
|
||||
continue
|
||||
|
||||
existing_model = ModelService.get_model_by_name(
|
||||
db, provider_id, model_data.provider_model_name
|
||||
)
|
||||
if existing_model:
|
||||
model_id = existing_model.id
|
||||
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
|
||||
model_data.provider_model_name,
|
||||
provider.name,
|
||||
)
|
||||
else:
|
||||
model = ModelService.create_model(db, provider_id, model_data)
|
||||
model_id = model.id
|
||||
created_model = True
|
||||
else:
|
||||
model_id = provider_config.model_id
|
||||
if not model_id:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
|
||||
)
|
||||
continue
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == model_id, Model.provider_id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
|
||||
)
|
||||
continue
|
||||
|
||||
# 批量分配功能需要适配 GlobalModel 架构
|
||||
# 参见 docs/optimization-backlog.md 中的待办项
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id,
|
||||
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error("批量添加模型映射失败(需要适配新架构)")
|
||||
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
|
||||
|
||||
return BatchAssignModelMappingResponse(
|
||||
success=len(created) > 0,
|
||||
created_mappings=created,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
|
||||
async def update_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
payload: UpdateModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> UpdateModelMappingResponse:
|
||||
"""更新模型映射"""
|
||||
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
|
||||
async def delete_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> DeleteModelMappingResponse:
|
||||
"""删除模型映射"""
|
||||
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
|
||||
"""更新模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
payload: UpdateModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
update_data = self.payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return UpdateModelMappingResponse(
|
||||
success=True,
|
||||
mapping_id=mapping.id,
|
||||
message="映射更新成功",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
|
||||
"""删除模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return DeleteModelMappingResponse(
|
||||
success=True,
|
||||
message=f"映射 {self.mapping_id} 已删除",
|
||||
)
|
||||
|
||||
@@ -123,7 +123,7 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import func
|
||||
|
||||
from src.models.database import Model, ModelMapping
|
||||
from src.models.database import Model
|
||||
|
||||
models = GlobalModelService.list_global_models(
|
||||
db=context.db,
|
||||
@@ -144,17 +144,8 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计别名数量
|
||||
alias_count = (
|
||||
context.db.query(func.count(ModelMapping.id))
|
||||
.filter(ModelMapping.target_global_model_id == gm.id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
response = GlobalModelResponse.model_validate(gm)
|
||||
response.provider_count = provider_count
|
||||
response.alias_count = alias_count
|
||||
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
|
||||
model_responses.append(response)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
@@ -26,7 +25,6 @@ from src.models.pydantic_models import (
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
@@ -136,8 +134,7 @@ async def get_provider_available_source_models(
|
||||
获取该 Provider 支持的所有统一模型名(source_model)
|
||||
|
||||
包括:
|
||||
1. 通过 ModelMapping 映射的模型
|
||||
2. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
1. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
"""
|
||||
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -294,10 +291,9 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
"""
|
||||
返回 Provider 支持的所有 GlobalModel
|
||||
|
||||
方案 A 逻辑:
|
||||
逻辑:
|
||||
1. 查询该 Provider 的所有 Model
|
||||
2. 通过 Model.global_model_id 获取 GlobalModel
|
||||
3. 查询所有指向该 GlobalModel 的别名(ModelMapping.alias)
|
||||
"""
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
@@ -324,27 +320,10 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
|
||||
# 如果该 GlobalModel 还未处理,初始化
|
||||
if global_model_name not in global_models_dict:
|
||||
# 查询指向该 GlobalModel 的所有别名/映射
|
||||
alias_rows = (
|
||||
db.query(ModelMapping.source_model)
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id == global_model.id,
|
||||
ModelMapping.is_active == True,
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
alias_list = [alias[0] for alias in alias_rows]
|
||||
|
||||
global_models_dict[global_model_name] = {
|
||||
"global_model_name": global_model_name,
|
||||
"display_name": global_model.display_name,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"has_alias": len(alias_list) > 0,
|
||||
"aliases": alias_list,
|
||||
"model_id": model.id,
|
||||
"price": {
|
||||
"input_price_per_1m": model.get_effective_input_price(),
|
||||
|
||||
@@ -20,14 +20,12 @@ from src.models.api import (
|
||||
ProviderStatsResponse,
|
||||
PublicGlobalModelListResponse,
|
||||
PublicGlobalModelResponse,
|
||||
PublicModelMappingResponse,
|
||||
PublicModelResponse,
|
||||
PublicProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
@@ -72,24 +70,6 @@ async def get_public_models(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
|
||||
async def get_public_model_mappings(
|
||||
request: Request,
|
||||
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
|
||||
alias: Optional[str] = Query(None, description="别名过滤(原source_model)"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicModelMappingsAdapter(
|
||||
provider_id=provider_id,
|
||||
alias=alias,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ProviderStatsResponse)
|
||||
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = PublicStatsAdapter()
|
||||
@@ -176,13 +156,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
||||
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
mappings_count = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
||||
active_endpoints_count = (
|
||||
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
|
||||
@@ -196,7 +169,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
||||
provider_priority=provider.provider_priority,
|
||||
models_count=models_count,
|
||||
active_models_count=active_models_count,
|
||||
mappings_count=mappings_count,
|
||||
endpoints_count=endpoints_count,
|
||||
active_endpoints_count=active_endpoints_count,
|
||||
)
|
||||
@@ -256,78 +228,6 @@ class PublicModelsAdapter(PublicApiAdapter):
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicModelMappingsAdapter(PublicApiAdapter):
|
||||
provider_id: Optional[str]
|
||||
alias: Optional[str] # 原 source_model,改为 alias
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求模型映射列表")
|
||||
|
||||
query = (
|
||||
db.query(ModelMapping, GlobalModel, Provider)
|
||||
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
|
||||
.filter(
|
||||
and_(
|
||||
ModelMapping.is_active.is_(True),
|
||||
GlobalModel.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self.provider_id is not None:
|
||||
provider_global_model_ids = (
|
||||
db.query(Model.global_model_id)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.filter(
|
||||
Provider.id == self.provider_id,
|
||||
Model.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
Model.global_model_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
and_(
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
if self.alias is not None:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
|
||||
|
||||
results = query.offset(self.skip).limit(self.limit).all()
|
||||
response = []
|
||||
for mapping, global_model, provider in results:
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
mapping_data = PublicModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=global_model.name if global_model else None,
|
||||
target_global_model_display_name=(
|
||||
global_model.display_name if global_model else None
|
||||
),
|
||||
provider_id=mapping.provider_id,
|
||||
scope=scope,
|
||||
is_active=mapping.is_active,
|
||||
)
|
||||
response.append(mapping_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(response)} 个模型映射")
|
||||
return response
|
||||
|
||||
|
||||
class PublicStatsAdapter(PublicApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
@@ -339,9 +239,6 @@ class PublicStatsAdapter(PublicApiAdapter):
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
from ...models.database import ModelMapping
|
||||
|
||||
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
|
||||
formats = (
|
||||
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
|
||||
)
|
||||
@@ -351,7 +248,6 @@ class PublicStatsAdapter(PublicApiAdapter):
|
||||
active_providers=active_providers,
|
||||
total_models=active_models,
|
||||
active_models=active_models,
|
||||
total_mappings=active_mappings,
|
||||
supported_formats=supported_formats,
|
||||
)
|
||||
logger.debug("返回系统统计信息")
|
||||
|
||||
@@ -7,6 +7,7 @@ import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
from sqlalchemy import (
|
||||
@@ -491,9 +492,6 @@ class Provider(Base):
|
||||
|
||||
# 关系
|
||||
models = relationship("Model", back_populates="provider", cascade="all, delete-orphan")
|
||||
model_mappings = relationship(
|
||||
"ModelMapping", back_populates="provider", cascade="all, delete-orphan"
|
||||
)
|
||||
endpoints = relationship(
|
||||
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -656,7 +654,11 @@ class Model(Base):
|
||||
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
|
||||
|
||||
# Provider 映射配置
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
|
||||
# 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
|
||||
# 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}]
|
||||
# 为空时只使用 provider_model_name
|
||||
provider_model_aliases = Column(JSON, nullable=True, default=None)
|
||||
|
||||
# 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值
|
||||
price_per_request = Column(Float, nullable=True) # 每次请求固定费用
|
||||
@@ -786,60 +788,83 @@ class Model(Base):
|
||||
def get_effective_supports_image_generation(self) -> bool:
|
||||
return self._get_effective_capability("supports_image_generation", False)
|
||||
|
||||
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
|
||||
"""按优先级选择要使用的 Provider 模型名称
|
||||
|
||||
class ModelMapping(Base):
|
||||
"""模型映射表 - 统一处理别名与降级策略
|
||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||
相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致);
|
||||
否则返回 provider_model_name。
|
||||
|
||||
设计原则:
|
||||
- source_model 接收用户请求的原始模型名/别名
|
||||
- target_global_model_id 指向真实的 GlobalModel
|
||||
- provider_id 为空表示全局别名,非空表示 Provider 特定映射/降级
|
||||
- 一个 (source_model, provider_id) 组合唯一
|
||||
Args:
|
||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
映射类型 (mapping_type):
|
||||
- alias: 别名模式,按目标模型计费(只是名称简写)
|
||||
- mapping: 映射模式,按源模型计费(模型降级/替代)
|
||||
"""
|
||||
if not self.provider_model_aliases:
|
||||
return self.provider_model_name
|
||||
|
||||
__tablename__ = "model_mappings"
|
||||
raw_aliases = self.provider_model_aliases
|
||||
if not isinstance(raw_aliases, list) or len(raw_aliases) == 0:
|
||||
return self.provider_model_name
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
aliases: list[dict] = []
|
||||
for raw in raw_aliases:
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
name = raw.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
continue
|
||||
|
||||
# 源模型名称(可能是别名或真实 GlobalModel.name)
|
||||
source_model = Column(String(200), nullable=False, index=True)
|
||||
raw_priority = raw.get("priority", 1)
|
||||
try:
|
||||
priority = int(raw_priority)
|
||||
except Exception:
|
||||
priority = 1
|
||||
if priority < 1:
|
||||
priority = 1
|
||||
|
||||
# 目标 GlobalModel
|
||||
target_global_model_id = Column(
|
||||
String(36), ForeignKey("global_models.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
aliases.append({"name": name.strip(), "priority": priority})
|
||||
|
||||
# Provider 关联:NULL 代表全局别名
|
||||
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=True, index=True)
|
||||
if not aliases:
|
||||
return self.provider_model_name
|
||||
|
||||
# 映射类型:alias=按目标模型计费,mapping=按源模型计费
|
||||
mapping_type = Column(String(20), nullable=False, default="alias", index=True)
|
||||
# 按优先级排序(数字越小越优先)
|
||||
sorted_aliases = sorted(aliases, key=lambda x: x["priority"])
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
# 获取最高优先级(最小数字)
|
||||
highest_priority = sorted_aliases[0]["priority"]
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
# 获取所有最高优先级的别名
|
||||
top_priority_aliases = [
|
||||
alias for alias in sorted_aliases
|
||||
if alias["priority"] == highest_priority
|
||||
]
|
||||
|
||||
# 关系
|
||||
target_global_model = relationship("GlobalModel", foreign_keys=[target_global_model_id])
|
||||
provider = relationship("Provider", back_populates="model_mappings")
|
||||
# 如果有多个相同优先级的别名,通过哈希分散选择
|
||||
if len(top_priority_aliases) > 1 and affinity_key:
|
||||
# 为每个别名计算哈希得分,选择得分最小的
|
||||
def hash_score(alias: dict) -> int:
|
||||
combined = f"{affinity_key}:{alias['name']}"
|
||||
return int(hashlib.md5(combined.encode()).hexdigest(), 16)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"),
|
||||
)
|
||||
selected = min(top_priority_aliases, key=hash_score)
|
||||
elif len(top_priority_aliases) > 1:
|
||||
# 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个)
|
||||
# 避免随机选择导致同一请求重试时选择不同的模型名称
|
||||
selected = min(top_priority_aliases, key=lambda x: x["name"])
|
||||
else:
|
||||
selected = top_priority_aliases[0]
|
||||
|
||||
return selected["name"]
|
||||
|
||||
def get_all_provider_model_names(self) -> list[str]:
|
||||
"""获取所有可用的 Provider 模型名称(主名称 + 别名)"""
|
||||
names = [self.provider_model_name]
|
||||
if self.provider_model_aliases:
|
||||
for alias in self.provider_model_aliases:
|
||||
if isinstance(alias, dict) and alias.get("name"):
|
||||
names.append(alias["name"])
|
||||
return names
|
||||
|
||||
|
||||
class ProviderAPIKey(Base):
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .api import ModelCreate
|
||||
|
||||
|
||||
# ========== 阶梯计费相关模型 ==========
|
||||
|
||||
@@ -131,24 +129,14 @@ class ModelCatalogProviderDetail(BaseModel):
|
||||
supports_function_calling: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
is_active: bool
|
||||
mapping_id: Optional[str]
|
||||
|
||||
|
||||
class OrphanedModel(BaseModel):
|
||||
"""孤立的统一模型(Mapping 存在但 GlobalModel 缺失)"""
|
||||
|
||||
alias: str # 别名
|
||||
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
|
||||
mapping_count: int
|
||||
|
||||
|
||||
class ModelCatalogItem(BaseModel):
|
||||
"""统一模型目录条目(方案 A:基于 GlobalModel)"""
|
||||
"""统一模型目录条目(基于 GlobalModel)"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
description: Optional[str] # GlobalModel.description
|
||||
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
|
||||
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
|
||||
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
|
||||
total_providers: int
|
||||
@@ -160,7 +148,6 @@ class ModelCatalogResponse(BaseModel):
|
||||
|
||||
models: List[ModelCatalogItem]
|
||||
total: int
|
||||
orphaned_models: List[OrphanedModel]
|
||||
|
||||
|
||||
class ProviderModelPriceInfo(BaseModel):
|
||||
@@ -174,13 +161,11 @@ class ProviderModelPriceInfo(BaseModel):
|
||||
|
||||
|
||||
class ProviderAvailableSourceModel(BaseModel):
|
||||
"""Provider 支持的统一模型条目(方案 A)"""
|
||||
"""Provider 支持的统一模型条目"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
|
||||
has_alias: bool # 是否有别名指向该 GlobalModel
|
||||
aliases: List[str] # 别名列表
|
||||
model_id: Optional[str] # Model.id
|
||||
price: ProviderModelPriceInfo
|
||||
capabilities: ModelCapabilities
|
||||
@@ -194,50 +179,7 @@ class ProviderAvailableSourceModelsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class BatchAssignProviderConfig(BaseModel):
|
||||
"""批量添加映射的 Provider 配置"""
|
||||
|
||||
provider_id: str
|
||||
create_model: bool = Field(False, description="是否需要创建新的 Model")
|
||||
model_data: Optional[ModelCreate] = Field(
|
||||
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
|
||||
)
|
||||
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
|
||||
|
||||
|
||||
class BatchAssignModelMappingRequest(BaseModel):
|
||||
"""批量添加模型映射请求(方案 A:暂不支持,需要重构)"""
|
||||
|
||||
global_model_id: str # 要分配的 GlobalModel ID
|
||||
providers: List[BatchAssignProviderConfig]
|
||||
|
||||
|
||||
class BatchAssignProviderResult(BaseModel):
|
||||
"""批量映射结果条目"""
|
||||
|
||||
provider_id: str
|
||||
mapping_id: Optional[str]
|
||||
created_model: bool
|
||||
model_id: Optional[str]
|
||||
updated: bool = False
|
||||
|
||||
|
||||
class BatchAssignError(BaseModel):
|
||||
"""批量映射错误信息"""
|
||||
|
||||
provider_id: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchAssignModelMappingResponse(BaseModel):
|
||||
"""批量映射响应"""
|
||||
|
||||
success: bool
|
||||
created_mappings: List[BatchAssignProviderResult]
|
||||
errors: List[BatchAssignError]
|
||||
|
||||
|
||||
# ========== 阶段二:GlobalModel 相关模型 ==========
|
||||
# ========== GlobalModel 相关模型 ==========
|
||||
|
||||
|
||||
class GlobalModelCreate(BaseModel):
|
||||
@@ -328,7 +270,6 @@ class GlobalModelResponse(BaseModel):
|
||||
)
|
||||
# 统计数据(可选)
|
||||
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
|
||||
alias_count: Optional[int] = Field(default=0, description="别名数量")
|
||||
usage_count: Optional[int] = Field(default=0, description="调用次数")
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
@@ -355,7 +296,7 @@ class GlobalModelListResponse(BaseModel):
|
||||
class BatchAssignToProvidersRequest(BaseModel):
|
||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||
|
||||
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表")
|
||||
provider_ids: List[str] = Field(..., min_length=1, description="Provider ID 列表")
|
||||
create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
|
||||
|
||||
|
||||
@@ -379,43 +320,11 @@ class BatchAssignModelsToProviderResponse(BaseModel):
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class UpdateModelMappingRequest(BaseModel):
|
||||
"""更新模型映射请求"""
|
||||
|
||||
source_model: Optional[str] = Field(
|
||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
||||
)
|
||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时为全局别名)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class UpdateModelMappingResponse(BaseModel):
|
||||
"""更新模型映射响应"""
|
||||
|
||||
success: bool
|
||||
mapping_id: str
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class DeleteModelMappingResponse(BaseModel):
|
||||
"""删除模型映射响应"""
|
||||
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchAssignError",
|
||||
"BatchAssignModelMappingRequest",
|
||||
"BatchAssignModelMappingResponse",
|
||||
"BatchAssignModelsToProviderRequest",
|
||||
"BatchAssignModelsToProviderResponse",
|
||||
"BatchAssignProviderConfig",
|
||||
"BatchAssignProviderResult",
|
||||
"BatchAssignToProvidersRequest",
|
||||
"BatchAssignToProvidersResponse",
|
||||
"DeleteModelMappingResponse",
|
||||
"GlobalModelCreate",
|
||||
"GlobalModelListResponse",
|
||||
"GlobalModelResponse",
|
||||
@@ -426,10 +335,7 @@ __all__ = [
|
||||
"ModelCatalogProviderDetail",
|
||||
"ModelCatalogResponse",
|
||||
"ModelPriceRange",
|
||||
"OrphanedModel",
|
||||
"ProviderAvailableSourceModel",
|
||||
"ProviderAvailableSourceModelsResponse",
|
||||
"ProviderModelPriceInfo",
|
||||
"UpdateModelMappingRequest",
|
||||
"UpdateModelMappingResponse",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user