refactor(backend): update model catalog and provider APIs after mappings removal

This commit is contained in:
fawney19
2025-12-15 14:30:10 +08:00
parent 728f9bb126
commit 56fb6bf36c
6 changed files with 85 additions and 566 deletions

View File

@@ -1,38 +1,26 @@
""" """
统一模型目录 Admin API 统一模型目录 Admin API
阶段一:基于 ModelMapping 和 Model 的聚合视图 基于 GlobalModel 的聚合视图
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Set from typing import Dict, List
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, Request
from sqlalchemy import func, or_
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.core.logger import logger
from src.database import get_db from src.database import get_db
from src.models.database import GlobalModel, Model, ModelMapping, Provider from src.models.database import GlobalModel, Model
from src.models.pydantic_models import ( from src.models.pydantic_models import (
BatchAssignError,
BatchAssignModelMappingRequest,
BatchAssignModelMappingResponse,
BatchAssignProviderResult,
DeleteModelMappingResponse,
ModelCapabilities, ModelCapabilities,
ModelCatalogItem, ModelCatalogItem,
ModelCatalogProviderDetail, ModelCatalogProviderDetail,
ModelCatalogResponse, ModelCatalogResponse,
ModelPriceRange, ModelPriceRange,
OrphanedModel,
UpdateModelMappingRequest,
UpdateModelMappingResponse,
) )
from src.services.cache.invalidation import get_cache_invalidation_service
from src.services.model.service import ModelService
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"]) router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
@@ -47,24 +35,13 @@ async def get_model_catalog(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
async def batch_assign_model_mappings(
request: Request,
payload: BatchAssignModelMappingRequest,
db: Session = Depends(get_db),
) -> BatchAssignModelMappingResponse:
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass @dataclass
class AdminGetModelCatalogAdapter(AdminApiAdapter): class AdminGetModelCatalogAdapter(AdminApiAdapter):
"""管理员查询统一模型目录 """管理员查询统一模型目录
架构说明: 架构说明:
1. 以 GlobalModel 为中心聚合数据 1. 以 GlobalModel 为中心聚合数据
2. ModelMapping 表提供别名信息provider_id=NULL 表示全局) 2. Model 表提供关联提供商和价格
3. Model 表提供关联提供商和价格
""" """
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
@@ -75,29 +52,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
db.query(GlobalModel).filter(GlobalModel.is_active == True).all() db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
) )
# 2. 获取所有活跃的别名(含全局和 Provider 特定 # 2. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格
aliases_rows: List[ModelMapping] = (
db.query(ModelMapping)
.options(joinedload(ModelMapping.target_global_model))
.filter(
ModelMapping.is_active == True,
ModelMapping.provider_id.is_(None),
)
.all()
)
# 按 GlobalModel ID 组织别名
aliases_by_global_model: Dict[str, List[str]] = {}
for alias_row in aliases_rows:
if not alias_row.target_global_model_id:
continue
gm_id = alias_row.target_global_model_id
if gm_id not in aliases_by_global_model:
aliases_by_global_model[gm_id] = []
if alias_row.source_model not in aliases_by_global_model[gm_id]:
aliases_by_global_model[gm_id].append(alias_row.source_model)
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
models: List[Model] = ( models: List[Model] = (
db.query(Model) db.query(Model)
.options(joinedload(Model.provider), joinedload(Model.global_model)) .options(joinedload(Model.provider), joinedload(Model.global_model))
@@ -111,7 +66,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
if model.global_model_id: if model.global_model_id:
models_by_global_model.setdefault(model.global_model_id, []).append(model) models_by_global_model.setdefault(model.global_model_id, []).append(model)
# 4. 为每个 GlobalModel 构建 catalog item # 3. 为每个 GlobalModel 构建 catalog item
catalog_items: List[ModelCatalogItem] = [] catalog_items: List[ModelCatalogItem] = []
for gm in global_models: for gm in global_models:
@@ -168,7 +123,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
supports_function_calling=model.get_effective_supports_function_calling(), supports_function_calling=model.get_effective_supports_function_calling(),
supports_streaming=model.get_effective_supports_streaming(), supports_streaming=model.get_effective_supports_streaming(),
is_active=bool(model.is_active), is_active=bool(model.is_active),
mapping_id=None, # 新架构中不再有 mapping_id
) )
) )
@@ -187,7 +141,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
global_model_name=gm.name, global_model_name=gm.name,
display_name=gm.display_name, display_name=gm.display_name,
description=gm.description, description=gm.description,
aliases=aliases_by_global_model.get(gm_id, []),
providers=provider_entries, providers=provider_entries,
price_range=price_range, price_range=price_range,
total_providers=len(provider_entries), total_providers=len(provider_entries),
@@ -195,238 +148,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
) )
) )
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
orphaned_rows = (
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
.filter(
ModelMapping.is_active == True,
ModelMapping.provider_id.is_(None),
or_(GlobalModel.id == None, GlobalModel.is_active == False),
)
.group_by(ModelMapping.source_model, GlobalModel.name)
.all()
)
orphaned_models = [
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
for row in orphaned_rows
if row[0]
]
return ModelCatalogResponse( return ModelCatalogResponse(
models=catalog_items, models=catalog_items,
total=len(catalog_items), total=len(catalog_items),
orphaned_models=orphaned_models,
)
@dataclass
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
payload: BatchAssignModelMappingRequest
async def handle(self, context): # type: ignore[override]
db: Session = context.db
created: List[BatchAssignProviderResult] = []
errors: List[BatchAssignError] = []
for provider_config in self.payload.providers:
provider_id = provider_config.provider_id
try:
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
errors.append(
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
)
continue
model_id: Optional[str] = None
created_model = False
if provider_config.create_model:
model_data = provider_config.model_data
if not model_data:
errors.append(
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
)
continue
existing_model = ModelService.get_model_by_name(
db, provider_id, model_data.provider_model_name
)
if existing_model:
model_id = existing_model.id
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
model_data.provider_model_name,
provider.name,
)
else:
model = ModelService.create_model(db, provider_id, model_data)
model_id = model.id
created_model = True
else:
model_id = provider_config.model_id
if not model_id:
errors.append(
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
)
continue
model = (
db.query(Model)
.filter(Model.id == model_id, Model.provider_id == provider_id)
.first()
)
if not model:
errors.append(
BatchAssignError(
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
)
continue
# 批量分配功能需要适配 GlobalModel 架构
# 参见 docs/optimization-backlog.md 中的待办项
errors.append(
BatchAssignError(
provider_id=provider_id,
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
)
)
continue
except Exception as exc:
db.rollback()
logger.error("批量添加模型映射失败(需要适配新架构)")
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
return BatchAssignModelMappingResponse(
success=len(created) > 0,
created_mappings=created,
errors=errors,
)
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
async def update_model_mapping(
request: Request,
mapping_id: str,
payload: UpdateModelMappingRequest,
db: Session = Depends(get_db),
) -> UpdateModelMappingResponse:
"""更新模型映射"""
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
async def delete_model_mapping(
request: Request,
mapping_id: str,
db: Session = Depends(get_db),
) -> DeleteModelMappingResponse:
"""删除模型映射"""
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@dataclass
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
"""更新模型映射"""
mapping_id: str
payload: UpdateModelMappingRequest
async def handle(self, context): # type: ignore[override]
db: Session = context.db
mapping: Optional[ModelMapping] = (
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
)
if not mapping:
raise HTTPException(status_code=404, detail="映射不存在")
update_data = self.payload.model_dump(exclude_unset=True)
if "provider_id" in update_data:
new_provider_id = update_data["provider_id"]
if new_provider_id:
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider 不存在")
mapping.provider_id = new_provider_id
if "target_global_model_id" in update_data:
target_model = (
db.query(GlobalModel)
.filter(
GlobalModel.id == update_data["target_global_model_id"],
GlobalModel.is_active == True,
)
.first()
)
if not target_model:
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
mapping.target_global_model_id = update_data["target_global_model_id"]
if "source_model" in update_data:
new_source = update_data["source_model"].strip()
if not new_source:
raise HTTPException(status_code=400, detail="source_model 不能为空")
mapping.source_model = new_source
if "is_active" in update_data:
mapping.is_active = update_data["is_active"]
duplicate = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == mapping.source_model,
ModelMapping.provider_id == mapping.provider_id,
ModelMapping.id != mapping.id,
)
.first()
)
if duplicate:
raise HTTPException(status_code=400, detail="映射已存在")
db.commit()
db.refresh(mapping)
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
return UpdateModelMappingResponse(
success=True,
mapping_id=mapping.id,
message="映射更新成功",
)
@dataclass
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
"""删除模型映射"""
mapping_id: str
async def handle(self, context): # type: ignore[override]
db: Session = context.db
mapping: Optional[ModelMapping] = (
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
)
if not mapping:
raise HTTPException(status_code=404, detail="映射不存在")
source_model = mapping.source_model
provider_id = mapping.provider_id
db.delete(mapping)
db.commit()
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return DeleteModelMappingResponse(
success=True,
message=f"映射 {self.mapping_id} 已删除",
) )

