mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(backend): remove model mappings module
This commit is contained in:
@@ -6,11 +6,9 @@ from fastapi import APIRouter
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .global_models import router as global_models_router
|
||||
from .mappings import router as mappings_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(global_models_router)
|
||||
router.include_router(mappings_router)
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
"""模型映射管理 API
|
||||
|
||||
提供模型映射的 CRUD 操作。
|
||||
|
||||
模型映射(Mapping)用于将源模型映射到目标模型,例如:
|
||||
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
|
||||
- 用于处理 Provider 不支持请求模型的情况
|
||||
|
||||
映射必须关联到特定的 Provider。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelMappingCreate,
|
||||
ModelMappingResponse,
|
||||
ModelMappingUpdate,
|
||||
)
|
||||
from src.models.database import GlobalModel, ModelMapping, Provider, User
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
|
||||
|
||||
|
||||
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
|
||||
target = mapping.target_global_model
|
||||
provider = mapping.provider
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
return ModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=target.name if target else None,
|
||||
target_global_model_display_name=target.display_name if target else None,
|
||||
provider_id=mapping.provider_id,
|
||||
provider_name=provider.name if provider else None,
|
||||
scope=scope,
|
||||
mapping_type=mapping.mapping_type,
|
||||
is_active=mapping.is_active,
|
||||
created_at=mapping.created_at,
|
||||
updated_at=mapping.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelMappingResponse])
|
||||
async def list_mappings(
|
||||
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
|
||||
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
|
||||
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
|
||||
scope: Optional[str] = Query(None, description="global 或 provider"),
|
||||
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
|
||||
is_active: Optional[bool] = Query(None, description="按状态筛选"),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取模型映射列表"""
|
||||
query = db.query(ModelMapping).options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
|
||||
if provider_id is not None:
|
||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
||||
if scope == "global":
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
elif scope == "provider":
|
||||
query = query.filter(ModelMapping.provider_id.isnot(None))
|
||||
if mapping_type is not None:
|
||||
query = query.filter(ModelMapping.mapping_type == mapping_type)
|
||||
if source_model:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
|
||||
if target_global_model_id is not None:
|
||||
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
|
||||
if is_active is not None:
|
||||
query = query.filter(ModelMapping.is_active == is_active)
|
||||
|
||||
mappings = query.offset(skip).limit(limit).all()
|
||||
return [_serialize_mapping(mapping) for mapping in mappings]
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
|
||||
async def get_mapping(
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取单个模型映射"""
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.post("", response_model=ModelMappingResponse, status_code=201)
|
||||
async def create_mapping(
|
||||
data: ModelMappingCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建模型映射"""
|
||||
source_model = data.source_model.strip()
|
||||
if not source_model:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
|
||||
# 验证 mapping_type
|
||||
if data.mapping_type not in ("alias", "mapping"):
|
||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
||||
|
||||
# 验证目标 GlobalModel 存在
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
|
||||
)
|
||||
|
||||
# 验证 Provider 存在
|
||||
provider = None
|
||||
provider_id = data.provider_id
|
||||
if provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
|
||||
|
||||
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
|
||||
existing = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
# 创建映射
|
||||
mapping = ModelMapping(
|
||||
id=str(uuid.uuid4()),
|
||||
source_model=source_model,
|
||||
target_global_model_id=data.target_global_model_id,
|
||||
provider_id=provider_id,
|
||||
mapping_type=data.mapping_type,
|
||||
is_active=data.is_active,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
|
||||
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
|
||||
async def update_mapping(
|
||||
mapping_id: str,
|
||||
data: ModelMappingUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新模型映射"""
|
||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# 更新 Provider
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
# 更新目标模型
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
|
||||
)
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
# 更新源模型名
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
# 检查唯一约束
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
# 更新映射类型
|
||||
if "mapping_type" in update_data:
|
||||
if update_data["mapping_type"] not in ("alias", "mapping"):
|
||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
||||
mapping.mapping_type = update_data["mapping_type"]
|
||||
|
||||
# 更新状态
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
mapping.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
logger.info(f"更新模型映射 (ID: {mapping.id})")
|
||||
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.delete("/{mapping_id}", status_code=204)
|
||||
async def delete_mapping(
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除模型映射"""
|
||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
logger.info(f"删除模型映射: {source_model} -> {mapping.target_global_model_id} (ID: {mapping.id})")
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return None
|
||||
@@ -1,432 +0,0 @@
|
||||
"""
|
||||
模型映射解析服务
|
||||
|
||||
负责统一的模型别名/降级解析,按优先级顺序:
|
||||
1. 映射(mapping):Provider 特定 → 全局
|
||||
2. 别名(alias):Provider 特定 → 全局
|
||||
3. 直接匹配 GlobalModel.name
|
||||
|
||||
支持特性:
|
||||
- 带缓存(本地或 Redis),减少数据库访问
|
||||
- 提供模糊匹配能力,用于提示相似模型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheSize, CacheTTL
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, ModelMapping
|
||||
from src.services.cache.backend import BaseCacheBackend, get_cache_backend
|
||||
|
||||
|
||||
class ModelMappingResolver:
|
||||
"""统一的 ModelMapping 解析服务(可跨进程共享缓存)。"""
|
||||
|
||||
def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"):
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cache_backend_type = cache_backend_type
|
||||
self._mapping_cache: Optional[BaseCacheBackend] = None
|
||||
self._global_model_cache: Optional[BaseCacheBackend] = None
|
||||
self._initialized = False
|
||||
self._stats = {
|
||||
"mapping_hits": 0,
|
||||
"mapping_misses": 0,
|
||||
"global_hits": 0,
|
||||
"global_misses": 0,
|
||||
}
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._mapping_cache = await get_cache_backend(
|
||||
name="model_mapping_resolver:mapping",
|
||||
backend_type=self._cache_backend_type,
|
||||
max_size=CacheSize.MODEL_MAPPING,
|
||||
ttl=self._cache_ttl,
|
||||
)
|
||||
self._global_model_cache = await get_cache_backend(
|
||||
name="model_mapping_resolver:global",
|
||||
backend_type=self._cache_backend_type,
|
||||
max_size=CacheSize.MODEL_MAPPING,
|
||||
ttl=self._cache_ttl,
|
||||
)
|
||||
self._initialized = True
|
||||
logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}")
|
||||
|
||||
def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str:
|
||||
return f"{provider_id or 'global'}:{source_model}"
|
||||
|
||||
async def _lookup_target_global_model_id(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
按优先级查找目标 GlobalModel ID:
|
||||
1. 映射(mapping_type='mapping'):Provider 特定 → 全局
|
||||
2. 别名(mapping_type='alias'):Provider 特定 → 全局
|
||||
3. 直接匹配 GlobalModel.name
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
cache_key = self._cache_key(source_model, provider_id)
|
||||
cached = await self._mapping_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
self._stats["mapping_hits"] += 1
|
||||
return cached or None
|
||||
|
||||
self._stats["mapping_misses"] += 1
|
||||
|
||||
target_id: Optional[str] = None
|
||||
|
||||
# 优先级 1:查找映射(mapping_type='mapping')
|
||||
# 1.1 Provider 特定映射
|
||||
if provider_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
target_id = mapping.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 1.2 全局映射
|
||||
if not target_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
target_id = mapping.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 优先级 2:查找别名(mapping_type='alias')
|
||||
# 2.1 Provider 特定别名
|
||||
if not target_id and provider_id:
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
target_id = alias.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 2.2 全局别名
|
||||
if not target_id:
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
target_id = alias.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 优先级 3:直接匹配 GlobalModel.name
|
||||
if not target_id:
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == source_model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
target_id = global_model.id
|
||||
logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}")
|
||||
|
||||
cached_value = target_id if target_id is not None else ""
|
||||
await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl)
|
||||
return target_id
|
||||
|
||||
async def resolve_to_global_model_name(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""解析模型名/别名为 GlobalModel.name。未找到时返回原始输入。"""
|
||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
||||
if not target_id:
|
||||
return source_model
|
||||
|
||||
await self._ensure_initialized()
|
||||
cached_name = await self._global_model_cache.get(target_id)
|
||||
if cached_name:
|
||||
self._stats["global_hits"] += 1
|
||||
return cached_name
|
||||
|
||||
self._stats["global_misses"] += 1
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl)
|
||||
return global_model.name
|
||||
|
||||
return source_model
|
||||
|
||||
async def get_global_model_by_request(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[GlobalModel]:
|
||||
"""解析并返回 GlobalModel 对象(绑定当前 Session)。"""
|
||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
||||
if not target_id:
|
||||
return None
|
||||
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
return global_model
|
||||
|
||||
async def get_global_model_with_mapping_info(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Tuple[Optional[GlobalModel], bool]:
|
||||
"""
|
||||
解析并返回 GlobalModel 对象,同时返回是否发生了映射。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 用户请求的模型名
|
||||
provider_id: Provider ID(可选)
|
||||
|
||||
Returns:
|
||||
(global_model, is_mapped) - GlobalModel 对象和是否发生了映射
|
||||
is_mapped=True 表示 source_model 通过 mapping 规则映射到了不同的模型
|
||||
is_mapped=False 表示 source_model 直接匹配或通过 alias 匹配
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# 先检查是否存在 mapping 类型的映射规则
|
||||
has_mapping = False
|
||||
|
||||
# 检查 Provider 特定映射
|
||||
if provider_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
has_mapping = True
|
||||
|
||||
# 检查全局映射
|
||||
if not has_mapping:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
has_mapping = True
|
||||
|
||||
# 获取 GlobalModel
|
||||
global_model = await self.get_global_model_by_request(db, source_model, provider_id)
|
||||
|
||||
return global_model, has_mapping
|
||||
|
||||
async def get_global_model_direct(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
) -> Optional[GlobalModel]:
|
||||
"""
|
||||
直接通过模型名获取 GlobalModel,不应用任何映射规则。
|
||||
仅查找 alias 和直接匹配。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 模型名
|
||||
|
||||
Returns:
|
||||
GlobalModel 对象或 None
|
||||
"""
|
||||
# 优先级 1:查找别名(alias)
|
||||
# 全局别名
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
return global_model
|
||||
|
||||
# 优先级 2:直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == source_model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return global_model
|
||||
|
||||
def find_similar_models(
|
||||
self,
|
||||
db: Session,
|
||||
invalid_model: str,
|
||||
limit: int = 3,
|
||||
threshold: float = 0.4,
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""用于提示相似的 GlobalModel.name。"""
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all()
|
||||
similarities: List[Tuple[str, float]] = []
|
||||
invalid_lower = invalid_model.lower()
|
||||
|
||||
for model in all_models:
|
||||
model_name = model.name
|
||||
ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio()
|
||||
if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower:
|
||||
ratio += 0.2
|
||||
if ratio >= threshold:
|
||||
similarities.append((model_name, ratio))
|
||||
|
||||
similarities.sort(key=lambda item: item[1], reverse=True)
|
||||
return similarities[:limit]
|
||||
|
||||
async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None):
|
||||
await self._ensure_initialized()
|
||||
keys = [self._cache_key(source_model, provider_id)]
|
||||
if provider_id:
|
||||
keys.append(self._cache_key(source_model, None))
|
||||
for key in keys:
|
||||
await self._mapping_cache.delete(key)
|
||||
|
||||
async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None):
|
||||
await self._ensure_initialized()
|
||||
if global_model_id:
|
||||
await self._global_model_cache.delete(global_model_id)
|
||||
else:
|
||||
await self._global_model_cache.clear()
|
||||
|
||||
async def clear_cache(self):
|
||||
await self._ensure_initialized()
|
||||
await self._mapping_cache.clear()
|
||||
await self._global_model_cache.clear()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
total_mapping = self._stats["mapping_hits"] + self._stats["mapping_misses"]
|
||||
total_global = self._stats["global_hits"] + self._stats["global_misses"]
|
||||
stats = {
|
||||
"mapping_hit_rate": (
|
||||
self._stats["mapping_hits"] / total_mapping if total_mapping else 0.0
|
||||
),
|
||||
"global_hit_rate": self._stats["global_hits"] / total_global if total_global else 0.0,
|
||||
"stats": self._stats,
|
||||
}
|
||||
if self._initialized:
|
||||
stats["mapping_cache_backend"] = self._mapping_cache.get_stats()
|
||||
stats["global_cache_backend"] = self._global_model_cache.get_stats()
|
||||
return stats
|
||||
|
||||
|
||||
_model_mapping_resolver: Optional[ModelMappingResolver] = None
|
||||
|
||||
|
||||
def get_model_mapping_resolver(
|
||||
cache_ttl: int = 300, cache_backend_type: Optional[str] = None
|
||||
) -> ModelMappingResolver:
|
||||
global _model_mapping_resolver
|
||||
|
||||
if _model_mapping_resolver is None:
|
||||
if cache_backend_type is None:
|
||||
cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto")
|
||||
_model_mapping_resolver = ModelMappingResolver(
|
||||
cache_ttl=cache_ttl,
|
||||
cache_backend_type=cache_backend_type,
|
||||
)
|
||||
logger.debug(f"[ModelMappingResolver] 初始化(cache_ttl={cache_ttl}s, backend={cache_backend_type})")
|
||||
|
||||
# 注册到缓存失效服务
|
||||
try:
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.set_mapping_resolver(_model_mapping_resolver)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}")
|
||||
|
||||
return _model_mapping_resolver
|
||||
|
||||
|
||||
async def resolve_model_to_global_name(
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> str:
|
||||
resolver = get_model_mapping_resolver()
|
||||
return await resolver.resolve_to_global_model_name(db, source_model, provider_id)
|
||||
|
||||
|
||||
async def get_global_model_by_request(
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[GlobalModel]:
|
||||
resolver = get_model_mapping_resolver()
|
||||
return await resolver.get_global_model_by_request(db, source_model, provider_id)
|
||||
Reference in New Issue
Block a user