From 718f56ba75ebc6ca036d95ef80138ca13b65d239 Mon Sep 17 00:00:00 2001 From: fawney19 Date: Mon, 15 Dec 2025 23:12:34 +0800 Subject: [PATCH] refactor(cache): optimize cache service architecture and provider transport --- src/core/cache_service.py | 27 +++++ src/services/cache/model_cache.py | 157 ++++++++++------------------- src/services/provider/transport.py | 45 +++++++-- 3 files changed, 115 insertions(+), 114 deletions(-) diff --git a/src/core/cache_service.py b/src/core/cache_service.py index 7668e32..20fbf08 100644 --- a/src/core/cache_service.py +++ b/src/core/cache_service.py @@ -120,6 +120,33 @@ class CacheService: logger.warning(f"缓存检查失败: {key} - {e}") return False + @staticmethod + async def incr(key: str, ttl_seconds: Optional[int] = None) -> int: + """ + 递增缓存值 + + Args: + key: 缓存键 + ttl_seconds: 可选,如果提供则刷新 TTL + + Returns: + 递增后的值,如果失败返回 0 + """ + try: + redis = await get_redis_client(require_redis=False) + if not redis: + return 0 + + result = await redis.incr(key) + # 如果提供了 TTL,刷新过期时间 + if ttl_seconds is not None: + await redis.expire(key, ttl_seconds) + return result + + except Exception as e: + logger.warning(f"缓存递增失败: {key} - {e}") + return 0 + # 缓存键前缀 class CacheKeys: diff --git a/src/services/cache/model_cache.py b/src/services/cache/model_cache.py index 7a25afc..da11cd7 100644 --- a/src/services/cache/model_cache.py +++ b/src/services/cache/model_cache.py @@ -2,11 +2,9 @@ Model 映射缓存服务 - 减少模型查询 """ -import json import time -from typing import Optional +from typing import List, Optional -from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.orm import Session from src.config.constants import CacheTTL @@ -106,6 +104,7 @@ class ModelCacheService: Model 对象或 None """ cache_key = f"model:provider_global:{provider_id}:{global_model_id}" + hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}" # 1. 尝试从缓存获取 cached_data = await CacheService.get(cache_key) @@ -113,6 +112,8 @@ class ModelCacheService: logger.debug( f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..." ) + # 递增命中计数,同时刷新 TTL + await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL) return ModelCacheService._dict_to_model(cached_data) # 2. 缓存未命中,查询数据库 @@ -130,6 +131,8 @@ class ModelCacheService: if model: model_dict = ModelCacheService._model_to_dict(model) await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL) + # 重置命中计数(新缓存从1开始) + await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL) logger.debug( f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..." ) @@ -189,9 +192,10 @@ class ModelCacheService: # 清除 model:id 缓存 await CacheService.delete(f"model:id:{model_id}") - # 清除 provider_global 缓存(如果提供了必要参数) + # 清除 provider_global 缓存及其命中计数(如果提供了必要参数) if provider_id and global_model_id: await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}") + await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}") logger.debug( f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..." ) @@ -230,16 +234,20 @@ class ModelCacheService: db: Session, model_name: str ) -> Optional[GlobalModel]: """ - 通过名称或映射解析 GlobalModel(带缓存,支持映射匹配) + 通过名称解析 GlobalModel(带缓存) 查找顺序: 1. 检查缓存 - 2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases) + 2. 通过 provider_model_name 匹配(查询 Model 表) 3. 直接匹配 GlobalModel.name(兜底) + 注意:此方法不使用 provider_model_aliases 进行全局解析。 + provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效, + 由 resolve_provider_model() 处理。 + Args: db: 数据库会话 - model_name: 模型名称(可以是 GlobalModel.name 或映射名称) + model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name) Returns: GlobalModel 对象或 None @@ -273,116 +281,53 @@ class ModelCacheService: logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}") return ModelCacheService._dict_to_global_model(cached_data) - # 2. 优先通过 provider_model_name 和映射名称匹配(Provider 配置优先级最高) - from sqlalchemy import or_ - + # 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases) + # 重要:provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效 + # 全局解析不应该受到某个 Provider 别名配置的影响 + # 例如:Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析 from src.models.database import Provider - # 构建精确的映射匹配条件 - # 注意:provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符 - # 对于 SQLite,会在 Python 层面进行过滤 - try: - # 尝试使用 PostgreSQL 的 JSONB 查询(更高效) - # 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入 - jsonb_pattern = json.dumps([{"name": normalized_name}]) - models_with_global = ( - db.query(Model, GlobalModel) - .join(Provider, Model.provider_id == Provider.id) - .join(GlobalModel, Model.global_model_id == GlobalModel.id) - .filter( - Provider.is_active == True, - Model.is_active == True, - GlobalModel.is_active == True, - or_( - Model.provider_model_name == normalized_name, - # PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素 - Model.provider_model_aliases.op("@>")(jsonb_pattern), - ), - ) - .all() - ) - except (OperationalError, ProgrammingError) as e: - # JSONB 操作符不支持(如 SQLite),回退到加载匹配 provider_model_name 的 Model - # 并在 Python 层过滤 aliases - logger.debug( - f"JSONB 查询失败,回退到 Python 过滤: {e}", - ) - # 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录 - models_with_global = ( - db.query(Model, GlobalModel) - .join(Provider, Model.provider_id == Provider.id) - .join(GlobalModel, Model.global_model_id == GlobalModel.id) - .filter( - Provider.is_active == True, - Model.is_active == True, - GlobalModel.is_active == True, - ) - .all() + models_with_global = ( + db.query(Model, GlobalModel) + .join(Provider, Model.provider_id == Provider.id) + .join(GlobalModel, Model.global_model_id == GlobalModel.id) + .filter( + Provider.is_active == True, + Model.is_active == True, + GlobalModel.is_active == True, + Model.provider_model_name == normalized_name, ) + .all() + ) - # 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)} - # 使用字典去重,同一个 Model 只保留优先级最高的匹配 - matched_models_dict = {} - - # 遍历查询结果进行匹配 + # 收集匹配的 GlobalModel(只通过 provider_model_name 匹配) + matched_global_models: List[GlobalModel] = [] + seen_global_model_ids: set[str] = set() for model, gm in models_with_global: - key = (model.id, gm.id) - - # 检查 provider_model_aliases 是否匹配(优先级更高) - if model.provider_model_aliases: - for alias_entry in model.provider_model_aliases: - if isinstance(alias_entry, dict): - alias_name = alias_entry.get("name", "").strip() - if alias_name == normalized_name: - # alias 优先级为 0(最高),覆盖任何已存在的匹配 - matched_models_dict[key] = (gm, "alias", 0) - logger.debug( - f"模型名称 '{normalized_name}' 通过映射名称匹配到 " - f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)" - ) - break - - # 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name - if key not in matched_models_dict or matched_models_dict[key][1] != "alias": - if model.provider_model_name == normalized_name: - # provider_model_name 优先级为 1(兜底),只在没有 alias 匹配时使用 - if key not in matched_models_dict: - matched_models_dict[key] = (gm, "provider_model_name", 1) - logger.debug( - f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 " - f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)" - ) - - # 如果通过 provider_model_name/alias 找到了,直接返回 - if matched_models_dict: - # 转换为列表并排序:按 priority(alias=0 优先)、然后按 GlobalModel.name - matched_global_models = [ - (gm, match_type) for gm, match_type, priority in matched_models_dict.values() - ] - matched_global_models.sort( - key=lambda item: ( - 0 if item[1] == "alias" else 1, # alias 优先 - item[0].name # 同优先级按名称排序(确定性) + if gm.id not in seen_global_model_ids: + seen_global_model_ids.add(gm.id) + matched_global_models.append(gm) + logger.debug( + f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 " + f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)" ) - ) - # 记录解析方式 - resolution_method = matched_global_models[0][1] + # 如果通过 provider_model_name 找到了,返回 + if matched_global_models: + resolution_method = "provider_model_name" if len(matched_global_models) > 1: - # 检测到冲突 - unique_models = {gm.id: gm for gm, _ in matched_global_models} - if len(unique_models) > 1: - model_names = [gm.name for gm in unique_models.values()] - logger.warning( - f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: " - f"{', '.join(model_names)},使用第一个匹配结果" - ) - # 记录冲突指标 - model_mapping_conflict_total.inc() + # 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name) + model_names = [gm.name for gm in matched_global_models if gm.name] + logger.warning( + f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: " + f"{', '.join(model_names)},使用第一个匹配结果" + ) + # 记录冲突指标 + model_mapping_conflict_total.inc() # 返回第一个匹配的 GlobalModel - result_global_model: GlobalModel = matched_global_models[0][0] + result_global_model = matched_global_models[0] global_model_dict = ModelCacheService._global_model_to_dict(result_global_model) await CacheService.set( cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL diff --git a/src/services/provider/transport.py b/src/services/provider/transport.py index b46f3c6..0572110 100644 --- a/src/services/provider/transport.py +++ b/src/services/provider/transport.py @@ -6,7 +6,7 @@ - 根据 API 格式或端点配置生成请求 URL """ -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from urllib.parse import urlencode from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format @@ -14,11 +14,14 @@ from src.core.crypto import crypto_service from src.core.enums import APIFormat from src.core.logger import logger +if TYPE_CHECKING: + from src.models.database import ProviderAPIKey, ProviderEndpoint + def build_provider_headers( - endpoint, - key, + endpoint: "ProviderEndpoint", + key: "ProviderAPIKey", original_headers: Optional[Dict[str, str]] = None, *, extra_headers: Optional[Dict[str, str]] = None, @@ -28,7 +31,8 @@ def build_provider_headers( """ headers: Dict[str, str] = {} - decrypted_key = crypto_service.decrypt(key.api_key) + # api_key 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制 + decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type] # 根据 API 格式自动选择认证头 api_format = getattr(endpoint, "api_format", None) @@ -68,8 +72,32 @@ def build_provider_headers( return headers +def _normalize_base_url(base_url: str, path: str) -> str: + """ + 规范化 base_url,去除末尾的斜杠和可能与 path 重复的版本前缀。 + + 只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀, + 避免拼接出 /v1/v1/messages 这样的重复路径。 + + 兼容用户填写的各种格式: + - https://api.example.com + - https://api.example.com/ + - https://api.example.com/v1 + - https://api.example.com/v1/ + """ + base = base_url.rstrip("/") + # 只在 path 以版本前缀开头时才去除 base_url 中的该前缀 + # 例如:base="/v1", path="/v1/messages" -> 去除 /v1 + # 例如:base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留) + for suffix in ("/v1beta", "/v1", "/v2", "/v3"): + if base.endswith(suffix) and path.startswith(suffix): + base = base[: -len(suffix)] + break + return base + + def build_provider_url( - endpoint, + endpoint: "ProviderEndpoint", *, query_params: Optional[Dict[str, Any]] = None, path_params: Optional[Dict[str, Any]] = None, @@ -88,8 +116,6 @@ def build_provider_url( path_params: 路径模板参数 (如 {model}) is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法 """ - base = endpoint.base_url.rstrip("/") - # 准备路径参数,添加 Gemini API 所需的 action 参数 effective_path_params = dict(path_params) if path_params else {} @@ -123,6 +149,9 @@ def build_provider_url( if not path.startswith("/"): path = f"/{path}" + # 先确定 path,再根据 path 规范化 base_url + # base_url 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制 + base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type] url = f"{base}{path}" # 添加查询参数 @@ -134,7 +163,7 @@ def build_provider_url( return url -def _resolve_default_path(api_format) -> str: +def _resolve_default_path(api_format: Optional[str]) -> str: """ 根据 API 格式返回默认路径 """