refactor(backend): remove model mappings module

This commit is contained in:
fawney19
2025-12-15 14:30:00 +08:00
parent 5319c06f0e
commit 728f9bb126
3 changed files with 0 additions and 737 deletions

View File

@@ -6,11 +6,9 @@ from fastapi import APIRouter
from .catalog import router as catalog_router from .catalog import router as catalog_router
from .global_models import router as global_models_router from .global_models import router as global_models_router
from .mappings import router as mappings_router
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"]) router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
# 挂载子路由 # 挂载子路由
router.include_router(catalog_router) router.include_router(catalog_router)
router.include_router(global_models_router) router.include_router(global_models_router)
router.include_router(mappings_router)

View File

@@ -1,303 +0,0 @@
"""模型映射管理 API
提供模型映射的 CRUD 操作。
模型映射Mapping用于将源模型映射到目标模型例如
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
- 用于处理 Provider 不支持请求模型的情况
映射必须关联到特定的 Provider。
"""
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session, joinedload
from src.core.logger import logger
from src.database import get_db
from src.models.api import (
ModelMappingCreate,
ModelMappingResponse,
ModelMappingUpdate,
)
from src.models.database import GlobalModel, ModelMapping, Provider, User
from src.services.cache.invalidation import get_cache_invalidation_service
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
target = mapping.target_global_model
provider = mapping.provider
scope = "provider" if mapping.provider_id else "global"
return ModelMappingResponse(
id=mapping.id,
source_model=mapping.source_model,
target_global_model_id=mapping.target_global_model_id,
target_global_model_name=target.name if target else None,
target_global_model_display_name=target.display_name if target else None,
provider_id=mapping.provider_id,
provider_name=provider.name if provider else None,
scope=scope,
mapping_type=mapping.mapping_type,
is_active=mapping.is_active,
created_at=mapping.created_at,
updated_at=mapping.updated_at,
)
@router.get("", response_model=List[ModelMappingResponse])
async def list_mappings(
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
scope: Optional[str] = Query(None, description="global 或 provider"),
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
is_active: Optional[bool] = Query(None, description="按状态筛选"),
skip: int = Query(0, ge=0, description="跳过记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
db: Session = Depends(get_db),
):
"""获取模型映射列表"""
query = db.query(ModelMapping).options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
if provider_id is not None:
query = query.filter(ModelMapping.provider_id == provider_id)
if scope == "global":
query = query.filter(ModelMapping.provider_id.is_(None))
elif scope == "provider":
query = query.filter(ModelMapping.provider_id.isnot(None))
if mapping_type is not None:
query = query.filter(ModelMapping.mapping_type == mapping_type)
if source_model:
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
if target_global_model_id is not None:
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
if is_active is not None:
query = query.filter(ModelMapping.is_active == is_active)
mappings = query.offset(skip).limit(limit).all()
return [_serialize_mapping(mapping) for mapping in mappings]
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
async def get_mapping(
mapping_id: str,
db: Session = Depends(get_db),
):
"""获取单个模型映射"""
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping_id)
.first()
)
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
return _serialize_mapping(mapping)
@router.post("", response_model=ModelMappingResponse, status_code=201)
async def create_mapping(
data: ModelMappingCreate,
db: Session = Depends(get_db),
):
"""创建模型映射"""
source_model = data.source_model.strip()
if not source_model:
raise HTTPException(status_code=400, detail="source_model 不能为空")
# 验证 mapping_type
if data.mapping_type not in ("alias", "mapping"):
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias''mapping'")
# 验证目标 GlobalModel 存在
target_model = (
db.query(GlobalModel)
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
.first()
)
if not target_model:
raise HTTPException(
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
)
# 验证 Provider 存在
provider = None
provider_id = data.provider_id
if provider_id:
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
existing = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id == provider_id,
)
.first()
)
if existing:
raise HTTPException(status_code=400, detail="映射已存在")
# 创建映射
mapping = ModelMapping(
id=str(uuid.uuid4()),
source_model=source_model,
target_global_model_id=data.target_global_model_id,
provider_id=provider_id,
mapping_type=data.mapping_type,
is_active=data.is_active,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
db.add(mapping)
db.commit()
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping.id)
.first()
)
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return _serialize_mapping(mapping)
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
async def update_mapping(
mapping_id: str,
data: ModelMappingUpdate,
db: Session = Depends(get_db),
):
"""更新模型映射"""
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
update_data = data.model_dump(exclude_unset=True)
# 更新 Provider
if "provider_id" in update_data:
new_provider_id = update_data["provider_id"]
if new_provider_id:
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
if not provider:
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
mapping.provider_id = new_provider_id
# 更新目标模型
if "target_global_model_id" in update_data:
target_model = (
db.query(GlobalModel)
.filter(
GlobalModel.id == update_data["target_global_model_id"],
GlobalModel.is_active == True,
)
.first()
)
if not target_model:
raise HTTPException(
status_code=404,
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
)
mapping.target_global_model_id = update_data["target_global_model_id"]
# 更新源模型名
if "source_model" in update_data:
new_source = update_data["source_model"].strip()
if not new_source:
raise HTTPException(status_code=400, detail="source_model 不能为空")
mapping.source_model = new_source
# 检查唯一约束
duplicate = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == mapping.source_model,
ModelMapping.provider_id == mapping.provider_id,
ModelMapping.id != mapping_id,
)
.first()
)
if duplicate:
raise HTTPException(status_code=400, detail="映射已存在")
# 更新映射类型
if "mapping_type" in update_data:
if update_data["mapping_type"] not in ("alias", "mapping"):
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias''mapping'")
mapping.mapping_type = update_data["mapping_type"]
# 更新状态
if "is_active" in update_data:
mapping.is_active = update_data["is_active"]
mapping.updated_at = datetime.now(timezone.utc)
db.commit()
db.refresh(mapping)
logger.info(f"更新模型映射 (ID: {mapping.id})")
mapping = (
db.query(ModelMapping)
.options(
joinedload(ModelMapping.target_global_model),
joinedload(ModelMapping.provider),
)
.filter(ModelMapping.id == mapping.id)
.first()
)
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
return _serialize_mapping(mapping)
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
mapping_id: str,
db: Session = Depends(get_db),
):
"""删除模型映射"""
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
if not mapping:
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
source_model = mapping.source_model
provider_id = mapping.provider_id
logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})")
db.delete(mapping)
db.commit()
cache_service = get_cache_invalidation_service()
cache_service.on_model_mapping_changed(source_model, provider_id)
return None

