From 728f9bb12682da9b9fb2aac4171d6f982f6e2650 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Mon, 15 Dec 2025 14:30:00 +0800 Subject: [PATCH] refactor(backend): remove model mappings module --- src/api/admin/models/__init__.py | 2 - src/api/admin/models/mappings.py | 303 ----------------- src/services/model/mapping_resolver.py | 432 ------------------------- 3 files changed, 737 deletions(-) delete mode 100644 src/api/admin/models/mappings.py delete mode 100644 src/services/model/mapping_resolver.py diff --git a/src/api/admin/models/__init__.py b/src/api/admin/models/__init__.py index 43f5730..ec17618 100644 --- a/src/api/admin/models/__init__.py +++ b/src/api/admin/models/__init__.py @@ -6,11 +6,9 @@ from fastapi import APIRouter from .catalog import router as catalog_router from .global_models import router as global_models_router -from .mappings import router as mappings_router router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"]) # 挂载子路由 router.include_router(catalog_router) router.include_router(global_models_router) -router.include_router(mappings_router) diff --git a/src/api/admin/models/mappings.py b/src/api/admin/models/mappings.py deleted file mode 100644 index d1c5fa5..0000000 --- a/src/api/admin/models/mappings.py +++ /dev/null @@ -1,303 +0,0 @@ -"""模型映射管理 API - -提供模型映射的 CRUD 操作。 - -模型映射(Mapping)用于将源模型映射到目标模型,例如: -- 请求 gpt-5.1 → Provider A 映射到 gpt-4 -- 用于处理 Provider 不支持请求模型的情况 - -映射必须关联到特定的 Provider。 -""" - -import uuid -from datetime import datetime, timezone -from typing import List, Optional - -from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.orm import Session, joinedload - -from src.core.logger import logger -from src.database import get_db -from src.models.api import ( - ModelMappingCreate, - ModelMappingResponse, - ModelMappingUpdate, -) -from src.models.database import GlobalModel, ModelMapping, Provider, User -from src.services.cache.invalidation import get_cache_invalidation_service - - -router = APIRouter(prefix="/mappings", tags=["Model Mappings"]) - - -def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse: - target = mapping.target_global_model - provider = mapping.provider - scope = "provider" if mapping.provider_id else "global" - return ModelMappingResponse( - id=mapping.id, - source_model=mapping.source_model, - target_global_model_id=mapping.target_global_model_id, - target_global_model_name=target.name if target else None, - target_global_model_display_name=target.display_name if target else None, - provider_id=mapping.provider_id, - provider_name=provider.name if provider else None, - scope=scope, - mapping_type=mapping.mapping_type, - is_active=mapping.is_active, - created_at=mapping.created_at, - updated_at=mapping.updated_at, - ) - - -@router.get("", response_model=List[ModelMappingResponse]) -async def list_mappings( - provider_id: Optional[str] = Query(None, description="按 Provider 筛选"), - source_model: Optional[str] = Query(None, description="按源模型名筛选"), - target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"), - scope: Optional[str] = Query(None, description="global 或 provider"), - mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"), - is_active: Optional[bool] = Query(None, description="按状态筛选"), - skip: int = Query(0, ge=0, description="跳过记录数"), - limit: int = Query(100, ge=1, le=1000, description="返回记录数"), - db: Session = Depends(get_db), -): - """获取模型映射列表""" - query = db.query(ModelMapping).options( - joinedload(ModelMapping.target_global_model), - joinedload(ModelMapping.provider), - ) - - if provider_id is not None: - query = query.filter(ModelMapping.provider_id == provider_id) - if scope == "global": - query = query.filter(ModelMapping.provider_id.is_(None)) - elif scope == "provider": - query = query.filter(ModelMapping.provider_id.isnot(None)) - if mapping_type is not None: - query = query.filter(ModelMapping.mapping_type == mapping_type) - if source_model: - query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%")) - if target_global_model_id is not None: - query = query.filter(ModelMapping.target_global_model_id == target_global_model_id) - if is_active is not None: - query = query.filter(ModelMapping.is_active == is_active) - - mappings = query.offset(skip).limit(limit).all() - return [_serialize_mapping(mapping) for mapping in mappings] - - -@router.get("/{mapping_id}", response_model=ModelMappingResponse) -async def get_mapping( - mapping_id: str, - db: Session = Depends(get_db), -): - """获取单个模型映射""" - mapping = ( - db.query(ModelMapping) - .options( - joinedload(ModelMapping.target_global_model), - joinedload(ModelMapping.provider), - ) - .filter(ModelMapping.id == mapping_id) - .first() - ) - - if not mapping: - raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在") - - return _serialize_mapping(mapping) - - -@router.post("", response_model=ModelMappingResponse, status_code=201) -async def create_mapping( - data: ModelMappingCreate, - db: Session = Depends(get_db), -): - """创建模型映射""" - source_model = data.source_model.strip() - if not source_model: - raise HTTPException(status_code=400, detail="source_model 不能为空") - - # 验证 mapping_type - if data.mapping_type not in ("alias", "mapping"): - raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'") - - # 验证目标 GlobalModel 存在 - target_model = ( - db.query(GlobalModel) - .filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True) - .first() - ) - if not target_model: - raise HTTPException( - status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活" - ) - - # 验证 Provider 存在 - provider = None - provider_id = data.provider_id - if provider_id: - provider = db.query(Provider).filter(Provider.id == provider_id).first() - if not provider: - raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在") - - # 检查映射是否已存在(全局或同一 Provider 下不可重复) - existing = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id == provider_id, - ) - .first() - ) - if existing: - raise HTTPException(status_code=400, detail="映射已存在") - - # 创建映射 - mapping = ModelMapping( - id=str(uuid.uuid4()), - source_model=source_model, - target_global_model_id=data.target_global_model_id, - provider_id=provider_id, - mapping_type=data.mapping_type, - is_active=data.is_active, - created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc), - ) - - db.add(mapping) - db.commit() - mapping = ( - db.query(ModelMapping) - .options( - joinedload(ModelMapping.target_global_model), - joinedload(ModelMapping.provider), - ) - .filter(ModelMapping.id == mapping.id) - .first() - ) - - logger.info(f"创建模型映射: {source_model} -> {target_model.name} " - f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})") - - cache_service = get_cache_invalidation_service() - cache_service.on_model_mapping_changed(source_model, provider_id) - - return _serialize_mapping(mapping) - - -@router.patch("/{mapping_id}", response_model=ModelMappingResponse) -async def update_mapping( - mapping_id: str, - data: ModelMappingUpdate, - db: Session = Depends(get_db), -): - """更新模型映射""" - mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first() - if not mapping: - raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在") - - update_data = data.model_dump(exclude_unset=True) - - # 更新 Provider - if "provider_id" in update_data: - new_provider_id = update_data["provider_id"] - if new_provider_id: - provider = db.query(Provider).filter(Provider.id == new_provider_id).first() - if not provider: - raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在") - mapping.provider_id = new_provider_id - - # 更新目标模型 - if "target_global_model_id" in update_data: - target_model = ( - db.query(GlobalModel) - .filter( - GlobalModel.id == update_data["target_global_model_id"], - GlobalModel.is_active == True, - ) - .first() - ) - if not target_model: - raise HTTPException( - status_code=404, - detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活", - ) - mapping.target_global_model_id = update_data["target_global_model_id"] - - # 更新源模型名 - if "source_model" in update_data: - new_source = update_data["source_model"].strip() - if not new_source: - raise HTTPException(status_code=400, detail="source_model 不能为空") - mapping.source_model = new_source - - # 检查唯一约束 - duplicate = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == mapping.source_model, - ModelMapping.provider_id == mapping.provider_id, - ModelMapping.id != mapping_id, - ) - .first() - ) - if duplicate: - raise HTTPException(status_code=400, detail="映射已存在") - - # 更新映射类型 - if "mapping_type" in update_data: - if update_data["mapping_type"] not in ("alias", "mapping"): - raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'") - mapping.mapping_type = update_data["mapping_type"] - - # 更新状态 - if "is_active" in update_data: - mapping.is_active = update_data["is_active"] - - mapping.updated_at = datetime.now(timezone.utc) - db.commit() - db.refresh(mapping) - - logger.info(f"更新模型映射 (ID: {mapping.id})") - - mapping = ( - db.query(ModelMapping) - .options( - joinedload(ModelMapping.target_global_model), - joinedload(ModelMapping.provider), - ) - .filter(ModelMapping.id == mapping.id) - .first() - ) - - cache_service = get_cache_invalidation_service() - cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id) - - return _serialize_mapping(mapping) - - -@router.delete("/{mapping_id}", status_code=204) -async def delete_mapping( - mapping_id: str, - db: Session = Depends(get_db), -): - """删除模型映射""" - mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first() - - if not mapping: - raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在") - - source_model = mapping.source_model - provider_id = mapping.provider_id - - logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})") - - db.delete(mapping) - db.commit() - - cache_service = get_cache_invalidation_service() - cache_service.on_model_mapping_changed(source_model, provider_id) - - return None diff --git a/src/services/model/mapping_resolver.py b/src/services/model/mapping_resolver.py deleted file mode 100644 index a7c6b11..0000000 --- a/src/services/model/mapping_resolver.py +++ /dev/null @@ -1,432 +0,0 @@ -""" -模型映射解析服务 - -负责统一的模型别名/降级解析,按优先级顺序: -1. 映射(mapping):Provider 特定 → 全局 -2. 别名(alias):Provider 特定 → 全局 -3. 直接匹配 GlobalModel.name - -支持特性: -- 带缓存(本地或 Redis),减少数据库访问 -- 提供模糊匹配能力,用于提示相似模型 -""" - -from __future__ import annotations - -import os -from typing import List, Optional, Tuple - -from src.core.logger import logger -from sqlalchemy.orm import Session - -from src.config.constants import CacheSize, CacheTTL -from src.core.logger import logger -from src.models.database import GlobalModel, ModelMapping -from src.services.cache.backend import BaseCacheBackend, get_cache_backend - - -class ModelMappingResolver: - """统一的 ModelMapping 解析服务(可跨进程共享缓存)。""" - - def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"): - self._cache_ttl = cache_ttl - self._cache_backend_type = cache_backend_type - self._mapping_cache: Optional[BaseCacheBackend] = None - self._global_model_cache: Optional[BaseCacheBackend] = None - self._initialized = False - self._stats = { - "mapping_hits": 0, - "mapping_misses": 0, - "global_hits": 0, - "global_misses": 0, - } - - async def _ensure_initialized(self): - if self._initialized: - return - - self._mapping_cache = await get_cache_backend( - name="model_mapping_resolver:mapping", - backend_type=self._cache_backend_type, - max_size=CacheSize.MODEL_MAPPING, - ttl=self._cache_ttl, - ) - self._global_model_cache = await get_cache_backend( - name="model_mapping_resolver:global", - backend_type=self._cache_backend_type, - max_size=CacheSize.MODEL_MAPPING, - ttl=self._cache_ttl, - ) - self._initialized = True - logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}") - - def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str: - return f"{provider_id or 'global'}:{source_model}" - - async def _lookup_target_global_model_id( - self, - db: Session, - source_model: str, - provider_id: Optional[str] = None, - ) -> Optional[str]: - """ - 按优先级查找目标 GlobalModel ID: - 1. 映射(mapping_type='mapping'):Provider 特定 → 全局 - 2. 别名(mapping_type='alias'):Provider 特定 → 全局 - 3. 直接匹配 GlobalModel.name - """ - await self._ensure_initialized() - cache_key = self._cache_key(source_model, provider_id) - cached = await self._mapping_cache.get(cache_key) - if cached is not None: - self._stats["mapping_hits"] += 1 - return cached or None - - self._stats["mapping_misses"] += 1 - - target_id: Optional[str] = None - - # 优先级 1:查找映射(mapping_type='mapping') - # 1.1 Provider 特定映射 - if provider_id: - mapping = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id == provider_id, - ModelMapping.mapping_type == "mapping", - ModelMapping.is_active == True, - ) - .first() - ) - if mapping: - target_id = mapping.target_global_model_id - logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...") - - # 1.2 全局映射 - if not target_id: - mapping = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id.is_(None), - ModelMapping.mapping_type == "mapping", - ModelMapping.is_active == True, - ) - .first() - ) - if mapping: - target_id = mapping.target_global_model_id - logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...") - - # 优先级 2:查找别名(mapping_type='alias') - # 2.1 Provider 特定别名 - if not target_id and provider_id: - alias = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id == provider_id, - ModelMapping.mapping_type == "alias", - ModelMapping.is_active == True, - ) - .first() - ) - if alias: - target_id = alias.target_global_model_id - logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...") - - # 2.2 全局别名 - if not target_id: - alias = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id.is_(None), - ModelMapping.mapping_type == "alias", - ModelMapping.is_active == True, - ) - .first() - ) - if alias: - target_id = alias.target_global_model_id - logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...") - - # 优先级 3:直接匹配 GlobalModel.name - if not target_id: - global_model = ( - db.query(GlobalModel) - .filter( - GlobalModel.name == source_model, - GlobalModel.is_active == True, - ) - .first() - ) - if global_model: - target_id = global_model.id - logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}") - - cached_value = target_id if target_id is not None else "" - await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl) - return target_id - - async def resolve_to_global_model_name( - self, - db: Session, - source_model: str, - provider_id: Optional[str] = None, - ) -> str: - """解析模型名/别名为 GlobalModel.name。未找到时返回原始输入。""" - target_id = await self._lookup_target_global_model_id(db, source_model, provider_id) - if not target_id: - return source_model - - await self._ensure_initialized() - cached_name = await self._global_model_cache.get(target_id) - if cached_name: - self._stats["global_hits"] += 1 - return cached_name - - self._stats["global_misses"] += 1 - global_model = ( - db.query(GlobalModel) - .filter(GlobalModel.id == target_id, GlobalModel.is_active == True) - .first() - ) - if global_model: - await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl) - return global_model.name - - return source_model - - async def get_global_model_by_request( - self, - db: Session, - source_model: str, - provider_id: Optional[str] = None, - ) -> Optional[GlobalModel]: - """解析并返回 GlobalModel 对象(绑定当前 Session)。""" - target_id = await self._lookup_target_global_model_id(db, source_model, provider_id) - if not target_id: - return None - - global_model = ( - db.query(GlobalModel) - .filter(GlobalModel.id == target_id, GlobalModel.is_active == True) - .first() - ) - return global_model - - async def get_global_model_with_mapping_info( - self, - db: Session, - source_model: str, - provider_id: Optional[str] = None, - ) -> Tuple[Optional[GlobalModel], bool]: - """ - 解析并返回 GlobalModel 对象,同时返回是否发生了映射。 - - Args: - db: 数据库会话 - source_model: 用户请求的模型名 - provider_id: Provider ID(可选) - - Returns: - (global_model, is_mapped) - GlobalModel 对象和是否发生了映射 - is_mapped=True 表示 source_model 通过 mapping 规则映射到了不同的模型 - is_mapped=False 表示 source_model 直接匹配或通过 alias 匹配 - """ - await self._ensure_initialized() - - # 先检查是否存在 mapping 类型的映射规则 - has_mapping = False - - # 检查 Provider 特定映射 - if provider_id: - mapping = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id == provider_id, - ModelMapping.mapping_type == "mapping", - ModelMapping.is_active == True, - ) - .first() - ) - if mapping: - has_mapping = True - - # 检查全局映射 - if not has_mapping: - mapping = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id.is_(None), - ModelMapping.mapping_type == "mapping", - ModelMapping.is_active == True, - ) - .first() - ) - if mapping: - has_mapping = True - - # 获取 GlobalModel - global_model = await self.get_global_model_by_request(db, source_model, provider_id) - - return global_model, has_mapping - - async def get_global_model_direct( - self, - db: Session, - source_model: str, - ) -> Optional[GlobalModel]: - """ - 直接通过模型名获取 GlobalModel,不应用任何映射规则。 - 仅查找 alias 和直接匹配。 - - Args: - db: 数据库会话 - source_model: 模型名 - - Returns: - GlobalModel 对象或 None - """ - # 优先级 1:查找别名(alias) - # 全局别名 - alias = ( - db.query(ModelMapping) - .filter( - ModelMapping.source_model == source_model, - ModelMapping.provider_id.is_(None), - ModelMapping.mapping_type == "alias", - ModelMapping.is_active == True, - ) - .first() - ) - if alias: - global_model = ( - db.query(GlobalModel) - .filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True) - .first() - ) - if global_model: - return global_model - - # 优先级 2:直接匹配 GlobalModel.name - global_model = ( - db.query(GlobalModel) - .filter( - GlobalModel.name == source_model, - GlobalModel.is_active == True, - ) - .first() - ) - return global_model - - def find_similar_models( - self, - db: Session, - invalid_model: str, - limit: int = 3, - threshold: float = 0.4, - ) -> List[Tuple[str, float]]: - """用于提示相似的 GlobalModel.name。""" - from difflib import SequenceMatcher - - all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all() - similarities: List[Tuple[str, float]] = [] - invalid_lower = invalid_model.lower() - - for model in all_models: - model_name = model.name - ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio() - if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower: - ratio += 0.2 - if ratio >= threshold: - similarities.append((model_name, ratio)) - - similarities.sort(key=lambda item: item[1], reverse=True) - return similarities[:limit] - - async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None): - await self._ensure_initialized() - keys = [self._cache_key(source_model, provider_id)] - if provider_id: - keys.append(self._cache_key(source_model, None)) - for key in keys: - await self._mapping_cache.delete(key) - - async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None): - await self._ensure_initialized() - if global_model_id: - await self._global_model_cache.delete(global_model_id) - else: - await self._global_model_cache.clear() - - async def clear_cache(self): - await self._ensure_initialized() - await self._mapping_cache.clear() - await self._global_model_cache.clear() - - def get_stats(self) -> dict: - total_mapping = self._stats["mapping_hits"] + self._stats["mapping_misses"] - total_global = self._stats["global_hits"] + self._stats["global_misses"] - stats = { - "mapping_hit_rate": ( - self._stats["mapping_hits"] / total_mapping if total_mapping else 0.0 - ), - "global_hit_rate": self._stats["global_hits"] / total_global if total_global else 0.0, - "stats": self._stats, - } - if self._initialized: - stats["mapping_cache_backend"] = self._mapping_cache.get_stats() - stats["global_cache_backend"] = self._global_model_cache.get_stats() - return stats - - -_model_mapping_resolver: Optional[ModelMappingResolver] = None - - -def get_model_mapping_resolver( - cache_ttl: int = 300, cache_backend_type: Optional[str] = None -) -> ModelMappingResolver: - global _model_mapping_resolver - - if _model_mapping_resolver is None: - if cache_backend_type is None: - cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto") - _model_mapping_resolver = ModelMappingResolver( - cache_ttl=cache_ttl, - cache_backend_type=cache_backend_type, - ) - logger.debug(f"[ModelMappingResolver] 初始化(cache_ttl={cache_ttl}s, backend={cache_backend_type})") - - # 注册到缓存失效服务 - try: - from src.services.cache.invalidation import get_cache_invalidation_service - - cache_service = get_cache_invalidation_service() - cache_service.set_mapping_resolver(_model_mapping_resolver) - except Exception as exc: - logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}") - - return _model_mapping_resolver - - -async def resolve_model_to_global_name( - db: Session, - source_model: str, - provider_id: Optional[str] = None, -) -> str: - resolver = get_model_mapping_resolver() - return await resolver.resolve_to_global_model_name(db, source_model, provider_id) - - -async def get_global_model_by_request( - db: Session, - source_model: str, - provider_id: Optional[str] = None, -) -> Optional[GlobalModel]: - resolver = get_model_mapping_resolver() - return await resolver.get_global_model_by_request(db, source_model, provider_id)