mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(backend): remove model mappings module
This commit is contained in:
@@ -6,11 +6,9 @@ from fastapi import APIRouter
|
|||||||
|
|
||||||
from .catalog import router as catalog_router
|
from .catalog import router as catalog_router
|
||||||
from .global_models import router as global_models_router
|
from .global_models import router as global_models_router
|
||||||
from .mappings import router as mappings_router
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||||
|
|
||||||
# 挂载子路由
|
# 挂载子路由
|
||||||
router.include_router(catalog_router)
|
router.include_router(catalog_router)
|
||||||
router.include_router(global_models_router)
|
router.include_router(global_models_router)
|
||||||
router.include_router(mappings_router)
|
|
||||||
|
|||||||
@@ -1,303 +0,0 @@
|
|||||||
"""模型映射管理 API
|
|
||||||
|
|
||||||
提供模型映射的 CRUD 操作。
|
|
||||||
|
|
||||||
模型映射(Mapping)用于将源模型映射到目标模型,例如:
|
|
||||||
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
|
|
||||||
- 用于处理 Provider 不支持请求模型的情况
|
|
||||||
|
|
||||||
映射必须关联到特定的 Provider。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
||||||
from sqlalchemy.orm import Session, joinedload
|
|
||||||
|
|
||||||
from src.core.logger import logger
|
|
||||||
from src.database import get_db
|
|
||||||
from src.models.api import (
|
|
||||||
ModelMappingCreate,
|
|
||||||
ModelMappingResponse,
|
|
||||||
ModelMappingUpdate,
|
|
||||||
)
|
|
||||||
from src.models.database import GlobalModel, ModelMapping, Provider, User
|
|
||||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
|
|
||||||
target = mapping.target_global_model
|
|
||||||
provider = mapping.provider
|
|
||||||
scope = "provider" if mapping.provider_id else "global"
|
|
||||||
return ModelMappingResponse(
|
|
||||||
id=mapping.id,
|
|
||||||
source_model=mapping.source_model,
|
|
||||||
target_global_model_id=mapping.target_global_model_id,
|
|
||||||
target_global_model_name=target.name if target else None,
|
|
||||||
target_global_model_display_name=target.display_name if target else None,
|
|
||||||
provider_id=mapping.provider_id,
|
|
||||||
provider_name=provider.name if provider else None,
|
|
||||||
scope=scope,
|
|
||||||
mapping_type=mapping.mapping_type,
|
|
||||||
is_active=mapping.is_active,
|
|
||||||
created_at=mapping.created_at,
|
|
||||||
updated_at=mapping.updated_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[ModelMappingResponse])
|
|
||||||
async def list_mappings(
|
|
||||||
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
|
|
||||||
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
|
|
||||||
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
|
|
||||||
scope: Optional[str] = Query(None, description="global 或 provider"),
|
|
||||||
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
|
|
||||||
is_active: Optional[bool] = Query(None, description="按状态筛选"),
|
|
||||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
|
||||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""获取模型映射列表"""
|
|
||||||
query = db.query(ModelMapping).options(
|
|
||||||
joinedload(ModelMapping.target_global_model),
|
|
||||||
joinedload(ModelMapping.provider),
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider_id is not None:
|
|
||||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
|
||||||
if scope == "global":
|
|
||||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
|
||||||
elif scope == "provider":
|
|
||||||
query = query.filter(ModelMapping.provider_id.isnot(None))
|
|
||||||
if mapping_type is not None:
|
|
||||||
query = query.filter(ModelMapping.mapping_type == mapping_type)
|
|
||||||
if source_model:
|
|
||||||
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
|
|
||||||
if target_global_model_id is not None:
|
|
||||||
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
|
|
||||||
if is_active is not None:
|
|
||||||
query = query.filter(ModelMapping.is_active == is_active)
|
|
||||||
|
|
||||||
mappings = query.offset(skip).limit(limit).all()
|
|
||||||
return [_serialize_mapping(mapping) for mapping in mappings]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
|
|
||||||
async def get_mapping(
|
|
||||||
mapping_id: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""获取单个模型映射"""
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.options(
|
|
||||||
joinedload(ModelMapping.target_global_model),
|
|
||||||
joinedload(ModelMapping.provider),
|
|
||||||
)
|
|
||||||
.filter(ModelMapping.id == mapping_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not mapping:
|
|
||||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
|
||||||
|
|
||||||
return _serialize_mapping(mapping)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ModelMappingResponse, status_code=201)
|
|
||||||
async def create_mapping(
|
|
||||||
data: ModelMappingCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""创建模型映射"""
|
|
||||||
source_model = data.source_model.strip()
|
|
||||||
if not source_model:
|
|
||||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
|
||||||
|
|
||||||
# 验证 mapping_type
|
|
||||||
if data.mapping_type not in ("alias", "mapping"):
|
|
||||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
|
||||||
|
|
||||||
# 验证目标 GlobalModel 存在
|
|
||||||
target_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not target_model:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 验证 Provider 存在
|
|
||||||
provider = None
|
|
||||||
provider_id = data.provider_id
|
|
||||||
if provider_id:
|
|
||||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
|
|
||||||
|
|
||||||
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
|
|
||||||
existing = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id == provider_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing:
|
|
||||||
raise HTTPException(status_code=400, detail="映射已存在")
|
|
||||||
|
|
||||||
# 创建映射
|
|
||||||
mapping = ModelMapping(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
source_model=source_model,
|
|
||||||
target_global_model_id=data.target_global_model_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
mapping_type=data.mapping_type,
|
|
||||||
is_active=data.is_active,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
updated_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(mapping)
|
|
||||||
db.commit()
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.options(
|
|
||||||
joinedload(ModelMapping.target_global_model),
|
|
||||||
joinedload(ModelMapping.provider),
|
|
||||||
)
|
|
||||||
.filter(ModelMapping.id == mapping.id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
|
|
||||||
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
|
|
||||||
|
|
||||||
cache_service = get_cache_invalidation_service()
|
|
||||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
|
||||||
|
|
||||||
return _serialize_mapping(mapping)
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
|
|
||||||
async def update_mapping(
|
|
||||||
mapping_id: str,
|
|
||||||
data: ModelMappingUpdate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""更新模型映射"""
|
|
||||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
|
||||||
if not mapping:
|
|
||||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
|
||||||
|
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
|
||||||
|
|
||||||
# 更新 Provider
|
|
||||||
if "provider_id" in update_data:
|
|
||||||
new_provider_id = update_data["provider_id"]
|
|
||||||
if new_provider_id:
|
|
||||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
|
|
||||||
mapping.provider_id = new_provider_id
|
|
||||||
|
|
||||||
# 更新目标模型
|
|
||||||
if "target_global_model_id" in update_data:
|
|
||||||
target_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(
|
|
||||||
GlobalModel.id == update_data["target_global_model_id"],
|
|
||||||
GlobalModel.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not target_model:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
|
|
||||||
)
|
|
||||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
|
||||||
|
|
||||||
# 更新源模型名
|
|
||||||
if "source_model" in update_data:
|
|
||||||
new_source = update_data["source_model"].strip()
|
|
||||||
if not new_source:
|
|
||||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
|
||||||
mapping.source_model = new_source
|
|
||||||
|
|
||||||
# 检查唯一约束
|
|
||||||
duplicate = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == mapping.source_model,
|
|
||||||
ModelMapping.provider_id == mapping.provider_id,
|
|
||||||
ModelMapping.id != mapping_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if duplicate:
|
|
||||||
raise HTTPException(status_code=400, detail="映射已存在")
|
|
||||||
|
|
||||||
# 更新映射类型
|
|
||||||
if "mapping_type" in update_data:
|
|
||||||
if update_data["mapping_type"] not in ("alias", "mapping"):
|
|
||||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
|
||||||
mapping.mapping_type = update_data["mapping_type"]
|
|
||||||
|
|
||||||
# 更新状态
|
|
||||||
if "is_active" in update_data:
|
|
||||||
mapping.is_active = update_data["is_active"]
|
|
||||||
|
|
||||||
mapping.updated_at = datetime.now(timezone.utc)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(mapping)
|
|
||||||
|
|
||||||
logger.info(f"更新模型映射 (ID: {mapping.id})")
|
|
||||||
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.options(
|
|
||||||
joinedload(ModelMapping.target_global_model),
|
|
||||||
joinedload(ModelMapping.provider),
|
|
||||||
)
|
|
||||||
.filter(ModelMapping.id == mapping.id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
cache_service = get_cache_invalidation_service()
|
|
||||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
|
||||||
|
|
||||||
return _serialize_mapping(mapping)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{mapping_id}", status_code=204)
|
|
||||||
async def delete_mapping(
|
|
||||||
mapping_id: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""删除模型映射"""
|
|
||||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
|
||||||
|
|
||||||
if not mapping:
|
|
||||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
|
||||||
|
|
||||||
source_model = mapping.source_model
|
|
||||||
provider_id = mapping.provider_id
|
|
||||||
|
|
||||||
logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})")
|
|
||||||
|
|
||||||
db.delete(mapping)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
cache_service = get_cache_invalidation_service()
|
|
||||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
|
||||||
|
|
||||||
return None
|
|
||||||
@@ -1,432 +0,0 @@
|
|||||||
"""
|
|
||||||
模型映射解析服务
|
|
||||||
|
|
||||||
负责统一的模型别名/降级解析,按优先级顺序:
|
|
||||||
1. 映射(mapping):Provider 特定 → 全局
|
|
||||||
2. 别名(alias):Provider 特定 → 全局
|
|
||||||
3. 直接匹配 GlobalModel.name
|
|
||||||
|
|
||||||
支持特性:
|
|
||||||
- 带缓存(本地或 Redis),减少数据库访问
|
|
||||||
- 提供模糊匹配能力,用于提示相似模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from src.core.logger import logger
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from src.config.constants import CacheSize, CacheTTL
|
|
||||||
from src.core.logger import logger
|
|
||||||
from src.models.database import GlobalModel, ModelMapping
|
|
||||||
from src.services.cache.backend import BaseCacheBackend, get_cache_backend
|
|
||||||
|
|
||||||
|
|
||||||
class ModelMappingResolver:
|
|
||||||
"""统一的 ModelMapping 解析服务(可跨进程共享缓存)。"""
|
|
||||||
|
|
||||||
def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"):
|
|
||||||
self._cache_ttl = cache_ttl
|
|
||||||
self._cache_backend_type = cache_backend_type
|
|
||||||
self._mapping_cache: Optional[BaseCacheBackend] = None
|
|
||||||
self._global_model_cache: Optional[BaseCacheBackend] = None
|
|
||||||
self._initialized = False
|
|
||||||
self._stats = {
|
|
||||||
"mapping_hits": 0,
|
|
||||||
"mapping_misses": 0,
|
|
||||||
"global_hits": 0,
|
|
||||||
"global_misses": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _ensure_initialized(self):
|
|
||||||
if self._initialized:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._mapping_cache = await get_cache_backend(
|
|
||||||
name="model_mapping_resolver:mapping",
|
|
||||||
backend_type=self._cache_backend_type,
|
|
||||||
max_size=CacheSize.MODEL_MAPPING,
|
|
||||||
ttl=self._cache_ttl,
|
|
||||||
)
|
|
||||||
self._global_model_cache = await get_cache_backend(
|
|
||||||
name="model_mapping_resolver:global",
|
|
||||||
backend_type=self._cache_backend_type,
|
|
||||||
max_size=CacheSize.MODEL_MAPPING,
|
|
||||||
ttl=self._cache_ttl,
|
|
||||||
)
|
|
||||||
self._initialized = True
|
|
||||||
logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}")
|
|
||||||
|
|
||||||
def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str:
|
|
||||||
return f"{provider_id or 'global'}:{source_model}"
|
|
||||||
|
|
||||||
async def _lookup_target_global_model_id(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
按优先级查找目标 GlobalModel ID:
|
|
||||||
1. 映射(mapping_type='mapping'):Provider 特定 → 全局
|
|
||||||
2. 别名(mapping_type='alias'):Provider 特定 → 全局
|
|
||||||
3. 直接匹配 GlobalModel.name
|
|
||||||
"""
|
|
||||||
await self._ensure_initialized()
|
|
||||||
cache_key = self._cache_key(source_model, provider_id)
|
|
||||||
cached = await self._mapping_cache.get(cache_key)
|
|
||||||
if cached is not None:
|
|
||||||
self._stats["mapping_hits"] += 1
|
|
||||||
return cached or None
|
|
||||||
|
|
||||||
self._stats["mapping_misses"] += 1
|
|
||||||
|
|
||||||
target_id: Optional[str] = None
|
|
||||||
|
|
||||||
# 优先级 1:查找映射(mapping_type='mapping')
|
|
||||||
# 1.1 Provider 特定映射
|
|
||||||
if provider_id:
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id == provider_id,
|
|
||||||
ModelMapping.mapping_type == "mapping",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if mapping:
|
|
||||||
target_id = mapping.target_global_model_id
|
|
||||||
logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...")
|
|
||||||
|
|
||||||
# 1.2 全局映射
|
|
||||||
if not target_id:
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id.is_(None),
|
|
||||||
ModelMapping.mapping_type == "mapping",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if mapping:
|
|
||||||
target_id = mapping.target_global_model_id
|
|
||||||
logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...")
|
|
||||||
|
|
||||||
# 优先级 2:查找别名(mapping_type='alias')
|
|
||||||
# 2.1 Provider 特定别名
|
|
||||||
if not target_id and provider_id:
|
|
||||||
alias = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id == provider_id,
|
|
||||||
ModelMapping.mapping_type == "alias",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if alias:
|
|
||||||
target_id = alias.target_global_model_id
|
|
||||||
logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...")
|
|
||||||
|
|
||||||
# 2.2 全局别名
|
|
||||||
if not target_id:
|
|
||||||
alias = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id.is_(None),
|
|
||||||
ModelMapping.mapping_type == "alias",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if alias:
|
|
||||||
target_id = alias.target_global_model_id
|
|
||||||
logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...")
|
|
||||||
|
|
||||||
# 优先级 3:直接匹配 GlobalModel.name
|
|
||||||
if not target_id:
|
|
||||||
global_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(
|
|
||||||
GlobalModel.name == source_model,
|
|
||||||
GlobalModel.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if global_model:
|
|
||||||
target_id = global_model.id
|
|
||||||
logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}")
|
|
||||||
|
|
||||||
cached_value = target_id if target_id is not None else ""
|
|
||||||
await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl)
|
|
||||||
return target_id
|
|
||||||
|
|
||||||
async def resolve_to_global_model_name(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
"""解析模型名/别名为 GlobalModel.name。未找到时返回原始输入。"""
|
|
||||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
|
||||||
if not target_id:
|
|
||||||
return source_model
|
|
||||||
|
|
||||||
await self._ensure_initialized()
|
|
||||||
cached_name = await self._global_model_cache.get(target_id)
|
|
||||||
if cached_name:
|
|
||||||
self._stats["global_hits"] += 1
|
|
||||||
return cached_name
|
|
||||||
|
|
||||||
self._stats["global_misses"] += 1
|
|
||||||
global_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if global_model:
|
|
||||||
await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl)
|
|
||||||
return global_model.name
|
|
||||||
|
|
||||||
return source_model
|
|
||||||
|
|
||||||
async def get_global_model_by_request(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
) -> Optional[GlobalModel]:
|
|
||||||
"""解析并返回 GlobalModel 对象(绑定当前 Session)。"""
|
|
||||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
|
||||||
if not target_id:
|
|
||||||
return None
|
|
||||||
|
|
||||||
global_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return global_model
|
|
||||||
|
|
||||||
async def get_global_model_with_mapping_info(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
) -> Tuple[Optional[GlobalModel], bool]:
|
|
||||||
"""
|
|
||||||
解析并返回 GlobalModel 对象,同时返回是否发生了映射。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
source_model: 用户请求的模型名
|
|
||||||
provider_id: Provider ID(可选)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(global_model, is_mapped) - GlobalModel 对象和是否发生了映射
|
|
||||||
is_mapped=True 表示 source_model 通过 mapping 规则映射到了不同的模型
|
|
||||||
is_mapped=False 表示 source_model 直接匹配或通过 alias 匹配
|
|
||||||
"""
|
|
||||||
await self._ensure_initialized()
|
|
||||||
|
|
||||||
# 先检查是否存在 mapping 类型的映射规则
|
|
||||||
has_mapping = False
|
|
||||||
|
|
||||||
# 检查 Provider 特定映射
|
|
||||||
if provider_id:
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id == provider_id,
|
|
||||||
ModelMapping.mapping_type == "mapping",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if mapping:
|
|
||||||
has_mapping = True
|
|
||||||
|
|
||||||
# 检查全局映射
|
|
||||||
if not has_mapping:
|
|
||||||
mapping = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id.is_(None),
|
|
||||||
ModelMapping.mapping_type == "mapping",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if mapping:
|
|
||||||
has_mapping = True
|
|
||||||
|
|
||||||
# 获取 GlobalModel
|
|
||||||
global_model = await self.get_global_model_by_request(db, source_model, provider_id)
|
|
||||||
|
|
||||||
return global_model, has_mapping
|
|
||||||
|
|
||||||
async def get_global_model_direct(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
) -> Optional[GlobalModel]:
|
|
||||||
"""
|
|
||||||
直接通过模型名获取 GlobalModel,不应用任何映射规则。
|
|
||||||
仅查找 alias 和直接匹配。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
source_model: 模型名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
GlobalModel 对象或 None
|
|
||||||
"""
|
|
||||||
# 优先级 1:查找别名(alias)
|
|
||||||
# 全局别名
|
|
||||||
alias = (
|
|
||||||
db.query(ModelMapping)
|
|
||||||
.filter(
|
|
||||||
ModelMapping.source_model == source_model,
|
|
||||||
ModelMapping.provider_id.is_(None),
|
|
||||||
ModelMapping.mapping_type == "alias",
|
|
||||||
ModelMapping.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if alias:
|
|
||||||
global_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if global_model:
|
|
||||||
return global_model
|
|
||||||
|
|
||||||
# 优先级 2:直接匹配 GlobalModel.name
|
|
||||||
global_model = (
|
|
||||||
db.query(GlobalModel)
|
|
||||||
.filter(
|
|
||||||
GlobalModel.name == source_model,
|
|
||||||
GlobalModel.is_active == True,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return global_model
|
|
||||||
|
|
||||||
def find_similar_models(
|
|
||||||
self,
|
|
||||||
db: Session,
|
|
||||||
invalid_model: str,
|
|
||||||
limit: int = 3,
|
|
||||||
threshold: float = 0.4,
|
|
||||||
) -> List[Tuple[str, float]]:
|
|
||||||
"""用于提示相似的 GlobalModel.name。"""
|
|
||||||
from difflib import SequenceMatcher
|
|
||||||
|
|
||||||
all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all()
|
|
||||||
similarities: List[Tuple[str, float]] = []
|
|
||||||
invalid_lower = invalid_model.lower()
|
|
||||||
|
|
||||||
for model in all_models:
|
|
||||||
model_name = model.name
|
|
||||||
ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio()
|
|
||||||
if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower:
|
|
||||||
ratio += 0.2
|
|
||||||
if ratio >= threshold:
|
|
||||||
similarities.append((model_name, ratio))
|
|
||||||
|
|
||||||
similarities.sort(key=lambda item: item[1], reverse=True)
|
|
||||||
return similarities[:limit]
|
|
||||||
|
|
||||||
async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None):
|
|
||||||
await self._ensure_initialized()
|
|
||||||
keys = [self._cache_key(source_model, provider_id)]
|
|
||||||
if provider_id:
|
|
||||||
keys.append(self._cache_key(source_model, None))
|
|
||||||
for key in keys:
|
|
||||||
await self._mapping_cache.delete(key)
|
|
||||||
|
|
||||||
async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None):
|
|
||||||
await self._ensure_initialized()
|
|
||||||
if global_model_id:
|
|
||||||
await self._global_model_cache.delete(global_model_id)
|
|
||||||
else:
|
|
||||||
await self._global_model_cache.clear()
|
|
||||||
|
|
||||||
async def clear_cache(self):
|
|
||||||
await self._ensure_initialized()
|
|
||||||
await self._mapping_cache.clear()
|
|
||||||
await self._global_model_cache.clear()
|
|
||||||
|
|
||||||
def get_stats(self) -> dict:
|
|
||||||
total_mapping = self._stats["mapping_hits"] + self._stats["mapping_misses"]
|
|
||||||
total_global = self._stats["global_hits"] + self._stats["global_misses"]
|
|
||||||
stats = {
|
|
||||||
"mapping_hit_rate": (
|
|
||||||
self._stats["mapping_hits"] / total_mapping if total_mapping else 0.0
|
|
||||||
),
|
|
||||||
"global_hit_rate": self._stats["global_hits"] / total_global if total_global else 0.0,
|
|
||||||
"stats": self._stats,
|
|
||||||
}
|
|
||||||
if self._initialized:
|
|
||||||
stats["mapping_cache_backend"] = self._mapping_cache.get_stats()
|
|
||||||
stats["global_cache_backend"] = self._global_model_cache.get_stats()
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
_model_mapping_resolver: Optional[ModelMappingResolver] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_mapping_resolver(
|
|
||||||
cache_ttl: int = 300, cache_backend_type: Optional[str] = None
|
|
||||||
) -> ModelMappingResolver:
|
|
||||||
global _model_mapping_resolver
|
|
||||||
|
|
||||||
if _model_mapping_resolver is None:
|
|
||||||
if cache_backend_type is None:
|
|
||||||
cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto")
|
|
||||||
_model_mapping_resolver = ModelMappingResolver(
|
|
||||||
cache_ttl=cache_ttl,
|
|
||||||
cache_backend_type=cache_backend_type,
|
|
||||||
)
|
|
||||||
logger.debug(f"[ModelMappingResolver] 初始化(cache_ttl={cache_ttl}s, backend={cache_backend_type})")
|
|
||||||
|
|
||||||
# 注册到缓存失效服务
|
|
||||||
try:
|
|
||||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
|
||||||
|
|
||||||
cache_service = get_cache_invalidation_service()
|
|
||||||
cache_service.set_mapping_resolver(_model_mapping_resolver)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}")
|
|
||||||
|
|
||||||
return _model_mapping_resolver
|
|
||||||
|
|
||||||
|
|
||||||
async def resolve_model_to_global_name(
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
) -> str:
|
|
||||||
resolver = get_model_mapping_resolver()
|
|
||||||
return await resolver.resolve_to_global_model_name(db, source_model, provider_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_global_model_by_request(
|
|
||||||
db: Session,
|
|
||||||
source_model: str,
|
|
||||||
provider_id: Optional[str] = None,
|
|
||||||
) -> Optional[GlobalModel]:
|
|
||||||
resolver = get_model_mapping_resolver()
|
|
||||||
return await resolver.get_global_model_by_request(db, source_model, provider_id)
|
|
||||||
Reference in New Issue
Block a user