View File

@@ -123,7 +123,7 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
from sqlalchemy import func from sqlalchemy import func
from src.models.database import Model, ModelMapping from src.models.database import Model
models = GlobalModelService.list_global_models( models = GlobalModelService.list_global_models(
db=context.db, db=context.db,
@@ -144,17 +144,8 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
or 0 or 0
) )
# 统计别名数量
alias_count = (
context.db.query(func.count(ModelMapping.id))
.filter(ModelMapping.target_global_model_id == gm.id)
.scalar()
or 0
)
response = GlobalModelResponse.model_validate(gm) response = GlobalModelResponse.model_validate(gm)
response.provider_count = provider_count response.provider_count = provider_count
response.alias_count = alias_count
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射 # usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
model_responses.append(response) model_responses.append(response)

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload from sqlalchemy.orm import Session, joinedload
from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.admin_adapter import AdminApiAdapter
@@ -26,7 +25,6 @@ from src.models.pydantic_models import (
from src.models.database import ( from src.models.database import (
GlobalModel, GlobalModel,
Model, Model,
ModelMapping,
Provider, Provider,
) )
from src.models.pydantic_models import ( from src.models.pydantic_models import (
@@ -136,8 +134,7 @@ async def get_provider_available_source_models(
获取该 Provider 支持的所有统一模型名source_model 获取该 Provider 支持的所有统一模型名source_model
包括: 包括:
1. 通过 ModelMapping 映射的模型 1. 直连模型Model.provider_model_name 直接作为统一模型名)
2. 直连模型Model.provider_model_name 直接作为统一模型名)
""" """
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id) adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -294,10 +291,9 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
""" """
返回 Provider 支持的所有 GlobalModel 返回 Provider 支持的所有 GlobalModel
方案 A 逻辑: 逻辑:
1. 查询该 Provider 的所有 Model 1. 查询该 Provider 的所有 Model
2. 通过 Model.global_model_id 获取 GlobalModel 2. 通过 Model.global_model_id 获取 GlobalModel
3. 查询所有指向该 GlobalModel 的别名ModelMapping.alias
""" """
db = context.db db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first() provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
@@ -324,27 +320,10 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
# 如果该 GlobalModel 还未处理,初始化 # 如果该 GlobalModel 还未处理,初始化
if global_model_name not in global_models_dict: if global_model_name not in global_models_dict:
# 查询指向该 GlobalModel 的所有别名/映射
alias_rows = (
db.query(ModelMapping.source_model)
.filter(
ModelMapping.target_global_model_id == global_model.id,
ModelMapping.is_active == True,
or_(
ModelMapping.provider_id == self.provider_id,
ModelMapping.provider_id.is_(None),
),
)
.all()
)
alias_list = [alias[0] for alias in alias_rows]
global_models_dict[global_model_name] = { global_models_dict[global_model_name] = {
"global_model_name": global_model_name, "global_model_name": global_model_name,
"display_name": global_model.display_name, "display_name": global_model.display_name,
"provider_model_name": model.provider_model_name, "provider_model_name": model.provider_model_name,
"has_alias": len(alias_list) > 0,
"aliases": alias_list,
"model_id": model.id, "model_id": model.id,
"price": { "price": {
"input_price_per_1m": model.get_effective_input_price(), "input_price_per_1m": model.get_effective_input_price(),

View File

@@ -20,14 +20,12 @@ from src.models.api import (
ProviderStatsResponse, ProviderStatsResponse,
PublicGlobalModelListResponse, PublicGlobalModelListResponse,
PublicGlobalModelResponse, PublicGlobalModelResponse,
PublicModelMappingResponse,
PublicModelResponse, PublicModelResponse,
PublicProviderResponse, PublicProviderResponse,
) )
from src.models.database import ( from src.models.database import (
GlobalModel, GlobalModel,
Model, Model,
ModelMapping,
Provider, Provider,
ProviderEndpoint, ProviderEndpoint,
RequestCandidate, RequestCandidate,
@@ -72,24 +70,6 @@ async def get_public_models(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
async def get_public_model_mappings(
request: Request,
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
alias: Optional[str] = Query(None, description="别名过滤原source_model"),
skip: int = Query(0, description="跳过记录数"),
limit: int = Query(100, description="返回记录数限制"),
db: Session = Depends(get_db),
):
adapter = PublicModelMappingsAdapter(
provider_id=provider_id,
alias=alias,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
@router.get("/stats", response_model=ProviderStatsResponse) @router.get("/stats", response_model=ProviderStatsResponse)
async def get_public_stats(request: Request, db: Session = Depends(get_db)): async def get_public_stats(request: Request, db: Session = Depends(get_db)):
adapter = PublicStatsAdapter() adapter = PublicStatsAdapter()
@@ -176,13 +156,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True))) .filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
.count() .count()
) )
mappings_count = (
db.query(ModelMapping)
.filter(
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
)
.count()
)
endpoints_count = len(provider.endpoints) if provider.endpoints else 0 endpoints_count = len(provider.endpoints) if provider.endpoints else 0
active_endpoints_count = ( active_endpoints_count = (
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0 sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
@@ -196,7 +169,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
provider_priority=provider.provider_priority, provider_priority=provider.provider_priority,
models_count=models_count, models_count=models_count,
active_models_count=active_models_count, active_models_count=active_models_count,
mappings_count=mappings_count,
endpoints_count=endpoints_count, endpoints_count=endpoints_count,
active_endpoints_count=active_endpoints_count, active_endpoints_count=active_endpoints_count,
) )
@@ -256,78 +228,6 @@ class PublicModelsAdapter(PublicApiAdapter):
return response return response
@dataclass
class PublicModelMappingsAdapter(PublicApiAdapter):
provider_id: Optional[str]
alias: Optional[str] # 原 source_model改为 alias
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
logger.debug("公共API请求模型映射列表")
query = (
db.query(ModelMapping, GlobalModel, Provider)
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
.filter(
and_(
ModelMapping.is_active.is_(True),
GlobalModel.is_active.is_(True),
)
)
)
if self.provider_id is not None:
provider_global_model_ids = (
db.query(Model.global_model_id)
.join(Provider, Model.provider_id == Provider.id)
.filter(
Provider.id == self.provider_id,
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
.distinct()
)
query = query.filter(
or_(
ModelMapping.provider_id == self.provider_id,
and_(
ModelMapping.provider_id.is_(None),
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
),
)
)
else:
query = query.filter(ModelMapping.provider_id.is_(None))
if self.alias is not None:
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
results = query.offset(self.skip).limit(self.limit).all()
response = []
for mapping, global_model, provider in results:
scope = "provider" if mapping.provider_id else "global"
mapping_data = PublicModelMappingResponse(
id=mapping.id,
source_model=mapping.source_model,
target_global_model_id=mapping.target_global_model_id,
target_global_model_name=global_model.name if global_model else None,
target_global_model_display_name=(
global_model.display_name if global_model else None
),
provider_id=mapping.provider_id,
scope=scope,
is_active=mapping.is_active,
)
response.append(mapping_data.model_dump())
logger.debug(f"返回 {len(response)} 个模型映射")
return response
class PublicStatsAdapter(PublicApiAdapter): class PublicStatsAdapter(PublicApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
db = context.db db = context.db
@@ -339,9 +239,6 @@ class PublicStatsAdapter(PublicApiAdapter):
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True))) .filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
.count() .count()
) )
from ...models.database import ModelMapping
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
formats = ( formats = (
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all() db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
) )
@@ -351,7 +248,6 @@ class PublicStatsAdapter(PublicApiAdapter):
active_providers=active_providers, active_providers=active_providers,
total_models=active_models, total_models=active_models,
active_models=active_models, active_models=active_models,
total_mappings=active_mappings,
supported_formats=supported_formats, supported_formats=supported_formats,
) )
logger.debug("返回系统统计信息") logger.debug("返回系统统计信息")

