refactor(cache): optimize cache service architecture and provider transport

This commit is contained in:
fawney19
2025-12-15 23:12:34 +08:00
parent d87de10f62
commit 718f56ba75
3 changed files with 115 additions and 114 deletions

View File

@@ -2,11 +2,9 @@
Model 映射缓存服务 - 减少模型查询
"""
import json
import time
from typing import Optional
from typing import List, Optional
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
@@ -106,6 +104,7 @@ class ModelCacheService:
Model 对象或 None
"""
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
@@ -113,6 +112,8 @@ class ModelCacheService:
logger.debug(
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
# 递增命中计数,同时刷新 TTL
await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
return ModelCacheService._dict_to_model(cached_data)
# 2. 缓存未命中,查询数据库
@@ -130,6 +131,8 @@ class ModelCacheService:
if model:
model_dict = ModelCacheService._model_to_dict(model)
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
# 重置命中计数新缓存从1开始
await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
logger.debug(
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
)
@@ -189,9 +192,10 @@ class ModelCacheService:
# 清除 model:id 缓存
await CacheService.delete(f"model:id:{model_id}")
# 清除 provider_global 缓存(如果提供了必要参数)
# 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
if provider_id and global_model_id:
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
logger.debug(
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
)
@@ -230,16 +234,20 @@ class ModelCacheService:
db: Session, model_name: str
) -> Optional[GlobalModel]:
"""
通过名称或映射解析 GlobalModel带缓存,支持映射匹配
通过名称解析 GlobalModel带缓存
查找顺序:
1. 检查缓存
2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases
2. 通过 provider_model_name 匹配(查询 Model 表
3. 直接匹配 GlobalModel.name兜底
注意:此方法不使用 provider_model_aliases 进行全局解析。
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
由 resolve_provider_model() 处理。
Args:
db: 数据库会话
model_name: 模型名称(可以是 GlobalModel.name 或映射名称
model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name
Returns:
GlobalModel 对象或 None
@@ -273,116 +281,53 @@ class ModelCacheService:
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
return ModelCacheService._dict_to_global_model(cached_data)
# 2. 优先通过 provider_model_name 和映射名称匹配Provider 配置优先级最高
from sqlalchemy import or_
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases
# 重要provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
# 全局解析不应该受到某个 Provider 别名配置的影响
# 例如Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
from src.models.database import Provider
# 构建精确的映射匹配条件
# 注意provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符
# 对于 SQLite会在 Python 层面进行过滤
try:
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效)
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入
jsonb_pattern = json.dumps([{"name": normalized_name}])
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
or_(
Model.provider_model_name == normalized_name,
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
Model.provider_model_aliases.op("@>")(jsonb_pattern),
),
)
.all()
)
except (OperationalError, ProgrammingError) as e:
# JSONB 操作符不支持(如 SQLite回退到加载匹配 provider_model_name 的 Model
# 并在 Python 层过滤 aliases
logger.debug(
f"JSONB 查询失败,回退到 Python 过滤: {e}",
)
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
)
.all()
models_with_global = (
db.query(Model, GlobalModel)
.join(Provider, Model.provider_id == Provider.id)
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(
Provider.is_active == True,
Model.is_active == True,
GlobalModel.is_active == True,
Model.provider_model_name == normalized_name,
)
.all()
)
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)}
# 使用字典去重,同一个 Model 只保留优先级最高的匹配
matched_models_dict = {}
# 遍历查询结果进行匹配
# 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
matched_global_models: List[GlobalModel] = []
seen_global_model_ids: set[str] = set()
for model, gm in models_with_global:
key = (model.id, gm.id)
# 检查 provider_model_aliases 是否匹配(优先级更高)
if model.provider_model_aliases:
for alias_entry in model.provider_model_aliases:
if isinstance(alias_entry, dict):
alias_name = alias_entry.get("name", "").strip()
if alias_name == normalized_name:
# alias 优先级为 0最高覆盖任何已存在的匹配
matched_models_dict[key] = (gm, "alias", 0)
logger.debug(
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
break
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
if model.provider_model_name == normalized_name:
# provider_model_name 优先级为 1兜底只在没有 alias 匹配时使用
if key not in matched_models_dict:
matched_models_dict[key] = (gm, "provider_model_name", 1)
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
# 如果通过 provider_model_name/alias 找到了,直接返回
if matched_models_dict:
# 转换为列表并排序:按 priorityalias=0 优先)、然后按 GlobalModel.name
matched_global_models = [
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
]
matched_global_models.sort(
key=lambda item: (
0 if item[1] == "alias" else 1, # alias 优先
item[0].name # 同优先级按名称排序(确定性)
if gm.id not in seen_global_model_ids:
seen_global_model_ids.add(gm.id)
matched_global_models.append(gm)
logger.debug(
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
)
)
# 记录解析方式
resolution_method = matched_global_models[0][1]
# 如果通过 provider_model_name 找到了,返回
if matched_global_models:
resolution_method = "provider_model_name"
if len(matched_global_models) > 1:
# 检测到冲突
unique_models = {gm.id: gm for gm, _ in matched_global_models}
if len(unique_models) > 1:
model_names = [gm.name for gm in unique_models.values()]
logger.warning(
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
f"{', '.join(model_names)},使用第一个匹配结果"
)
# 记录冲突指标
model_mapping_conflict_total.inc()
# 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name
model_names = [gm.name for gm in matched_global_models if gm.name]
logger.warning(
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
f"{', '.join(model_names)},使用第一个匹配结果"
)
# 记录冲突指标
model_mapping_conflict_total.inc()
# 返回第一个匹配的 GlobalModel
result_global_model: GlobalModel = matched_global_models[0][0]
result_global_model = matched_global_models[0]
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
await CacheService.set(
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL