From 51b85915d29012e92b945b76f2c546a5a004cfaa Mon Sep 17 00:00:00 2001 From: fawney19 Date: Mon, 15 Dec 2025 18:13:28 +0800 Subject: [PATCH] feat(cache): implement model alias resolution with caching - Add resolve_global_model_by_name_or_alias() supporting direct match and alias lookup - Support both provider_model_name and provider_model_aliases matching - Implement caching for resolved models with TTL - Add conflict detection when alias maps to multiple GlobalModels - Record resolution metrics: method, cache hits, duration, conflicts - Fallback to Python-level filtering for non-PostgreSQL databases - Add cache invalidation methods for GlobalModel --- src/services/cache/model_cache.py | 224 ++++++++++++++++++++++++++++-- 1 file changed, 216 insertions(+), 8 deletions(-) diff --git a/src/services/cache/model_cache.py b/src/services/cache/model_cache.py index 1f9486b..e79172a 100644 --- a/src/services/cache/model_cache.py +++ b/src/services/cache/model_cache.py @@ -2,6 +2,8 @@ Model 映射缓存服务 - 减少模型查询 """ +import json +import time from typing import Optional from sqlalchemy.orm import Session @@ -9,6 +11,11 @@ from sqlalchemy.orm import Session from src.config.constants import CacheTTL from src.core.cache_service import CacheService from src.core.logger import logger +from src.core.metrics import ( + model_alias_conflict_total, + model_alias_resolution_duration_seconds, + model_alias_resolution_total, +) from src.models.database import GlobalModel, Model @@ -102,7 +109,9 @@ class ModelCacheService: # 1. 尝试从缓存获取 cached_data = await CacheService.get(cache_key) if cached_data: - logger.debug(f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...") + logger.debug( + f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..." + ) return ModelCacheService._dict_to_model(cached_data) # 2. 缓存未命中,查询数据库 @@ -120,7 +129,9 @@ class ModelCacheService: if model: model_dict = ModelCacheService._model_to_dict(model) await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL) - logger.debug(f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...") + logger.debug( + f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..." + ) return model @@ -160,7 +171,7 @@ class ModelCacheService: @staticmethod async def invalidate_model_cache( model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None - ): + ) -> None: """清除 Model 缓存 Args: @@ -174,18 +185,207 @@ class ModelCacheService: # 清除 provider_global 缓存(如果提供了必要参数) if provider_id and global_model_id: await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}") - logger.debug(f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}...") + logger.debug( + f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..." + ) else: logger.debug(f"Model 缓存已清除: {model_id}") @staticmethod - async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None): + async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None) -> None: """清除 GlobalModel 缓存""" await CacheService.delete(f"global_model:id:{global_model_id}") if name: await CacheService.delete(f"global_model:name:{name}") logger.debug(f"GlobalModel 缓存已清除: {global_model_id}") + @staticmethod + async def resolve_global_model_by_name_or_alias( + db: Session, model_name: str + ) -> Optional[GlobalModel]: + """ + 通过名称或别名解析 GlobalModel(带缓存,支持别名匹配) + + 查找顺序: + 1. 检查缓存 + 2. 直接匹配 GlobalModel.name + 3. 通过别名匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases) + + Args: + db: 数据库会话 + model_name: 模型名称(可以是 GlobalModel.name 或别名) + + Returns: + GlobalModel 对象或 None + """ + start_time = time.time() + resolution_method = "not_found" + cache_hit = False + + normalized_name = model_name.strip() + if not normalized_name: + return None + + cache_key = f"global_model:resolve:{normalized_name}" + + try: + # 1. 尝试从缓存获取 + cached_data = await CacheService.get(cache_key) + if cached_data: + if cached_data == "NOT_FOUND": + # 缓存的负结果 + cache_hit = True + resolution_method = "not_found" + logger.debug(f"GlobalModel 缓存命中(别名解析-未找到): {normalized_name}") + return None + if isinstance(cached_data, dict) and "supported_capabilities" not in cached_data: + # 兼容旧缓存:字段不全时视为未命中,走 DB 刷新 + logger.debug(f"GlobalModel 缓存命中但 schema 过旧,刷新: {normalized_name}") + else: + cache_hit = True + resolution_method = "direct_match" # 缓存命中时无法区分原始解析方式 + logger.debug(f"GlobalModel 缓存命中(别名解析): {normalized_name}") + return ModelCacheService._dict_to_global_model(cached_data) + + # 2. 直接通过 GlobalModel.name 查找 + global_model = ( + db.query(GlobalModel) + .filter(GlobalModel.name == normalized_name, GlobalModel.is_active == True) + .first() + ) + + if global_model: + resolution_method = "direct_match" + # 缓存结果 + global_model_dict = ModelCacheService._global_model_to_dict(global_model) + await CacheService.set( + cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL + ) + logger.debug(f"GlobalModel 已缓存(别名解析-直接匹配): {normalized_name}") + return global_model + + # 3. 通过别名匹配(优化:精确查询,避免加载所有 Model) + from sqlalchemy import or_ + + from src.models.database import Provider + + # 构建精确的别名匹配条件 + # 注意:provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符 + # 对于 SQLite,会在 Python 层面进行过滤 + try: + # 尝试使用 PostgreSQL 的 JSONB 查询(更高效) + # 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入 + jsonb_pattern = json.dumps([{"name": normalized_name}]) + models_with_global = ( + db.query(Model, GlobalModel) + .join(Provider, Model.provider_id == Provider.id) + .join(GlobalModel, Model.global_model_id == GlobalModel.id) + .filter( + Provider.is_active == True, + Model.is_active == True, + GlobalModel.is_active == True, + or_( + Model.provider_model_name == normalized_name, + # PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素 + Model.provider_model_aliases.op("@>")(jsonb_pattern), + ), + ) + .all() + ) + except Exception as e: + # 如果 JSONB 查询失败(如使用 SQLite),回退到加载所有活跃 Model 并在 Python 层过滤 + logger.debug( + f"JSONB 查询失败,回退到 Python 过滤: {e}", + ) + models_with_global = ( + db.query(Model, GlobalModel) + .join(Provider, Model.provider_id == Provider.id) + .join(GlobalModel, Model.global_model_id == GlobalModel.id) + .filter( + Provider.is_active == True, + Model.is_active == True, + GlobalModel.is_active == True, + ) + .all() + ) + + # 用于检测别名冲突 + matched_global_models = [] + + # 遍历查询结果进行匹配 + for model, gm in models_with_global: + # 检查 provider_model_name 是否匹配 + if model.provider_model_name == normalized_name: + matched_global_models.append((gm, "provider_model_name")) + logger.debug( + f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 " + f"GlobalModel: {gm.name}" + ) + + # 检查 provider_model_aliases 是否匹配 + if model.provider_model_aliases: + for alias_entry in model.provider_model_aliases: + if isinstance(alias_entry, dict): + alias_name = alias_entry.get("name", "").strip() + if alias_name == normalized_name: + matched_global_models.append((gm, "alias")) + logger.debug( + f"模型名称 '{normalized_name}' 通过别名匹配到 " + f"GlobalModel: {gm.name}" + ) + break + + # 处理匹配结果 + if not matched_global_models: + resolution_method = "not_found" + # 未找到匹配,缓存负结果 + await CacheService.set( + cache_key, "NOT_FOUND", ttl_seconds=ModelCacheService.CACHE_TTL + ) + logger.debug(f"GlobalModel 未找到(别名解析): {normalized_name}") + return None + + # 优先使用 provider_model_name 的直接匹配,其次才是 aliases;同级别按名称排序保证确定性 + matched_global_models.sort( + key=lambda item: (0 if item[1] == "provider_model_name" else 1, item[0].name) + ) + + # 记录解析方式 + resolution_method = matched_global_models[0][1] + + if len(matched_global_models) > 1: + # 检测到别名冲突 + unique_models = {gm.id: gm for gm, _ in matched_global_models} + if len(unique_models) > 1: + model_names = [gm.name for gm in unique_models.values()] + logger.warning( + f"别名冲突: 模型名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: " + f"{', '.join(model_names)},使用第一个匹配结果" + ) + # 记录别名冲突指标 + model_alias_conflict_total.inc() + + # 返回第一个匹配的 GlobalModel + result_global_model: GlobalModel = matched_global_models[0][0] + global_model_dict = ModelCacheService._global_model_to_dict(result_global_model) + await CacheService.set( + cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL + ) + logger.debug( + f"GlobalModel 已缓存(别名解析): {normalized_name} -> {result_global_model.name}" + ) + return result_global_model + + finally: + # 记录监控指标 + duration = time.time() - start_time + model_alias_resolution_total.labels( + method=resolution_method, cache_hit=str(cache_hit).lower() + ).inc() + model_alias_resolution_duration_seconds.labels(method=resolution_method).observe( + duration + ) + @staticmethod def _model_to_dict(model: Model) -> dict: """将 Model 对象转换为字典""" @@ -243,6 +443,7 @@ class ModelCacheService: "default_supports_streaming": global_model.default_supports_streaming, "default_supports_extended_thinking": global_model.default_supports_extended_thinking, "default_supports_image_generation": global_model.default_supports_image_generation, + "supported_capabilities": global_model.supported_capabilities, "is_active": global_model.is_active, "description": global_model.description, } @@ -255,10 +456,17 @@ class ModelCacheService: name=global_model_dict["name"], display_name=global_model_dict.get("display_name"), default_supports_vision=global_model_dict.get("default_supports_vision", False), - default_supports_function_calling=global_model_dict.get("default_supports_function_calling", False), + default_supports_function_calling=global_model_dict.get( + "default_supports_function_calling", False + ), default_supports_streaming=global_model_dict.get("default_supports_streaming", True), - default_supports_extended_thinking=global_model_dict.get("default_supports_extended_thinking", False), - default_supports_image_generation=global_model_dict.get("default_supports_image_generation", False), + default_supports_extended_thinking=global_model_dict.get( + "default_supports_extended_thinking", False + ), + default_supports_image_generation=global_model_dict.get( + "default_supports_image_generation", False + ), + supported_capabilities=global_model_dict.get("supported_capabilities") or [], is_active=global_model_dict.get("is_active", True), description=global_model_dict.get("description"), )