refactor(backend): update model catalog and provider APIs after mappings removal

This commit is contained in:
fawney19
2025-12-15 14:30:10 +08:00
parent 728f9bb126
commit 56fb6bf36c
6 changed files with 85 additions and 566 deletions

View File

@@ -7,6 +7,7 @@ import secrets
import uuid
from datetime import datetime, timezone
from enum import Enum as PyEnum
from typing import Optional
import bcrypt
from sqlalchemy import (
@@ -491,9 +492,6 @@ class Provider(Base):
# 关系
models = relationship("Model", back_populates="provider", cascade="all, delete-orphan")
model_mappings = relationship(
"ModelMapping", back_populates="provider", cascade="all, delete-orphan"
)
endpoints = relationship(
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
)
@@ -656,7 +654,11 @@ class Model(Base):
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
# Provider 映射配置
provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称
provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称
# 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
# 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}]
# 为空时只使用 provider_model_name
provider_model_aliases = Column(JSON, nullable=True, default=None)
# 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值
price_per_request = Column(Float, nullable=True) # 每次请求固定费用
@@ -786,60 +788,83 @@ class Model(Base):
def get_effective_supports_image_generation(self) -> bool:
return self._get_effective_capability("supports_image_generation", False)
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
"""按优先级选择要使用的 Provider 模型名称
class ModelMapping(Base):
"""模型映射表 - 统一处理别名与降级策略
如果配置了 provider_model_aliases按优先级选择数字越小越优先
相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致);
否则返回 provider_model_name。
设计原则:
- source_model 接收用户请求的原始模型名/别名
- target_global_model_id 指向真实的 GlobalModel
- provider_id 为空表示全局别名,非空表示 Provider 特定映射/降级
- 一个 (source_model, provider_id) 组合唯一
Args:
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
"""
import hashlib
映射类型 (mapping_type):
- alias: 别名模式,按目标模型计费(只是名称简写)
- mapping: 映射模式,按源模型计费(模型降级/替代)
"""
if not self.provider_model_aliases:
return self.provider_model_name
__tablename__ = "model_mappings"
raw_aliases = self.provider_model_aliases
if not isinstance(raw_aliases, list) or len(raw_aliases) == 0:
return self.provider_model_name
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
aliases: list[dict] = []
for raw in raw_aliases:
if not isinstance(raw, dict):
continue
name = raw.get("name")
if not isinstance(name, str) or not name.strip():
continue
# 源模型名称(可能是别名或真实 GlobalModel.name
source_model = Column(String(200), nullable=False, index=True)
raw_priority = raw.get("priority", 1)
try:
priority = int(raw_priority)
except Exception:
priority = 1
if priority < 1:
priority = 1
# 目标 GlobalModel
target_global_model_id = Column(
String(36), ForeignKey("global_models.id", ondelete="CASCADE"), nullable=False, index=True
)
aliases.append({"name": name.strip(), "priority": priority})
# Provider 关联NULL 代表全局别名
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=True, index=True)
if not aliases:
return self.provider_model_name
# 映射类型alias=按目标模型计费mapping=按源模型计费
mapping_type = Column(String(20), nullable=False, default="alias", index=True)
# 按优先级排序(数字越小越优先)
sorted_aliases = sorted(aliases, key=lambda x: x["priority"])
# 状态
is_active = Column(Boolean, default=True, nullable=False)
# 获取最高优先级(最小数字)
highest_priority = sorted_aliases[0]["priority"]
# 时间戳
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
)
updated_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc),
nullable=False,
)
# 获取所有最高优先级的别名
top_priority_aliases = [
alias for alias in sorted_aliases
if alias["priority"] == highest_priority
]
# 关系
target_global_model = relationship("GlobalModel", foreign_keys=[target_global_model_id])
provider = relationship("Provider", back_populates="model_mappings")
# 如果有多个相同优先级的别名,通过哈希分散选择
if len(top_priority_aliases) > 1 and affinity_key:
# 为每个别名计算哈希得分,选择得分最小的
def hash_score(alias: dict) -> int:
combined = f"{affinity_key}:{alias['name']}"
return int(hashlib.md5(combined.encode()).hexdigest(), 16)
__table_args__ = (
UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"),
)
selected = min(top_priority_aliases, key=hash_score)
elif len(top_priority_aliases) > 1:
# 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个)
# 避免随机选择导致同一请求重试时选择不同的模型名称
selected = min(top_priority_aliases, key=lambda x: x["name"])
else:
selected = top_priority_aliases[0]
return selected["name"]
def get_all_provider_model_names(self) -> list[str]:
"""获取所有可用的 Provider 模型名称(主名称 + 别名)"""
names = [self.provider_model_name]
if self.provider_model_aliases:
for alias in self.provider_model_aliases:
if isinstance(alias, dict) and alias.get("name"):
names.append(alias["name"])
return names
class ProviderAPIKey(Base):

View File