View File

@@ -7,6 +7,7 @@ import secrets
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum as PyEnum from enum import Enum as PyEnum
from typing import Optional
import bcrypt import bcrypt
from sqlalchemy import ( from sqlalchemy import (
@@ -491,9 +492,6 @@ class Provider(Base):
# 关系 # 关系
models = relationship("Model", back_populates="provider", cascade="all, delete-orphan") models = relationship("Model", back_populates="provider", cascade="all, delete-orphan")
model_mappings = relationship(
"ModelMapping", back_populates="provider", cascade="all, delete-orphan"
)
endpoints = relationship( endpoints = relationship(
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan" "ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
) )
@@ -656,7 +654,11 @@ class Model(Base):
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True) global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
# Provider 映射配置 # Provider 映射配置
provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称 provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称
# 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
# 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}]
# 为空时只使用 provider_model_name
provider_model_aliases = Column(JSON, nullable=True, default=None)
# 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值 # 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值
price_per_request = Column(Float, nullable=True) # 每次请求固定费用 price_per_request = Column(Float, nullable=True) # 每次请求固定费用
@@ -786,60 +788,83 @@ class Model(Base):
def get_effective_supports_image_generation(self) -> bool: def get_effective_supports_image_generation(self) -> bool:
return self._get_effective_capability("supports_image_generation", False) return self._get_effective_capability("supports_image_generation", False)
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
"""按优先级选择要使用的 Provider 模型名称
class ModelMapping(Base): 如果配置了 provider_model_aliases按优先级选择数字越小越优先
"""模型映射表 - 统一处理别名与降级策略 相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致);
否则返回 provider_model_name。
设计原则: Args:
- source_model 接收用户请求的原始模型名/别名 affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
- target_global_model_id 指向真实的 GlobalModel """
- provider_id 为空表示全局别名,非空表示 Provider 特定映射/降级 import hashlib
- 一个 (source_model, provider_id) 组合唯一
映射类型 (mapping_type): if not self.provider_model_aliases:
- alias: 别名模式,按目标模型计费(只是名称简写) return self.provider_model_name
- mapping: 映射模式,按源模型计费(模型降级/替代)
"""
__tablename__ = "model_mappings" raw_aliases = self.provider_model_aliases
if not isinstance(raw_aliases, list) or len(raw_aliases) == 0:
return self.provider_model_name
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True) aliases: list[dict] = []
for raw in raw_aliases:
if not isinstance(raw, dict):
continue
name = raw.get("name")
if not isinstance(name, str) or not name.strip():
continue
# 源模型名称(可能是别名或真实 GlobalModel.name raw_priority = raw.get("priority", 1)
source_model = Column(String(200), nullable=False, index=True) try:
priority = int(raw_priority)
except Exception:
priority = 1
if priority < 1:
priority = 1
# 目标 GlobalModel aliases.append({"name": name.strip(), "priority": priority})
target_global_model_id = Column(
String(36), ForeignKey("global_models.id", ondelete="CASCADE"), nullable=False, index=True
)
# Provider 关联NULL 代表全局别名 if not aliases:
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=True, index=True) return self.provider_model_name
# 映射类型alias=按目标模型计费mapping=按源模型计费 # 按优先级排序(数字越小越优先)
mapping_type = Column(String(20), nullable=False, default="alias", index=True) sorted_aliases = sorted(aliases, key=lambda x: x["priority"])
# 状态 # 获取最高优先级(最小数字)
is_active = Column(Boolean, default=True, nullable=False) highest_priority = sorted_aliases[0]["priority"]
# 时间戳 # 获取所有最高优先级的别名
created_at = Column( top_priority_aliases = [
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False alias for alias in sorted_aliases
) if alias["priority"] == highest_priority
updated_at = Column( ]
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
nullable=False,
)
# 关系 # 如果有多个相同优先级的别名,通过哈希分散选择
target_global_model = relationship("GlobalModel", foreign_keys=[target_global_model_id]) if len(top_priority_aliases) > 1 and affinity_key:
provider = relationship("Provider", back_populates="model_mappings") # 为每个别名计算哈希得分,选择得分最小的
def hash_score(alias: dict) -> int:
combined = f"{affinity_key}:{alias['name']}"
return int(hashlib.md5(combined.encode()).hexdigest(), 16)
__table_args__ = ( selected = min(top_priority_aliases, key=hash_score)
UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"), elif len(top_priority_aliases) > 1:
) # 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个)
# 避免随机选择导致同一请求重试时选择不同的模型名称
selected = min(top_priority_aliases, key=lambda x: x["name"])
else:
selected = top_priority_aliases[0]
return selected["name"]
def get_all_provider_model_names(self) -> list[str]:
"""获取所有可用的 Provider 模型名称(主名称 + 别名)"""
names = [self.provider_model_name]
if self.provider_model_aliases:
for alias in self.provider_model_aliases:
if isinstance(alias, dict) and alias.get("name"):
names.append(alias["name"])
return names
class ProviderAPIKey(Base): class ProviderAPIKey(Base):

