mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(cache): optimize cache service architecture and provider transport
This commit is contained in:
@@ -120,6 +120,33 @@ class CacheService:
|
||||
logger.warning(f"缓存检查失败: {key} - {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def incr(key: str, ttl_seconds: Optional[int] = None) -> int:
|
||||
"""
|
||||
递增缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
ttl_seconds: 可选,如果提供则刷新 TTL
|
||||
|
||||
Returns:
|
||||
递增后的值,如果失败返回 0
|
||||
"""
|
||||
try:
|
||||
redis = await get_redis_client(require_redis=False)
|
||||
if not redis:
|
||||
return 0
|
||||
|
||||
result = await redis.incr(key)
|
||||
# 如果提供了 TTL,刷新过期时间
|
||||
if ttl_seconds is not None:
|
||||
await redis.expire(key, ttl_seconds)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"缓存递增失败: {key} - {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 缓存键前缀
|
||||
class CacheKeys:
|
||||
|
||||
157
src/services/cache/model_cache.py
vendored
157
src/services/cache/model_cache.py
vendored
@@ -2,11 +2,9 @@
|
||||
Model 映射缓存服务 - 减少模型查询
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
@@ -106,6 +104,7 @@ class ModelCacheService:
|
||||
Model 对象或 None
|
||||
"""
|
||||
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
|
||||
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
@@ -113,6 +112,8 @@ class ModelCacheService:
|
||||
logger.debug(
|
||||
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||
)
|
||||
# 递增命中计数,同时刷新 TTL
|
||||
await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
return ModelCacheService._dict_to_model(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
@@ -130,6 +131,8 @@ class ModelCacheService:
|
||||
if model:
|
||||
model_dict = ModelCacheService._model_to_dict(model)
|
||||
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
# 重置命中计数(新缓存从1开始)
|
||||
await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
logger.debug(
|
||||
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||
)
|
||||
@@ -189,9 +192,10 @@ class ModelCacheService:
|
||||
# 清除 model:id 缓存
|
||||
await CacheService.delete(f"model:id:{model_id}")
|
||||
|
||||
# 清除 provider_global 缓存(如果提供了必要参数)
|
||||
# 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
|
||||
if provider_id and global_model_id:
|
||||
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
|
||||
await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
|
||||
logger.debug(
|
||||
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
|
||||
)
|
||||
@@ -230,16 +234,20 @@ class ModelCacheService:
|
||||
db: Session, model_name: str
|
||||
) -> Optional[GlobalModel]:
|
||||
"""
|
||||
通过名称或映射解析 GlobalModel(带缓存,支持映射匹配)
|
||||
通过名称解析 GlobalModel(带缓存)
|
||||
|
||||
查找顺序:
|
||||
1. 检查缓存
|
||||
2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases)
|
||||
2. 通过 provider_model_name 匹配(查询 Model 表)
|
||||
3. 直接匹配 GlobalModel.name(兜底)
|
||||
|
||||
注意:此方法不使用 provider_model_aliases 进行全局解析。
|
||||
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
|
||||
由 resolve_provider_model() 处理。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name)
|
||||
|
||||
Returns:
|
||||
GlobalModel 对象或 None
|
||||
@@ -273,116 +281,53 @@ class ModelCacheService:
|
||||
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
|
||||
return ModelCacheService._dict_to_global_model(cached_data)
|
||||
|
||||
# 2. 优先通过 provider_model_name 和映射名称匹配(Provider 配置优先级最高)
|
||||
from sqlalchemy import or_
|
||||
|
||||
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases)
|
||||
# 重要:provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
|
||||
# 全局解析不应该受到某个 Provider 别名配置的影响
|
||||
# 例如:Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
|
||||
from src.models.database import Provider
|
||||
|
||||
# 构建精确的映射匹配条件
|
||||
# 注意:provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符
|
||||
# 对于 SQLite,会在 Python 层面进行过滤
|
||||
try:
|
||||
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效)
|
||||
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入
|
||||
jsonb_pattern = json.dumps([{"name": normalized_name}])
|
||||
models_with_global = (
|
||||
db.query(Model, GlobalModel)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
Provider.is_active == True,
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
or_(
|
||||
Model.provider_model_name == normalized_name,
|
||||
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
|
||||
Model.provider_model_aliases.op("@>")(jsonb_pattern),
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
except (OperationalError, ProgrammingError) as e:
|
||||
# JSONB 操作符不支持(如 SQLite),回退到加载匹配 provider_model_name 的 Model
|
||||
# 并在 Python 层过滤 aliases
|
||||
logger.debug(
|
||||
f"JSONB 查询失败,回退到 Python 过滤: {e}",
|
||||
)
|
||||
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
|
||||
models_with_global = (
|
||||
db.query(Model, GlobalModel)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
Provider.is_active == True,
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.all()
|
||||
models_with_global = (
|
||||
db.query(Model, GlobalModel)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
Provider.is_active == True,
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
Model.provider_model_name == normalized_name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)}
|
||||
# 使用字典去重,同一个 Model 只保留优先级最高的匹配
|
||||
matched_models_dict = {}
|
||||
|
||||
# 遍历查询结果进行匹配
|
||||
# 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
|
||||
matched_global_models: List[GlobalModel] = []
|
||||
seen_global_model_ids: set[str] = set()
|
||||
for model, gm in models_with_global:
|
||||
key = (model.id, gm.id)
|
||||
|
||||
# 检查 provider_model_aliases 是否匹配(优先级更高)
|
||||
if model.provider_model_aliases:
|
||||
for alias_entry in model.provider_model_aliases:
|
||||
if isinstance(alias_entry, dict):
|
||||
alias_name = alias_entry.get("name", "").strip()
|
||||
if alias_name == normalized_name:
|
||||
# alias 优先级为 0(最高),覆盖任何已存在的匹配
|
||||
matched_models_dict[key] = (gm, "alias", 0)
|
||||
logger.debug(
|
||||
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
|
||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||
)
|
||||
break
|
||||
|
||||
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
|
||||
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
|
||||
if model.provider_model_name == normalized_name:
|
||||
# provider_model_name 优先级为 1(兜底),只在没有 alias 匹配时使用
|
||||
if key not in matched_models_dict:
|
||||
matched_models_dict[key] = (gm, "provider_model_name", 1)
|
||||
logger.debug(
|
||||
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
|
||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||
)
|
||||
|
||||
# 如果通过 provider_model_name/alias 找到了,直接返回
|
||||
if matched_models_dict:
|
||||
# 转换为列表并排序:按 priority(alias=0 优先)、然后按 GlobalModel.name
|
||||
matched_global_models = [
|
||||
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
|
||||
]
|
||||
matched_global_models.sort(
|
||||
key=lambda item: (
|
||||
0 if item[1] == "alias" else 1, # alias 优先
|
||||
item[0].name # 同优先级按名称排序(确定性)
|
||||
if gm.id not in seen_global_model_ids:
|
||||
seen_global_model_ids.add(gm.id)
|
||||
matched_global_models.append(gm)
|
||||
logger.debug(
|
||||
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
|
||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||
)
|
||||
)
|
||||
|
||||
# 记录解析方式
|
||||
resolution_method = matched_global_models[0][1]
|
||||
# 如果通过 provider_model_name 找到了,返回
|
||||
if matched_global_models:
|
||||
resolution_method = "provider_model_name"
|
||||
|
||||
if len(matched_global_models) > 1:
|
||||
# 检测到冲突
|
||||
unique_models = {gm.id: gm for gm, _ in matched_global_models}
|
||||
if len(unique_models) > 1:
|
||||
model_names = [gm.name for gm in unique_models.values()]
|
||||
logger.warning(
|
||||
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
|
||||
f"{', '.join(model_names)},使用第一个匹配结果"
|
||||
)
|
||||
# 记录冲突指标
|
||||
model_mapping_conflict_total.inc()
|
||||
# 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name)
|
||||
model_names = [gm.name for gm in matched_global_models if gm.name]
|
||||
logger.warning(
|
||||
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
|
||||
f"{', '.join(model_names)},使用第一个匹配结果"
|
||||
)
|
||||
# 记录冲突指标
|
||||
model_mapping_conflict_total.inc()
|
||||
|
||||
# 返回第一个匹配的 GlobalModel
|
||||
result_global_model: GlobalModel = matched_global_models[0][0]
|
||||
result_global_model = matched_global_models[0]
|
||||
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
|
||||
await CacheService.set(
|
||||
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
- 根据 API 格式或端点配置生成请求 URL
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from src.core.api_format_metadata import get_auth_config, get_default_path, resolve_api_format
|
||||
@@ -14,11 +14,14 @@ from src.core.crypto import crypto_service
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.models.database import ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
|
||||
|
||||
def build_provider_headers(
|
||||
endpoint,
|
||||
key,
|
||||
endpoint: "ProviderEndpoint",
|
||||
key: "ProviderAPIKey",
|
||||
original_headers: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
@@ -28,7 +31,8 @@ def build_provider_headers(
|
||||
"""
|
||||
headers: Dict[str, str] = {}
|
||||
|
||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||
# api_key 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||||
decrypted_key = crypto_service.decrypt(key.api_key) # type: ignore[arg-type]
|
||||
|
||||
# 根据 API 格式自动选择认证头
|
||||
api_format = getattr(endpoint, "api_format", None)
|
||||
@@ -68,8 +72,32 @@ def build_provider_headers(
|
||||
return headers
|
||||
|
||||
|
||||
def _normalize_base_url(base_url: str, path: str) -> str:
|
||||
"""
|
||||
规范化 base_url,去除末尾的斜杠和可能与 path 重复的版本前缀。
|
||||
|
||||
只有当 path 以版本前缀开头时,才从 base_url 中移除该前缀,
|
||||
避免拼接出 /v1/v1/messages 这样的重复路径。
|
||||
|
||||
兼容用户填写的各种格式:
|
||||
- https://api.example.com
|
||||
- https://api.example.com/
|
||||
- https://api.example.com/v1
|
||||
- https://api.example.com/v1/
|
||||
"""
|
||||
base = base_url.rstrip("/")
|
||||
# 只在 path 以版本前缀开头时才去除 base_url 中的该前缀
|
||||
# 例如:base="/v1", path="/v1/messages" -> 去除 /v1
|
||||
# 例如:base="/v1", path="/chat/completions" -> 不去除(用户可能期望保留)
|
||||
for suffix in ("/v1beta", "/v1", "/v2", "/v3"):
|
||||
if base.endswith(suffix) and path.startswith(suffix):
|
||||
base = base[: -len(suffix)]
|
||||
break
|
||||
return base
|
||||
|
||||
|
||||
def build_provider_url(
|
||||
endpoint,
|
||||
endpoint: "ProviderEndpoint",
|
||||
*,
|
||||
query_params: Optional[Dict[str, Any]] = None,
|
||||
path_params: Optional[Dict[str, Any]] = None,
|
||||
@@ -88,8 +116,6 @@ def build_provider_url(
|
||||
path_params: 路径模板参数 (如 {model})
|
||||
is_stream: 是否为流式请求,用于 Gemini API 选择正确的操作方法
|
||||
"""
|
||||
base = endpoint.base_url.rstrip("/")
|
||||
|
||||
# 准备路径参数,添加 Gemini API 所需的 action 参数
|
||||
effective_path_params = dict(path_params) if path_params else {}
|
||||
|
||||
@@ -123,6 +149,9 @@ def build_provider_url(
|
||||
if not path.startswith("/"):
|
||||
path = f"/{path}"
|
||||
|
||||
# 先确定 path,再根据 path 规范化 base_url
|
||||
# base_url 在数据库中是 NOT NULL,类型标注为 Optional 是 SQLAlchemy 限制
|
||||
base = _normalize_base_url(endpoint.base_url, path) # type: ignore[arg-type]
|
||||
url = f"{base}{path}"
|
||||
|
||||
# 添加查询参数
|
||||
@@ -134,7 +163,7 @@ def build_provider_url(
|
||||
return url
|
||||
|
||||
|
||||
def _resolve_default_path(api_format) -> str:
|
||||
def _resolve_default_path(api_format: Optional[str]) -> str:
|
||||
"""
|
||||
根据 API 格式返回默认路径
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user