mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
Initial commit
This commit is contained in:
946
src/services/model/cost.py
Normal file
946
src/services/model/cost.py
Normal file
@@ -0,0 +1,946 @@
|
||||
"""
|
||||
模型成本服务
|
||||
负责统一的价格解析、缓存以及成本计算逻辑。
|
||||
支持固定价格、按次计费和阶梯计费三种模式。
|
||||
|
||||
计费策略:
|
||||
- 不同 API format 可以有不同的计费逻辑
|
||||
- 通过 PricingStrategy 抽象,支持自定义总输入上下文计算、缓存 TTL 差异化等
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
|
||||
|
||||
ProviderRef = Union[str, Provider, None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TieredPriceResult:
|
||||
"""阶梯计费价格查询结果"""
|
||||
input_price_per_1m: float
|
||||
output_price_per_1m: float
|
||||
cache_creation_price_per_1m: Optional[float] = None
|
||||
cache_read_price_per_1m: Optional[float] = None
|
||||
tier_index: int = 0 # 命中的阶梯索引
|
||||
|
||||
|
||||
@dataclass
|
||||
class CostBreakdown:
|
||||
"""成本明细"""
|
||||
input_cost: float
|
||||
output_cost: float
|
||||
cache_creation_cost: float
|
||||
cache_read_cost: float
|
||||
cache_cost: float
|
||||
request_cost: float
|
||||
total_cost: float
|
||||
|
||||
|
||||
class ModelCostService:
|
||||
"""集中负责模型价格与成本计算,避免在 mapper/usage 中重复实现。"""
|
||||
|
||||
_price_cache: Dict[str, Dict[str, float]] = {}
|
||||
_cache_price_cache: Dict[str, Dict[str, float]] = {}
|
||||
_tiered_pricing_cache: Dict[str, Optional[dict]] = {}
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 阶梯计费相关方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def get_tier_for_tokens(
|
||||
tiered_pricing: dict,
|
||||
total_input_tokens: int
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯。
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数(input_tokens + cache_read_tokens)
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置,如果未找到返回 None
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
|
||||
@staticmethod
|
||||
def get_cache_read_price_for_ttl(
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格。
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟),如果为 None 使用默认价格
|
||||
|
||||
Returns:
|
||||
缓存读取价格
|
||||
"""
|
||||
# 首先检查是否有 TTL 差异化定价
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
# 找到匹配或最接近的 TTL 价格
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 如果超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
# 使用默认的缓存读取价格
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
async def get_tiered_pricing_async(
|
||||
self, provider: ProviderRef, model: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
异步获取模型的阶梯计费配置。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 模型名称
|
||||
|
||||
Returns:
|
||||
阶梯计费配置,如果未配置返回 None
|
||||
"""
|
||||
result = await self.get_tiered_pricing_with_source_async(provider, model)
|
||||
return result.get("pricing") if result else None
|
||||
|
||||
async def get_tiered_pricing_with_source_async(
|
||||
self, provider: ProviderRef, model: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
异步获取模型的阶梯计费配置及来源信息。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 模型名称
|
||||
|
||||
Returns:
|
||||
包含 pricing 和 source 的字典:
|
||||
- pricing: 阶梯计费配置
|
||||
- source: 'provider' 或 'global'
|
||||
"""
|
||||
provider_name = self._provider_name(provider)
|
||||
cache_key = f"{provider_name}:{model}:tiered_with_source"
|
||||
|
||||
if cache_key in self._tiered_pricing_cache:
|
||||
return self._tiered_pricing_cache[cache_key]
|
||||
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
result = None
|
||||
|
||||
if provider_obj:
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(
|
||||
self.db, model, provider_obj.id
|
||||
)
|
||||
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_obj:
|
||||
# 判断定价来源
|
||||
if model_obj.tiered_pricing is not None:
|
||||
result = {
|
||||
"pricing": model_obj.tiered_pricing,
|
||||
"source": "provider"
|
||||
}
|
||||
elif global_model.default_tiered_pricing is not None:
|
||||
result = {
|
||||
"pricing": global_model.default_tiered_pricing,
|
||||
"source": "global"
|
||||
}
|
||||
|
||||
self._tiered_pricing_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
def get_tiered_pricing(self, provider: ProviderRef, model: str) -> Optional[dict]:
|
||||
"""同步获取模型的阶梯计费配置。"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_tiered_pricing_async(provider, model))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 公共方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_model_price_async(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
|
||||
"""
|
||||
异步版本: 返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。
|
||||
实际计费时应使用 compute_cost_with_tiered_pricing 方法。
|
||||
|
||||
计费逻辑(基于 mapping_type):
|
||||
1. 查找 ModelMapping(如果存在)
|
||||
2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格
|
||||
3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格
|
||||
- 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格
|
||||
- 否则回退到目标 GlobalModel 的价格
|
||||
4. 如果没有找到任何 ModelMapping,尝试直接匹配 GlobalModel.name
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
"""
|
||||
provider_name = self._provider_name(provider)
|
||||
cache_key = f"{provider_name}:{model}"
|
||||
|
||||
if cache_key in self._price_cache:
|
||||
prices = self._price_cache[cache_key]
|
||||
return prices["input"], prices["output"]
|
||||
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
input_price = None
|
||||
output_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 查找 ModelMapping 以确定 mapping_type
|
||||
from src.models.database import ModelMapping
|
||||
|
||||
mapping = None
|
||||
# 先查 Provider 特定映射
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id == provider_obj.id,
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# 再查全局映射
|
||||
if not mapping:
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# 有映射,根据 mapping_type 决定计费模型
|
||||
if mapping.mapping_type == "mapping":
|
||||
# mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格
|
||||
source_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_global_model:
|
||||
source_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == source_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = source_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = source_model_obj.get_effective_input_price()
|
||||
output_price = source_model_obj.get_effective_output_price()
|
||||
logger.debug(f"[mapping模式] 使用源模型价格: {model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格
|
||||
if input_price is None:
|
||||
target_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == mapping.target_global_model_id,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_global_model:
|
||||
target_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == target_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = target_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = target_model_obj.get_effective_input_price()
|
||||
output_price = target_model_obj.get_effective_output_price()
|
||||
mode_label = (
|
||||
"alias模式"
|
||||
if mapping.mapping_type == "alias"
|
||||
else "mapping模式(回退)"
|
||||
)
|
||||
logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
else:
|
||||
# 没有映射,尝试直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# 如果没有找到价格配置,使用 0.0 并记录警告
|
||||
if input_price is None:
|
||||
input_price = 0.0
|
||||
if output_price is None:
|
||||
output_price = 0.0
|
||||
|
||||
# 检查是否有按次计费配置(按次计费模型的 token 价格可以为 0)
|
||||
if input_price == 0.0 and output_price == 0.0:
|
||||
# 异步检查按次计费价格
|
||||
price_per_request = await self.get_request_price_async(provider, model)
|
||||
if price_per_request is None or price_per_request == 0.0:
|
||||
logger.warning(f"未找到模型价格配置: {provider_name}/{model},请在 GlobalModel 中配置价格")
|
||||
|
||||
self._price_cache[cache_key] = {"input": input_price, "output": output_price}
|
||||
return input_price, output_price
|
||||
|
||||
def get_model_price(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
|
||||
"""
|
||||
返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 在同步上下文中调用异步方法
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_model_price_async(provider, model))
|
||||
|
||||
async def get_cache_prices_async(
|
||||
self, provider: ProviderRef, model: str, input_price: float
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
异步版本: 返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
(cache_creation_price, cache_read_price) 元组
|
||||
"""
|
||||
provider_name = self._provider_name(provider)
|
||||
cache_key = f"{provider_name}:{model}"
|
||||
|
||||
if cache_key in self._cache_price_cache:
|
||||
prices = self._cache_price_cache[cache_key]
|
||||
return prices["creation"], prices["read"]
|
||||
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
cache_creation_price = None
|
||||
cache_read_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费配置
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
# 使用第一个阶梯的缓存价格作为默认值
|
||||
first_tier = tiered["tiers"][0]
|
||||
cache_creation_price = first_tier.get("cache_creation_price_per_1m")
|
||||
cache_read_price = first_tier.get("cache_read_price_per_1m")
|
||||
else:
|
||||
# 使用 get_effective_* 方法,会自动回退到 GlobalModel 的默认值
|
||||
cache_creation_price = model_obj.get_effective_cache_creation_price()
|
||||
cache_read_price = model_obj.get_effective_cache_read_price()
|
||||
|
||||
# 默认缓存价格估算(如果没有配置)- 基于输入价格计算
|
||||
if cache_creation_price is None or cache_read_price is None:
|
||||
if cache_creation_price is None:
|
||||
cache_creation_price = input_price * 1.25
|
||||
if cache_read_price is None:
|
||||
cache_read_price = input_price * 0.1
|
||||
|
||||
self._cache_price_cache[cache_key] = {
|
||||
"creation": cache_creation_price,
|
||||
"read": cache_read_price,
|
||||
}
|
||||
return cache_creation_price, cache_read_price
|
||||
|
||||
async def get_request_price_async(self, provider: ProviderRef, model: str) -> Optional[float]:
|
||||
"""
|
||||
异步版本: 返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取按次计费价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
"""
|
||||
provider_obj = self._resolve_provider(provider)
|
||||
price_per_request = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model_obj:
|
||||
# 使用 get_effective_* 方法,会自动回退到 GlobalModel 的默认值
|
||||
price_per_request = model_obj.get_effective_price_per_request()
|
||||
|
||||
return price_per_request
|
||||
|
||||
def get_request_price(self, provider: ProviderRef, model: str) -> Optional[float]:
|
||||
"""
|
||||
返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 在同步上下文中调用异步方法
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_request_price_async(provider, model))
|
||||
|
||||
def get_cache_prices(
|
||||
self, provider: ProviderRef, model: str, input_price: float
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
(cache_creation_price, cache_read_price) 元组
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# 在同步上下文中调用异步方法
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(self.get_cache_prices_async(provider, model, input_price))
|
||||
|
||||
def calculate_cost(
|
||||
self,
|
||||
provider: Provider,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
) -> Dict[str, float]:
|
||||
"""返回与旧 ModelMapper.calculate_cost 相同结构的费用信息。"""
|
||||
input_price, output_price = self.get_model_price(provider, model)
|
||||
input_cost, output_cost, _, _, _, _, total_cost = self.compute_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
input_price_per_1m=input_price,
|
||||
output_price_per_1m=output_price,
|
||||
)
|
||||
return {
|
||||
"input_cost": round(input_cost, 6),
|
||||
"output_cost": round(output_cost, 6),
|
||||
"total_cost": round(total_cost, 6),
|
||||
"input_price_per_1m": input_price,
|
||||
"output_price_per_1m": output_price,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def compute_cost(
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
input_price_per_1m: float,
|
||||
output_price_per_1m: float,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
cache_creation_price_per_1m: Optional[float] = None,
|
||||
cache_read_price_per_1m: Optional[float] = None,
|
||||
price_per_request: Optional[float] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float]:
|
||||
"""成本计算核心逻辑(固定价格模式),供 UsageService 等复用。
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost, output_cost, cache_creation_cost,
|
||||
cache_read_cost, cache_cost, request_cost, total_cost)
|
||||
"""
|
||||
input_cost = (input_tokens / 1_000_000) * input_price_per_1m
|
||||
output_cost = (output_tokens / 1_000_000) * output_price_per_1m
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * cache_creation_price_per_1m
|
||||
if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
|
||||
cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
|
||||
# 按次计费成本
|
||||
request_cost = price_per_request if price_per_request is not None else 0.0
|
||||
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
return (
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def compute_cost_with_tiered_pricing(
|
||||
*,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
tiered_pricing: Optional[dict] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
price_per_request: Optional[float] = None,
|
||||
# 回退价格(当没有阶梯配置时使用)
|
||||
fallback_input_price_per_1m: float = 0.0,
|
||||
fallback_output_price_per_1m: float = 0.0,
|
||||
fallback_cache_creation_price_per_1m: Optional[float] = None,
|
||||
fallback_cache_read_price_per_1m: Optional[float] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
|
||||
"""
|
||||
支持阶梯计费的成本计算核心逻辑。
|
||||
|
||||
阶梯判定:使用 input_tokens + cache_read_input_tokens(总输入上下文)
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
tiered_pricing: 阶梯计费配置
|
||||
cache_ttl_minutes: 缓存时长(分钟),用于 TTL 差异化定价
|
||||
price_per_request: 按次计费价格
|
||||
fallback_*: 回退价格配置
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost, output_cost, cache_creation_cost,
|
||||
cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
|
||||
tier_index: 命中的阶梯索引(0-based),如果未使用阶梯计费则为 None
|
||||
"""
|
||||
# 计算总输入上下文(用于阶梯判定)
|
||||
total_input_context = input_tokens + cache_read_input_tokens
|
||||
|
||||
tier_index = None
|
||||
input_price_per_1m = fallback_input_price_per_1m
|
||||
output_price_per_1m = fallback_output_price_per_1m
|
||||
cache_creation_price_per_1m = fallback_cache_creation_price_per_1m
|
||||
cache_read_price_per_1m = fallback_cache_read_price_per_1m
|
||||
|
||||
# 如果有阶梯配置,查找匹配的阶梯
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
tier = ModelCostService.get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
if tier:
|
||||
# 找到阶梯索引
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
|
||||
input_price_per_1m = tier.get("input_price_per_1m", fallback_input_price_per_1m)
|
||||
output_price_per_1m = tier.get("output_price_per_1m", fallback_output_price_per_1m)
|
||||
cache_creation_price_per_1m = tier.get(
|
||||
"cache_creation_price_per_1m", fallback_cache_creation_price_per_1m
|
||||
)
|
||||
|
||||
# 获取缓存读取价格(考虑 TTL 差异化)
|
||||
cache_read_price_per_1m = ModelCostService.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if cache_read_price_per_1m is None:
|
||||
cache_read_price_per_1m = fallback_cache_read_price_per_1m
|
||||
|
||||
logger.debug(
|
||||
f"[阶梯计费] 总输入上下文: {total_input_context}, "
|
||||
f"命中阶梯: {tier_index + 1}, "
|
||||
f"输入价格: ${input_price_per_1m}/M, "
|
||||
f"输出价格: ${output_price_per_1m}/M"
|
||||
)
|
||||
|
||||
# 计算成本
|
||||
input_cost = (input_tokens / 1_000_000) * input_price_per_1m
|
||||
output_cost = (output_tokens / 1_000_000) * output_price_per_1m
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * cache_creation_price_per_1m
|
||||
if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
|
||||
cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
|
||||
# 按次计费成本
|
||||
request_cost = price_per_request if price_per_request is not None else 0.0
|
||||
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return (
|
||||
input_cost,
|
||||
output_cost,
|
||||
cache_creation_cost,
|
||||
cache_read_cost,
|
||||
cache_cost,
|
||||
request_cost,
|
||||
total_cost,
|
||||
tier_index,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def clear_cache(cls):
|
||||
"""清理价格相关缓存。"""
|
||||
cls._price_cache.clear()
|
||||
cls._cache_price_cache.clear()
|
||||
cls._tiered_pricing_cache.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 内部辅助
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _provider_name(self, provider: ProviderRef) -> str:
|
||||
if isinstance(provider, Provider):
|
||||
return provider.name
|
||||
return provider or "unknown"
|
||||
|
||||
def _resolve_provider(self, provider: ProviderRef) -> Optional[Provider]:
|
||||
if isinstance(provider, Provider):
|
||||
return provider
|
||||
if not provider or provider == "unknown":
|
||||
return None
|
||||
return self.db.query(Provider).filter(Provider.name == provider).first()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 基于策略模式的计费方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def compute_cost_with_strategy_async(
|
||||
self,
|
||||
provider: ProviderRef,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
api_format: Optional[str] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
|
||||
"""
|
||||
使用计费策略计算成本(异步版本)
|
||||
|
||||
根据 api_format 选择对应的 Adapter 计费逻辑,支持阶梯计费和 TTL 差异化。
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 模型名称
|
||||
input_tokens: 输入 token 数
|
||||
output_tokens: 输出 token 数
|
||||
cache_creation_input_tokens: 缓存创建 token 数
|
||||
cache_read_input_tokens: 缓存读取 token 数
|
||||
api_format: API 格式(用于选择计费策略)
|
||||
cache_ttl_minutes: 缓存时长(分钟),用于 TTL 差异化定价
|
||||
|
||||
Returns:
|
||||
Tuple of (input_cost, output_cost, cache_creation_cost,
|
||||
cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
|
||||
"""
|
||||
# 获取价格配置
|
||||
input_price, output_price = await self.get_model_price_async(provider, model)
|
||||
cache_creation_price, cache_read_price = await self.get_cache_prices_async(
|
||||
provider, model, input_price
|
||||
)
|
||||
request_price = await self.get_request_price_async(provider, model)
|
||||
tiered_pricing = await self.get_tiered_pricing_async(provider, model)
|
||||
|
||||
# 获取对应 API 格式的 Adapter 实例来计算成本
|
||||
# 优先检查 Chat Adapter,然后检查 CLI Adapter
|
||||
from src.api.handlers.base.chat_adapter_base import get_adapter_instance
|
||||
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_instance
|
||||
|
||||
adapter = None
|
||||
if api_format:
|
||||
adapter = get_adapter_instance(api_format)
|
||||
if adapter is None:
|
||||
adapter = get_cli_adapter_instance(api_format)
|
||||
|
||||
if adapter:
|
||||
# 使用 Adapter 的计费方法
|
||||
result = adapter.compute_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price,
|
||||
output_price_per_1m=output_price,
|
||||
cache_creation_price_per_1m=cache_creation_price,
|
||||
cache_read_price_per_1m=cache_read_price,
|
||||
price_per_request=request_price,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
)
|
||||
return (
|
||||
result["input_cost"],
|
||||
result["output_cost"],
|
||||
result["cache_creation_cost"],
|
||||
result["cache_read_cost"],
|
||||
result["cache_cost"],
|
||||
result["request_cost"],
|
||||
result["total_cost"],
|
||||
result["tier_index"],
|
||||
)
|
||||
else:
|
||||
# 回退到默认计算逻辑(无 Adapter 时使用静态方法)
|
||||
return self.compute_cost_with_tiered_pricing(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
price_per_request=request_price,
|
||||
fallback_input_price_per_1m=input_price,
|
||||
fallback_output_price_per_1m=output_price,
|
||||
fallback_cache_creation_price_per_1m=cache_creation_price,
|
||||
fallback_cache_read_price_per_1m=cache_read_price,
|
||||
)
|
||||
|
||||
def compute_cost_with_strategy(
|
||||
self,
|
||||
provider: ProviderRef,
|
||||
model: str,
|
||||
input_tokens: int,
|
||||
output_tokens: int,
|
||||
cache_creation_input_tokens: int = 0,
|
||||
cache_read_input_tokens: int = 0,
|
||||
api_format: Optional[str] = None,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
|
||||
"""
|
||||
使用计费策略计算成本(同步版本)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
return loop.run_until_complete(
|
||||
self.compute_cost_with_strategy_async(
|
||||
provider=provider,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
api_format=api_format,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
)
|
||||
)
|
||||
Reference in New Issue
Block a user