@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
from .api import ModelCreate
# ========== 阶梯计费相关模型 ==========
@@ -131,24 +129,14 @@ class ModelCatalogProviderDetail(BaseModel):
supports_function_calling: Optional[bool] = None
supports_streaming: Optional[bool] = None
is_active: bool
mapping_id: Optional[str]
class OrphanedModel(BaseModel):
"""孤立的统一模型Mapping 存在但 GlobalModel 缺失)"""
alias: str # 别名
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
mapping_count: int
class ModelCatalogItem(BaseModel):
"""统一模型目录条目(方案 A基于 GlobalModel"""
"""统一模型目录条目(基于 GlobalModel"""
global_model_name: str # GlobalModel.name
display_name: str # GlobalModel.display_name
description: Optional[str] # GlobalModel.description
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
total_providers: int
@@ -160,7 +148,6 @@ class ModelCatalogResponse(BaseModel):
models: List[ModelCatalogItem]
total: int
orphaned_models: List[OrphanedModel]
class ProviderModelPriceInfo(BaseModel):
@@ -174,13 +161,11 @@ class ProviderModelPriceInfo(BaseModel):
class ProviderAvailableSourceModel(BaseModel):
"""Provider 支持的统一模型条目(方案 A"""
"""Provider 支持的统一模型条目"""
global_model_name: str # GlobalModel.name
display_name: str # GlobalModel.display_name
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
has_alias: bool # 是否有别名指向该 GlobalModel
aliases: List[str] # 别名列表
model_id: Optional[str] # Model.id
price: ProviderModelPriceInfo
capabilities: ModelCapabilities
@@ -194,50 +179,7 @@ class ProviderAvailableSourceModelsResponse(BaseModel):
total: int
class BatchAssignProviderConfig(BaseModel):
"""批量添加映射的 Provider 配置"""
provider_id: str
create_model: bool = Field(False, description="是否需要创建新的 Model")
model_data: Optional[ModelCreate] = Field(
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
)
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
class BatchAssignModelMappingRequest(BaseModel):
"""批量添加模型映射请求(方案 A暂不支持需要重构"""
global_model_id: str # 要分配的 GlobalModel ID
providers: List[BatchAssignProviderConfig]
class BatchAssignProviderResult(BaseModel):
"""批量映射结果条目"""
provider_id: str
mapping_id: Optional[str]
created_model: bool
model_id: Optional[str]
updated: bool = False
class BatchAssignError(BaseModel):
"""批量映射错误信息"""
provider_id: str
error: str
class BatchAssignModelMappingResponse(BaseModel):
"""批量映射响应"""
success: bool
created_mappings: List[BatchAssignProviderResult]
errors: List[BatchAssignError]
# ========== 阶段二GlobalModel 相关模型 ==========
# ========== GlobalModel 相关模型 ==========
class GlobalModelCreate(BaseModel):
@@ -328,7 +270,6 @@ class GlobalModelResponse(BaseModel):
)
# 统计数据(可选)
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
alias_count: Optional[int] = Field(default=0, description="别名数量")
usage_count: Optional[int] = Field(default=0, description="调用次数")
created_at: datetime
updated_at: Optional[datetime]
@@ -355,7 +296,7 @@ class GlobalModelListResponse(BaseModel):
class BatchAssignToProvidersRequest(BaseModel):
"""批量为 Provider 添加 GlobalModel 实现"""
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表")
provider_ids: List[str] = Field(..., min_length=1, description="Provider ID 列表")
create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
@@ -379,43 +320,11 @@ class BatchAssignModelsToProviderResponse(BaseModel):
errors: List[dict]
class UpdateModelMappingRequest(BaseModel):
"""更新模型映射请求"""
source_model: Optional[str] = Field(
None, min_length=1, max_length=200, description="源模型名或别名"
)
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
provider_id: Optional[str] = Field(None, description="Provider ID为空时为全局别名")
is_active: Optional[bool] = Field(None, description="是否启用")
class UpdateModelMappingResponse(BaseModel):
"""更新模型映射响应"""
success: bool
mapping_id: str
message: Optional[str] = None
class DeleteModelMappingResponse(BaseModel):
"""删除模型映射响应"""
success: bool
message: Optional[str] = None
__all__ = [
"BatchAssignError",
"BatchAssignModelMappingRequest",
"BatchAssignModelMappingResponse",
"BatchAssignModelsToProviderRequest",
"BatchAssignModelsToProviderResponse",
"BatchAssignProviderConfig",
"BatchAssignProviderResult",
"BatchAssignToProvidersRequest",
"BatchAssignToProvidersResponse",
"DeleteModelMappingResponse",
"GlobalModelCreate",
"GlobalModelListResponse",
"GlobalModelResponse",
@@ -426,10 +335,7 @@ __all__ = [
"ModelCatalogProviderDetail",
"ModelCatalogResponse",
"ModelPriceRange",
"OrphanedModel",
"ProviderAvailableSourceModel",
"ProviderAvailableSourceModelsResponse",
"ProviderModelPriceInfo",
"UpdateModelMappingRequest",
"UpdateModelMappingResponse",
]