mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 10:12:27 +08:00
refactor(backend): optimize cache system and model/provider services
This commit is contained in:
@@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model, Provider
|
||||
|
||||
|
||||
ProviderRef = Union[str, Provider, None]
|
||||
@@ -161,16 +161,11 @@ class ModelCostService:
|
||||
result = None
|
||||
|
||||
if provider_obj:
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(
|
||||
self.db, model, provider_obj.id
|
||||
)
|
||||
|
||||
# 直接通过 GlobalModel.name 查找
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
@@ -226,17 +221,14 @@ class ModelCostService:
|
||||
注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。
|
||||
实际计费时应使用 compute_cost_with_tiered_pricing 方法。
|
||||
|
||||
计费逻辑(基于 mapping_type):
|
||||
1. 查找 ModelMapping(如果存在)
|
||||
2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格
|
||||
3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格
|
||||
- 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格
|
||||
- 否则回退到目标 GlobalModel 的价格
|
||||
4. 如果没有找到任何 ModelMapping,尝试直接匹配 GlobalModel.name
|
||||
计费逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
@@ -253,136 +245,37 @@ class ModelCostService:
|
||||
output_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 查找 ModelMapping 以确定 mapping_type
|
||||
from src.models.database import ModelMapping
|
||||
|
||||
mapping = None
|
||||
# 先查 Provider 特定映射
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id == provider_obj.id,
|
||||
ModelMapping.is_active == True,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# 再查全局映射
|
||||
if not mapping:
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.is_active == True,
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# 有映射,根据 mapping_type 决定计费模型
|
||||
if mapping.mapping_type == "mapping":
|
||||
# mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格
|
||||
source_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_global_model:
|
||||
source_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == source_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = source_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = source_model_obj.get_effective_input_price()
|
||||
output_price = source_model_obj.get_effective_output_price()
|
||||
logger.debug(f"[mapping模式] 使用源模型价格: {model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格
|
||||
if input_price is None:
|
||||
target_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == mapping.target_global_model_id,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_global_model:
|
||||
target_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == target_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = target_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = target_model_obj.get_effective_input_price()
|
||||
output_price = target_model_obj.get_effective_output_price()
|
||||
mode_label = (
|
||||
"alias模式"
|
||||
if mapping.mapping_type == "alias"
|
||||
else "mapping模式(回退)"
|
||||
)
|
||||
logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
else:
|
||||
# 没有映射,尝试直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# 如果没有找到价格配置,使用 0.0 并记录警告
|
||||
if input_price is None:
|
||||
@@ -404,15 +297,14 @@ class ModelCostService:
|
||||
"""
|
||||
返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取价格配置
|
||||
逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
@@ -434,15 +326,9 @@ class ModelCostService:
|
||||
"""
|
||||
异步版本: 返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
@@ -460,22 +346,17 @@ class ModelCostService:
|
||||
cache_read_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
# 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
@@ -517,15 +398,9 @@ class ModelCostService:
|
||||
"""
|
||||
异步版本: 返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取按次计费价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
@@ -534,22 +409,17 @@ class ModelCostService:
|
||||
price_per_request = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
# 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
@@ -595,15 +465,14 @@ class ModelCostService:
|
||||
"""
|
||||
返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
|
||||
Reference in New Issue
Block a user