"""
Model mapping cache service - reduces repeated model lookups against the database.
"""
import json
import time
from typing import Optional
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.logger import logger
from src.core.metrics import (
model_alias_conflict_total,
model_alias_resolution_duration_seconds,
model_alias_resolution_total,
)
from src.models.database import GlobalModel, Model
class ModelCacheService:
    """Read-through cache layer for Model / GlobalModel lookups."""

    # Cache TTL - taken from the project-wide constant set.
    CACHE_TTL = CacheTTL.MODEL
@staticmethod
async def get_model_by_id(db: Session, model_id: str) -> Optional[Model]:
"""
获取 Model带缓存
Args:
db: 数据库会话
model_id: Model ID
Returns:
Model 对象或 None
"""
cache_key = f"model:id:{model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"Model 缓存命中: {model_id}")
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
model = db.query(Model).filter(Model.id == model_id).first()
# 3. 写入缓存
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(f"Model 已缓存: {model_id}")
return model
@staticmethod
async def get_global_model_by_id(db: Session, global_model_id: str) -> Optional[GlobalModel]:
"""
获取 GlobalModel带缓存
Args:
db: 数据库会话
global_model_id: GlobalModel ID
Returns:
GlobalModel 对象或 None
"""
cache_key = f"global_model:id:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"GlobalModel 缓存命中: {global_model_id}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 缓存未命中,查询数据库
global_model = db.query(GlobalModel).filter(GlobalModel.id == global_model_id).first()
# 3. 写入缓存
if global_model:
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(f"GlobalModel 已缓存: {global_model_id}")
return global_model
@staticmethod
async def get_model_by_provider_and_global_model(
db: Session, provider_id: str, global_model_id: str
) -> Optional[Model]:
"""
通过 Provider ID GlobalModel ID 获取 Model带缓存
Args:
db: 数据库会话
provider_id: Provider ID
global_model_id: GlobalModel ID
Returns:
Model 对象或 None
"""
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
2025-12-10 20:52:44 +08:00
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
model = (
db.query(Model)
.filter(
Model.provider_id == provider_id,
Model.global_model_id == global_model_id,
Model.is_active == True,
)
.first()
)
# 3. 写入缓存
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
2025-12-10 20:52:44 +08:00
return model
@staticmethod
async def get_global_model_by_name(db: Session, name: str) -> Optional[GlobalModel]:
"""
通过名称获取 GlobalModel带缓存
Args:
db: 数据库会话
name: GlobalModel 名称
Returns:
GlobalModel 对象或 None
"""
cache_key = f"global_model:name:{name}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
logger.debug(f"GlobalModel 缓存命中(名称): {name}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 缓存未命中,查询数据库
global_model = db.query(GlobalModel).filter(GlobalModel.name == name).first()
# 3. 写入缓存
if global_model:
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(f"GlobalModel 已缓存(名称): {name}")
return global_model
@staticmethod
async def invalidate_model_cache(
model_id: str,
provider_id: Optional[str] = None,
global_model_id: Optional[str] = None,
provider_model_name: Optional[str] = None,
provider_model_aliases: Optional[list] = None,
) -> None:
2025-12-10 20:52:44 +08:00
"""清除 Model 缓存
Args:
model_id: Model ID
provider_id: Provider ID用于清除 provider_global 缓存
global_model_id: GlobalModel ID用于清除 provider_global 缓存
provider_model_name: provider_model_name用于清除 resolve 缓存
provider_model_aliases: 别名列表用于清除 resolve 缓存
2025-12-10 20:52:44 +08:00
"""
# 清除 model:id 缓存
await CacheService.delete(f"model:id:{model_id}")
# 清除 provider_global 缓存(如果提供了必要参数)
if provider_id and global_model_id:
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
logger.debug(
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
)
2025-12-10 20:52:44 +08:00
else:
logger.debug(f"Model 缓存已清除: {model_id}")
# 清除 resolve 缓存provider_model_name 和 aliases 可能都被用作解析 key
resolve_keys_to_clear = []
if provider_model_name:
resolve_keys_to_clear.append(provider_model_name)
if provider_model_aliases:
for alias_entry in provider_model_aliases:
if isinstance(alias_entry, dict):
alias_name = alias_entry.get("name", "").strip()
if alias_name:
resolve_keys_to_clear.append(alias_name)
for key in resolve_keys_to_clear:
await CacheService.delete(f"global_model:resolve:{key}")
if resolve_keys_to_clear:
logger.debug(f"Model resolve 缓存已清除: {resolve_keys_to_clear}")
2025-12-10 20:52:44 +08:00
@staticmethod
async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None) -> None:
2025-12-10 20:52:44 +08:00
"""清除 GlobalModel 缓存"""
await CacheService.delete(f"global_model:id:{global_model_id}")
if name:
await CacheService.delete(f"global_model:name:{name}")
# 同时清除 resolve 缓存,因为 GlobalModel.name 也是一个 resolve key
await CacheService.delete(f"global_model:resolve:{name}")
2025-12-10 20:52:44 +08:00
logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
@staticmethod
async def resolve_global_model_by_name_or_alias(
db: Session, model_name: str
) -> Optional[GlobalModel]:
"""
通过名称或别名解析 GlobalModel带缓存支持别名匹配
查找顺序
1. 检查缓存
2. 直接匹配 GlobalModel.name
3. 通过别名匹配查询 Model 表的 provider_model_name provider_model_aliases
Args:
db: 数据库会话
model_name: 模型名称可以是 GlobalModel.name 或别名
Returns:
GlobalModel 对象或 None
"""
start_time = time.time()
resolution_method = "not_found"
cache_hit = False
normalized_name = model_name.strip()
if not normalized_name:
return None
cache_key = f"global_model:resolve:{normalized_name}"
try:
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data:
if cached_data == "NOT_FOUND":
# 缓存的负结果
cache_hit = True
resolution_method = "not_found"
logger.debug(f"GlobalModel 缓存命中(别名解析-未找到): {normalized_name}")
return None
if isinstance(cached_data, dict) and "supported_capabilities" not in cached_data:
# 兼容旧缓存:字段不全时视为未命中,走 DB 刷新
logger.debug(f"GlobalModel 缓存命中但 schema 过旧,刷新: {normalized_name}")
else:
cache_hit = True
resolution_method = "direct_match" # 缓存命中时无法区分原始解析方式
logger.debug(f"GlobalModel 缓存命中(别名解析): {normalized_name}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 直接通过 GlobalModel.name 查找
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == normalized_name, GlobalModel.is_active == True)
.first()
)
if global_model:
resolution_method = "direct_match"
# 缓存结果
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(f"GlobalModel 已缓存(别名解析-直接匹配): {normalized_name}")
return global_model
# 3. 通过别名匹配(优化:精确查询,避免加载所有 Model
from sqlalchemy import or_
from src.models.database import Provider
# 构建精确的别名匹配条件
# 注意provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符
# 对于 SQLite会在 Python 层面进行过滤
try:
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效)
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入
jsonb_pattern = json.dumps([{"name": normalized_name}])
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
or_(
Model.provider_model_name == normalized_name,
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
Model.provider_model_aliases.op("@>")(jsonb_pattern),
),
)
.all()
)
except (OperationalError, ProgrammingError) as e:
# JSONB 操作符不支持(如 SQLite回退到加载匹配 provider_model_name 的 Model
# 并在 Python 层过滤 aliases
logger.debug(
f"JSONB 查询失败,回退到 Python 过滤: {e}",
)
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
)
.all()
)
# 用于检测别名冲突
matched_global_models = []
# 遍历查询结果进行匹配
for model, gm in models_with_global:
# 检查 provider_model_name 是否匹配
if model.provider_model_name == normalized_name:
matched_global_models.append((gm, "provider_model_name"))
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name}"
)
# 检查 provider_model_aliases 是否匹配
if model.provider_model_aliases:
for alias_entry in model.provider_model_aliases:
if isinstance(alias_entry, dict):
alias_name = alias_entry.get("name", "").strip()
if alias_name == normalized_name:
matched_global_models.append((gm, "alias"))
logger.debug(
f"模型名称 '{normalized_name}' 通过别名匹配到 "
f"GlobalModel: {gm.name}"
)
break
# 处理匹配结果
if not matched_global_models:
resolution_method = "not_found"
# 未找到匹配,缓存负结果
await CacheService.set(
cache_key, "NOT_FOUND", ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(f"GlobalModel 未找到(别名解析): {normalized_name}")
return None
# 优先使用 provider_model_name 的直接匹配,其次才是 aliases同级别按名称排序保证确定性
matched_global_models.sort(
key=lambda item: (0 if item[1] == "provider_model_name" else 1, item[0].name)
)
# 记录解析方式
resolution_method = matched_global_models[0][1]
if len(matched_global_models) > 1:
# 检测到别名冲突
unique_models = {gm.id: gm for gm, _ in matched_global_models}
if len(unique_models) > 1:
model_names = [gm.name for gm in unique_models.values()]
logger.warning(
f"别名冲突: 模型名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
f"{', '.join(model_names)},使用第一个匹配结果"
)
# 记录别名冲突指标
model_alias_conflict_total.inc()
# 返回第一个匹配的 GlobalModel
result_global_model: GlobalModel = matched_global_models[0][0]
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
)
logger.debug(
f"GlobalModel 已缓存(别名解析): {normalized_name} -> {result_global_model.name}"
)
return result_global_model
finally:
# 记录监控指标
duration = time.time() - start_time
model_alias_resolution_total.labels(
method=resolution_method, cache_hit=str(cache_hit).lower()
).inc()
model_alias_resolution_duration_seconds.labels(method=resolution_method).observe(
duration
)
2025-12-10 20:52:44 +08:00
@staticmethod
def _model_to_dict(model: Model) -> dict:
"""将 Model 对象转换为字典"""
return {
"id": model.id,
"provider_id": model.provider_id,
"global_model_id": model.global_model_id,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
2025-12-10 20:52:44 +08:00
"is_active": model.is_active,
"is_available": model.is_available if hasattr(model, "is_available") else True,
"price_per_request": (
float(model.price_per_request) if model.price_per_request else None
),
"tiered_pricing": model.tiered_pricing,
"supports_vision": model.supports_vision,
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": getattr(model, "supports_image_generation", None),
2025-12-10 20:52:44 +08:00
"config": model.config,
}
@staticmethod
def _dict_to_model(model_dict: dict) -> Model:
"""从字典重建 Model 对象"""
model = Model(
id=model_dict["id"],
provider_id=model_dict["provider_id"],
global_model_id=model_dict["global_model_id"],
provider_model_name=model_dict["provider_model_name"],
provider_model_aliases=model_dict.get("provider_model_aliases"),
2025-12-10 20:52:44 +08:00
is_active=model_dict["is_active"],
is_available=model_dict.get("is_available", True),
price_per_request=model_dict.get("price_per_request"),
tiered_pricing=model_dict.get("tiered_pricing"),
supports_vision=model_dict.get("supports_vision"),
supports_function_calling=model_dict.get("supports_function_calling"),
supports_streaming=model_dict.get("supports_streaming"),
supports_extended_thinking=model_dict.get("supports_extended_thinking"),
supports_image_generation=model_dict.get("supports_image_generation"),
2025-12-10 20:52:44 +08:00
config=model_dict.get("config"),
)
return model
@staticmethod
def _global_model_to_dict(global_model: GlobalModel) -> dict:
"""将 GlobalModel 对象转换为字典"""
return {
"id": global_model.id,
"name": global_model.name,
"display_name": global_model.display_name,
"default_supports_vision": global_model.default_supports_vision,
"default_supports_function_calling": global_model.default_supports_function_calling,
"default_supports_streaming": global_model.default_supports_streaming,
"default_supports_extended_thinking": global_model.default_supports_extended_thinking,
"default_supports_image_generation": global_model.default_supports_image_generation,
"supported_capabilities": global_model.supported_capabilities,
2025-12-10 20:52:44 +08:00
"is_active": global_model.is_active,
"description": global_model.description,
}
@staticmethod
def _dict_to_global_model(global_model_dict: dict) -> GlobalModel:
"""从字典重建 GlobalModel 对象"""
global_model = GlobalModel(
id=global_model_dict["id"],
name=global_model_dict["name"],
display_name=global_model_dict.get("display_name"),
default_supports_vision=global_model_dict.get("default_supports_vision", False),
default_supports_function_calling=global_model_dict.get(
"default_supports_function_calling", False
),
default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
default_supports_extended_thinking=global_model_dict.get(
"default_supports_extended_thinking", False
),
default_supports_image_generation=global_model_dict.get(
"default_supports_image_generation", False
),
supported_capabilities=global_model_dict.get("supported_capabilities") or [],
2025-12-10 20:52:44 +08:00
is_active=global_model_dict.get("is_active", True),
description=global_model_dict.get("description"),
)
return global_model