View File

@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from .api import ModelCreate
# ========== 阶梯计费相关模型 ========== # ========== 阶梯计费相关模型 ==========
@@ -131,24 +129,14 @@ class ModelCatalogProviderDetail(BaseModel):
supports_function_calling: Optional[bool] = None supports_function_calling: Optional[bool] = None
supports_streaming: Optional[bool] = None supports_streaming: Optional[bool] = None
is_active: bool is_active: bool
mapping_id: Optional[str]
class OrphanedModel(BaseModel):
"""孤立的统一模型Mapping 存在但 GlobalModel 缺失)"""
alias: str # 别名
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
mapping_count: int
class ModelCatalogItem(BaseModel): class ModelCatalogItem(BaseModel):
"""统一模型目录条目(方案 A基于 GlobalModel""" """统一模型目录条目(基于 GlobalModel"""
global_model_name: str # GlobalModel.name global_model_name: str # GlobalModel.name
display_name: str # GlobalModel.display_name display_name: str # GlobalModel.display_name
description: Optional[str] # GlobalModel.description description: Optional[str] # GlobalModel.description
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表 providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合) price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
total_providers: int total_providers: int
@@ -160,7 +148,6 @@ class ModelCatalogResponse(BaseModel):
models: List[ModelCatalogItem] models: List[ModelCatalogItem]
total: int total: int
orphaned_models: List[OrphanedModel]
class ProviderModelPriceInfo(BaseModel): class ProviderModelPriceInfo(BaseModel):
@@ -174,13 +161,11 @@ class ProviderModelPriceInfo(BaseModel):
class ProviderAvailableSourceModel(BaseModel): class ProviderAvailableSourceModel(BaseModel):
"""Provider 支持的统一模型条目(方案 A""" """Provider 支持的统一模型条目"""
global_model_name: str # GlobalModel.name global_model_name: str # GlobalModel.name
display_name: str # GlobalModel.display_name display_name: str # GlobalModel.display_name
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名) provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
has_alias: bool # 是否有别名指向该 GlobalModel
aliases: List[str] # 别名列表
model_id: Optional[str] # Model.id model_id: Optional[str] # Model.id
price: ProviderModelPriceInfo price: ProviderModelPriceInfo
capabilities: ModelCapabilities capabilities: ModelCapabilities
@@ -194,50 +179,7 @@ class ProviderAvailableSourceModelsResponse(BaseModel):
total: int total: int
class BatchAssignProviderConfig(BaseModel): # ========== GlobalModel 相关模型 ==========
"""批量添加映射的 Provider 配置"""
provider_id: str
create_model: bool = Field(False, description="是否需要创建新的 Model")
model_data: Optional[ModelCreate] = Field(
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
)
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
class BatchAssignModelMappingRequest(BaseModel):
"""批量添加模型映射请求(方案 A暂不支持需要重构"""
global_model_id: str # 要分配的 GlobalModel ID
providers: List[BatchAssignProviderConfig]
class BatchAssignProviderResult(BaseModel):
"""批量映射结果条目"""
provider_id: str
mapping_id: Optional[str]
created_model: bool
model_id: Optional[str]
updated: bool = False
class BatchAssignError(BaseModel):
"""批量映射错误信息"""
provider_id: str
error: str
class BatchAssignModelMappingResponse(BaseModel):
"""批量映射响应"""
success: bool
created_mappings: List[BatchAssignProviderResult]
errors: List[BatchAssignError]
# ========== 阶段二GlobalModel 相关模型 ==========
class GlobalModelCreate(BaseModel): class GlobalModelCreate(BaseModel):
@@ -328,7 +270,6 @@ class GlobalModelResponse(BaseModel):
) )
# 统计数据(可选) # 统计数据(可选)
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量") provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
alias_count: Optional[int] = Field(default=0, description="别名数量")
usage_count: Optional[int] = Field(default=0, description="调用次数") usage_count: Optional[int] = Field(default=0, description="调用次数")
created_at: datetime created_at: datetime
updated_at: Optional[datetime] updated_at: Optional[datetime]
@@ -355,7 +296,7 @@ class GlobalModelListResponse(BaseModel):
class BatchAssignToProvidersRequest(BaseModel): class BatchAssignToProvidersRequest(BaseModel):
"""批量为 Provider 添加 GlobalModel 实现""" """批量为 Provider 添加 GlobalModel 实现"""
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表") provider_ids: List[str] = Field(..., min_length=1, description="Provider ID 列表")
create_models: bool = Field(default=False, description="是否自动创建 Model 记录") create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
@@ -379,43 +320,11 @@ class BatchAssignModelsToProviderResponse(BaseModel):
errors: List[dict] errors: List[dict]
class UpdateModelMappingRequest(BaseModel):
"""更新模型映射请求"""
source_model: Optional[str] = Field(
None, min_length=1, max_length=200, description="源模型名或别名"
)
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时为全局别名")
is_active: Optional[bool] = Field(None, description="是否启用")
class UpdateModelMappingResponse(BaseModel):
"""更新模型映射响应"""
success: bool
mapping_id: str
message: Optional[str] = None
class DeleteModelMappingResponse(BaseModel):
"""删除模型映射响应"""
success: bool
message: Optional[str] = None
__all__ = [ __all__ = [
"BatchAssignError",
"BatchAssignModelMappingRequest",
"BatchAssignModelMappingResponse",
"BatchAssignModelsToProviderRequest", "BatchAssignModelsToProviderRequest",
"BatchAssignModelsToProviderResponse", "BatchAssignModelsToProviderResponse",
"BatchAssignProviderConfig",
"BatchAssignProviderResult",
"BatchAssignToProvidersRequest", "BatchAssignToProvidersRequest",
"BatchAssignToProvidersResponse", "BatchAssignToProvidersResponse",
"DeleteModelMappingResponse",
"GlobalModelCreate", "GlobalModelCreate",
"GlobalModelListResponse", "GlobalModelListResponse",
"GlobalModelResponse", "GlobalModelResponse",
@@ -426,10 +335,7 @@ __all__ = [
"ModelCatalogProviderDetail", "ModelCatalogProviderDetail",
"ModelCatalogResponse", "ModelCatalogResponse",
"ModelPriceRange", "ModelPriceRange",
"OrphanedModel",
"ProviderAvailableSourceModel", "ProviderAvailableSourceModel",
"ProviderAvailableSourceModelsResponse", "ProviderAvailableSourceModelsResponse",
"ProviderModelPriceInfo", "ProviderModelPriceInfo",
"UpdateModelMappingRequest",
"UpdateModelMappingResponse",
] ]