From 56fb6bf36c69a2b0494d23f294c6591305099769 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Mon, 15 Dec 2025 14:30:10 +0800 Subject: [PATCH] refactor(backend): update model catalog and provider APIs after mappings removal --- src/api/admin/models/catalog.py | 294 +------------------------- src/api/admin/models/global_models.py | 11 +- src/api/admin/providers/models.py | 25 +-- src/api/public/catalog.py | 104 --------- src/models/database.py | 115 ++++++---- src/models/pydantic_models.py | 102 +-------- 6 files changed, 85 insertions(+), 566 deletions(-) diff --git a/src/api/admin/models/catalog.py b/src/api/admin/models/catalog.py index e12ab1c..57e1705 100644 --- a/src/api/admin/models/catalog.py +++ b/src/api/admin/models/catalog.py @@ -1,38 +1,26 @@ """ 统一模型目录 Admin API -阶段一:基于 ModelMapping 和 Model 的聚合视图 +基于 GlobalModel 的聚合视图 """ from dataclasses import dataclass -from typing import Dict, List, Optional, Set +from typing import Dict, List -from fastapi import APIRouter, Depends, HTTPException, Request -from sqlalchemy import func, or_ +from fastapi import APIRouter, Depends, Request from sqlalchemy.orm import Session, joinedload from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.pipeline import ApiRequestPipeline -from src.core.logger import logger from src.database import get_db -from src.models.database import GlobalModel, Model, ModelMapping, Provider +from src.models.database import GlobalModel, Model from src.models.pydantic_models import ( - BatchAssignError, - BatchAssignModelMappingRequest, - BatchAssignModelMappingResponse, - BatchAssignProviderResult, - DeleteModelMappingResponse, ModelCapabilities, ModelCatalogItem, ModelCatalogProviderDetail, ModelCatalogResponse, ModelPriceRange, - OrphanedModel, - UpdateModelMappingRequest, - UpdateModelMappingResponse, ) -from src.services.cache.invalidation import get_cache_invalidation_service -from src.services.model.service import ModelService router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"]) pipeline = ApiRequestPipeline() @@ -47,24 +35,13 @@ async def get_model_catalog( return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) -@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse) -async def batch_assign_model_mappings( - request: Request, - payload: BatchAssignModelMappingRequest, - db: Session = Depends(get_db), -) -> BatchAssignModelMappingResponse: - adapter = AdminBatchAssignModelMappingsAdapter(payload=payload) - return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) - - @dataclass class AdminGetModelCatalogAdapter(AdminApiAdapter): """管理员查询统一模型目录 - 新架构说明: + 架构说明: 1. 以 GlobalModel 为中心聚合数据 - 2. ModelMapping 表提供别名信息(provider_id=NULL 表示全局) - 3. Model 表提供关联提供商和价格 + 2. Model 表提供关联提供商和价格 """ async def handle(self, context): # type: ignore[override] @@ -75,29 +52,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter): db.query(GlobalModel).filter(GlobalModel.is_active == True).all() ) - # 2. 获取所有活跃的别名(含全局和 Provider 特定) - aliases_rows: List[ModelMapping] = ( - db.query(ModelMapping) - .options(joinedload(ModelMapping.target_global_model)) - .filter( - ModelMapping.is_active == True, - ModelMapping.provider_id.is_(None), - ) - .all() - ) - - # 按 GlobalModel ID 组织别名 - aliases_by_global_model: Dict[str, List[str]] = {} - for alias_row in aliases_rows: - if not alias_row.target_global_model_id: - continue - gm_id = alias_row.target_global_model_id - if gm_id not in aliases_by_global_model: - aliases_by_global_model[gm_id] = [] - if alias_row.source_model not in aliases_by_global_model[gm_id]: - aliases_by_global_model[gm_id].append(alias_row.source_model) - - # 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格) + # 2. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格) models: List[Model] = ( db.query(Model) .options(joinedload(Model.provider), joinedload(Model.global_model)) @@ -111,7 +66,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter): if model.global_model_id: models_by_global_model.setdefault(model.global_model_id, []).append(model) - # 4. 为每个 GlobalModel 构建 catalog item + # 3. 为每个 GlobalModel 构建 catalog item catalog_items: List[ModelCatalogItem] = [] for gm in global_models: @@ -168,7 +123,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter): supports_function_calling=model.get_effective_supports_function_calling(), supports_streaming=model.get_effective_supports_streaming(), is_active=bool(model.is_active), - mapping_id=None, # 新架构中不再有 mapping_id ) ) @@ -187,7 +141,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter): global_model_name=gm.name, display_name=gm.display_name, description=gm.description, - aliases=aliases_by_global_model.get(gm_id, []), providers=provider_entries, price_range=price_range, total_providers=len(provider_entries), @@ -195,238 +148,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter): ) ) - # 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃) - orphaned_rows = ( - db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id)) - .outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id) - .filter( - ModelMapping.is_active == True, - ModelMapping.provider_id.is_(None), - or_(GlobalModel.id == None, GlobalModel.is_active == False), - ) - .group_by(ModelMapping.source_model, GlobalModel.name) - .all() - ) - orphaned_models = [ - OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2]) - for row in orphaned_rows - if row[0] - ] - return ModelCatalogResponse( models=catalog_items, total=len(catalog_items), - orphaned_models=orphaned_models, - ) - - -@dataclass -class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter): - payload: BatchAssignModelMappingRequest - - async def handle(self, context): # type: ignore[override] - db: Session = context.db - created: List[BatchAssignProviderResult] = [] - errors: List[BatchAssignError] = [] - - for provider_config in self.payload.providers: - provider_id = provider_config.provider_id - try: - provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first() - if not provider: - errors.append( - BatchAssignError(provider_id=provider_id, error="Provider 不存在") - ) - continue - - model_id: Optional[str] = None - created_model = False - - if provider_config.create_model: - model_data = provider_config.model_data - if not model_data: - errors.append( - BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置") - ) - continue - - existing_model = ModelService.get_model_by_name( - db, provider_id, model_data.provider_model_name - ) - if existing_model: - model_id = existing_model.id - logger.info("模型 %s 已存在于 Provider %s,复用现有模型", - model_data.provider_model_name, - provider.name, - ) - else: - model = ModelService.create_model(db, provider_id, model_data) - model_id = model.id - created_model = True - else: - model_id = provider_config.model_id - if not model_id: - errors.append( - BatchAssignError(provider_id=provider_id, error="缺少 model_id") - ) - continue - model = ( - db.query(Model) - .filter(Model.id == model_id, Model.provider_id == provider_id) - .first() - ) - if not model: - errors.append( - BatchAssignError( - provider_id=provider_id, error="模型不存在或不属于当前 Provider") - ) - continue - - # 批量分配功能需要适配 GlobalModel 架构 - # 参见 docs/optimization-backlog.md 中的待办项 - errors.append( - BatchAssignError( - provider_id=provider_id, - error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构", - ) - ) - continue - - except Exception as exc: - db.rollback() - logger.error("批量添加模型映射失败(需要适配新架构)") - errors.append(BatchAssignError(provider_id=provider_id, error=str(exc))) - - return BatchAssignModelMappingResponse( - success=len(created) > 0, - created_mappings=created, - errors=errors, - ) - - -@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse) -async def update_model_mapping( - request: Request, - mapping_id: str, - payload: UpdateModelMappingRequest, - db: Session = Depends(get_db), -) -> UpdateModelMappingResponse: - """更新模型映射""" - adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload) - return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) - - -@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse) -async def delete_model_mapping( - request: Request, - mapping_id: str, - db: Session = Depends(get_db), -) -> DeleteModelMappingResponse: - """删除模型映射""" - adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id) - return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) - - -@dataclass -class AdminUpdateModelMappingAdapter(AdminApiAdapter): - """更新模型映射""" - - mapping_id: str - payload: UpdateModelMappingRequest - - async def handle(self, context): # type: ignore[override] - db: Session = context.db - - mapping: Optional[ModelMapping] = ( - db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first() - ) - - if not mapping: - raise HTTPException(status_code=404, detail="映射不存在") - - update_data = self.payload.model_dump(exclude_unset=True) - - if "provider_id" in update_data: - new_provider_id = update_data["provider_id"] - if new_provider_id: - provider = db.query(Provider).filter(Provider.id == new_provider_id).first() - if not provider: - raise HTTPException(status_code=404, detail="Provider 不存在") - mapping.provider_id = new_provider_id - - if "target_global_model_id" in update_data: - target_model = ( - db.query(GlobalModel) - .filter( - GlobalModel.id == update_data["target_global_model_id"], - GlobalModel.is_active == True, - ) - .first() - ) - if not target_model: - raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活") - mapping.target_global_model_id = update_data["target_global_model_id"] - - if "source_model" in update_data: - new_source = update_data["source_model"].strip() - if not new_source: - raise HTTPException(status_code=400, detail="source_model 不能为空") - mapping.source_model = new_source - - if "is_active" in update_data: - mapping.is_active = update_data["is_active"] - - duplicate = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == mapping.source_model, - ModelMapping.provider_id == mapping.provider_id, - ModelMapping.id != mapping.id, - ) - .first() - ) - if duplicate: - raise HTTPException(status_code=400, detail="映射已存在") - - db.commit() - db.refresh(mapping) - - cache_service = get_cache_invalidation_service() - cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id) - - return UpdateModelMappingResponse( - success=True, - mapping_id=mapping.id, - message="映射更新成功", - ) - - -@dataclass -class AdminDeleteModelMappingAdapter(AdminApiAdapter): - """删除模型映射""" - - mapping_id: str - - async def handle(self, context): # type: ignore[override] - db: Session = context.db - - mapping: Optional[ModelMapping] = ( - db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first() - ) - - if not mapping: - raise HTTPException(status_code=404, detail="映射不存在") - - source_model = mapping.source_model - provider_id = mapping.provider_id - - db.delete(mapping) - db.commit() - - cache_service = get_cache_invalidation_service() - cache_service.on_model_mapping_changed(source_model, provider_id) - - return DeleteModelMappingResponse( - success=True, - message=f"映射 {self.mapping_id} 已删除", ) diff --git a/src/api/admin/models/global_models.py b/src/api/admin/models/global_models.py index 1292eee..766743c 100644 --- a/src/api/admin/models/global_models.py +++ b/src/api/admin/models/global_models.py @@ -123,7 +123,7 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter): async def handle(self, context): # type: ignore[override] from sqlalchemy import func - from src.models.database import Model, ModelMapping + from src.models.database import Model models = GlobalModelService.list_global_models( db=context.db, @@ -144,17 +144,8 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter): or 0 ) - # 统计别名数量 - alias_count = ( - context.db.query(func.count(ModelMapping.id)) - .filter(ModelMapping.target_global_model_id == gm.id) - .scalar() - or 0 - ) - response = GlobalModelResponse.model_validate(gm) response.provider_count = provider_count - response.alias_count = alias_count # usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射 model_responses.append(response) diff --git a/src/api/admin/providers/models.py b/src/api/admin/providers/models.py index 359f5d6..8b84ca6 100644 --- a/src/api/admin/providers/models.py +++ b/src/api/admin/providers/models.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, Request -from sqlalchemy import or_ from sqlalchemy.orm import Session, joinedload from src.api.base.admin_adapter import AdminApiAdapter @@ -26,7 +25,6 @@ from src.models.pydantic_models import ( from src.models.database import ( GlobalModel, Model, - ModelMapping, Provider, ) from src.models.pydantic_models import ( @@ -136,8 +134,7 @@ async def get_provider_available_source_models( 获取该 Provider 支持的所有统一模型名(source_model) 包括: - 1. 通过 ModelMapping 映射的模型 - 2. 直连模型(Model.provider_model_name 直接作为统一模型名) + 1. 直连模型(Model.provider_model_name 直接作为统一模型名) """ adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) @@ -294,10 +291,9 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter): """ 返回 Provider 支持的所有 GlobalModel - 方案 A 逻辑: + 逻辑: 1. 查询该 Provider 的所有 Model 2. 通过 Model.global_model_id 获取 GlobalModel - 3. 查询所有指向该 GlobalModel 的别名(ModelMapping.alias) """ db = context.db provider = db.query(Provider).filter(Provider.id == self.provider_id).first() @@ -324,27 +320,10 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter): # 如果该 GlobalModel 还未处理,初始化 if global_model_name not in global_models_dict: - # 查询指向该 GlobalModel 的所有别名/映射 - alias_rows = ( - db.query(ModelMapping.source_model) - .filter( - ModelMapping.target_global_model_id == global_model.id, - ModelMapping.is_active == True, - or_( - ModelMapping.provider_id == self.provider_id, - ModelMapping.provider_id.is_(None), - ), - ) - .all() - ) - alias_list = [alias[0] for alias in alias_rows] - global_models_dict[global_model_name] = { "global_model_name": global_model_name, "display_name": global_model.display_name, "provider_model_name": model.provider_model_name, - "has_alias": len(alias_list) > 0, - "aliases": alias_list, "model_id": model.id, "price": { "input_price_per_1m": model.get_effective_input_price(), diff --git a/src/api/public/catalog.py b/src/api/public/catalog.py index 290c465..bdc4381 100644 --- a/src/api/public/catalog.py +++ b/src/api/public/catalog.py @@ -20,14 +20,12 @@ from src.models.api import ( ProviderStatsResponse, PublicGlobalModelListResponse, PublicGlobalModelResponse, - PublicModelMappingResponse, PublicModelResponse, PublicProviderResponse, ) from src.models.database import ( GlobalModel, Model, - ModelMapping, Provider, ProviderEndpoint, RequestCandidate, @@ -72,24 +70,6 @@ async def get_public_models( return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC) -@router.get("/model-mappings", response_model=List[PublicModelMappingResponse]) -async def get_public_model_mappings( - request: Request, - provider_id: Optional[str] = Query(None, description="提供商ID过滤"), - alias: Optional[str] = Query(None, description="别名过滤(原source_model)"), - skip: int = Query(0, description="跳过记录数"), - limit: int = Query(100, description="返回记录数限制"), - db: Session = Depends(get_db), -): - adapter = PublicModelMappingsAdapter( - provider_id=provider_id, - alias=alias, - skip=skip, - limit=limit, - ) - return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC) - - @router.get("/stats", response_model=ProviderStatsResponse) async def get_public_stats(request: Request, db: Session = Depends(get_db)): adapter = PublicStatsAdapter() @@ -176,13 +156,6 @@ class PublicProvidersAdapter(PublicApiAdapter): .filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True))) .count() ) - mappings_count = ( - db.query(ModelMapping) - .filter( - and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True)) - ) - .count() - ) endpoints_count = len(provider.endpoints) if provider.endpoints else 0 active_endpoints_count = ( sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0 @@ -196,7 +169,6 @@ class PublicProvidersAdapter(PublicApiAdapter): provider_priority=provider.provider_priority, models_count=models_count, active_models_count=active_models_count, - mappings_count=mappings_count, endpoints_count=endpoints_count, active_endpoints_count=active_endpoints_count, ) @@ -256,78 +228,6 @@ class PublicModelsAdapter(PublicApiAdapter): return response -@dataclass -class PublicModelMappingsAdapter(PublicApiAdapter): - provider_id: Optional[str] - alias: Optional[str] # 原 source_model,改为 alias - skip: int - limit: int - - async def handle(self, context): # type: ignore[override] - db = context.db - logger.debug("公共API请求模型映射列表") - - query = ( - db.query(ModelMapping, GlobalModel, Provider) - .join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id) - .outerjoin(Provider, ModelMapping.provider_id == Provider.id) - .filter( - and_( - ModelMapping.is_active.is_(True), - GlobalModel.is_active.is_(True), - ) - ) - ) - - if self.provider_id is not None: - provider_global_model_ids = ( - db.query(Model.global_model_id) - .join(Provider, Model.provider_id == Provider.id) - .filter( - Provider.id == self.provider_id, - Model.is_active.is_(True), - Provider.is_active.is_(True), - Model.global_model_id.isnot(None), - ) - .distinct() - ) - query = query.filter( - or_( - ModelMapping.provider_id == self.provider_id, - and_( - ModelMapping.provider_id.is_(None), - ModelMapping.target_global_model_id.in_(provider_global_model_ids), - ), - ) - ) - else: - query = query.filter(ModelMapping.provider_id.is_(None)) - - if self.alias is not None: - query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%")) - - results = query.offset(self.skip).limit(self.limit).all() - response = [] - for mapping, global_model, provider in results: - scope = "provider" if mapping.provider_id else "global" - mapping_data = PublicModelMappingResponse( - id=mapping.id, - source_model=mapping.source_model, - target_global_model_id=mapping.target_global_model_id, - target_global_model_name=global_model.name if global_model else None, - target_global_model_display_name=( - global_model.display_name if global_model else None - ), - provider_id=mapping.provider_id, - scope=scope, - is_active=mapping.is_active, - ) - response.append(mapping_data.model_dump()) - - logger.debug(f"返回 {len(response)} 个模型映射") - return response - - class PublicStatsAdapter(PublicApiAdapter): async def handle(self, context): # type: ignore[override] db = context.db @@ -339,9 +239,6 @@ class PublicStatsAdapter(PublicApiAdapter): .filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True))) .count() ) - from ...models.database import ModelMapping - - active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count() formats = ( db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all() ) @@ -351,7 +248,6 @@ class PublicStatsAdapter(PublicApiAdapter): active_providers=active_providers, total_models=active_models, active_models=active_models, - total_mappings=active_mappings, supported_formats=supported_formats, ) logger.debug("返回系统统计信息") diff --git a/src/models/database.py b/src/models/database.py index bd97009..98b231d 100644 --- a/src/models/database.py +++ b/src/models/database.py @@ -7,6 +7,7 @@ import secrets import uuid from datetime import datetime, timezone from enum import Enum as PyEnum +from typing import Optional import bcrypt from sqlalchemy import ( @@ -491,9 +492,6 @@ class Provider(Base): # 关系 models = relationship("Model", back_populates="provider", cascade="all, delete-orphan") - model_mappings = relationship( - "ModelMapping", back_populates="provider", cascade="all, delete-orphan" - ) endpoints = relationship( "ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan" ) @@ -656,7 +654,11 @@ class Model(Base): global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True) # Provider 映射配置 - provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称 + provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称 + # 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景 + # 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}] + # 为空时只使用 provider_model_name + provider_model_aliases = Column(JSON, nullable=True, default=None) # 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值 price_per_request = Column(Float, nullable=True) # 每次请求固定费用 @@ -786,60 +788,83 @@ class Model(Base): def get_effective_supports_image_generation(self) -> bool: return self._get_effective_capability("supports_image_generation", False) + def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str: + """按优先级选择要使用的 Provider 模型名称 -class ModelMapping(Base): - """模型映射表 - 统一处理别名与降级策略 + 如果配置了 provider_model_aliases,按优先级选择(数字越小越优先); + 相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致); + 否则返回 provider_model_name。 - 设计原则: - - source_model 接收用户请求的原始模型名/别名 - - target_global_model_id 指向真实的 GlobalModel - - provider_id 为空表示全局别名,非空表示 Provider 特定映射/降级 - - 一个 (source_model, provider_id) 组合唯一 + Args: + affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名 + """ + import hashlib - 映射类型 (mapping_type): - - alias: 别名模式,按目标模型计费(只是名称简写) - - mapping: 映射模式,按源模型计费(模型降级/替代) - """ + if not self.provider_model_aliases: + return self.provider_model_name - __tablename__ = "model_mappings" + raw_aliases = self.provider_model_aliases + if not isinstance(raw_aliases, list) or len(raw_aliases) == 0: + return self.provider_model_name - id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True) + aliases: list[dict] = [] + for raw in raw_aliases: + if not isinstance(raw, dict): + continue + name = raw.get("name") + if not isinstance(name, str) or not name.strip(): + continue - # 源模型名称(可能是别名或真实 GlobalModel.name) - source_model = Column(String(200), nullable=False, index=True) + raw_priority = raw.get("priority", 1) + try: + priority = int(raw_priority) + except Exception: + priority = 1 + if priority < 1: + priority = 1 - # 目标 GlobalModel - target_global_model_id = Column( - String(36), ForeignKey("global_models.id", ondelete="CASCADE"), nullable=False, index=True - ) + aliases.append({"name": name.strip(), "priority": priority}) - # Provider 关联:NULL 代表全局别名 - provider_id = Column(String(36), ForeignKey("providers.id"), nullable=True, index=True) + if not aliases: + return self.provider_model_name - # 映射类型:alias=按目标模型计费,mapping=按源模型计费 - mapping_type = Column(String(20), nullable=False, default="alias", index=True) + # 按优先级排序(数字越小越优先) + sorted_aliases = sorted(aliases, key=lambda x: x["priority"]) - # 状态 - is_active = Column(Boolean, default=True, nullable=False) + # 获取最高优先级(最小数字) + highest_priority = sorted_aliases[0]["priority"] - # 时间戳 - created_at = Column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False - ) - updated_at = Column( - DateTime(timezone=True), - default=lambda: datetime.now(timezone.utc), - onupdate=lambda: datetime.now(timezone.utc), - nullable=False, - ) + # 获取所有最高优先级的别名 + top_priority_aliases = [ + alias for alias in sorted_aliases + if alias["priority"] == highest_priority + ] - # 关系 - target_global_model = relationship("GlobalModel", foreign_keys=[target_global_model_id]) - provider = relationship("Provider", back_populates="model_mappings") + # 如果有多个相同优先级的别名,通过哈希分散选择 + if len(top_priority_aliases) > 1 and affinity_key: + # 为每个别名计算哈希得分,选择得分最小的 + def hash_score(alias: dict) -> int: + combined = f"{affinity_key}:{alias['name']}" + return int(hashlib.md5(combined.encode()).hexdigest(), 16) - __table_args__ = ( - UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"), - ) + selected = min(top_priority_aliases, key=hash_score) + elif len(top_priority_aliases) > 1: + # 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个) + # 避免随机选择导致同一请求重试时选择不同的模型名称 + selected = min(top_priority_aliases, key=lambda x: x["name"]) + else: + selected = top_priority_aliases[0] + + return selected["name"] + + def get_all_provider_model_names(self) -> list[str]: + """获取所有可用的 Provider 模型名称(主名称 + 别名)""" + names = [self.provider_model_name] + if self.provider_model_aliases: + for alias in self.provider_model_aliases: + if isinstance(alias, dict) and alias.get("name"): + names.append(alias["name"]) + return names class ProviderAPIKey(Base): diff --git a/src/models/pydantic_models.py b/src/models/pydantic_models.py index 22362d6..ca83513 100644 --- a/src/models/pydantic_models.py +++ b/src/models/pydantic_models.py @@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, model_validator -from .api import ModelCreate - # ========== 阶梯计费相关模型 ========== @@ -131,24 +129,14 @@ class ModelCatalogProviderDetail(BaseModel): supports_function_calling: Optional[bool] = None supports_streaming: Optional[bool] = None is_active: bool - mapping_id: Optional[str] - - -class OrphanedModel(BaseModel): - """孤立的统一模型(Mapping 存在但 GlobalModel 缺失)""" - - alias: str # 别名 - global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有) - mapping_count: int class ModelCatalogItem(BaseModel): - """统一模型目录条目(方案 A:基于 GlobalModel)""" + """统一模型目录条目(基于 GlobalModel)""" global_model_name: str # GlobalModel.name display_name: str # GlobalModel.display_name description: Optional[str] # GlobalModel.description - aliases: List[str] # 所有指向该 GlobalModel 的别名列表 providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表 price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合) total_providers: int @@ -160,7 +148,6 @@ class ModelCatalogResponse(BaseModel): models: List[ModelCatalogItem] total: int - orphaned_models: List[OrphanedModel] class ProviderModelPriceInfo(BaseModel): @@ -174,13 +161,11 @@ class ProviderModelPriceInfo(BaseModel): class ProviderAvailableSourceModel(BaseModel): - """Provider 支持的统一模型条目(方案 A)""" + """Provider 支持的统一模型条目""" global_model_name: str # GlobalModel.name display_name: str # GlobalModel.display_name provider_model_name: str # Model.provider_model_name (Provider 侧的模型名) - has_alias: bool # 是否有别名指向该 GlobalModel - aliases: List[str] # 别名列表 model_id: Optional[str] # Model.id price: ProviderModelPriceInfo capabilities: ModelCapabilities @@ -194,50 +179,7 @@ class ProviderAvailableSourceModelsResponse(BaseModel): total: int -class BatchAssignProviderConfig(BaseModel): - """批量添加映射的 Provider 配置""" - - provider_id: str - create_model: bool = Field(False, description="是否需要创建新的 Model") - model_data: Optional[ModelCreate] = Field( - None, description="create_model=true 时需要提供的模型配置", alias="model_config" - ) - model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID") - - -class BatchAssignModelMappingRequest(BaseModel): - """批量添加模型映射请求(方案 A:暂不支持,需要重构)""" - - global_model_id: str # 要分配的 GlobalModel ID - providers: List[BatchAssignProviderConfig] - - -class BatchAssignProviderResult(BaseModel): - """批量映射结果条目""" - - provider_id: str - mapping_id: Optional[str] - created_model: bool - model_id: Optional[str] - updated: bool = False - - -class BatchAssignError(BaseModel): - """批量映射错误信息""" - - provider_id: str - error: str - - -class BatchAssignModelMappingResponse(BaseModel): - """批量映射响应""" - - success: bool - created_mappings: List[BatchAssignProviderResult] - errors: List[BatchAssignError] - - -# ========== 阶段二:GlobalModel 相关模型 ========== +# ========== GlobalModel 相关模型 ========== class GlobalModelCreate(BaseModel): @@ -328,7 +270,6 @@ class GlobalModelResponse(BaseModel): ) # 统计数据(可选) provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量") - alias_count: Optional[int] = Field(default=0, description="别名数量") usage_count: Optional[int] = Field(default=0, description="调用次数") created_at: datetime updated_at: Optional[datetime] @@ -355,7 +296,7 @@ class GlobalModelListResponse(BaseModel): class BatchAssignToProvidersRequest(BaseModel): """批量为 Provider 添加 GlobalModel 实现""" - provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表") + provider_ids: List[str] = Field(..., min_length=1, description="Provider ID 列表") create_models: bool = Field(default=False, description="是否自动创建 Model 记录") @@ -379,43 +320,11 @@ class BatchAssignModelsToProviderResponse(BaseModel): errors: List[dict] -class UpdateModelMappingRequest(BaseModel): - """更新模型映射请求""" - - source_model: Optional[str] = Field( - None, min_length=1, max_length=200, description="源模型名或别名" - ) - target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID") - provider_id: Optional[str] = Field(None, description="Provider ID(为空时为全局别名)") - is_active: Optional[bool] = Field(None, description="是否启用") - - -class UpdateModelMappingResponse(BaseModel): - """更新模型映射响应""" - - success: bool - mapping_id: str - message: Optional[str] = None - - -class DeleteModelMappingResponse(BaseModel): - """删除模型映射响应""" - - success: bool - message: Optional[str] = None - - __all__ = [ - "BatchAssignError", - "BatchAssignModelMappingRequest", - "BatchAssignModelMappingResponse", "BatchAssignModelsToProviderRequest", "BatchAssignModelsToProviderResponse", - "BatchAssignProviderConfig", - "BatchAssignProviderResult", "BatchAssignToProvidersRequest", "BatchAssignToProvidersResponse", - "DeleteModelMappingResponse", "GlobalModelCreate", "GlobalModelListResponse", "GlobalModelResponse", @@ -426,10 +335,7 @@ __all__ = [ "ModelCatalogProviderDetail", "ModelCatalogResponse", "ModelPriceRange", - "OrphanedModel", "ProviderAvailableSourceModel", "ProviderAvailableSourceModelsResponse", "ProviderModelPriceInfo", - "UpdateModelMappingRequest", - "UpdateModelMappingResponse", ]