Mirror of https://github.com/fawney19/Aether.git, synced 2026-01-02 15:52:26 +08:00
feat(cache): implement model alias resolution with caching
- Add resolve_global_model_by_name_or_alias() supporting direct match and alias lookup
- Support both provider_model_name and provider_model_aliases matching
- Implement caching for resolved models with TTL
- Add conflict detection when an alias maps to multiple GlobalModels
- Record resolution metrics: method, cache hits, duration, conflicts
- Fall back to Python-level filtering for non-PostgreSQL databases
- Add cache invalidation methods for GlobalModel
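For reviewers, a minimal usage sketch of the new resolver (the call site below is hypothetical; only resolve_global_model_by_name_or_alias() itself is part of this commit, and the import path is assumed from the file location):

    from sqlalchemy.orm import Session

    from src.services.cache.model_cache import ModelCacheService

    async def pick_global_model(db: Session, requested_name: str):
        # Accepts either a GlobalModel.name or a provider-side alias; returns None when nothing matches.
        global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, requested_name)
        if global_model is None:
            raise ValueError(f"Unknown model: {requested_name}")
        return global_model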
src/services/cache/model_cache.py (vendored) | 224
@@ -2,6 +2,8 @@
 Model mapping cache service - reduces model lookups
 """

+import json
+import time
 from typing import Optional

 from sqlalchemy.orm import Session
@@ -9,6 +11,11 @@ from sqlalchemy.orm import Session
 from src.config.constants import CacheTTL
 from src.core.cache_service import CacheService
 from src.core.logger import logger
+from src.core.metrics import (
+    model_alias_conflict_total,
+    model_alias_resolution_duration_seconds,
+    model_alias_resolution_total,
+)
 from src.models.database import GlobalModel, Model


@@ -102,7 +109,9 @@ class ModelCacheService:
         # 1. Try to read from the cache
         cached_data = await CacheService.get(cache_key)
         if cached_data:
-            logger.debug(f"Model cache hit (provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
+            logger.debug(
+                f"Model cache hit (provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
+            )
             return ModelCacheService._dict_to_model(cached_data)

         # 2. Cache miss, query the database
@@ -120,7 +129,9 @@ class ModelCacheService:
         if model:
             model_dict = ModelCacheService._model_to_dict(model)
             await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
-            logger.debug(f"Model cached (provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
+            logger.debug(
+                f"Model cached (provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
+            )

         return model

@@ -160,7 +171,7 @@ class ModelCacheService:
     @staticmethod
     async def invalidate_model_cache(
        model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
-    ):
+    ) -> None:
         """Clear the Model cache

         Args:
@@ -174,18 +185,207 @@ class ModelCacheService:
         # Clear the provider_global cache (if the required parameters were provided)
         if provider_id and global_model_id:
             await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
-            logger.debug(f"Model cache cleared: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}...")
+            logger.debug(
+                f"Model cache cleared: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
+            )
         else:
             logger.debug(f"Model cache cleared: {model_id}")

     @staticmethod
-    async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None):
+    async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None) -> None:
         """Clear the GlobalModel cache"""
         await CacheService.delete(f"global_model:id:{global_model_id}")
         if name:
             await CacheService.delete(f"global_model:name:{name}")
         logger.debug(f"GlobalModel cache cleared: {global_model_id}")

+    @staticmethod
+    async def resolve_global_model_by_name_or_alias(
+        db: Session, model_name: str
+    ) -> Optional[GlobalModel]:
+        """
+        Resolve a GlobalModel by name or alias (cached, supports alias matching)
+
+        Lookup order:
+        1. Check the cache
+        2. Direct match on GlobalModel.name
+        3. Match by alias (query the Model table's provider_model_name and provider_model_aliases)
+
+        Args:
+            db: database session
+            model_name: model name (either a GlobalModel.name or an alias)
+
+        Returns:
+            The GlobalModel object, or None
+        """
+        start_time = time.time()
+        resolution_method = "not_found"
+        cache_hit = False
+
+        normalized_name = model_name.strip()
+        if not normalized_name:
+            return None
+
+        cache_key = f"global_model:resolve:{normalized_name}"
+
+        try:
+            # 1. Try to read from the cache
+            cached_data = await CacheService.get(cache_key)
+            if cached_data:
+                if cached_data == "NOT_FOUND":
+                    # Cached negative result
+                    cache_hit = True
+                    resolution_method = "not_found"
+                    logger.debug(f"GlobalModel cache hit (alias resolution, not found): {normalized_name}")
+                    return None
+                if isinstance(cached_data, dict) and "supported_capabilities" not in cached_data:
+                    # Compatibility with old cache entries: missing fields are treated as a miss and refreshed from the DB
+                    logger.debug(f"GlobalModel cache hit but schema is outdated, refreshing: {normalized_name}")
+                else:
+                    cache_hit = True
+                    resolution_method = "direct_match"  # on a cache hit the original resolution method cannot be recovered
+                    logger.debug(f"GlobalModel cache hit (alias resolution): {normalized_name}")
+                    return ModelCacheService._dict_to_global_model(cached_data)
+
+            # 2. Look up directly by GlobalModel.name
+            global_model = (
+                db.query(GlobalModel)
+                .filter(GlobalModel.name == normalized_name, GlobalModel.is_active == True)
+                .first()
+            )
+
+            if global_model:
+                resolution_method = "direct_match"
+                # Cache the result
+                global_model_dict = ModelCacheService._global_model_to_dict(global_model)
+                await CacheService.set(
+                    cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
+                )
+                logger.debug(f"GlobalModel cached (alias resolution, direct match): {normalized_name}")
+                return global_model
+
+            # 3. Match by alias (optimized: precise query, avoids loading every Model)
+            from sqlalchemy import or_
+
+            from src.models.database import Provider
+
+            # Build a precise alias-matching condition.
+            # Note: provider_model_aliases is a JSONB array, so PostgreSQL's JSONB operators are required.
+            # On SQLite, filtering falls back to the Python level.
+            try:
+                # Prefer the PostgreSQL JSONB query (more efficient).
+                # json.dumps escapes special characters correctly and avoids SQL injection.
+                jsonb_pattern = json.dumps([{"name": normalized_name}])
+                models_with_global = (
+                    db.query(Model, GlobalModel)
+                    .join(Provider, Model.provider_id == Provider.id)
+                    .join(GlobalModel, Model.global_model_id == GlobalModel.id)
+                    .filter(
+                        Provider.is_active == True,
+                        Model.is_active == True,
+                        GlobalModel.is_active == True,
+                        or_(
+                            Model.provider_model_name == normalized_name,
+                            # PostgreSQL JSONB query: does the array contain an element matching {"name": "xxx"}?
+                            Model.provider_model_aliases.op("@>")(jsonb_pattern),
+                        ),
+                    )
+                    .all()
+                )
+            except Exception as e:
+                # If the JSONB query fails (e.g. on SQLite), fall back to loading all active Models and filtering in Python
+                logger.debug(
+                    f"JSONB query failed, falling back to Python-level filtering: {e}",
+                )
+                models_with_global = (
+                    db.query(Model, GlobalModel)
+                    .join(Provider, Model.provider_id == Provider.id)
+                    .join(GlobalModel, Model.global_model_id == GlobalModel.id)
+                    .filter(
+                        Provider.is_active == True,
+                        Model.is_active == True,
+                        GlobalModel.is_active == True,
+                    )
+                    .all()
+                )
+
+            # Used to detect alias conflicts
+            matched_global_models = []
+
+            # Walk the query results and match
+            for model, gm in models_with_global:
+                # Check whether provider_model_name matches
+                if model.provider_model_name == normalized_name:
+                    matched_global_models.append((gm, "provider_model_name"))
+                    logger.debug(
+                        f"Model name '{normalized_name}' matched via provider_model_name to "
+                        f"GlobalModel: {gm.name}"
+                    )
+
+                # Check whether provider_model_aliases matches
+                if model.provider_model_aliases:
+                    for alias_entry in model.provider_model_aliases:
+                        if isinstance(alias_entry, dict):
+                            alias_name = alias_entry.get("name", "").strip()
+                            if alias_name == normalized_name:
+                                matched_global_models.append((gm, "alias"))
+                                logger.debug(
+                                    f"Model name '{normalized_name}' matched via alias to "
+                                    f"GlobalModel: {gm.name}"
+                                )
+                                break
+
+            # Handle the match results
+            if not matched_global_models:
+                resolution_method = "not_found"
+                # No match found, cache the negative result
+                await CacheService.set(
+                    cache_key, "NOT_FOUND", ttl_seconds=ModelCacheService.CACHE_TTL
+                )
+                logger.debug(f"GlobalModel not found (alias resolution): {normalized_name}")
+                return None
+
+            # Prefer direct provider_model_name matches over aliases; within a tier, sort by name for determinism
+            matched_global_models.sort(
+                key=lambda item: (0 if item[1] == "provider_model_name" else 1, item[0].name)
+            )
+
+            # Record the resolution method
+            resolution_method = matched_global_models[0][1]
+
+            if len(matched_global_models) > 1:
+                # Alias conflict detected
+                unique_models = {gm.id: gm for gm, _ in matched_global_models}
+                if len(unique_models) > 1:
+                    model_names = [gm.name for gm in unique_models.values()]
+                    logger.warning(
+                        f"Alias conflict: model name '{normalized_name}' matched multiple different GlobalModels: "
+                        f"{', '.join(model_names)}; using the first match"
+                    )
+                    # Record the alias-conflict metric
+                    model_alias_conflict_total.inc()
+
+            # Return the first matching GlobalModel
+            result_global_model: GlobalModel = matched_global_models[0][0]
+            global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
+            await CacheService.set(
+                cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
+            )
+            logger.debug(
+                f"GlobalModel cached (alias resolution): {normalized_name} -> {result_global_model.name}"
+            )
+            return result_global_model
+
+        finally:
+            # Record monitoring metrics
+            duration = time.time() - start_time
+            model_alias_resolution_total.labels(
+                method=resolution_method, cache_hit=str(cache_hit).lower()
+            ).inc()
+            model_alias_resolution_duration_seconds.labels(method=resolution_method).observe(
+                duration
+            )
+
     @staticmethod
     def _model_to_dict(model: Model) -> dict:
         """Convert a Model object to a dict"""
@@ -243,6 +443,7 @@ class ModelCacheService:
             "default_supports_streaming": global_model.default_supports_streaming,
             "default_supports_extended_thinking": global_model.default_supports_extended_thinking,
             "default_supports_image_generation": global_model.default_supports_image_generation,
+            "supported_capabilities": global_model.supported_capabilities,
             "is_active": global_model.is_active,
             "description": global_model.description,
         }
@@ -255,10 +456,17 @@ class ModelCacheService:
             name=global_model_dict["name"],
             display_name=global_model_dict.get("display_name"),
             default_supports_vision=global_model_dict.get("default_supports_vision", False),
-            default_supports_function_calling=global_model_dict.get("default_supports_function_calling", False),
+            default_supports_function_calling=global_model_dict.get(
+                "default_supports_function_calling", False
+            ),
             default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
-            default_supports_extended_thinking=global_model_dict.get("default_supports_extended_thinking", False),
-            default_supports_image_generation=global_model_dict.get("default_supports_image_generation", False),
+            default_supports_extended_thinking=global_model_dict.get(
+                "default_supports_extended_thinking", False
+            ),
+            default_supports_image_generation=global_model_dict.get(
+                "default_supports_image_generation", False
+            ),
+            supported_capabilities=global_model_dict.get("supported_capabilities") or [],
             is_active=global_model_dict.get("is_active", True),
             description=global_model_dict.get("description"),
         )
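A note on the alias matching in the diff above: the PostgreSQL path relies on JSONB array containment (the @> operator), and the SQLite fallback re-implements the same check in Python. A self-contained sketch of that containment rule (pure Python; the alias data is illustrative, not taken from the repository):

    def alias_matches(provider_model_aliases, normalized_name: str) -> bool:
        # Python-level mirror of the PostgreSQL containment test
        #   provider_model_aliases @> '[{"name": "<normalized_name>"}]'
        # which is what the SQLite fallback in the diff effectively does per row.
        return any(
            isinstance(entry, dict) and entry.get("name", "").strip() == normalized_name
            for entry in provider_model_aliases or []
        )

    assert alias_matches([{"name": "gpt-4o-latest"}], "gpt-4o-latest")
    assert not alias_matches([{"name": "gpt-4o-latest"}], "gpt-4o")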