View File

@@ -1,432 +0,0 @@
"""
模型映射解析服务
负责统一的模型别名/降级解析,按优先级顺序:
1. 映射mappingProvider 特定 → 全局
2. 别名aliasProvider 特定 → 全局
3. 直接匹配 GlobalModel.name
支持特性:
- 带缓存(本地或 Redis减少数据库访问
- 提供模糊匹配能力,用于提示相似模型
"""
from __future__ import annotations
import os
from typing import List, Optional, Tuple
from src.core.logger import logger
from sqlalchemy.orm import Session
from src.config.constants import CacheSize, CacheTTL
from src.core.logger import logger
from src.models.database import GlobalModel, ModelMapping
from src.services.cache.backend import BaseCacheBackend, get_cache_backend
class ModelMappingResolver:
"""统一的 ModelMapping 解析服务(可跨进程共享缓存)。"""
def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"):
self._cache_ttl = cache_ttl
self._cache_backend_type = cache_backend_type
self._mapping_cache: Optional[BaseCacheBackend] = None
self._global_model_cache: Optional[BaseCacheBackend] = None
self._initialized = False
self._stats = {
"mapping_hits": 0,
"mapping_misses": 0,
"global_hits": 0,
"global_misses": 0,
}
async def _ensure_initialized(self):
if self._initialized:
return
self._mapping_cache = await get_cache_backend(
name="model_mapping_resolver:mapping",
backend_type=self._cache_backend_type,
max_size=CacheSize.MODEL_MAPPING,
ttl=self._cache_ttl,
)
self._global_model_cache = await get_cache_backend(
name="model_mapping_resolver:global",
backend_type=self._cache_backend_type,
max_size=CacheSize.MODEL_MAPPING,
ttl=self._cache_ttl,
)
self._initialized = True
logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}")
def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str:
return f"{provider_id or 'global'}:{source_model}"
async def _lookup_target_global_model_id(
self,
db: Session,
source_model: str,
provider_id: Optional[str] = None,
) -> Optional[str]:
"""
按优先级查找目标 GlobalModel ID
1. 映射mapping_type='mapping'Provider 特定 → 全局
2. 别名mapping_type='alias'Provider 特定 → 全局
3. 直接匹配 GlobalModel.name
"""
await self._ensure_initialized()
cache_key = self._cache_key(source_model, provider_id)
cached = await self._mapping_cache.get(cache_key)
if cached is not None:
self._stats["mapping_hits"] += 1
return cached or None
self._stats["mapping_misses"] += 1
target_id: Optional[str] = None
# 优先级 1查找映射mapping_type='mapping'
# 1.1 Provider 特定映射
if provider_id:
mapping = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id == provider_id,
ModelMapping.mapping_type == "mapping",
ModelMapping.is_active == True,
)
.first()
)
if mapping:
target_id = mapping.target_global_model_id
logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...")
# 1.2 全局映射
if not target_id:
mapping = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id.is_(None),
ModelMapping.mapping_type == "mapping",
ModelMapping.is_active == True,
)
.first()
)
if mapping:
target_id = mapping.target_global_model_id
logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...")
# 优先级 2查找别名mapping_type='alias'
# 2.1 Provider 特定别名
if not target_id and provider_id:
alias = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id == provider_id,
ModelMapping.mapping_type == "alias",
ModelMapping.is_active == True,
)
.first()
)
if alias:
target_id = alias.target_global_model_id
logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...")
# 2.2 全局别名
if not target_id:
alias = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id.is_(None),
ModelMapping.mapping_type == "alias",
ModelMapping.is_active == True,
)
.first()
)
if alias:
target_id = alias.target_global_model_id
logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...")
# 优先级 3直接匹配 GlobalModel.name
if not target_id:
global_model = (
db.query(GlobalModel)
.filter(
GlobalModel.name == source_model,
GlobalModel.is_active == True,
)
.first()
)
if global_model:
target_id = global_model.id
logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}")
cached_value = target_id if target_id is not None else ""
await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl)
return target_id
async def resolve_to_global_model_name(
self,
db: Session,
source_model: str,
provider_id: Optional[str] = None,
) -> str:
"""解析模型名/别名为 GlobalModel.name。未找到时返回原始输入。"""
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
if not target_id:
return source_model
await self._ensure_initialized()
cached_name = await self._global_model_cache.get(target_id)
if cached_name:
self._stats["global_hits"] += 1
return cached_name
self._stats["global_misses"] += 1
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
.first()
)
if global_model:
await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl)
return global_model.name
return source_model
async def get_global_model_by_request(
self,
db: Session,
source_model: str,
provider_id: Optional[str] = None,
) -> Optional[GlobalModel]:
"""解析并返回 GlobalModel 对象(绑定当前 Session"""
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
if not target_id:
return None
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
.first()
)
return global_model
async def get_global_model_with_mapping_info(
self,
db: Session,
source_model: str,
provider_id: Optional[str] = None,
) -> Tuple[Optional[GlobalModel], bool]:
"""
解析并返回 GlobalModel 对象,同时返回是否发生了映射。
Args:
db: 数据库会话
source_model: 用户请求的模型名
provider_id: Provider ID可选
Returns:
(global_model, is_mapped) - GlobalModel 对象和是否发生了映射
is_mapped=True 表示 source_model 通过 mapping 规则映射到了不同的模型
is_mapped=False 表示 source_model 直接匹配或通过 alias 匹配
"""
await self._ensure_initialized()
# 先检查是否存在 mapping 类型的映射规则
has_mapping = False
# 检查 Provider 特定映射
if provider_id:
mapping = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id == provider_id,
ModelMapping.mapping_type == "mapping",
ModelMapping.is_active == True,
)
.first()
)
if mapping:
has_mapping = True
# 检查全局映射
if not has_mapping:
mapping = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id.is_(None),
ModelMapping.mapping_type == "mapping",
ModelMapping.is_active == True,
)
.first()
)
if mapping:
has_mapping = True
# 获取 GlobalModel
global_model = await self.get_global_model_by_request(db, source_model, provider_id)
return global_model, has_mapping
async def get_global_model_direct(
self,
db: Session,
source_model: str,
) -> Optional[GlobalModel]:
"""
直接通过模型名获取 GlobalModel不应用任何映射规则。
仅查找 alias 和直接匹配。
Args:
db: 数据库会话
source_model: 模型名
Returns:
GlobalModel 对象或 None
"""
# 优先级 1查找别名alias
# 全局别名
alias = (
db.query(ModelMapping)
.filter(
ModelMapping.source_model == source_model,
ModelMapping.provider_id.is_(None),
ModelMapping.mapping_type == "alias",
ModelMapping.is_active == True,
)
.first()
)
if alias:
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True)
.first()
)
if global_model:
return global_model
# 优先级 2直接匹配 GlobalModel.name
global_model = (
db.query(GlobalModel)
.filter(
GlobalModel.name == source_model,
GlobalModel.is_active == True,
)
.first()
)
return global_model
def find_similar_models(
self,
db: Session,
invalid_model: str,
limit: int = 3,
threshold: float = 0.4,
) -> List[Tuple[str, float]]:
"""用于提示相似的 GlobalModel.name。"""
from difflib import SequenceMatcher
all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all()
similarities: List[Tuple[str, float]] = []
invalid_lower = invalid_model.lower()
for model in all_models:
model_name = model.name
ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio()
if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower:
ratio += 0.2
if ratio >= threshold:
similarities.append((model_name, ratio))
similarities.sort(key=lambda item: item[1], reverse=True)
return similarities[:limit]
async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None):
await self._ensure_initialized()
keys = [self._cache_key(source_model, provider_id)]
if provider_id:
keys.append(self._cache_key(source_model, None))
for key in keys:
await self._mapping_cache.delete(key)
async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None):
await self._ensure_initialized()
if global_model_id:
await self._global_model_cache.delete(global_model_id)
else:
await self._global_model_cache.clear()
async def clear_cache(self):
await self._ensure_initialized()
await self._mapping_cache.clear()
await self._global_model_cache.clear()
def get_stats(self) -> dict:
total_mapping = self._stats["mapping_hits"] + self._stats["mapping_misses"]
total_global = self._stats["global_hits"] + self._stats["global_misses"]
stats = {
"mapping_hit_rate": (
self._stats["mapping_hits"] / total_mapping if total_mapping else 0.0
),
"global_hit_rate": self._stats["global_hits"] / total_global if total_global else 0.0,
"stats": self._stats,
}
if self._initialized:
stats["mapping_cache_backend"] = self._mapping_cache.get_stats()
stats["global_cache_backend"] = self._global_model_cache.get_stats()
return stats
_model_mapping_resolver: Optional[ModelMappingResolver] = None
def get_model_mapping_resolver(
cache_ttl: int = 300, cache_backend_type: Optional[str] = None
) -> ModelMappingResolver:
global _model_mapping_resolver
if _model_mapping_resolver is None:
if cache_backend_type is None:
cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto")
_model_mapping_resolver = ModelMappingResolver(
cache_ttl=cache_ttl,
cache_backend_type=cache_backend_type,
)
logger.debug(f"[ModelMappingResolver] 初始化cache_ttl={cache_ttl}s, backend={cache_backend_type})")
# 注册到缓存失效服务
try:
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.set_mapping_resolver(_model_mapping_resolver)
except Exception as exc:
logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}")
return _model_mapping_resolver
async def resolve_model_to_global_name(
db: Session,
source_model: str,
provider_id: Optional[str] = None,
) -> str:
resolver = get_model_mapping_resolver()
return await resolver.resolve_to_global_model_name(db, source_model, provider_id)
async def get_global_model_by_request(
db: Session,
source_model: str,
provider_id: Optional[str] = None,
) -> Optional[GlobalModel]:
resolver = get_model_mapping_resolver()
return await resolver.get_global_model_by_request(db, source_model, provider_id)