Files
Aether/src/services/model/cost.py

947 lines
37 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
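Model cost service.

Centralizes price resolution, price caching, and cost calculation.
Supports three billing modes: fixed per-token pricing, per-request pricing,
and tiered pricing.

Billing strategies:
- Different API formats may use different billing logic.
- The PricingStrategy abstraction allows customization of, for example, how the
  total input context is computed and TTL-differentiated cache pricing.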
"""
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping, Provider
# A provider reference accepted by the service methods: a Provider object,
# a provider name string, or None.
ProviderRef = Union[str, Provider, None]
@dataclass
class TieredPriceResult:
    """Result of a tiered-pricing price lookup."""

    input_price_per_1m: float
    output_price_per_1m: float
    cache_creation_price_per_1m: Optional[float] = None
    cache_read_price_per_1m: Optional[float] = None
    tier_index: int = 0  # index of the tier that was hit
@dataclass
class CostBreakdown:
    """Itemized cost of a request."""

    input_cost: float
    output_cost: float
    cache_creation_cost: float
    cache_read_cost: float
    cache_cost: float
    request_cost: float
    total_cost: float
class ModelCostService:
    """Central place for model price lookup and cost calculation, so mapper/usage code does not re-implement it."""

    # The caches below are class attributes, i.e. shared by every instance.
    # Input/output token prices, keyed by "{provider_name}:{model}".
    _price_cache: Dict[str, Dict[str, float]] = {}
    # Cache creation/read prices, keyed by "{provider_name}:{model}".
    _cache_price_cache: Dict[str, Dict[str, float]] = {}
    # Tiered pricing lookups (None when nothing is configured),
    # keyed by "{provider_name}:{model}:tiered_with_source".
    _tiered_pricing_cache: Dict[str, Optional[dict]] = {}
def __init__(self, db: Session):
    """Keep the database session that the pricing lookups query through."""
    self.db = db
# ------------------------------------------------------------------
# 阶梯计费相关方法
# ------------------------------------------------------------------
@staticmethod
def get_tier_for_tokens(
    tiered_pricing: dict,
    total_input_tokens: int
) -> Optional[dict]:
    """
    Pick the price tier that applies to a total input token count.

    Tiers are scanned in list order; the first tier whose "up_to" is None
    (no upper bound) or is >= total_input_tokens wins.

    Args:
        tiered_pricing: Tiered pricing config of the form {"tiers": [...]}.
        total_input_tokens: Total input tokens (input_tokens + cache_read_tokens).

    Returns:
        The matching tier dict. If every tier has an upper bound and all are
        exceeded, the last tier. None when there is no usable tier list.
    """
    if not tiered_pricing or "tiers" not in tiered_pricing:
        return None

    tier_list = tiered_pricing.get("tiers", [])
    if not tier_list:
        return None

    for candidate in tier_list:
        cap = candidate.get("up_to")
        if cap is not None and total_input_tokens > cap:
            continue
        return candidate

    # Every tier is capped and the count exceeds all caps: bill at the last tier.
    return tier_list[-1]
@staticmethod
def get_cache_read_price_for_ttl(
    tier: dict,
    cache_ttl_minutes: Optional[int] = None
) -> Optional[float]:
    """
    Resolve the cache read price of a tier, honoring TTL-differentiated pricing.

    Resolution order when the tier has "cache_ttl_pricing" and a TTL is given:
    1. the first TTL bucket whose "ttl_minutes" is >= cache_ttl_minutes;
    2. otherwise (no bucket matched, or the matched bucket has no price)
       the last configured bucket;
    3. if that still yields no price, the tier's own "cache_read_price_per_1m".

    Step 3 is what distinguishes this from a plain bucket lookup: a TTL table
    with missing prices no longer hides the tier's default price.

    Args:
        tier: The tier config currently in effect.
        cache_ttl_minutes: Cache lifetime in minutes; None means "use the
            tier's default cache read price".

    Returns:
        The cache read price per 1M tokens, or None if the tier defines none.
    """
    default_price = tier.get("cache_read_price_per_1m")

    ttl_pricing = tier.get("cache_ttl_pricing")
    if not ttl_pricing or cache_ttl_minutes is None:
        return default_price

    ttl_price = None
    for ttl_config in ttl_pricing:
        if cache_ttl_minutes <= ttl_config.get("ttl_minutes", 0):
            ttl_price = ttl_config.get("cache_read_price_per_1m")
            break

    if ttl_price is None:
        # TTL exceeds every bucket (or the matched bucket is unpriced): use the last bucket.
        ttl_price = ttl_pricing[-1].get("cache_read_price_per_1m")

    # Never return None while the tier itself carries a default price.
    return ttl_price if ttl_price is not None else default_price
async def get_tiered_pricing_async(
    self, provider: ProviderRef, model: str
) -> Optional[dict]:
    """
    Fetch a model's tiered pricing config asynchronously.

    Thin wrapper over get_tiered_pricing_with_source_async that drops the
    source information.

    Args:
        provider: Provider object or provider name.
        model: Model name.

    Returns:
        The tiered pricing config, or None when none is configured.
    """
    sourced = await self.get_tiered_pricing_with_source_async(provider, model)
    if not sourced:
        return None
    return sourced.get("pricing")
async def get_tiered_pricing_with_source_async(
    self, provider: ProviderRef, model: str
) -> Optional[dict]:
    """
    Fetch a model's tiered pricing config together with where it came from.

    The outcome (including None) is memoized in the class-level
    _tiered_pricing_cache under "{provider_name}:{model}:tiered_with_source".

    Args:
        provider: Provider object or provider name.
        model: Model name.

    Returns:
        None when nothing is configured, otherwise a dict with:
        - pricing: the tiered pricing config
        - source: 'provider' (set on the provider's Model) or
          'global' (GlobalModel default)
    """
    cache_key = f"{self._provider_name(provider)}:{model}:tiered_with_source"
    cached_entries = self._tiered_pricing_cache
    if cache_key in cached_entries:
        return cached_entries[cache_key]

    sourced_pricing = None
    provider_obj = self._resolve_provider(provider)

    if provider_obj:
        from src.services.model.mapping_resolver import resolve_model_to_global_name

        canonical_name = await resolve_model_to_global_name(
            self.db, model, provider_obj.id
        )
        global_model = (
            self.db.query(GlobalModel)
            .filter(
                GlobalModel.name == canonical_name,
                GlobalModel.is_active == True,  # noqa: E712 - SQL expression
            )
            .first()
        )

        model_obj = None
        if global_model:
            model_obj = (
                self.db.query(Model)
                .filter(
                    Model.provider_id == provider_obj.id,
                    Model.global_model_id == global_model.id,
                    Model.is_active == True,  # noqa: E712 - SQL expression
                )
                .first()
            )

        if model_obj:
            # Provider-level pricing wins over the GlobalModel default.
            if model_obj.tiered_pricing is not None:
                sourced_pricing = {
                    "pricing": model_obj.tiered_pricing,
                    "source": "provider"
                }
            elif global_model.default_tiered_pricing is not None:
                sourced_pricing = {
                    "pricing": global_model.default_tiered_pricing,
                    "source": "global"
                }

    cached_entries[cache_key] = sourced_pricing
    return sourced_pricing
def get_tiered_pricing(self, provider: ProviderRef, model: str) -> Optional[dict]:
    """
    Synchronous wrapper around get_tiered_pricing_async.

    Runs the coroutine on the current thread's event loop, creating and
    installing a new loop when the thread has none or when the current one
    is already closed (run_until_complete() refuses to run on a closed loop).

    Must not be called from code that is itself running inside the event
    loop: run_until_complete() raises RuntimeError on a running loop.

    Args:
        provider: Provider object or provider name.
        model: Model name.

    Returns:
        The tiered pricing config, or None when none is configured.
    """
    import asyncio

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(self.get_tiered_pricing_async(provider, model))
# ------------------------------------------------------------------
# 公共方法
# ------------------------------------------------------------------
async def get_model_price_async(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
    """
    Async lookup of (input_price, output_price) for a provider/model pair.

    Note: when the model uses tiered pricing, this returns the FIRST tier's
    prices as a default. Actual billing should go through
    compute_cost_with_tiered_pricing.

    Which model is billed depends on the ModelMapping's mapping_type:
    1. Look up a ModelMapping for the requested name (provider-specific
       first, then a global one with provider_id NULL).
    2. mapping_type == 'mapping': try the GlobalModel named like the
       requested (source) model; if this provider has an active Model for
       it and it yields an input price, use that.
    3. Otherwise (any other mapping_type such as 'alias', or step 2 found
       no price): use the mapping's target GlobalModel.
    4. No mapping at all: match the requested name against GlobalModel.name.

    Missing prices become 0.0. If both are 0.0 and no per-request price is
    configured either, a warning is logged. The result is memoized in the
    class-level _price_cache under "{provider_name}:{model}".

    Args:
        provider: Provider object or provider name.
        model: Requested model name (a GlobalModel.name or an alias).

    Returns:
        (input_price, output_price) per 1M tokens.
    """
    provider_name = self._provider_name(provider)
    cache_key = f"{provider_name}:{model}"

    if cache_key in self._price_cache:
        prices = self._price_cache[cache_key]
        return prices["input"], prices["output"]

    provider_obj = self._resolve_provider(provider)
    input_price = None
    output_price = None

    if provider_obj:
        # Step 1: provider-specific mapping first ...
        mapping = (
            self.db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == model,
                ModelMapping.provider_id == provider_obj.id,
                ModelMapping.is_active == True,  # noqa: E712 - SQL expression
            )
            .first()
        )
        # ... then the global mapping.
        if not mapping:
            mapping = (
                self.db.query(ModelMapping)
                .filter(
                    ModelMapping.source_model == model,
                    ModelMapping.provider_id.is_(None),
                    ModelMapping.is_active == True,  # noqa: E712 - SQL expression
                )
                .first()
            )

        if mapping:
            if mapping.mapping_type == "mapping":
                # 'mapping' mode: prefer the price of the source model itself.
                _, source_model_obj = self._find_active_model_for_price(
                    provider_obj, GlobalModel.name == model
                )
                if source_model_obj:
                    input_price, output_price = self._base_token_prices(source_model_obj)
                    logger.debug(f"[mapping模式] 使用源模型价格: {model} "
                                 f"(输入: ${input_price}/M, 输出: ${output_price}/M)")

            # 'alias' mode, or 'mapping' mode without a source price: bill the target model.
            if input_price is None:
                target_global_model, target_model_obj = self._find_active_model_for_price(
                    provider_obj, GlobalModel.id == mapping.target_global_model_id
                )
                if target_model_obj:
                    input_price, output_price = self._base_token_prices(target_model_obj)
                    mode_label = (
                        "alias模式"
                        if mapping.mapping_type == "alias"
                        else "mapping模式(回退)"
                    )
                    logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
                                 f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
        else:
            # No mapping: treat the requested name as a GlobalModel.name.
            _, model_obj = self._find_active_model_for_price(
                provider_obj, GlobalModel.name == model
            )
            if model_obj:
                input_price, output_price = self._base_token_prices(model_obj)
                logger.debug(f"找到模型价格配置: {provider_name}/{model} "
                             f"(输入: ${input_price}/M, 输出: ${output_price}/M)")

    # No price configuration found: fall back to 0.0.
    if input_price is None:
        input_price = 0.0
    if output_price is None:
        output_price = 0.0

    # Zero token prices are legitimate for per-request billed models,
    # so only warn when no per-request price exists either.
    if input_price == 0.0 and output_price == 0.0:
        price_per_request = await self.get_request_price_async(provider, model)
        if price_per_request is None or price_per_request == 0.0:
            logger.warning(f"未找到模型价格配置: {provider_name}/{model},请在 GlobalModel 中配置价格")

    self._price_cache[cache_key] = {"input": input_price, "output": output_price}
    return input_price, output_price

def _find_active_model_for_price(self, provider_obj, global_model_criterion):
    """Return (global_model, model_obj): the active GlobalModel matching the criterion and this provider's active Model for it; either may be None."""
    global_model = (
        self.db.query(GlobalModel)
        .filter(
            global_model_criterion,
            GlobalModel.is_active == True,  # noqa: E712 - SQL expression
        )
        .first()
    )
    if not global_model:
        return None, None

    model_obj = (
        self.db.query(Model)
        .filter(
            Model.provider_id == provider_obj.id,
            Model.global_model_id == global_model.id,
            Model.is_active == True,  # noqa: E712 - SQL expression
        )
        .first()
    )
    return global_model, model_obj

@staticmethod
def _base_token_prices(model_obj):
    """Return (input, output) prices of a Model: the first tier's when tiered pricing has tiers, else the flat effective prices."""
    tiered = model_obj.get_effective_tiered_pricing()
    if tiered and tiered.get("tiers"):
        first_tier = tiered["tiers"][0]
        return (
            first_tier.get("input_price_per_1m", 0),
            first_tier.get("output_price_per_1m", 0),
        )
    return (
        model_obj.get_effective_input_price(),
        model_obj.get_effective_output_price(),
    )
def get_model_price(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
    """
    Synchronous wrapper around get_model_price_async.

    Runs the coroutine on the current thread's event loop, creating and
    installing a new loop when the thread has none or when the current one
    is already closed (run_until_complete() refuses to run on a closed loop).

    Must not be called from code that is itself running inside the event
    loop: run_until_complete() raises RuntimeError on a running loop.

    Args:
        provider: Provider object or provider name.
        model: Requested model name (a GlobalModel.name or an alias).

    Returns:
        (input_price, output_price) tuple.
    """
    import asyncio

    # Drive the async implementation from a synchronous context.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(self.get_model_price_async(provider, model))
async def get_cache_prices_async(
    self, provider: ProviderRef, model: str, input_price: float
) -> Tuple[Optional[float], Optional[float]]:
    """
    Async lookup of the cache creation / cache read prices (per 1M tokens).

    Steps:
    1. Resolve the requested name (possibly an alias) to a GlobalModel name.
    2. Find the active GlobalModel.
    3. Find this provider's active Model for it.
    4. Read the cache prices: the first tier's when tiered pricing has
       tiers, otherwise the Model's effective cache prices.

    Only the CONFIGURED prices are memoized (class-level _cache_price_cache,
    key "{provider_name}:{model}"). A price that is not configured is
    estimated on every call from the input_price passed to that call
    (creation = input_price * 1.25, read = input_price * 0.1), so a cached
    entry can never serve an estimate derived from another call's input_price.

    Args:
        provider: Provider object or provider name.
        model: Requested model name (a GlobalModel.name or an alias).
        input_price: Base input price used for the default estimates.

    Returns:
        (cache_creation_price, cache_read_price) tuple.
    """
    provider_name = self._provider_name(provider)
    cache_key = f"{provider_name}:{model}"

    if cache_key in self._cache_price_cache:
        configured = self._cache_price_cache[cache_key]
    else:
        configured = {"creation": None, "read": None}
        provider_obj = self._resolve_provider(provider)

        if provider_obj:
            # Step 1: resolve a possible alias
            from src.services.model.mapping_resolver import resolve_model_to_global_name

            global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)

            # Step 2: find the GlobalModel
            global_model = (
                self.db.query(GlobalModel)
                .filter(
                    GlobalModel.name == global_model_name,
                    GlobalModel.is_active == True,  # noqa: E712 - SQL expression
                )
                .first()
            )

            # Step 3: find this provider's Model implementation
            if global_model:
                model_obj = (
                    self.db.query(Model)
                    .filter(
                        Model.provider_id == provider_obj.id,
                        Model.global_model_id == global_model.id,
                        Model.is_active == True,  # noqa: E712 - SQL expression
                    )
                    .first()
                )
                if model_obj:
                    tiered = model_obj.get_effective_tiered_pricing()
                    if tiered and tiered.get("tiers"):
                        # Tiered pricing: the first tier's cache prices are the default.
                        first_tier = tiered["tiers"][0]
                        configured["creation"] = first_tier.get("cache_creation_price_per_1m")
                        configured["read"] = first_tier.get("cache_read_price_per_1m")
                    else:
                        configured["creation"] = model_obj.get_effective_cache_creation_price()
                        configured["read"] = model_obj.get_effective_cache_read_price()

        self._cache_price_cache[cache_key] = configured

    cache_creation_price = configured["creation"]
    cache_read_price = configured["read"]

    # Unconfigured prices are estimated from this call's input price.
    if cache_creation_price is None:
        cache_creation_price = input_price * 1.25
    if cache_read_price is None:
        cache_read_price = input_price * 0.1

    return cache_creation_price, cache_read_price
async def get_request_price_async(self, provider: ProviderRef, model: str) -> Optional[float]:
    """
    Async lookup of the per-request price (a flat fee charged per call).

    Steps:
    1. Resolve the requested name (possibly an alias) to a GlobalModel name.
    2. Find the active GlobalModel.
    3. Find this provider's active Model for it.
    4. Return that Model's effective per-request price.

    Args:
        provider: Provider object or provider name.
        model: Requested model name (a GlobalModel.name or an alias).

    Returns:
        The per-request price, or None when the provider, GlobalModel or
        Model cannot be found (or no price is configured).
    """
    provider_obj = self._resolve_provider(provider)
    if not provider_obj:
        return None

    # Step 1: resolve a possible alias
    from src.services.model.mapping_resolver import resolve_model_to_global_name

    canonical_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)

    # Step 2: find the GlobalModel
    active_global = (
        self.db.query(GlobalModel)
        .filter(
            GlobalModel.name == canonical_name,
            GlobalModel.is_active == True,  # noqa: E712 - SQL expression
        )
        .first()
    )
    if not active_global:
        return None

    # Step 3: find this provider's Model implementation
    provider_model = (
        self.db.query(Model)
        .filter(
            Model.provider_id == provider_obj.id,
            Model.global_model_id == active_global.id,
            Model.is_active == True,  # noqa: E712 - SQL expression
        )
        .first()
    )
    if not provider_model:
        return None

    # NOTE(review): get_effective_* presumably falls back to the GlobalModel default; confirm on Model.
    return provider_model.get_effective_price_per_request()
def get_request_price(self, provider: ProviderRef, model: str) -> Optional[float]:
    """
    Synchronous wrapper around get_request_price_async.

    Runs the coroutine on the current thread's event loop, creating and
    installing a new loop when the thread has none or when the current one
    is already closed (run_until_complete() refuses to run on a closed loop).

    Must not be called from code that is itself running inside the event
    loop: run_until_complete() raises RuntimeError on a running loop.

    Args:
        provider: Provider object or provider name.
        model: Requested model name (a GlobalModel.name or an alias).

    Returns:
        The per-request price (flat fee per call), or None when not configured.
    """
    import asyncio

    # Drive the async implementation from a synchronous context.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(self.get_request_price_async(provider, model))
def get_cache_prices(
    self, provider: ProviderRef, model: str, input_price: float
) -> Tuple[Optional[float], Optional[float]]:
    """
    Synchronous wrapper around get_cache_prices_async.

    Runs the coroutine on the current thread's event loop, creating and
    installing a new loop when the thread has none or when the current one
    is already closed (run_until_complete() refuses to run on a closed loop).

    Must not be called from code that is itself running inside the event
    loop: run_until_complete() raises RuntimeError on a running loop.

    Args:
        provider: Provider object or provider name.
        model: Requested model name (a GlobalModel.name or an alias).
        input_price: Base input price used for the default estimates.

    Returns:
        (cache_creation_price, cache_read_price) tuple, per 1M tokens.
    """
    import asyncio

    # Drive the async implementation from a synchronous context.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(self.get_cache_prices_async(provider, model, input_price))
def calculate_cost(
    self,
    provider: Provider,
    model: str,
    input_tokens: int,
    output_tokens: int,
) -> Dict[str, float]:
    """
    Compute token costs in the dict shape of the old ModelMapper.calculate_cost.

    Prices come from get_model_price; the arithmetic is delegated to
    compute_cost (no cache or per-request components are passed).

    Returns:
        Dict with input_cost, output_cost and total_cost (each rounded to
        6 decimals) plus the input_price_per_1m / output_price_per_1m used.
    """
    price_in, price_out = self.get_model_price(provider, model)

    (
        cost_in,
        cost_out,
        _creation,
        _read,
        _cache,
        _request,
        cost_total,
    ) = self.compute_cost(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        input_price_per_1m=price_in,
        output_price_per_1m=price_out,
    )

    summary = {
        label: round(amount, 6)
        for label, amount in (
            ("input_cost", cost_in),
            ("output_cost", cost_out),
            ("total_cost", cost_total),
        )
    }
    summary["input_price_per_1m"] = price_in
    summary["output_price_per_1m"] = price_out
    return summary
@staticmethod
def compute_cost(
    *,
    input_tokens: int,
    output_tokens: int,
    input_price_per_1m: float,
    output_price_per_1m: float,
    cache_creation_input_tokens: int = 0,
    cache_read_input_tokens: int = 0,
    cache_creation_price_per_1m: Optional[float] = None,
    cache_read_price_per_1m: Optional[float] = None,
    price_per_request: Optional[float] = None,
) -> Tuple[float, float, float, float, float, float, float]:
    """Core cost arithmetic for fixed (non-tiered) prices, shared with UsageService and others.

    A cache component is only charged when its token count is positive and
    its price is not None; a missing per-request price counts as 0.0.

    Returns:
        Tuple of (input_cost, output_cost, cache_creation_cost,
        cache_read_cost, cache_cost, request_cost, total_cost)
    """

    def _charge(tokens, price_per_1m):
        # Prices are quoted per one million tokens.
        return (tokens / 1_000_000) * price_per_1m

    input_cost = _charge(input_tokens, input_price_per_1m)
    output_cost = _charge(output_tokens, output_price_per_1m)

    cache_creation_cost = (
        _charge(cache_creation_input_tokens, cache_creation_price_per_1m)
        if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None
        else 0.0
    )
    cache_read_cost = (
        _charge(cache_read_input_tokens, cache_read_price_per_1m)
        if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None
        else 0.0
    )
    cache_cost = cache_creation_cost + cache_read_cost

    # Flat per-request fee
    request_cost = 0.0 if price_per_request is None else price_per_request

    total_cost = input_cost + output_cost + cache_cost + request_cost
    return (
        input_cost,
        output_cost,
        cache_creation_cost,
        cache_read_cost,
        cache_cost,
        request_cost,
        total_cost,
    )
@staticmethod
def compute_cost_with_tiered_pricing(
    *,
    input_tokens: int,
    output_tokens: int,
    cache_creation_input_tokens: int = 0,
    cache_read_input_tokens: int = 0,
    tiered_pricing: Optional[dict] = None,
    cache_ttl_minutes: Optional[int] = None,
    price_per_request: Optional[float] = None,
    # Fallback prices (used when there is no tier configuration)
    fallback_input_price_per_1m: float = 0.0,
    fallback_output_price_per_1m: float = 0.0,
    fallback_cache_creation_price_per_1m: Optional[float] = None,
    fallback_cache_read_price_per_1m: Optional[float] = None,
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
    """
    Core cost arithmetic with tiered-pricing support.

    The tier is chosen from input_tokens + cache_read_input_tokens (the
    total input context). Each price of the chosen tier falls back to the
    corresponding fallback_* value when the tier does not define it. A key
    that is present with a None value counts as "not defined", the same way
    for all four prices.

    Args:
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_creation_input_tokens: Number of cache creation tokens.
        cache_read_input_tokens: Number of cache read tokens.
        tiered_pricing: Tiered pricing config.
        cache_ttl_minutes: Cache lifetime in minutes, for TTL-differentiated pricing.
        price_per_request: Flat per-request price.
        fallback_*: Prices used when no tier applies or the tier lacks a price.

    Returns:
        Tuple of (input_cost, output_cost, cache_creation_cost,
        cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
        tier_index: 0-based index of the tier that was hit, or None when
        tiered pricing was not used.
    """
    # Total input context, used for tier selection
    total_input_context = input_tokens + cache_read_input_tokens

    tier_index = None
    input_price_per_1m = fallback_input_price_per_1m
    output_price_per_1m = fallback_output_price_per_1m
    cache_creation_price_per_1m = fallback_cache_creation_price_per_1m
    cache_read_price_per_1m = fallback_cache_read_price_per_1m

    # With a tier configuration, look up the matching tier
    if tiered_pricing and tiered_pricing.get("tiers"):
        tier = ModelCostService.get_tier_for_tokens(tiered_pricing, total_input_context)
        if tier:
            tier_index = tiered_pricing["tiers"].index(tier)

            # dict.get(key, default) would hand back an explicit None,
            # so test for None instead of relying on the default argument.
            tier_input_price = tier.get("input_price_per_1m")
            if tier_input_price is not None:
                input_price_per_1m = tier_input_price

            tier_output_price = tier.get("output_price_per_1m")
            if tier_output_price is not None:
                output_price_per_1m = tier_output_price

            tier_cache_creation_price = tier.get("cache_creation_price_per_1m")
            if tier_cache_creation_price is not None:
                cache_creation_price_per_1m = tier_cache_creation_price

            # Cache read price, taking TTL-differentiated pricing into account
            tier_cache_read_price = ModelCostService.get_cache_read_price_for_ttl(
                tier, cache_ttl_minutes
            )
            if tier_cache_read_price is not None:
                cache_read_price_per_1m = tier_cache_read_price

            logger.debug(
                f"[阶梯计费] 总输入上下文: {total_input_context}, "
                f"命中阶梯: {tier_index + 1}, "
                f"输入价格: ${input_price_per_1m}/M, "
                f"输出价格: ${output_price_per_1m}/M"
            )

    # Cost arithmetic (prices are per one million tokens)
    input_cost = (input_tokens / 1_000_000) * input_price_per_1m
    output_cost = (output_tokens / 1_000_000) * output_price_per_1m

    cache_creation_cost = 0.0
    cache_read_cost = 0.0
    if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
        cache_creation_cost = (
            cache_creation_input_tokens / 1_000_000
        ) * cache_creation_price_per_1m
    if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
        cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m
    cache_cost = cache_creation_cost + cache_read_cost

    # Flat per-request fee
    request_cost = price_per_request if price_per_request is not None else 0.0

    total_cost = input_cost + output_cost + cache_cost + request_cost
    return (
        input_cost,
        output_cost,
        cache_creation_cost,
        cache_read_cost,
        cache_cost,
        request_cost,
        total_cost,
        tier_index,
    )
@classmethod
def clear_cache(cls):
    """Empty all price-related class-level caches in place."""
    for cache in (cls._price_cache, cls._cache_price_cache, cls._tiered_pricing_cache):
        cache.clear()
# ------------------------------------------------------------------
# 内部辅助
# ------------------------------------------------------------------
def _provider_name(self, provider: ProviderRef) -> str:
    """Return the provider's name: .name of a Provider object, the string itself, or "unknown" when empty/None."""
    if not isinstance(provider, Provider):
        return provider or "unknown"
    return provider.name
def _resolve_provider(self, provider: ProviderRef) -> Optional[Provider]:
    """Turn a provider reference into a Provider object.

    A Provider object is returned as is; an empty reference or the
    placeholder name "unknown" yields None; any other string is looked up
    by Provider.name (None when no such row exists).
    """
    if isinstance(provider, Provider):
        return provider
    is_placeholder = not provider or provider == "unknown"
    if is_placeholder:
        return None
    by_name = self.db.query(Provider).filter(Provider.name == provider)
    return by_name.first()
# ------------------------------------------------------------------
# 基于策略模式的计费方法
# ------------------------------------------------------------------
async def compute_cost_with_strategy_async(
    self,
    provider: ProviderRef,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_creation_input_tokens: int = 0,
    cache_read_input_tokens: int = 0,
    api_format: Optional[str] = None,
    cache_ttl_minutes: Optional[int] = None,
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
    """
    Compute the cost through the billing strategy of an API format (async).

    The adapter registered for api_format (chat adapters are checked first,
    then CLI adapters) performs the calculation. Without a matching adapter
    the static compute_cost_with_tiered_pricing is used. Both paths receive
    the tiered pricing config and the cache TTL.

    Args:
        provider: Provider object or provider name.
        model: Model name.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_creation_input_tokens: Number of cache creation tokens.
        cache_read_input_tokens: Number of cache read tokens.
        api_format: API format, selects the billing strategy.
        cache_ttl_minutes: Cache lifetime in minutes, for TTL-differentiated pricing.

    Returns:
        Tuple of (input_cost, output_cost, cache_creation_cost,
        cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
    """
    # Gather the price configuration
    input_price, output_price = await self.get_model_price_async(provider, model)
    cache_creation_price, cache_read_price = await self.get_cache_prices_async(
        provider, model, input_price
    )
    request_price = await self.get_request_price_async(provider, model)
    tiered_pricing = await self.get_tiered_pricing_async(provider, model)

    # Find the adapter for this API format: chat adapter first, then CLI adapter
    from src.api.handlers.base.chat_adapter_base import get_adapter_instance
    from src.api.handlers.base.cli_adapter_base import get_cli_adapter_instance

    adapter = None
    if api_format:
        adapter = get_adapter_instance(api_format)
        if adapter is None:
            adapter = get_cli_adapter_instance(api_format)

    token_usage = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cache_creation_input_tokens": cache_creation_input_tokens,
        "cache_read_input_tokens": cache_read_input_tokens,
    }

    if not adapter:
        # No adapter: fall back to the default static calculation
        return self.compute_cost_with_tiered_pricing(
            **token_usage,
            tiered_pricing=tiered_pricing,
            cache_ttl_minutes=cache_ttl_minutes,
            price_per_request=request_price,
            fallback_input_price_per_1m=input_price,
            fallback_output_price_per_1m=output_price,
            fallback_cache_creation_price_per_1m=cache_creation_price,
            fallback_cache_read_price_per_1m=cache_read_price,
        )

    # Let the adapter apply its own billing rules
    result = adapter.compute_cost(
        **token_usage,
        input_price_per_1m=input_price,
        output_price_per_1m=output_price,
        cache_creation_price_per_1m=cache_creation_price,
        cache_read_price_per_1m=cache_read_price,
        price_per_request=request_price,
        tiered_pricing=tiered_pricing,
        cache_ttl_minutes=cache_ttl_minutes,
    )
    result_fields = (
        "input_cost",
        "output_cost",
        "cache_creation_cost",
        "cache_read_cost",
        "cache_cost",
        "request_cost",
        "total_cost",
        "tier_index",
    )
    return tuple(result[field] for field in result_fields)
def compute_cost_with_strategy(
    self,
    provider: ProviderRef,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_creation_input_tokens: int = 0,
    cache_read_input_tokens: int = 0,
    api_format: Optional[str] = None,
    cache_ttl_minutes: Optional[int] = None,
) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
    """
    Synchronous wrapper around compute_cost_with_strategy_async.

    Runs the coroutine on the current thread's event loop, creating and
    installing a new loop when the thread has none or when the current one
    is already closed (run_until_complete() refuses to run on a closed loop).

    Must not be called from code that is itself running inside the event
    loop: run_until_complete() raises RuntimeError on a running loop.

    Returns:
        Tuple of (input_cost, output_cost, cache_creation_cost,
        cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
    """
    import asyncio

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(
        self.compute_cost_with_strategy_async(
            provider=provider,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            api_format=api_format,
            cache_ttl_minutes=cache_ttl_minutes,
        )
    )