refactor(backend): optimize cache system and model/provider services

This commit is contained in:
fawney19
2025-12-15 14:30:21 +08:00
parent 56fb6bf36c
commit 7068aa9130
12 changed files with 170 additions and 517 deletions

View File

@@ -36,7 +36,7 @@ from src.core.logger import logger
from sqlalchemy.orm import Session, selectinload
from src.core.enums import APIFormat
from src.core.exceptions import ProviderNotAvailableException
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
from src.models.database import (
ApiKey,
Model,
@@ -227,19 +227,6 @@ class CacheAwareScheduler:
if provider_offset == 0:
# No candidates found; give a friendly error message
error_msg = f"Model '{model_name}' is not available"
# Look up similar models
from src.services.model.mapping_resolver import get_model_mapping_resolver
resolver = get_model_mapping_resolver()
similar_models = resolver.find_similar_models(db, model_name, limit=3)
if similar_models:
suggestions = [
f"{name} (相似度: {score:.0%})" for name, score in similar_models
]
error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions)
raise ProviderNotAvailableException(error_msg)
break
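Editor's note: the removed find_similar_models lives in mapping_resolver and is not shown in this diff; for reference, a minimal standard-library sketch of the same idea (names, threshold, and scores below are illustrative, not from this codebase):

import difflib

def find_similar_models(requested: str, known: list[str], limit: int = 3) -> list[tuple[str, float]]:
    """Rank known model names by string similarity to the requested name."""
    scored = [
        (name, difflib.SequenceMatcher(None, requested.lower(), name.lower()).ratio())
        for name in known
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [(name, score) for name, score in scored[:limit] if score >= 0.4]

# find_similar_models("gpt4o", ["gpt-4o", "gpt-4o-mini", "claude-sonnet-4"])
# -> approximately [("gpt-4o", 0.91), ("gpt-4o-mini", 0.62)]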
@@ -579,13 +566,20 @@ class CacheAwareScheduler:
target_format = normalize_api_format(api_format)
# 0. Resolve model_name to a GlobalModel (normalized identifier for cache affinity)
from src.services.model.mapping_resolver import get_model_mapping_resolver
mapping_resolver = get_model_mapping_resolver()
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
# 0. Resolve model_name to a GlobalModel (direct lookup; users must use the canonical name)
from src.models.database import GlobalModel
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
if not global_model:
logger.warning(f"GlobalModel not found: {model_name}")
raise ModelNotSupportedException(model=model_name)
# Use GlobalModel.id as the model identifier for cache affinity, so an alias and its canonical name hit the same cache entry
global_model_id: str = str(global_model.id) if global_model else model_name
global_model_id: str = str(global_model.id)
# Get the merged access restrictions (ApiKey + User)
restrictions = self._get_effective_restrictions(user_api_key)
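Editor's note: keying affinity by GlobalModel.id rather than the raw request string is what keeps an alias and its canonical name on the same cache entry. A minimal sketch of that idea; the map, TTL, and function names below are illustrative, not from this codebase:

import time

# (user api key, global model id) -> (provider id, last-used timestamp)
_affinity: dict[tuple[str, str], tuple[str, float]] = {}
AFFINITY_TTL_SECONDS = 300.0  # illustrative value

def sticky_provider(user_key: str, global_model_id: str) -> str | None:
    """Return the provider this user/model pair last used, if the entry is still fresh."""
    entry = _affinity.get((user_key, global_model_id))
    if entry is None:
        return None
    provider_id, stamp = entry
    if time.monotonic() - stamp > AFFINITY_TTL_SECONDS:
        _affinity.pop((user_key, global_model_id), None)
        return None
    return provider_id

def record_route(user_key: str, global_model_id: str, provider_id: str) -> None:
    """Refresh affinity after a successful request."""
    _affinity[(user_key, global_model_id)] = (provider_id, time.monotonic())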
@@ -685,7 +679,6 @@ class CacheAwareScheduler:
db.query(Provider)
.options(
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
selectinload(Provider.model_mappings),
# Load both the models and global_model relationships so the get_effective_* methods can inherit defaults correctly
selectinload(Provider.models).selectinload(Model.global_model),
)
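Editor's note: the comment above is why both relationships are eagerly loaded: a per-provider Model row can leave a capability unset and inherit the GlobalModel default. A plausible shape for such a get_effective_* accessor (the real one is not part of this diff; field names follow the columns visible later in it):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class GlobalModelDefaults:
    # mirrors the default_supports_* columns on GlobalModel seen later in this diff
    default_supports_streaming: bool = True
    default_supports_vision: bool = False

@dataclass
class ProviderModel:
    # per-provider override; None means "inherit the global default"
    supports_streaming: Optional[bool] = None
    supports_vision: Optional[bool] = None
    global_model: GlobalModelDefaults = field(default_factory=GlobalModelDefaults)

def get_effective_streaming(model: ProviderModel) -> bool:
    """The provider-level value wins when explicitly set; otherwise inherit the default."""
    if model.supports_streaming is not None:
        return model.supports_streaming
    return model.global_model.default_supports_streaming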
@@ -715,63 +708,33 @@ class CacheAwareScheduler:
- A model's supported capabilities are global and independent of any specific key
- If the model lacks a capability, every key of that provider should be skipped
Mapping fallback logic:
- If a model mapping exists, try the mapped model first
- If the mapped model fails because its capabilities are not met, fall back to the original model
- Other failure causes (e.g. model not found, provider not implemented) do not trigger the fallback
Args:
db: database session
provider: Provider object
model_name: model name
model_name: model name (must be a GlobalModel.name)
is_stream: whether the request is streaming; if True, streaming support is checked as well
capability_requirements: capability requirements (optional), used to check whether the model supports them
Returns:
(is_supported, skip_reason, supported_capabilities) - whether the model is supported, the skip reason, and the capabilities it supports
"""
from src.services.model.mapping_resolver import get_model_mapping_resolver
from src.models.database import GlobalModel
mapping_resolver = get_model_mapping_resolver()
# Get the mapped model and also check whether a mapping occurred
global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
db, model_name, str(provider.id)
# Look up directly by GlobalModel.name
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
if not global_model:
return False, "模型不存在", None
# 尝试检查映射后的模型
# 检查模型支持
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db, provider, global_model, model_name, is_stream, capability_requirements
)
# If the mapped model failed on capabilities and a mapping occurred, fall back to the original model
if not is_supported and is_mapped and skip_reason and "does not support capability" in skip_reason:
logger.debug(
f"Provider {provider.name}: mapped model {global_model.name} lacks the required capabilities, "
f"falling back to the original model {model_name}"
)
# Fetch the original model (without applying the mapping)
original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
if original_global_model and original_global_model.id != global_model.id:
# Try the original model
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db,
provider,
original_global_model,
model_name,
is_stream,
capability_requirements,
)
if is_supported:
logger.debug(
f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力"
)
return is_supported, skip_reason, caps
async def _check_model_support_for_global_model(
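Editor's note: for context, a hypothetical caller of the simplified check, showing how the three-part return value drives provider skipping. The loop below is illustrative and the call matches only the documented Args above; the real scheduling loop is elsewhere in this file:

async def pick_provider(self, db, candidates, model_name: str):
    """Hypothetical caller: return the first provider whose model passes the check."""
    for provider in candidates:
        ok, skip_reason, caps = await self._is_model_supported(
            db, provider, model_name,
            is_stream=True,
            capability_requirements={"function_calling"},
        )
        if not ok:
            logger.debug(f"Skipping provider {provider.name}: {skip_reason}")
            continue
        return provider, caps
    raise ProviderNotAvailableException(f"Model '{model_name}' is not available")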

View File

@@ -6,8 +6,7 @@
2. RedisCache: Redis cache (distributed)
Use cases:
- ModelMappingResolver: model mapping and alias resolution cache
- ModelMapper: model mapping cache
- ModelCacheService: model resolution cache
- Other services that need caching
"""

View File

@@ -3,18 +3,14 @@
Centrally manages the invalidation logic for all caches, supporting:
1. Invalidating related caches when a GlobalModel changes
2. Invalidating alias/fallback caches when a ModelMapping changes
3. Invalidating model mapping caches when a Model changes
4. Both sync and async cache backends
2. Invalidating model mapping caches when a Model changes
3. Both sync and async cache backends
"""
import asyncio
from typing import Optional
from src.core.logger import logger
from src.core.logger import logger
class CacheInvalidationService:
"""
@@ -25,14 +21,8 @@ class CacheInvalidationService:
def __init__(self):
"""初始化缓存失效服务"""
self._mapping_resolver = None
self._model_mappers = []  # there may be multiple ModelMapperMiddleware instances
def set_mapping_resolver(self, mapping_resolver):
"""设置模型映射解析器实例"""
self._mapping_resolver = mapping_resolver
logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")
def register_model_mapper(self, model_mapper):
"""注册 ModelMapper 实例"""
if model_mapper not in self._model_mappers:
@@ -48,37 +38,12 @@ class CacheInvalidationService:
"""
logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
# 异步失效模型解析器中的缓存
if self._mapping_resolver:
asyncio.create_task(self._mapping_resolver.invalidate_global_model_cache())
# Invalidate caches related to this model in every ModelMapper
for mapper in self._model_mappers:
# Clear all caches (we do not know which providers use this model)
mapper.clear_cache()
logger.debug("[CacheInvalidation] ModelMapper caches cleared")
def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
"""
Cache invalidation when a ModelMapping changes
Args:
source_model: the source model name that changed
provider_id: the related provider (None means global)
"""
logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")
if self._mapping_resolver:
asyncio.create_task(
self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
)
for mapper in self._model_mappers:
if provider_id:
mapper.refresh_cache(provider_id)
else:
mapper.clear_cache()
def on_model_changed(self, provider_id: str, global_model_id: str):
"""
Cache invalidation when a Model changes
@@ -98,9 +63,6 @@ class CacheInvalidationService:
"""清空所有缓存"""
logger.info("[CacheInvalidation] 清空所有缓存")
if self._mapping_resolver:
asyncio.create_task(self._mapping_resolver.clear_cache())
for mapper in self._model_mappers:
mapper.clear_cache()
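Editor's note: the pattern used throughout this service is worth calling out: the hooks themselves are synchronous (so they can be called from ORM event handlers), while the underlying cache deletes are async, so they are scheduled with asyncio.create_task rather than awaited. A condensed sketch; the cache key format here is illustrative:

import asyncio

class InvalidationHook:
    def __init__(self, cache) -> None:
        self._cache = cache  # any backend with an async delete()

    def on_model_changed(self, provider_id: str, global_model_id: str) -> None:
        # create_task requires a running event loop; under an ASGI server there is one.
        asyncio.create_task(
            self._cache.delete(f"model:{provider_id}:{global_model_id}")
        )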

View File

@@ -1,5 +1,5 @@
"""
Model mapping cache service - reduces model mapping and alias queries
Model mapping cache service - reduces model queries
"""
from typing import Optional
@@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping
from src.models.database import GlobalModel, Model
class ModelCacheService:
@@ -158,56 +157,6 @@ class ModelCacheService:
return global_model
@staticmethod
async def resolve_alias(
db: Session, source_model: str, provider_id: Optional[str] = None
) -> Optional[str]:
"""
Resolve a model alias (with caching)
Args:
db: database session
source_model: the source model name or alias
provider_id: provider ID (optional, used for provider-specific aliases)
Returns:
the target GlobalModel ID, or None
"""
# Build the cache key
if provider_id:
cache_key = f"alias:provider:{provider_id}:{source_model}"
else:
cache_key = f"alias:global:{source_model}"
# 1. Try to read from the cache
cached_result = await CacheService.get(cache_key)
if cached_result:
logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
return cached_result
# 2. Cache miss; query the database
query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)
if provider_id:
# Provider-specific aliases take precedence
query = query.filter(ModelMapping.provider_id == provider_id)
else:
# Global aliases
query = query.filter(ModelMapping.provider_id.is_(None))
mapping = query.first()
# 3. Write the result to the cache
target_global_model_id = mapping.target_global_model_id if mapping else None
await CacheService.set(
cache_key, target_global_model_id, ttl_seconds=ModelCacheService.CACHE_TTL
)
if mapping:
logger.debug(f"别名已缓存: {source_model}{target_global_model_id}")
return target_global_model_id
@staticmethod
async def invalidate_model_cache(
model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
@@ -237,17 +186,6 @@ class ModelCacheService:
await CacheService.delete(f"global_model:name:{name}")
logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
@staticmethod
async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
"""清除别名缓存"""
if provider_id:
cache_key = f"alias:provider:{provider_id}:{source_model}"
else:
cache_key = f"alias:global:{source_model}"
await CacheService.delete(cache_key)
logger.debug(f"别名缓存已清除: {source_model}")
@staticmethod
def _model_to_dict(model: Model) -> dict:
"""将 Model 对象转换为字典"""
@@ -256,6 +194,7 @@ class ModelCacheService:
"provider_id": model.provider_id,
"global_model_id": model.global_model_id,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
"is_active": model.is_active,
"is_available": model.is_available if hasattr(model, "is_available") else True,
"price_per_request": (
@@ -266,6 +205,7 @@ class ModelCacheService:
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": getattr(model, "supports_image_generation", None),
"config": model.config,
}
@@ -277,6 +217,7 @@ class ModelCacheService:
provider_id=model_dict["provider_id"],
global_model_id=model_dict["global_model_id"],
provider_model_name=model_dict["provider_model_name"],
provider_model_aliases=model_dict.get("provider_model_aliases"),
is_active=model_dict["is_active"],
is_available=model_dict.get("is_available", True),
price_per_request=model_dict.get("price_per_request"),
@@ -285,6 +226,7 @@ class ModelCacheService:
supports_function_calling=model_dict.get("supports_function_calling"),
supports_streaming=model_dict.get("supports_streaming"),
supports_extended_thinking=model_dict.get("supports_extended_thinking"),
supports_image_generation=model_dict.get("supports_image_generation"),
config=model_dict.get("config"),
)
return model
@@ -296,12 +238,11 @@ class ModelCacheService:
"id": global_model.id,
"name": global_model.name,
"display_name": global_model.display_name,
"family": global_model.family,
"group_id": global_model.group_id,
"supports_vision": global_model.supports_vision,
"supports_thinking": global_model.supports_thinking,
"context_window": global_model.context_window,
"max_output_tokens": global_model.max_output_tokens,
"default_supports_vision": global_model.default_supports_vision,
"default_supports_function_calling": global_model.default_supports_function_calling,
"default_supports_streaming": global_model.default_supports_streaming,
"default_supports_extended_thinking": global_model.default_supports_extended_thinking,
"default_supports_image_generation": global_model.default_supports_image_generation,
"is_active": global_model.is_active,
"description": global_model.description,
}
@@ -313,12 +254,11 @@ class ModelCacheService:
id=global_model_dict["id"],
name=global_model_dict["name"],
display_name=global_model_dict.get("display_name"),
family=global_model_dict.get("family"),
group_id=global_model_dict.get("group_id"),
supports_vision=global_model_dict.get("supports_vision", False),
supports_thinking=global_model_dict.get("supports_thinking", False),
context_window=global_model_dict.get("context_window"),
max_output_tokens=global_model_dict.get("max_output_tokens"),
default_supports_vision=global_model_dict.get("default_supports_vision", False),
default_supports_function_calling=global_model_dict.get("default_supports_function_calling", False),
default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
default_supports_extended_thinking=global_model_dict.get("default_supports_extended_thinking", False),
default_supports_image_generation=global_model_dict.get("default_supports_image_generation", False),
is_active=global_model_dict.get("is_active", True),
description=global_model_dict.get("description"),
)
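Editor's note: the _to_dict/_from_dict pairs exist because live ORM instances are session-bound and cannot be stored in a shared cache; only plain dicts go in, and detached instances are rebuilt on the way out. A condensed sketch of the resulting cache-aside flow (the helper names below are hypothetical stand-ins for the private converters above; the key format and TTL constant follow this file):

async def get_global_model_cached(db, name: str):
    """Cache-aside lookup: a cache hit returns a rebuilt, detached instance."""
    cache_key = f"global_model:name:{name}"
    cached = await CacheService.get(cache_key)
    if cached:
        return global_model_from_dict(cached)  # hypothetical inverse of the dict converter
    row = (
        db.query(GlobalModel)
        .filter(GlobalModel.name == name, GlobalModel.is_active == True)
        .first()
    )
    if row:
        await CacheService.set(
            cache_key, global_model_to_dict(row), ttl_seconds=ModelCacheService.CACHE_TTL
        )
    return row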

View File

@@ -6,7 +6,7 @@
Use cases:
1. In multi-instance deployments, ensure cache consistency across all instances
2. When a GlobalModel/ModelMapping changes, invalidate the caches of all instances in sync
2. When a GlobalModel/Model changes, invalidate the caches of all instances in sync
"""
import asyncio
@@ -29,7 +29,6 @@ class CacheSyncService:
# Redis channel names
CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping"
CHANNEL_MODEL = "cache:invalidate:model"
CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"
@@ -58,7 +57,6 @@ class CacheSyncService:
# Subscribe to all cache invalidation channels
await self._pubsub.subscribe(
self.CHANNEL_GLOBAL_MODEL,
self.CHANNEL_MODEL_MAPPING,
self.CHANNEL_MODEL,
self.CHANNEL_CLEAR_ALL,
)
@@ -68,7 +66,7 @@ class CacheSyncService:
self._running = True
logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: "
f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, "
f"{self.CHANNEL_GLOBAL_MODEL}, "
f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}")
except Exception as e:
logger.error(f"[CacheSync] 启动失败: {e}")
@@ -141,14 +139,6 @@ class CacheSyncService:
"""发布 GlobalModel 变更通知"""
await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name})
async def publish_model_mapping_changed(
self, source_model: str, provider_id: Optional[str] = None
):
"""发布 ModelMapping 变更通知"""
await self._publish(
self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id}
)
async def publish_model_changed(self, provider_id: str, global_model_id: str):
"""发布 Model 变更通知"""
await self._publish(
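Editor's note: for reference, a standalone sketch of the pub/sub mechanics behind these publish_* methods, using redis-py's asyncio client. The channel name comes from this diff; the payload shape is inferred from publish_model_changed:

import asyncio
import json
import redis.asyncio as redis

CHANNEL_MODEL = "cache:invalidate:model"

async def main() -> None:
    r = redis.Redis()
    pubsub = r.pubsub()
    await pubsub.subscribe(CHANNEL_MODEL)
    # In production the publisher is another instance; here we publish to ourselves.
    await r.publish(CHANNEL_MODEL, json.dumps({"provider_id": "p1", "global_model_id": "g1"}))
    async for message in pubsub.listen():
        if message["type"] != "message":
            continue  # skip the subscribe confirmation
        payload = json.loads(message["data"])
        print("invalidate:", payload)
        break

asyncio.run(main())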