"""
Model cost service.

Centralizes price resolution, price caching, and cost calculation. Three
pricing modes are supported: fixed per-token prices, per-request pricing, and
tiered pricing.

Pricing strategy:
- Different API formats may bill differently.
- Per-format billing is delegated to the API adapter registered for that
  format (see ``ModelCostService.compute_cost_with_strategy_async``), which can
  customize how the total input context is computed, apply cache-TTL-specific
  prices, and so on.
"""

from dataclasses import dataclass
from typing import Any, ClassVar, Coroutine, Dict, Optional, Tuple, TypeVar, Union

from sqlalchemy.orm import Session

from src.core.logger import logger
from src.models.database import GlobalModel, Model, ModelMapping, Provider

ProviderRef = Union[str, Provider, None]

_T = TypeVar("_T")


def _or_default(value: Optional[float], default: Optional[float]) -> Optional[float]:
    """Return ``value`` unless it is None, in which case return ``default``.

    Pricing configs may contain a key that is present but explicitly null;
    such entries must fall back exactly like a missing key instead of leaking
    ``None`` into the cost arithmetic.
    """
    return default if value is None else value


@dataclass
class TieredPriceResult:
    """Result of a tiered-pricing price lookup (all prices per 1M tokens)."""

    input_price_per_1m: float
    output_price_per_1m: float
    cache_creation_price_per_1m: Optional[float] = None
    cache_read_price_per_1m: Optional[float] = None
    tier_index: int = 0  # index of the matched tier


@dataclass
class CostBreakdown:
    """Itemized cost of a single request."""

    input_cost: float
    output_cost: float
    cache_creation_cost: float
    cache_read_cost: float
    cache_cost: float
    request_cost: float
    total_cost: float


