refactor(backend): optimize cache system and model/provider services

This commit is contained in:
fawney19
2025-12-15 14:30:21 +08:00
parent 56fb6bf36c
commit 7068aa9130
12 changed files with 170 additions and 517 deletions

View File

@@ -1,19 +1,15 @@
"""
模型服务模块
包含模型管理、模型映射、成本计算等功能。
包含模型管理、成本计算等功能。
"""
from src.services.model.cost import ModelCostService
from src.services.model.global_model import GlobalModelService
from src.services.model.mapper import ModelMapperMiddleware
from src.services.model.mapping_resolver import ModelMappingResolver
from src.services.model.service import ModelService
__all__ = [
"ModelService",
"GlobalModelService",
"ModelMapperMiddleware",
"ModelMappingResolver",
"ModelCostService",
]

View File

@@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple, Union
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping, Provider
from src.models.database import GlobalModel, Model, Provider
ProviderRef = Union[str, Provider, None]
@@ -161,16 +161,11 @@ class ModelCostService:
result = None
if provider_obj:
from src.services.model.mapping_resolver import resolve_model_to_global_name
global_model_name = await resolve_model_to_global_name(
self.db, model, provider_obj.id
)
# 直接通过 GlobalModel.name 查找
global_model = (
self.db.query(GlobalModel)
.filter(
GlobalModel.name == global_model_name,
GlobalModel.name == model,
GlobalModel.is_active == True,
)
.first()
@@ -226,17 +221,14 @@ class ModelCostService:
注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。
实际计费时应使用 compute_cost_with_tiered_pricing 方法。
计费逻辑(基于 mapping_type:
1. 查找 ModelMapping如果存在
2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格
3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格
- 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格
- 否则回退到目标 GlobalModel 的价格
4. 如果没有找到任何 ModelMapping尝试直接匹配 GlobalModel.name
计费逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现
3. 获取价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名
model: 用户请求的模型名(必须是 GlobalModel.name
Returns:
(input_price, output_price) 元组
@@ -253,136 +245,37 @@ class ModelCostService:
output_price = None
if provider_obj:
# 步骤 1: 查找 ModelMapping 以确定 mapping_type
from src.models.database import ModelMapping
mapping = None
# 先查 Provider 特定映射
mapping = (
self.db.query(ModelMapping)
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
global_model = (
self.db.query(GlobalModel)
.filter(
ModelMapping.source_model == model,
ModelMapping.provider_id == provider_obj.id,
ModelMapping.is_active == True,
GlobalModel.name == model,
GlobalModel.is_active == True,
)
.first()
)
# 再查全局映射
if not mapping:
mapping = (
self.db.query(ModelMapping)
if global_model:
model_obj = (
self.db.query(Model)
.filter(
ModelMapping.source_model == model,
ModelMapping.provider_id.is_(None),
ModelMapping.is_active == True,
Model.provider_id == provider_obj.id,
Model.global_model_id == global_model.id,
Model.is_active == True,
)
.first()
)
if mapping:
# 有映射,根据 mapping_type 决定计费模型
if mapping.mapping_type == "mapping":
# mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格
source_global_model = (
self.db.query(GlobalModel)
.filter(
GlobalModel.name == model,
GlobalModel.is_active == True,
)
.first()
)
if source_global_model:
source_model_obj = (
self.db.query(Model)
.filter(
Model.provider_id == provider_obj.id,
Model.global_model_id == source_global_model.id,
Model.is_active == True,
)
.first()
)
if source_model_obj:
# 检查是否有阶梯计费
tiered = source_model_obj.get_effective_tiered_pricing()
if tiered and tiered.get("tiers"):
first_tier = tiered["tiers"][0]
input_price = first_tier.get("input_price_per_1m", 0)
output_price = first_tier.get("output_price_per_1m", 0)
else:
input_price = source_model_obj.get_effective_input_price()
output_price = source_model_obj.get_effective_output_price()
logger.debug(f"[mapping模式] 使用源模型价格: {model} "
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
# alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格
if input_price is None:
target_global_model = (
self.db.query(GlobalModel)
.filter(
GlobalModel.id == mapping.target_global_model_id,
GlobalModel.is_active == True,
)
.first()
)
if target_global_model:
target_model_obj = (
self.db.query(Model)
.filter(
Model.provider_id == provider_obj.id,
Model.global_model_id == target_global_model.id,
Model.is_active == True,
)
.first()
)
if target_model_obj:
# 检查是否有阶梯计费
tiered = target_model_obj.get_effective_tiered_pricing()
if tiered and tiered.get("tiers"):
first_tier = tiered["tiers"][0]
input_price = first_tier.get("input_price_per_1m", 0)
output_price = first_tier.get("output_price_per_1m", 0)
else:
input_price = target_model_obj.get_effective_input_price()
output_price = target_model_obj.get_effective_output_price()
mode_label = (
"alias模式"
if mapping.mapping_type == "alias"
else "mapping模式(回退)"
)
logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
else:
# 没有映射,尝试直接匹配 GlobalModel.name
global_model = (
self.db.query(GlobalModel)
.filter(
GlobalModel.name == model,
GlobalModel.is_active == True,
)
.first()
)
if global_model:
model_obj = (
self.db.query(Model)
.filter(
Model.provider_id == provider_obj.id,
Model.global_model_id == global_model.id,
Model.is_active == True,
)
.first()
)
if model_obj:
# 检查是否有阶梯计费
tiered = model_obj.get_effective_tiered_pricing()
if tiered and tiered.get("tiers"):
first_tier = tiered["tiers"][0]
input_price = first_tier.get("input_price_per_1m", 0)
output_price = first_tier.get("output_price_per_1m", 0)
else:
input_price = model_obj.get_effective_input_price()
output_price = model_obj.get_effective_output_price()
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
if model_obj:
# 检查是否有阶梯计费
tiered = model_obj.get_effective_tiered_pricing()
if tiered and tiered.get("tiers"):
first_tier = tiered["tiers"][0]
input_price = first_tier.get("input_price_per_1m", 0)
output_price = first_tier.get("output_price_per_1m", 0)
else:
input_price = model_obj.get_effective_input_price()
output_price = model_obj.get_effective_output_price()
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
# 如果没有找到价格配置,使用 0.0 并记录警告
if input_price is None:
@@ -404,15 +297,14 @@ class ModelCostService:
"""
返回给定 provider/model 的 (input_price, output_price)。
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取价格配置
逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现
3. 获取价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名
model: 用户请求的模型名(必须是 GlobalModel.name
Returns:
(input_price, output_price) 元组
@@ -434,15 +326,9 @@ class ModelCostService:
"""
异步版本: 返回缓存创建/读取价格(每 1M tokens
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取缓存价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名
model: 用户请求的模型名(必须是 GlobalModel.name
input_price: 基础输入价格(用于 Claude 模型的默认估算)
Returns:
@@ -460,22 +346,17 @@ class ModelCostService:
cache_read_price = None
if provider_obj:
# 步骤 1: 检查是否是别名
from src.services.model.mapping_resolver import resolve_model_to_global_name
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
# 步骤 2: 查找 GlobalModel
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
global_model = (
self.db.query(GlobalModel)
.filter(
GlobalModel.name == global_model_name,
GlobalModel.name == model,
GlobalModel.is_active == True,
)
.first()
)
# 步骤 3: 查找该 Provider 的 Model 实现
# 查找该 Provider 的 Model 实现
if global_model:
model_obj = (
self.db.query(Model)
@@ -517,15 +398,9 @@ class ModelCostService:
"""
异步版本: 返回按次计费价格(每次请求的固定费用)。
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取按次计费价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名
model: 用户请求的模型名(必须是 GlobalModel.name
Returns:
按次计费价格,如果没有配置则返回 None
@@ -534,22 +409,17 @@ class ModelCostService:
price_per_request = None
if provider_obj:
# 步骤 1: 检查是否是别名
from src.services.model.mapping_resolver import resolve_model_to_global_name
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
# 步骤 2: 查找 GlobalModel
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
global_model = (
self.db.query(GlobalModel)
.filter(
GlobalModel.name == global_model_name,
GlobalModel.name == model,
GlobalModel.is_active == True,
)
.first()
)
# 步骤 3: 查找该 Provider 的 Model 实现
# 查找该 Provider 的 Model 实现
if global_model:
model_obj = (
self.db.query(Model)
@@ -595,15 +465,14 @@ class ModelCostService:
"""
返回缓存创建/读取价格(每 1M tokens
新架构逻辑:
1. 使用 ModelMappingResolver 解析别名(如果是)
2. 解析为 GlobalModel.name
3. 查找该 Provider 的 Model 实现
4. 获取缓存价格配置
逻辑:
1. 直接通过 GlobalModel.name 匹配
2. 查找该 Provider 的 Model 实现
3. 获取缓存价格配置
Args:
provider: Provider 对象或提供商名称
model: 用户请求的模型名(可能是 GlobalModel.name 或别名
model: 用户请求的模型名(必须是 GlobalModel.name
input_price: 基础输入价格(用于 Claude 模型的默认估算)
Returns:

View File

@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, joinedload
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping
from src.models.database import GlobalModel, Model
from src.models.pydantic_models import GlobalModelUpdate

View File

@@ -5,18 +5,13 @@
from typing import Dict, List, Optional
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, joinedload
from src.core.cache_utils import SyncLRUCache
from src.core.logger import logger
from src.models.claude import ClaudeMessagesRequest
from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint
from src.models.database import GlobalModel, Model, Provider, ProviderEndpoint
from src.services.cache.model_cache import ModelCacheService
from src.services.model.mapping_resolver import (
get_model_mapping_resolver,
resolve_model_to_global_name,
)
class ModelMapperMiddleware:
@@ -71,10 +66,10 @@ class ModelMapperMiddleware:
if mapping:
# 应用映射
original_model = request.model
request.model = mapping.model.provider_model_name
request.model = mapping.model.select_provider_model_name()
logger.debug(f"Applied model mapping for provider {provider.name}: "
f"{original_model} -> {mapping.model.provider_model_name}")
f"{original_model} -> {request.model}")
else:
# 没有找到映射,使用原始模型名
logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, "
@@ -84,17 +79,16 @@ class ModelMapperMiddleware:
async def get_mapping(
self, source_model: str, provider_id: str
) -> Optional[ModelMapping]: # UUID
) -> Optional[object]:
"""
获取模型映射
化后逻辑:
1. 使用统一的 ModelMappingResolver 解析别名(带缓存)
2. 通过 GlobalModel 找该 Provider 的 Model 实现
3. 使用独立的映射缓存
化后逻辑:
1. 通过 GlobalModel.name 直接查找
2. 找到 GlobalModel 后,查找该 Provider 的 Model 实现
Args:
source_model: 用户请求的模型名或别名
source_model: 用户请求的模型名(必须是 GlobalModel.name
provider_id: 提供商ID (UUID)
Returns:
@@ -107,62 +101,59 @@ class ModelMapperMiddleware:
mapping = None
# 步骤 1 & 2: 通过统一的模型映射解析服务
mapping_resolver = get_model_mapping_resolver()
global_model = await mapping_resolver.get_global_model_by_request(
self.db, source_model, provider_id
# 步骤 1: 直接通过名称查找 GlobalModel
global_model = (
self.db.query(GlobalModel)
.filter(GlobalModel.name == source_model, GlobalModel.is_active == True)
.first()
)
if not global_model:
logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)")
logger.debug(f"GlobalModel not found: {source_model}")
self._cache[cache_key] = None
return None
# 步骤 3: 查找该 Provider 是否有实现这个 GlobalModel 的 Model使用缓存
# 步骤 2: 查找该 Provider 是否有实现这个 GlobalModel 的 Model使用缓存
model = await ModelCacheService.get_model_by_provider_and_global_model(
self.db, provider_id, global_model.id
)
if model:
# 只有当模型名发生变化时才返回映射
if model.provider_model_name != source_model:
mapping = type(
"obj",
(object,),
{
"source_model": source_model,
"model": model,
"is_active": True,
"provider_id": provider_id,
},
)()
# 创建映射对象
mapping = type(
"obj",
(object,),
{
"source_model": source_model,
"model": model,
"is_active": True,
"provider_id": provider_id,
},
)()
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
f"(provider={provider_id[:8]}...)")
else:
logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)")
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
f"(provider={provider_id[:8]}...)")
# 缓存结果
self._cache[cache_key] = mapping
return mapping
def get_all_mappings(self, provider_id: str) -> List[ModelMapping]: # UUID
def get_all_mappings(self, provider_id: str) -> List[object]:
"""
获取提供商的所有可用模型(通过 GlobalModel)
方案 A: 返回该 Provider 所有可用的 GlobalModel
Args:
provider_id: 提供商ID (UUID)
Returns:
模型映射列表(模拟的 ModelMapping 对象列表)
模型映射列表
"""
# 查询该 Provider 的所有活跃 Model
# 查询该 Provider 的所有活跃 Model(使用 joinedload 避免 N+1
models = (
self.db.query(Model)
.join(GlobalModel)
.options(joinedload(Model.global_model))
.filter(
Model.provider_id == provider_id,
Model.is_active == True,
@@ -171,7 +162,7 @@ class ModelMapperMiddleware:
.all()
)
# 构造兼容的 ModelMapping 对象列表
# 构造兼容的映射对象列表
mappings = []
for model in models:
mapping = type(
@@ -188,7 +179,7 @@ class ModelMapperMiddleware:
return mappings
def get_supported_models(self, provider_id: str) -> List[str]: # UUID
def get_supported_models(self, provider_id: str) -> List[str]:
"""
获取提供商支持的所有源模型名
@@ -223,15 +214,6 @@ class ModelMapperMiddleware:
if not mapping.is_active:
return False, f"Model mapping for {request.model} is disabled"
# 不限制max_tokens作为中转服务不应该限制用户的请求
# if request.max_tokens and request.max_tokens > mapping.max_output_tokens:
# return False, (
# f"Requested max_tokens {request.max_tokens} exceeds limit "
# f"{mapping.max_output_tokens} for model {request.model}"
# )
# 可以添加更多验证逻辑,比如检查输入长度等
return True, None
def clear_cache(self):
@@ -239,7 +221,7 @@ class ModelMapperMiddleware:
self._cache.clear()
logger.debug("Model mapping cache cleared")
def refresh_cache(self, provider_id: Optional[str] = None): # UUID
def refresh_cache(self, provider_id: Optional[str] = None):
"""
刷新缓存
@@ -285,16 +267,10 @@ class ModelRoutingMiddleware:
"""
根据模型名选择提供商
逻辑:
1. 如果指定了提供商,使用指定的提供商
2. 如果没指定,使用默认提供商
3. 选定提供商后会检查该提供商的模型映射在apply_mapping中处理
4. 如果指定了allowed_api_formats只选择符合格式的提供商
Args:
model_name: 请求的模型名
preferred_provider: 首选提供商名称
allowed_api_formats: 允许的API格式列表(如 ['CLAUDE', 'CLAUDE_CLI']
allowed_api_formats: 允许的API格式列表
request_id: 请求ID用于日志关联
Returns:
@@ -313,14 +289,12 @@ class ModelRoutingMiddleware:
if provider:
# 检查API格式 - 从 endpoints 中检查
if allowed_api_formats:
# 检查是否有符合要求的活跃端点
has_matching_endpoint = any(
ep.is_active and ep.api_format and ep.api_format in allowed_api_formats
for ep in provider.endpoints
)
if not has_matching_endpoint:
logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})")
# 不返回该提供商,继续查找
else:
logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
return provider
@@ -330,10 +304,9 @@ class ModelRoutingMiddleware:
else:
logger.warning(f"Specified provider {preferred_provider} not found or inactive")
# 2. 查找优先级最高的活动提供商provider_priority 最小)
# 2. 查找优先级最高的活动提供商
query = self.db.query(Provider).filter(Provider.is_active == True)
# 如果指定了API格式过滤添加过滤条件 - 检查是否有符合要求的 endpoint
if allowed_api_formats:
query = (
query.join(ProviderEndpoint)
@@ -344,32 +317,27 @@ class ModelRoutingMiddleware:
.distinct()
)
# 按 provider_priority 排序,优先级最高(数字最小)的在前
best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first()
if best_provider:
logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}")
return best_provider
# 3. 没有任何活动提供商
if allowed_api_formats:
logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.")
logger.error(f"No active providers found with allowed API formats {allowed_api_formats}.")
else:
logger.error("No active providers found. Please configure at least one provider.")
logger.error("No active providers found.")
return None
def get_available_models(self) -> Dict[str, List[str]]:
"""
获取所有可用的模型及其提供商
方案 A: 基于 GlobalModel 查询
Returns:
字典,键为 GlobalModel.name值为支持该模型的提供商名列表
"""
result = {}
# 查询所有活跃的 GlobalModel 及其 Provider
models = (
self.db.query(GlobalModel.name, Provider.name)
.join(Model, GlobalModel.id == Model.global_model_id)
@@ -392,28 +360,23 @@ class ModelRoutingMiddleware:
"""
获取某个模型最便宜的提供商
方案 A: 通过 GlobalModel 查找
Args:
model_name: GlobalModel 名称或别名
model_name: GlobalModel 名称
Returns:
最便宜的提供商
"""
# 步骤 1: 解析模型名
global_model_name = await resolve_model_to_global_name(self.db, model_name)
# 步骤 2: 查找 GlobalModel
# 直接查找 GlobalModel
global_model = (
self.db.query(GlobalModel)
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
.first()
)
if not global_model:
return None
# 步骤 3: 查询所有支持该模型的 Provider 及其价格
# 查询所有支持该模型的 Provider 及其价格
models_with_providers = (
self.db.query(Provider, Model)
.join(Model, Provider.id == Model.provider_id)
@@ -428,15 +391,16 @@ class ModelRoutingMiddleware:
if not models_with_providers:
return None
# 按总价格排序(输入+输出价格)
# 按总价格排序
cheapest = min(
models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m
models_with_providers,
key=lambda x: x[1].get_effective_input_price() + x[1].get_effective_output_price()
)
provider = cheapest[0]
model = cheapest[1]
logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} "
f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)")
f"(input: ${model.get_effective_input_price()}/M, output: ${model.get_effective_output_price()}/M)")
return provider

View File

@@ -50,6 +50,7 @@ class ModelService:
provider_id=provider_id,
global_model_id=model_data.global_model_id,
provider_model_name=model_data.provider_model_name,
provider_model_aliases=model_data.provider_model_aliases,
price_per_request=model_data.price_per_request,
tiered_pricing=model_data.tiered_pricing,
supports_vision=model_data.supports_vision,
@@ -191,7 +192,6 @@ class ModelService:
新架构删除逻辑:
- Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel
- 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除)
- 不检查 ModelMapping映射是 GlobalModel 之间的关系,别名也统一存储在此表中)
"""
model = db.query(Model).filter(Model.id == model_id).first()
if not model:
@@ -326,6 +326,7 @@ class ModelService:
provider_id=model.provider_id,
global_model_id=model.global_model_id,
provider_model_name=model.provider_model_name,
provider_model_aliases=model.provider_model_aliases,
# 原始配置值(可能为空)
price_per_request=model.price_per_request,
tiered_pricing=model.tiered_pricing,