refactor(backend): optimize cache system and model/provider services

This commit is contained in:
fawney19
2025-12-15 14:30:21 +08:00
parent 56fb6bf36c
commit 7068aa9130
12 changed files with 170 additions and 517 deletions

View File

@@ -36,7 +36,7 @@ from src.core.logger import logger
from sqlalchemy.orm import Session, selectinload
from src.core.enums import APIFormat
from src.core.exceptions import ProviderNotAvailableException
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
from src.models.database import (
ApiKey,
Model,
@@ -227,19 +227,6 @@ class CacheAwareScheduler:
if provider_offset == 0:
# 没有找到任何候选,提供友好的错误提示
error_msg = f"模型 '{model_name}' 不可用"
# 查找相似模型
from src.services.model.mapping_resolver import get_model_mapping_resolver
resolver = get_model_mapping_resolver()
similar_models = resolver.find_similar_models(db, model_name, limit=3)
if similar_models:
suggestions = [
f"{name} (相似度: {score:.0%})" for name, score in similar_models
]
error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions)
raise ProviderNotAvailableException(error_msg)
break
@@ -579,13 +566,20 @@ class CacheAwareScheduler:
target_format = normalize_api_format(api_format)
# 0. 解析 model_name 到 GlobalModel(用于缓存亲和性的规范化标识)
from src.services.model.mapping_resolver import get_model_mapping_resolver
mapping_resolver = get_model_mapping_resolver()
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
# 0. 解析 model_name 到 GlobalModel(直接查找,用户必须使用标准名称)
from src.models.database import GlobalModel
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
if not global_model:
logger.warning(f"GlobalModel not found: {model_name}")
raise ModelNotSupportedException(model=model_name)
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保别名和规范名都能命中同一个缓存
global_model_id: str = str(global_model.id) if global_model else model_name
global_model_id: str = str(global_model.id)
# 获取合并后的访问限制(ApiKey + User)
restrictions = self._get_effective_restrictions(user_api_key)
@@ -685,7 +679,6 @@ class CacheAwareScheduler:
db.query(Provider)
.options(
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
selectinload(Provider.model_mappings),
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值
selectinload(Provider.models).selectinload(Model.global_model),
)
@@ -715,63 +708,33 @@ class CacheAwareScheduler:
- 模型支持的能力是全局的,与具体的 Key 无关
- 如果模型不支持某能力,整个 Provider 的所有 Key 都应该被跳过
映射回退逻辑:
- 如果存在模型映射(mapping),先尝试映射后的模型
- 如果映射后的模型因能力不满足而失败,回退尝试原模型
- 其他失败原因(如模型不存在、Provider 未实现等)不触发回退
Args:
db: 数据库会话
provider: Provider 对象
model_name: 模型名称
model_name: 模型名称(必须是 GlobalModel.name)
is_stream: 是否是流式请求,如果为 True 则同时检查流式支持
capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力
Returns:
(is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表
"""
from src.services.model.mapping_resolver import get_model_mapping_resolver
from src.models.database import GlobalModel
mapping_resolver = get_model_mapping_resolver()
# 获取映射后的模型,同时检查是否发生了映射
global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
db, model_name, str(provider.id)
# 直接通过 GlobalModel.name 查找
global_model = (
db.query(GlobalModel)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
if not global_model:
return False, "模型不存在", None
# 尝试检查映射后的模型
# 检查模型支持
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db, provider, global_model, model_name, is_stream, capability_requirements
)
# 如果映射后的模型因能力不满足而失败,且存在映射,则回退尝试原模型
if not is_supported and is_mapped and skip_reason and "不支持能力" in skip_reason:
logger.debug(
f"Provider {provider.name} 映射模型 {global_model.name} 能力不满足,"
f"回退尝试原模型 {model_name}"
)
# 获取原模型(不应用映射)
original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
if original_global_model and original_global_model.id != global_model.id:
# 尝试原模型
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
db,
provider,
original_global_model,
model_name,
is_stream,
capability_requirements,
)
if is_supported:
logger.debug(
f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力"
)
return is_supported, skip_reason, caps
async def _check_model_support_for_global_model(