class ModelCostService:
    """Single place for model price lookup and cost calculation.

    Keeps pricing logic out of mappers / usage tracking so it is not
    re-implemented in several places.

    The price caches are class attributes: they are shared by every instance
    in the process (regardless of which DB session the instance holds) and are
    only emptied by ``clear_cache()``.
    """

    # "provider:model" -> {"input": price, "output": price}
    _price_cache: ClassVar[Dict[str, Dict[str, float]]] = {}
    # "provider:model" -> {"creation": price, "read": price}
    _cache_price_cache: ClassVar[Dict[str, Dict[str, float]]] = {}
    # "provider:model:tiered_with_source" -> {"pricing": ..., "source": ...} or None
    _tiered_pricing_cache: ClassVar[Dict[str, Optional[dict]]] = {}

    # Defaults used when a model has no cache price configured, expressed as
    # multiples of the model's input price.
    _DEFAULT_CACHE_CREATION_MULTIPLIER: ClassVar[float] = 1.25
    _DEFAULT_CACHE_READ_MULTIPLIER: ClassVar[float] = 0.1

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # Tiered pricing
    # ------------------------------------------------------------------

    @staticmethod
    def _select_tier(
        tiered_pricing: Optional[dict], total_input_tokens: int
    ) -> Tuple[Optional[int], Optional[dict]]:
        """Return ``(index, tier)`` for the tier matching ``total_input_tokens``.

        Tiers are scanned in order; the first tier whose ``up_to`` is missing /
        None, or is >= ``total_input_tokens``, wins. If every tier has an
        ``up_to`` below the token count, the last tier is used. Returns
        ``(None, None)`` when no tiers are configured.

        The index comes from the scan itself rather than ``list.index``, so two
        tiers with identical contents still report the correct position.
        """
        if not tiered_pricing:
            return None, None
        tiers = tiered_pricing.get("tiers")
        if not tiers:
            return None, None

        for index, tier in enumerate(tiers):
            up_to = tier.get("up_to")
            if up_to is None or total_input_tokens <= up_to:
                return index, tier

        # Every tier is capped below the token count: bill at the last tier.
        last_index = len(tiers) - 1
        return last_index, tiers[last_index]

    @staticmethod
    def get_tier_for_tokens(
        tiered_pricing: dict, total_input_tokens: int
    ) -> Optional[dict]:
        """
        Pick the price tier for a given total input token count.

        Args:
            tiered_pricing: Tiered pricing config of the form ``{"tiers": [...]}``.
            total_input_tokens: Total input tokens used for tier selection
                (callers in this module pass input_tokens + cache_read_tokens).

        Returns:
            The matching tier dict, the last tier if the count exceeds every
            ``up_to``, or None if no tiers are configured.
        """
        _, tier = ModelCostService._select_tier(tiered_pricing, total_input_tokens)
        return tier

    @staticmethod
    def get_cache_read_price_for_ttl(
        tier: dict, cache_ttl_minutes: Optional[int] = None
    ) -> Optional[float]:
        """
        Return the cache-read price for a tier, honoring TTL-specific pricing.

        If the tier has a ``cache_ttl_pricing`` list and a TTL is given, the
        first entry whose ``ttl_minutes`` is >= the TTL is used (entries are
        presumably sorted by ``ttl_minutes`` ascending — TODO confirm with the
        config producer). If the TTL exceeds every entry, the last entry is
        used. If the selected entry has no ``cache_read_price_per_1m``, or no
        TTL pricing applies, the tier's own ``cache_read_price_per_1m`` is
        returned.

        Args:
            tier: The tier config.
            cache_ttl_minutes: Cache TTL in minutes; None means use the default price.

        Returns:
            Cache-read price per 1M tokens, or None if nothing is configured.
        """
        ttl_pricing = tier.get("cache_ttl_pricing")
        if ttl_pricing and cache_ttl_minutes is not None:
            for ttl_config in ttl_pricing:
                if cache_ttl_minutes <= ttl_config.get("ttl_minutes", 0):
                    ttl_price = ttl_config.get("cache_read_price_per_1m")
                    break
            else:
                # TTL is longer than every configured bucket: use the last one.
                ttl_price = ttl_pricing[-1].get("cache_read_price_per_1m")

            if ttl_price is not None:
                return ttl_price

        # No TTL-specific price applies: use the tier's default cache-read price.
        return tier.get("cache_read_price_per_1m")

    async def get_tiered_pricing_async(
        self, provider: ProviderRef, model: str
    ) -> Optional[dict]:
        """
        Return the tiered pricing config for a provider/model.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested.

        Returns:
            The tiered pricing config, or None if none is configured.
        """
        result = await self.get_tiered_pricing_with_source_async(provider, model)
        return result.get("pricing") if result else None

    async def get_tiered_pricing_with_source_async(
        self, provider: ProviderRef, model: str
    ) -> Optional[dict]:
        """
        Return the tiered pricing config together with where it came from.

        The provider's own Model row takes precedence over the GlobalModel's
        ``default_tiered_pricing``. Results (including None) are cached per
        provider/model until ``clear_cache()``.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested.

        Returns:
            None if not configured, otherwise a dict with:
                - pricing: the tiered pricing config
                - source: 'provider' or 'global'
        """
        provider_name = self._provider_name(provider)
        cache_key = f"{provider_name}:{model}:tiered_with_source"
        if cache_key in self._tiered_pricing_cache:
            return self._tiered_pricing_cache[cache_key]

        result: Optional[dict] = None
        provider_obj = self._resolve_provider(provider)
        if provider_obj:
            global_model, model_obj = await self._resolve_provider_model(provider_obj, model)
            if model_obj:
                if model_obj.tiered_pricing is not None:
                    result = {"pricing": model_obj.tiered_pricing, "source": "provider"}
                elif global_model.default_tiered_pricing is not None:
                    result = {
                        "pricing": global_model.default_tiered_pricing,
                        "source": "global",
                    }

        self._tiered_pricing_cache[cache_key] = result
        return result

    def get_tiered_pricing(self, provider: ProviderRef, model: str) -> Optional[dict]:
        """Synchronous wrapper around ``get_tiered_pricing_async``."""
        return self._run_sync(self.get_tiered_pricing_async(provider, model))

    # ------------------------------------------------------------------
    # Public price lookups
    # ------------------------------------------------------------------

    async def get_model_price_async(
        self, provider: ProviderRef, model: str
    ) -> Tuple[float, float]:
        """
        Return ``(input_price, output_price)`` per 1M tokens for provider/model.

        If the model uses tiered pricing, the FIRST tier's prices are returned
        as defaults; actual billing should go through
        ``compute_cost_with_tiered_pricing``.

        Resolution (driven by ModelMapping.mapping_type):
            1. Look up an active ModelMapping for ``model`` (provider-specific
               first, then global).
            2. mapping_type='mapping': try the GlobalModel named ``model`` and
               this provider's Model for it; if that yields no input price,
               fall back to the mapping's target GlobalModel.
            3. mapping_type='alias' (or any other type): use the target
               GlobalModel's price.
            4. No mapping: match GlobalModel.name == ``model`` directly.

        Missing prices become 0.0. A warning is logged when both token prices
        are 0.0 and there is no non-zero per-request price either. Results are
        cached per provider/model until ``clear_cache()``.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested (GlobalModel.name or an alias).

        Returns:
            ``(input_price, output_price)`` tuple.
        """
        provider_name = self._provider_name(provider)
        cache_key = f"{provider_name}:{model}"
        cached = self._price_cache.get(cache_key)
        if cached is not None:
            return cached["input"], cached["output"]

        input_price: Optional[float] = None
        output_price: Optional[float] = None

        provider_obj = self._resolve_provider(provider)
        if provider_obj:
            mapping = self._find_model_mapping(model, provider_obj.id)
            if mapping:
                if mapping.mapping_type == "mapping":
                    # mapping mode: prefer the price of the GlobalModel named
                    # after the requested (source) model, if this provider has it.
                    source_global_model = self._find_active_global_model(model)
                    if source_global_model:
                        source_model_obj = self._find_provider_model(
                            provider_obj.id, source_global_model.id
                        )
                        if source_model_obj:
                            input_price, output_price = self._token_prices_from_model(
                                source_model_obj
                            )
                            logger.debug(
                                f"[mapping模式] 使用源模型价格: {model} "
                                f"(输入: ${input_price}/M, 输出: ${output_price}/M)"
                            )

                # alias mode, or mapping mode without a usable source price:
                # bill at the mapping's target GlobalModel.
                if input_price is None:
                    target_global_model = self._find_active_global_model_by_id(
                        mapping.target_global_model_id
                    )
                    if target_global_model:
                        target_model_obj = self._find_provider_model(
                            provider_obj.id, target_global_model.id
                        )
                        if target_model_obj:
                            input_price, output_price = self._token_prices_from_model(
                                target_model_obj
                            )
                            mode_label = (
                                "alias模式"
                                if mapping.mapping_type == "alias"
                                else "mapping模式(回退)"
                            )
                            logger.debug(
                                f"[{mode_label}] 使用目标模型价格: {model} -> "
                                f"{target_global_model.name} "
                                f"(输入: ${input_price}/M, 输出: ${output_price}/M)"
                            )
            else:
                # No mapping: the requested name must be a GlobalModel.name.
                global_model = self._find_active_global_model(model)
                if global_model:
                    model_obj = self._find_provider_model(provider_obj.id, global_model.id)
                    if model_obj:
                        input_price, output_price = self._token_prices_from_model(model_obj)
                        logger.debug(
                            f"找到模型价格配置: {provider_name}/{model} "
                            f"(输入: ${input_price}/M, 输出: ${output_price}/M)"
                        )

        # Unconfigured prices default to 0.0.
        if input_price is None:
            input_price = 0.0
        if output_price is None:
            output_price = 0.0

        # Per-request-billed models may legitimately have zero token prices,
        # so only warn when there is no per-request price either.
        if input_price == 0.0 and output_price == 0.0:
            price_per_request = await self.get_request_price_async(provider, model)
            if price_per_request is None or price_per_request == 0.0:
                logger.warning(
                    f"未找到模型价格配置: {provider_name}/{model},请在 GlobalModel 中配置价格"
                )

        self._price_cache[cache_key] = {"input": input_price, "output": output_price}
        return input_price, output_price

    def get_model_price(self, provider: ProviderRef, model: str) -> Tuple[float, float]:
        """
        Synchronous wrapper around ``get_model_price_async``.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested (GlobalModel.name or an alias).

        Returns:
            ``(input_price, output_price)`` tuple, per 1M tokens.
        """
        return self._run_sync(self.get_model_price_async(provider, model))

    async def get_cache_prices_async(
        self, provider: ProviderRef, model: str, input_price: float
    ) -> Tuple[Optional[float], Optional[float]]:
        """
        Return ``(cache_creation_price, cache_read_price)`` per 1M tokens.

        Resolution:
            1. Resolve ``model`` (possibly an alias) to a GlobalModel name via
               ``resolve_model_to_global_name``.
            2. Find the active GlobalModel and this provider's active Model.
            3. Use the first tier's cache prices if tiered pricing is set,
               otherwise the Model's effective cache prices.
            4. Anything still missing is derived from ``input_price``
               (creation = 1.25x, read = 0.1x).

        NOTE: results are cached per provider/model, so derived defaults are
        computed from the ``input_price`` of the first call for that key.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested (GlobalModel.name or an alias).
            input_price: Base input price, used to derive default cache prices.

        Returns:
            ``(cache_creation_price, cache_read_price)`` tuple.
        """
        provider_name = self._provider_name(provider)
        cache_key = f"{provider_name}:{model}"
        cached = self._cache_price_cache.get(cache_key)
        if cached is not None:
            return cached["creation"], cached["read"]

        cache_creation_price: Optional[float] = None
        cache_read_price: Optional[float] = None

        provider_obj = self._resolve_provider(provider)
        if provider_obj:
            _, model_obj = await self._resolve_provider_model(provider_obj, model)
            if model_obj:
                tiered = model_obj.get_effective_tiered_pricing()
                if tiered and tiered.get("tiers"):
                    # Tiered model: the first tier's cache prices are the defaults.
                    first_tier = tiered["tiers"][0]
                    cache_creation_price = first_tier.get("cache_creation_price_per_1m")
                    cache_read_price = first_tier.get("cache_read_price_per_1m")
                else:
                    # get_effective_* falls back to the GlobalModel defaults.
                    cache_creation_price = model_obj.get_effective_cache_creation_price()
                    cache_read_price = model_obj.get_effective_cache_read_price()

        if cache_creation_price is None:
            cache_creation_price = input_price * self._DEFAULT_CACHE_CREATION_MULTIPLIER
        if cache_read_price is None:
            cache_read_price = input_price * self._DEFAULT_CACHE_READ_MULTIPLIER

        self._cache_price_cache[cache_key] = {
            "creation": cache_creation_price,
            "read": cache_read_price,
        }
        return cache_creation_price, cache_read_price

    async def get_request_price_async(
        self, provider: ProviderRef, model: str
    ) -> Optional[float]:
        """
        Return the fixed per-request price for provider/model.

        Resolves ``model`` (possibly an alias) to a GlobalModel, finds this
        provider's active Model for it, and returns its effective per-request
        price (which falls back to the GlobalModel default). Not cached.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested (GlobalModel.name or an alias).

        Returns:
            The per-request price, or None if not configured / not resolvable.
        """
        provider_obj = self._resolve_provider(provider)
        if not provider_obj:
            return None

        _, model_obj = await self._resolve_provider_model(provider_obj, model)
        if not model_obj:
            return None

        return model_obj.get_effective_price_per_request()

    def get_request_price(self, provider: ProviderRef, model: str) -> Optional[float]:
        """
        Synchronous wrapper around ``get_request_price_async``.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested (GlobalModel.name or an alias).

        Returns:
            The per-request price, or None if not configured.
        """
        return self._run_sync(self.get_request_price_async(provider, model))

    def get_cache_prices(
        self, provider: ProviderRef, model: str, input_price: float
    ) -> Tuple[Optional[float], Optional[float]]:
        """
        Synchronous wrapper around ``get_cache_prices_async``.

        Args:
            provider: Provider object or provider name.
            model: Model name as requested (GlobalModel.name or an alias).
            input_price: Base input price, used to derive default cache prices.

        Returns:
            ``(cache_creation_price, cache_read_price)`` tuple.
        """
        return self._run_sync(self.get_cache_prices_async(provider, model, input_price))

    def calculate_cost(
        self,
        provider: Provider,
        model: str,
        input_tokens: int,
        output_tokens: int,
    ) -> Dict[str, float]:
        """Return costs in the same shape as the legacy ModelMapper.calculate_cost."""
        input_price, output_price = self.get_model_price(provider, model)
        input_cost, output_cost, _, _, _, _, total_cost = self.compute_cost(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_price_per_1m=input_price,
            output_price_per_1m=output_price,
        )
        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "total_cost": round(total_cost, 6),
            "input_price_per_1m": input_price,
            "output_price_per_1m": output_price,
        }

    # ------------------------------------------------------------------
    # Cost calculation
    # ------------------------------------------------------------------

    @staticmethod
    def compute_cost(
        *,
        input_tokens: int,
        output_tokens: int,
        input_price_per_1m: float,
        output_price_per_1m: float,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        cache_creation_price_per_1m: Optional[float] = None,
        cache_read_price_per_1m: Optional[float] = None,
        price_per_request: Optional[float] = None,
    ) -> Tuple[float, float, float, float, float, float, float]:
        """Core cost calculation for fixed prices (reused by UsageService etc.).

        Cache costs are only charged when both the token count is positive and
        the corresponding price is configured; a missing per-request price
        counts as 0.

        Returns:
            Tuple of (input_cost, output_cost, cache_creation_cost,
            cache_read_cost, cache_cost, request_cost, total_cost)
        """
        input_cost = (input_tokens / 1_000_000) * input_price_per_1m
        output_cost = (output_tokens / 1_000_000) * output_price_per_1m

        cache_creation_cost = 0.0
        if cache_creation_input_tokens > 0 and cache_creation_price_per_1m is not None:
            cache_creation_cost = (
                cache_creation_input_tokens / 1_000_000
            ) * cache_creation_price_per_1m

        cache_read_cost = 0.0
        if cache_read_input_tokens > 0 and cache_read_price_per_1m is not None:
            cache_read_cost = (cache_read_input_tokens / 1_000_000) * cache_read_price_per_1m

        cache_cost = cache_creation_cost + cache_read_cost
        request_cost = price_per_request if price_per_request is not None else 0.0
        total_cost = input_cost + output_cost + cache_cost + request_cost

        return (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        )

    @staticmethod
    def compute_cost_with_tiered_pricing(
        *,
        input_tokens: int,
        output_tokens: int,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        tiered_pricing: Optional[dict] = None,
        cache_ttl_minutes: Optional[int] = None,
        price_per_request: Optional[float] = None,
        # Fallback prices, used when there is no tier config or a tier omits a price.
        fallback_input_price_per_1m: float = 0.0,
        fallback_output_price_per_1m: float = 0.0,
        fallback_cache_creation_price_per_1m: Optional[float] = None,
        fallback_cache_read_price_per_1m: Optional[float] = None,
    ) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
        """
        Core cost calculation with tiered pricing support.

        The tier is selected by total input context:
        ``input_tokens + cache_read_input_tokens``. A price that is missing or
        null in the selected tier falls back to the matching ``fallback_*``
        argument.

        Args:
            input_tokens: Input token count.
            output_tokens: Output token count.
            cache_creation_input_tokens: Cache-creation token count.
            cache_read_input_tokens: Cache-read token count.
            tiered_pricing: Tiered pricing config (``{"tiers": [...]}``).
            cache_ttl_minutes: Cache TTL in minutes, for TTL-specific pricing.
            price_per_request: Fixed per-request price.
            fallback_*: Fallback prices.

        Returns:
            Tuple of (input_cost, output_cost, cache_creation_cost,
            cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
            where tier_index is the 0-based matched tier, or None if tiered
            pricing was not used.
        """
        total_input_context = input_tokens + cache_read_input_tokens

        tier_index: Optional[int] = None
        input_price_per_1m = fallback_input_price_per_1m
        output_price_per_1m = fallback_output_price_per_1m
        cache_creation_price_per_1m = fallback_cache_creation_price_per_1m
        cache_read_price_per_1m = fallback_cache_read_price_per_1m

        matched_index, tier = ModelCostService._select_tier(tiered_pricing, total_input_context)
        if tier:
            tier_index = matched_index
            input_price_per_1m = _or_default(
                tier.get("input_price_per_1m"), fallback_input_price_per_1m
            )
            output_price_per_1m = _or_default(
                tier.get("output_price_per_1m"), fallback_output_price_per_1m
            )
            cache_creation_price_per_1m = _or_default(
                tier.get("cache_creation_price_per_1m"), fallback_cache_creation_price_per_1m
            )
            # Cache-read price may vary by TTL within the tier.
            cache_read_price_per_1m = _or_default(
                ModelCostService.get_cache_read_price_for_ttl(tier, cache_ttl_minutes),
                fallback_cache_read_price_per_1m,
            )

            logger.debug(
                f"[阶梯计费] 总输入上下文: {total_input_context}, "
                f"命中阶梯: {tier_index + 1}, "
                f"输入价格: ${input_price_per_1m}/M, "
                f"输出价格: ${output_price_per_1m}/M"
            )

        costs = ModelCostService.compute_cost(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_price_per_1m=input_price_per_1m,
            output_price_per_1m=output_price_per_1m,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            cache_creation_price_per_1m=cache_creation_price_per_1m,
            cache_read_price_per_1m=cache_read_price_per_1m,
            price_per_request=price_per_request,
        )
        return (*costs, tier_index)

    @classmethod
    def clear_cache(cls):
        """Clear all class-level price caches."""
        cls._price_cache.clear()
        cls._cache_price_cache.clear()
        cls._tiered_pricing_cache.clear()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _run_sync(coro: Coroutine[Any, Any, _T]) -> _T:
        """Run ``coro`` to completion from synchronous code.

        Reuses the thread's current event loop when one is available and still
        open; otherwise creates and installs a new loop. If the current loop is
        already running, ``run_until_complete`` raises RuntimeError as before.
        """
        import asyncio

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            # A closed loop can never run the coroutine; replace it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    def _find_active_global_model(self, name: str) -> Optional[GlobalModel]:
        """Return the active GlobalModel with the given name, if any."""
        return (
            self.db.query(GlobalModel)
            .filter(
                GlobalModel.name == name,
                # `== True` builds a SQL expression; `is True` would not.
                GlobalModel.is_active == True,  # noqa: E712
            )
            .first()
        )

    def _find_active_global_model_by_id(self, global_model_id: Any) -> Optional[GlobalModel]:
        """Return the active GlobalModel with the given id, if any."""
        return (
            self.db.query(GlobalModel)
            .filter(
                GlobalModel.id == global_model_id,
                GlobalModel.is_active == True,  # noqa: E712
            )
            .first()
        )

    def _find_provider_model(self, provider_id: Any, global_model_id: Any) -> Optional[Model]:
        """Return the provider's active Model implementing a GlobalModel, if any."""
        return (
            self.db.query(Model)
            .filter(
                Model.provider_id == provider_id,
                Model.global_model_id == global_model_id,
                Model.is_active == True,  # noqa: E712
            )
            .first()
        )

    def _find_model_mapping(self, model: str, provider_id: Any) -> Optional[ModelMapping]:
        """Return the active ModelMapping for ``model``: provider-specific first, then global."""
        mapping = (
            self.db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == model,
                ModelMapping.provider_id == provider_id,
                ModelMapping.is_active == True,  # noqa: E712
            )
            .first()
        )
        if mapping:
            return mapping
        return (
            self.db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == model,
                ModelMapping.provider_id.is_(None),
                ModelMapping.is_active == True,  # noqa: E712
            )
            .first()
        )

    async def _resolve_provider_model(
        self, provider_obj: Provider, model: str
    ) -> Tuple[Optional[GlobalModel], Optional[Model]]:
        """Resolve ``model`` (possibly an alias) to ``(GlobalModel, provider Model)``.

        Either element may be None; the Model is None whenever the GlobalModel is.
        """
        from src.services.model.mapping_resolver import resolve_model_to_global_name

        global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
        global_model = self._find_active_global_model(global_model_name)
        if not global_model:
            return None, None
        return global_model, self._find_provider_model(provider_obj.id, global_model.id)

    @staticmethod
    def _token_prices_from_model(model_obj: Model) -> Tuple[Optional[float], Optional[float]]:
        """Return a Model's default ``(input_price, output_price)`` per 1M tokens.

        For tiered models this is the first tier's prices (missing keys count
        as 0); otherwise the Model's effective prices, which may be None.
        """
        tiered = model_obj.get_effective_tiered_pricing()
        if tiered and tiered.get("tiers"):
            first_tier = tiered["tiers"][0]
            return (
                first_tier.get("input_price_per_1m", 0),
                first_tier.get("output_price_per_1m", 0),
            )
        return model_obj.get_effective_input_price(), model_obj.get_effective_output_price()

    def _provider_name(self, provider: ProviderRef) -> str:
        """Return the provider's name, or "unknown" for an empty reference."""
        if isinstance(provider, Provider):
            return provider.name
        return provider or "unknown"

    def _resolve_provider(self, provider: ProviderRef) -> Optional[Provider]:
        """Return a Provider object for a Provider or provider name, or None."""
        if isinstance(provider, Provider):
            return provider
        if not provider or provider == "unknown":
            return None
        return self.db.query(Provider).filter(Provider.name == provider).first()

    # ------------------------------------------------------------------
    # Strategy-based billing
    # ------------------------------------------------------------------

    async def compute_cost_with_strategy_async(
        self,
        provider: ProviderRef,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        api_format: Optional[str] = None,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
        """
        Compute cost using the billing logic of the adapter for ``api_format``.

        Looks up all prices for provider/model, then delegates to the Chat
        adapter registered for ``api_format`` (or the CLI adapter if there is
        no Chat adapter). Without an adapter, falls back to
        ``compute_cost_with_tiered_pricing``.

        Args:
            provider: Provider object or provider name.
            model: Model name.
            input_tokens: Input token count.
            output_tokens: Output token count.
            cache_creation_input_tokens: Cache-creation token count.
            cache_read_input_tokens: Cache-read token count.
            api_format: API format, used to pick the adapter.
            cache_ttl_minutes: Cache TTL in minutes, for TTL-specific pricing.

        Returns:
            Tuple of (input_cost, output_cost, cache_creation_cost,
            cache_read_cost, cache_cost, request_cost, total_cost, tier_index)
        """
        input_price, output_price = await self.get_model_price_async(provider, model)
        cache_creation_price, cache_read_price = await self.get_cache_prices_async(
            provider, model, input_price
        )
        request_price = await self.get_request_price_async(provider, model)
        tiered_pricing = await self.get_tiered_pricing_async(provider, model)

        # Prefer the Chat adapter for this format, then the CLI adapter.
        from src.api.handlers.base.chat_adapter_base import get_adapter_instance
        from src.api.handlers.base.cli_adapter_base import get_cli_adapter_instance

        adapter = None
        if api_format:
            adapter = get_adapter_instance(api_format)
            if adapter is None:
                adapter = get_cli_adapter_instance(api_format)

        if not adapter:
            # No adapter: use the default static calculation.
            return self.compute_cost_with_tiered_pricing(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cache_creation_input_tokens=cache_creation_input_tokens,
                cache_read_input_tokens=cache_read_input_tokens,
                tiered_pricing=tiered_pricing,
                cache_ttl_minutes=cache_ttl_minutes,
                price_per_request=request_price,
                fallback_input_price_per_1m=input_price,
                fallback_output_price_per_1m=output_price,
                fallback_cache_creation_price_per_1m=cache_creation_price,
                fallback_cache_read_price_per_1m=cache_read_price,
            )

        result = adapter.compute_cost(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            input_price_per_1m=input_price,
            output_price_per_1m=output_price,
            cache_creation_price_per_1m=cache_creation_price,
            cache_read_price_per_1m=cache_read_price,
            price_per_request=request_price,
            tiered_pricing=tiered_pricing,
            cache_ttl_minutes=cache_ttl_minutes,
        )
        return (
            result["input_cost"],
            result["output_cost"],
            result["cache_creation_cost"],
            result["cache_read_cost"],
            result["cache_cost"],
            result["request_cost"],
            result["total_cost"],
            result["tier_index"],
        )

    def compute_cost_with_strategy(
        self,
        provider: ProviderRef,
        model: str,
        input_tokens: int,
        output_tokens: int,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        api_format: Optional[str] = None,
        cache_ttl_minutes: Optional[int] = None,
    ) -> Tuple[float, float, float, float, float, float, float, Optional[int]]:
        """Synchronous wrapper around ``compute_cost_with_strategy_async``."""
        return self._run_sync(
            self.compute_cost_with_strategy_async(
                provider=provider,
                model=model,
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cache_creation_input_tokens=cache_creation_input_tokens,
                cache_read_input_tokens=cache_read_input_tokens,
                api_format=api_format,
                cache_ttl_minutes=cache_ttl_minutes,
            )
        )