refactor: 抽取统一计费模块,支持配置驱动的多厂商计费

- 新增 src/services/billing/ 模块,包含计费计算器、模板和使用量映射
- 将 ChatAdapterBase 和 CliAdapterBase 中的计费逻辑重构为调用 billing 模块
- 为每个 adapter 添加 BILLING_TEMPLATE 类属性,指定计费模板
- 支持 Claude/OpenAI/Gemini 三种计费模板,支持阶梯计费和缓存 TTL 定价
- 新增 tests/services/billing/ 单元测试
This commit is contained in:
fawney19
2026-01-05 16:48:59 +08:00
parent 465da6f818
commit 35e29d46bd
15 changed files with 1649 additions and 224 deletions

View File

@@ -40,6 +40,7 @@ from src.core.exceptions import (
UpstreamClientException,
)
from src.core.logger import logger
from src.services.billing import calculate_request_cost as _calculate_request_cost
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base"
mode = ApiMode.STANDARD
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
BILLING_TEMPLATE: str = "claude"
# 子类可以配置的特殊方法用于check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
使用 billing 模块的配置驱动计费
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
或覆盖此方法实现完全自定义的计费逻辑。
Args:
input_tokens: 输入 token 数
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
"tier_index": Optional[int], # 命中的阶梯索引
}
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 计算总输入上下文(使用子类可覆盖的方法)
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""
根据总输入 token 数确定价格阶梯
Args:
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
total_input_tokens: 总输入 token 数
Returns:
匹配的阶梯配置
"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
return tiers[-1] if tiers else None
return _calculate_request_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
tiered_pricing=tiered_pricing,
cache_ttl_minutes=cache_ttl_minutes,
total_input_context=total_input_context,
billing_template=self.BILLING_TEMPLATE,
)
# =========================================================================
# 模型列表查询 - 子类应覆盖此方法

View File

@@ -38,6 +38,7 @@ from src.core.exceptions import (
UpstreamClientException,
)
from src.core.logger import logger
from src.services.billing import calculate_request_cost as _calculate_request_cost
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
name: str = "cli.base"
mode = ApiMode.PROXY
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
BILLING_TEMPLATE: str = "claude"
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
使用 billing 模块的配置驱动计费
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
或覆盖此方法实现完全自定义的计费逻辑。
Args:
input_tokens: 输入 token 数
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
Returns:
包含各项成本的字典
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 计算总输入上下文(使用子类可覆盖的方法)
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""根据总输入 token 数确定价格阶梯"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
return tiers[-1] if tiers else None
return _calculate_request_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
tiered_pricing=tiered_pricing,
cache_ttl_minutes=cache_ttl_minutes,
total_input_context=total_input_context,
billing_template=self.BILLING_TEMPLATE,
)
# =========================================================================
# 模型列表查询 - 子类应覆盖此方法

View File

@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "CLAUDE"
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
name = "claude.chat"
@property

View File

@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
"""
FORMAT_ID = "CLAUDE_CLI"
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
name = "claude.cli"
@property

View File

@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "GEMINI"
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
name = "gemini.chat"
@property

View File

@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
"""
FORMAT_ID = "GEMINI_CLI"
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
name = "gemini.cli"
@property

View File

@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "OPENAI"
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
name = "openai.chat"
@property

View File

@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
"""
FORMAT_ID = "OPENAI_CLI"
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
name = "openai.cli"
@property

View File

@@ -0,0 +1,51 @@
"""
计费模块
提供配置驱动的计费计算,支持不同厂商的差异化计费模式:
- Claude: input + output + cache_creation + cache_read
- OpenAI: input + output + cache_read (无缓存创建费用)
- 豆包: input + output + cache_read + cache_storage (缓存按时计费)
- 按次计费: per_request
使用方式:
from src.services.billing import BillingCalculator, UsageMapper, StandardizedUsage
# 1. 将原始 usage 映射为标准格式
usage = UsageMapper.map(raw_usage, api_format="OPENAI")
# 2. 使用计费计算器计算费用
calculator = BillingCalculator(template="openai")
result = calculator.calculate(usage, prices)
# 3. 获取费用明细
print(result.total_cost)
print(result.costs) # {"input": 0.01, "output": 0.02, ...}
"""
from src.services.billing.calculator import BillingCalculator, calculate_request_cost
from src.services.billing.models import (
BillingDimension,
BillingUnit,
CostBreakdown,
StandardizedUsage,
)
from src.services.billing.templates import BILLING_TEMPLATE_REGISTRY, BillingTemplates
from src.services.billing.usage_mapper import UsageMapper, map_usage, map_usage_from_response
__all__ = [
# 数据模型
"BillingDimension",
"BillingUnit",
"CostBreakdown",
"StandardizedUsage",
# 模板
"BillingTemplates",
"BILLING_TEMPLATE_REGISTRY",
# 计算器
"BillingCalculator",
"calculate_request_cost",
# 映射器
"UsageMapper",
"map_usage",
"map_usage_from_response",
]

View File

@@ -0,0 +1,339 @@
"""
计费计算器
配置驱动的计费计算,支持:
- 固定价格计费
- 阶梯计费
- 多种计费模板
- 自定义计费维度
"""
from typing import Any, Dict, List, Optional, Tuple
from src.services.billing.models import (
BillingDimension,
BillingUnit,
CostBreakdown,
StandardizedUsage,
)
from src.services.billing.templates import (
BILLING_TEMPLATE_REGISTRY,
BillingTemplates,
get_template,
)
class BillingCalculator:
    """
    Configuration-driven billing calculator.

    The calculator holds a list of billing dimensions and prices a
    ``StandardizedUsage`` against a price table, optionally applying tiered
    pricing and TTL-specific cache-read prices.

    The dimensions come from one of:
    - an explicit list of ``BillingDimension`` objects,
    - a predefined template name (resolved through ``get_template``),
    - the Claude standard template when neither is given.

    Example:
        # Use a template
        calculator = BillingCalculator(template="openai")

        # Custom dimensions
        calculator = BillingCalculator(dimensions=[
            BillingDimension(name="input", usage_field="input_tokens", price_field="input_price_per_1m"),
            BillingDimension(name="output", usage_field="output_tokens", price_field="output_price_per_1m"),
        ])

        # Compute the cost
        usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
        prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
        result = calculator.calculate(usage, prices)
    """

    def __init__(
        self,
        dimensions: Optional[List[BillingDimension]] = None,
        template: Optional[str] = None,
    ):
        """
        Initialise the calculator.

        Args:
            dimensions: Custom billing dimensions; a non-empty list takes
                precedence over ``template``.
            template: Name of a predefined template, resolved with
                ``get_template``.
        """
        if dimensions:
            self.dimensions = dimensions
        elif template:
            self.dimensions = get_template(template)
        else:
            # Default to the Claude template (backward compatible).
            self.dimensions = BillingTemplates.CLAUDE_STANDARD
        self.template_name = template

    def calculate(
        self,
        usage: StandardizedUsage,
        prices: Dict[str, float],
        tiered_pricing: Optional[Dict[str, Any]] = None,
        cache_ttl_minutes: Optional[int] = None,
        total_input_context: Optional[int] = None,
    ) -> CostBreakdown:
        """
        Compute the cost of one usage record.

        Args:
            usage: Standardized usage data.
            prices: Price table, e.g.
                ``{"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}``.
            tiered_pricing: Optional tier configuration ``{"tiers": [...]}``.
            cache_ttl_minutes: Cache TTL in minutes; only consulted when a
                tier was matched, to pick a TTL-specific cache-read price.
            total_input_context: Token count used to select the tier. When
                omitted, ``_compute_total_input_context`` is used.

        Returns:
            A ``CostBreakdown`` with per-dimension costs, the total, the
            matched tier index and the prices that were actually applied.
        """
        result = CostBreakdown()

        # Work on a copy so the caller's price table is never modified.
        effective_prices = prices.copy()
        if tiered_pricing and tiered_pricing.get("tiers"):
            tier, tier_index = self._get_tier(usage, tiered_pricing, total_input_context)
            if tier:
                result.tier_index = tier_index
                # Tier prices override the base prices; "up_to" and
                # "cache_ttl_pricing" are tier metadata, not prices.
                for key, value in tier.items():
                    if key not in ("up_to", "cache_ttl_pricing") and value is not None:
                        effective_prices[key] = value
                # TTL-specific pricing overrides the tier's cache-read price.
                if cache_ttl_minutes is not None:
                    ttl_price = self._get_cache_read_price_for_ttl(tier, cache_ttl_minutes)
                    if ttl_price is not None:
                        effective_prices["cache_read_price_per_1m"] = ttl_price

        # Record the prices that were used.
        result.effective_prices = effective_prices.copy()

        total = 0.0
        for dim in self.dimensions:
            usage_value = usage.get(dim.usage_field, 0)
            price = effective_prices.get(dim.price_field, dim.default_price)
            # Dimensions with no usage or no price are left out of `costs`.
            if usage_value and price:
                cost = dim.calculate(usage_value, price)
                result.costs[dim.name] = cost
                total += cost
        result.total_cost = total
        return result

    def _get_tier(
        self,
        usage: StandardizedUsage,
        tiered_pricing: Dict[str, Any],
        total_input_context: Optional[int] = None,
    ) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
        """
        Select the pricing tier.

        Args:
            usage: Usage data.
            tiered_pricing: Tier configuration ``{"tiers": [...]}``.
            total_input_context: Precomputed total input context (optional).

        Returns:
            ``(tier, tier_index)``, or ``(None, None)`` when no tiers are
            configured.
        """
        tiers = tiered_pricing.get("tiers", [])
        if not tiers:
            return None, None

        if total_input_context is None:
            total_input_context = self._compute_total_input_context(usage)

        for i, tier in enumerate(tiers):
            up_to = tier.get("up_to")
            # A missing/None "up_to" means the tier has no upper bound.
            if up_to is None or total_input_context <= up_to:
                return tier, i

        # Every tier has an upper bound and all were exceeded: use the last.
        return tiers[-1], len(tiers) - 1

    def _compute_total_input_context(self, usage: StandardizedUsage) -> int:
        """
        Compute the total input context used for tier selection.

        Default: ``input_tokens + cache_read_tokens``.

        Args:
            usage: Usage data.

        Returns:
            Total input token count.
        """
        return usage.input_tokens + usage.cache_read_tokens

    def _get_cache_read_price_for_ttl(
        self,
        tier: Dict[str, Any],
        cache_ttl_minutes: int,
    ) -> Optional[float]:
        """
        Look up the cache-read price for a given cache TTL.

        Args:
            tier: The matched tier configuration.
            cache_ttl_minutes: Cache lifetime in minutes.

        Returns:
            The cache-read price, or ``None`` when the tier has no
            ``cache_ttl_pricing`` entries or the selected entry has no price.
        """
        ttl_pricing = tier.get("cache_ttl_pricing")
        if not ttl_pricing:
            return None

        # First entry whose limit covers the TTL; beyond every configured
        # limit, fall back to the last entry.
        matched = next(
            (
                ttl_config
                for ttl_config in ttl_pricing
                if cache_ttl_minutes <= ttl_config.get("ttl_minutes", 0)
            ),
            ttl_pricing[-1],
        )
        price = matched.get("cache_read_price_per_1m")
        return float(price) if price is not None else None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BillingCalculator":
        """
        Build a calculator from a configuration dictionary.

        Config format:
            {
                "template": "claude",  # or "openai", "doubao", "per_request"
                # or custom dimensions:
                "dimensions": [
                    {"name": "input", "usage_field": "input_tokens", "price_field": "input_price_per_1m"},
                    ...
                ]
            }

        Args:
            config: Configuration dictionary.

        Returns:
            A ``BillingCalculator`` instance.
        """
        # Only a non-empty "dimensions" list counts as custom dimensions; an
        # empty or None value falls through to the configured template
        # instead of silently discarding it.
        dimension_configs = config.get("dimensions")
        if dimension_configs:
            dimensions = [BillingDimension.from_dict(d) for d in dimension_configs]
            return cls(dimensions=dimensions)
        return cls(template=config.get("template", "claude"))

    def get_dimension_names(self) -> List[str]:
        """Return the names of all billing dimensions."""
        return [dim.name for dim in self.dimensions]

    def get_required_price_fields(self) -> List[str]:
        """Return the price-table field names the dimensions read."""
        return [dim.price_field for dim in self.dimensions]

    def get_required_usage_fields(self) -> List[str]:
        """Return the usage field names the dimensions read."""
        return [dim.usage_field for dim in self.dimensions]
def calculate_request_cost(
    input_tokens: int,
    output_tokens: int,
    cache_creation_input_tokens: int,
    cache_read_input_tokens: int,
    input_price_per_1m: float,
    output_price_per_1m: float,
    cache_creation_price_per_1m: Optional[float],
    cache_read_price_per_1m: Optional[float],
    price_per_request: Optional[float],
    tiered_pricing: Optional[Dict[str, Any]] = None,
    cache_ttl_minutes: Optional[int] = None,
    total_input_context: Optional[int] = None,
    billing_template: str = "claude",
) -> Dict[str, Any]:
    """
    Convenience wrapper that computes the cost of a single request.

    Wraps ``BillingCalculator`` and returns a dictionary in the legacy
    ``compute_cost`` format.

    Args:
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_creation_input_tokens: Number of cache-creation tokens.
        cache_read_input_tokens: Number of cache-read tokens.
        input_price_per_1m: Input price per 1M tokens.
        output_price_per_1m: Output price per 1M tokens.
        cache_creation_price_per_1m: Cache-creation price per 1M tokens.
        cache_read_price_per_1m: Cache-read price per 1M tokens.
        price_per_request: Flat price charged per request.
        tiered_pricing: Tiered pricing configuration.
        cache_ttl_minutes: Cache lifetime in minutes.
        total_input_context: Total input context used for tier selection.
        billing_template: Name of the billing template.

    Returns:
        Dictionary with the individual costs:
        {
            "input_cost": float,
            "output_cost": float,
            "cache_creation_cost": float,
            "cache_read_cost": float,
            "cache_cost": float,
            "request_cost": float,
            "total_cost": float,
            "tier_index": Optional[int],
        }
    """
    # Build the standardized usage record.
    usage = StandardizedUsage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cache_creation_tokens=cache_creation_input_tokens,
        cache_read_tokens=cache_read_input_tokens,
        request_count=1,
    )

    # Build the price table; optional prices are only included when set.
    prices: Dict[str, float] = {
        "input_price_per_1m": input_price_per_1m,
        "output_price_per_1m": output_price_per_1m,
    }
    if cache_creation_price_per_1m is not None:
        prices["cache_creation_price_per_1m"] = cache_creation_price_per_1m
    if cache_read_price_per_1m is not None:
        prices["cache_read_price_per_1m"] = cache_read_price_per_1m
    if price_per_request is not None:
        prices["price_per_request"] = price_per_request

    calculator = BillingCalculator(template=billing_template)
    result = calculator.calculate(
        usage, prices, tiered_pricing, cache_ttl_minutes, total_input_context
    )

    request_cost = result.request_cost
    total_cost = result.total_cost
    # The legacy compute_cost always added price_per_request to the total.
    # Templates without a "request" dimension would otherwise drop that
    # charge, so add it here to keep the legacy result.
    if price_per_request and "request" not in calculator.get_dimension_names():
        request_cost = price_per_request
        total_cost += request_cost

    # Return the legacy dictionary format.
    return {
        "input_cost": result.input_cost,
        "output_cost": result.output_cost,
        "cache_creation_cost": result.cache_creation_cost,
        "cache_read_cost": result.cache_read_cost,
        "cache_cost": result.cache_cost,
        "request_cost": request_cost,
        "total_cost": total_cost,
        "tier_index": result.tier_index,
    }

View File

@@ -0,0 +1,281 @@
"""
计费模块数据模型
定义计费相关的核心数据结构:
- BillingUnit: 计费单位枚举
- BillingDimension: 计费维度定义
- StandardizedUsage: 标准化的 usage 数据
- CostBreakdown: 计费明细结果
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
class BillingUnit(str, Enum):
    """Unit in which a billing dimension is priced."""

    # Price per one million tokens.
    PER_1M_TOKENS = "per_1m_tokens"
    # Price per one million tokens per hour (Doubao cache storage).
    PER_1M_TOKENS_HOUR = "per_1m_tokens_hour"
    # Price per request.
    PER_REQUEST = "per_request"
    # Fixed fee.
    FIXED = "fixed"
@dataclass
class BillingDimension:
    """
    Definition of one billing dimension.

    A dimension ties a usage field to a price field and a billing unit,
    for example input tokens, output tokens, cache reads or per-request
    charges.
    """

    # Dimension name, e.g. "input", "output", "cache_read".
    name: str
    # Name of the usage field the quantity is read from.
    usage_field: str
    # Name of the field in the price table.
    price_field: str
    # Billing unit.
    unit: BillingUnit = BillingUnit.PER_1M_TOKENS
    # Price used when the price table has no entry for `price_field`.
    default_price: float = 0.0

    def calculate(self, usage_value: float, price: float) -> float:
        """
        Compute the cost for this dimension.

        Args:
            usage_value: Quantity used.
            price: Unit price.

        Returns:
            The cost; ``0.0`` when either argument is not positive or the
            unit is not recognised.
        """
        if usage_value <= 0 or price <= 0:
            return 0.0

        unit = self.unit
        # Both per-1M units share one formula: for the hourly unit the usage
        # value is expected to already be token*hours.
        if unit in (BillingUnit.PER_1M_TOKENS, BillingUnit.PER_1M_TOKENS_HOUR):
            return (usage_value / 1_000_000) * price
        if unit == BillingUnit.PER_REQUEST:
            return usage_value * price
        if unit == BillingUnit.FIXED:
            return price
        return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the dimension to a plain dictionary."""
        return dict(
            name=self.name,
            usage_field=self.usage_field,
            price_field=self.price_field,
            unit=self.unit.value,
            default_price=self.default_price,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BillingDimension":
        """Create a dimension from a dictionary such as ``to_dict`` produces."""
        unit = BillingUnit(data.get("unit", "per_1m_tokens"))
        default_price = data.get("default_price", 0.0)
        return cls(
            data["name"],
            data["usage_field"],
            data["price_field"],
            unit,
            default_price,
        )
@dataclass
class StandardizedUsage:
    """
    Standardized usage data.

    Usage reported in different API formats is normalised into this shape
    so that billing can be computed uniformly.
    """

    # Basic token counts.
    input_tokens: int = 0
    output_tokens: int = 0
    # Cache related.
    cache_creation_tokens: int = 0  # Claude: cache creation
    cache_read_tokens: int = 0  # Claude/OpenAI/Doubao: cache read / hit
    # Special token types.
    reasoning_tokens: int = 0  # o1/Doubao: reasoning tokens, recorded separately for analysis
    # Time based (for time-based billing).
    cache_storage_token_hours: float = 0.0  # Doubao: cache storage in token*hours
    # Request count (for per-request billing).
    request_count: int = 1
    # Extension fields for additional dimensions.
    extra: Dict[str, Any] = field(default_factory=dict)

    def _is_standard_field(self, field_name: str) -> bool:
        """Return True when ``field_name`` is a declared field other than ``extra``."""
        # Checking the declared dataclass fields (rather than hasattr) keeps
        # method names such as "get" or "to_dict" from being treated as fields.
        return field_name != "extra" and field_name in self.__dataclass_fields__

    def get(self, field_name: str, default: Any = 0) -> Any:
        """
        Generic field lookup.

        Standard fields are read from the instance; any other name
        (including ``"extra"`` itself) is looked up in ``extra``.

        Args:
            field_name: Field name.
            default: Value returned when the name is not found in ``extra``.

        Returns:
            The field value.
        """
        if self._is_standard_field(field_name):
            return getattr(self, field_name)
        return self.extra.get(field_name, default)

    def set(self, field_name: str, value: Any) -> None:
        """
        Generic field assignment.

        Standard fields are set on the instance; any other name is stored
        in ``extra``.

        Args:
            field_name: Field name.
            value: Field value.
        """
        if self._is_standard_field(field_name):
            setattr(self, field_name, value)
        else:
            self.extra[field_name] = value

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary; ``extra`` is included only when non-empty."""
        result: Dict[str, Any] = {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cache_creation_tokens": self.cache_creation_tokens,
            "cache_read_tokens": self.cache_read_tokens,
            "reasoning_tokens": self.reasoning_tokens,
            "cache_storage_token_hours": self.cache_storage_token_hours,
            "request_count": self.request_count,
        }
        if self.extra:
            result["extra"] = self.extra
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StandardizedUsage":
        """
        Create an instance from a dictionary.

        Unknown keys are ignored. ``data`` is not modified; the ``extra``
        mapping is shallow-copied.
        """
        extra = dict(data.get("extra") or {})
        # Keep only the declared fields; "extra" is handled separately.
        known_fields = {name for name in cls.__dataclass_fields__ if name != "extra"}
        filtered = {k: v for k, v in data.items() if k in known_fields}
        return cls(**filtered, extra=extra)
@dataclass
class CostBreakdown:
    """
    Result of a billing calculation.

    Holds the cost of each dimension together with the total cost.
    """

    # Cost per dimension, e.g. {"input": 0.01, "output": 0.02, "cache_read": 0.001}.
    costs: Dict[str, float] = field(default_factory=dict)
    # Total cost.
    total_cost: float = 0.0
    # Index of the matched tier (when tiered pricing was used).
    tier_index: Optional[int] = None
    # Currency.
    currency: str = "USD"
    # Prices that were applied (for records and auditing).
    effective_prices: Dict[str, float] = field(default_factory=dict)

    def _cost(self, dimension: str) -> float:
        """Return the cost recorded for ``dimension``, or ``0.0`` if absent."""
        return self.costs.get(dimension, 0.0)

    # =========================================================================
    # Properties mirroring the legacy interface (eases gradual migration)
    # =========================================================================
    @property
    def input_cost(self) -> float:
        """Input cost."""
        return self._cost("input")

    @property
    def output_cost(self) -> float:
        """Output cost."""
        return self._cost("output")

    @property
    def cache_creation_cost(self) -> float:
        """Cache-creation cost."""
        return self._cost("cache_creation")

    @property
    def cache_read_cost(self) -> float:
        """Cache-read cost."""
        return self._cost("cache_read")

    @property
    def cache_cost(self) -> float:
        """Total cache cost (creation + read)."""
        creation = self.cache_creation_cost
        read = self.cache_read_cost
        return creation + read

    @property
    def request_cost(self) -> float:
        """Per-request cost."""
        return self._cost("request")

    @property
    def cache_storage_cost(self) -> float:
        """Cache-storage cost (Doubao and similar)."""
        return self._cost("cache_storage")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary, including the legacy cost fields."""
        payload: Dict[str, Any] = {
            "costs": self.costs,
            "total_cost": self.total_cost,
            "tier_index": self.tier_index,
            "currency": self.currency,
            "effective_prices": self.effective_prices,
        }
        # Legacy-compatible fields, appended in the established key order.
        for legacy_name in (
            "input_cost",
            "output_cost",
            "cache_creation_cost",
            "cache_read_cost",
            "cache_cost",
            "request_cost",
        ):
            payload[legacy_name] = getattr(self, legacy_name)
        return payload

    def to_legacy_tuple(self) -> tuple:
        """
        Convert to the tuple layout of the legacy interface.

        Returns:
            (input_cost, output_cost, cache_creation_cost, cache_read_cost,
            cache_cost, request_cost, total_cost, tier_index)
        """
        ordered_names = (
            "input_cost",
            "output_cost",
            "cache_creation_cost",
            "cache_read_cost",
            "cache_cost",
            "request_cost",
            "total_cost",
            "tier_index",
        )
        return tuple(getattr(self, attr) for attr in ordered_names)

View File

@@ -0,0 +1,213 @@
"""
预定义计费模板
提供常见厂商的计费配置模板,避免重复配置:
- CLAUDE_STANDARD: Claude/Anthropic 标准计费
- OPENAI_STANDARD: OpenAI 标准计费
- DOUBAO_STANDARD: 豆包计费(含缓存存储)
- GEMINI_STANDARD: Gemini 标准计费
- PER_REQUEST: 按次计费
"""
from typing import Dict, List, Optional
from src.services.billing.models import BillingDimension, BillingUnit
class BillingTemplates:
    """Predefined billing templates, each a list of billing dimensions."""

    # =========================================================================
    # Claude/Anthropic standard billing
    # - input tokens
    # - output tokens
    # - cache creation (charged on creation, about 1.25x the input price)
    # - cache read (about 0.1x the input price)
    # =========================================================================
    CLAUDE_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension(
            "cache_creation", "cache_creation_tokens", "cache_creation_price_per_1m"
        ),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
    ]

    # =========================================================================
    # OpenAI standard billing
    # - input tokens
    # - output tokens
    # - cache read (supported by some models; no cache-creation charge)
    # =========================================================================
    OPENAI_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
    ]

    # =========================================================================
    # Doubao billing
    # - inference input (input_tokens)
    # - inference output (output_tokens)
    # - cache hit (cache_read_tokens), similar to Claude's cache read
    # - cache storage (cache_storage_token_hours), tokens * storage duration
    #
    # Note: cache creation is free for Doubao, but storage is billed by time.
    # =========================================================================
    DOUBAO_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
        BillingDimension(
            "cache_storage",
            "cache_storage_token_hours",
            "cache_storage_price_per_1m_hour",
            BillingUnit.PER_1M_TOKENS_HOUR,
        ),
    ]

    # =========================================================================
    # Gemini standard billing
    # - input tokens
    # - output tokens
    # - cache read
    # =========================================================================
    GEMINI_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
    ]

    # =========================================================================
    # Per-request billing
    # - for some image-generation models, special APIs, etc.
    # - billed by request count only, not by tokens
    # =========================================================================
    PER_REQUEST: List[BillingDimension] = [
        BillingDimension(
            "request", "request_count", "price_per_request", BillingUnit.PER_REQUEST
        ),
    ]

    # =========================================================================
    # Hybrid billing (per request + per token)
    # - for models that carry both a fixed fee and token fees
    # =========================================================================
    HYBRID_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension(
            "request", "request_count", "price_per_request", BillingUnit.PER_REQUEST
        ),
    ]
# =========================================================================
# 模板注册表
# =========================================================================
# Maps every accepted template name (lower-case) to its dimension list.
# Aliases of one template are grouped together; insertion order is the order
# in which the names are listed.
BILLING_TEMPLATE_REGISTRY: Dict[str, List[BillingDimension]] = {
    alias: dimensions
    for aliases, dimensions in (
        # By vendor name
        (("claude", "anthropic"), BillingTemplates.CLAUDE_STANDARD),
        (("openai",), BillingTemplates.OPENAI_STANDARD),
        (("doubao", "bytedance"), BillingTemplates.DOUBAO_STANDARD),
        (("gemini", "google"), BillingTemplates.GEMINI_STANDARD),
        # By billing mode
        (("per_request",), BillingTemplates.PER_REQUEST),
        (("hybrid",), BillingTemplates.HYBRID_STANDARD),
        # Default
        (("default",), BillingTemplates.CLAUDE_STANDARD),
    )
    for alias in aliases
}
def get_template(name: Optional[str]) -> List[BillingDimension]:
    """
    Look up a billing template.

    Args:
        name: Template name (case-insensitive). An empty or ``None`` name
            selects the ``"default"`` template.

    Returns:
        The list of billing dimensions registered under that name.

    Raises:
        ValueError: If no template is registered under ``name``.
    """
    key = name.lower() if name else "default"
    dimensions = BILLING_TEMPLATE_REGISTRY.get(key)
    if dimensions is not None:
        return dimensions

    available = ", ".join(sorted(BILLING_TEMPLATE_REGISTRY.keys()))
    raise ValueError(f"Unknown billing template: {name!r}. Available: {available}")
def list_templates() -> List[str]:
    """Return the names of all registered templates, in registry order."""
    return [*BILLING_TEMPLATE_REGISTRY]

View File

@@ -0,0 +1,267 @@
"""
Usage 字段映射器
将不同 API 格式的原始 usage 数据映射为标准化格式。
支持的格式:
- OPENAI / OPENAI_CLI: OpenAI Chat Completions API
- CLAUDE / CLAUDE_CLI: Anthropic Messages API
- GEMINI / GEMINI_CLI: Google Gemini API
"""
from typing import Any, Dict, Optional
from src.services.billing.models import StandardizedUsage
class UsageMapper:
    """
    Map provider-specific usage payloads onto StandardizedUsage.

    Every supported API format has a table of ``source path -> target field``
    entries. A source path may contain dots to address nested dictionaries.
    Values found in the raw payload are written to the result through
    ``StandardizedUsage.set``.

    Example:
        # OpenAI format
        raw_usage = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "prompt_tokens_details": {"cached_tokens": 20},
            "completion_tokens_details": {"reasoning_tokens": 10}
        }
        usage = UsageMapper.map(raw_usage, "OPENAI")

        # Claude format
        raw_usage = {
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_creation_input_tokens": 30,
            "cache_read_input_tokens": 20
        }
        usage = UsageMapper.map(raw_usage, "CLAUDE")
    """

    # =========================================================================
    # Field mapping tables
    # Format: "source_path" -> "target_field"
    # source_path supports dot-separated nested paths
    # =========================================================================

    # OpenAI field mapping
    OPENAI_MAPPING: Dict[str, str] = {
        "prompt_tokens": "input_tokens",
        "completion_tokens": "output_tokens",
        "prompt_tokens_details.cached_tokens": "cache_read_tokens",
        "completion_tokens_details.reasoning_tokens": "reasoning_tokens",
    }

    # Claude field mapping
    CLAUDE_MAPPING: Dict[str, str] = {
        "input_tokens": "input_tokens",
        "output_tokens": "output_tokens",
        "cache_creation_input_tokens": "cache_creation_tokens",
        "cache_read_input_tokens": "cache_read_tokens",
    }

    # Gemini field mapping
    GEMINI_MAPPING: Dict[str, str] = {
        "promptTokenCount": "input_tokens",
        "candidatesTokenCount": "output_tokens",
        "cachedContentTokenCount": "cache_read_tokens",
        # Same counters when the caller passes a dict that still wraps them
        # in "usageMetadata"
        "usageMetadata.promptTokenCount": "input_tokens",
        "usageMetadata.candidatesTokenCount": "output_tokens",
        "usageMetadata.cachedContentTokenCount": "cache_read_tokens",
    }

    # Format name -> mapping table
    FORMAT_MAPPINGS: Dict[str, Dict[str, str]] = {
        "OPENAI": OPENAI_MAPPING,
        "OPENAI_CLI": OPENAI_MAPPING,
        "CLAUDE": CLAUDE_MAPPING,
        "CLAUDE_CLI": CLAUDE_MAPPING,
        "GEMINI": GEMINI_MAPPING,
        "GEMINI_CLI": GEMINI_MAPPING,
    }

    @classmethod
    def map(
        cls,
        raw_usage: Dict[str, Any],
        api_format: str,
        extra_mapping: Optional[Dict[str, str]] = None,
    ) -> "StandardizedUsage":
        """
        Map a raw usage dictionary to the standardized representation.

        Args:
            raw_usage: Raw usage dictionary as returned by the provider.
            api_format: API format name ("OPENAI", "CLAUDE", "GEMINI", ...).
            extra_mapping: Additional ``source path -> target field`` entries;
                they are merged over the format's own table and win on
                conflicting source paths.

        Returns:
            A StandardizedUsage object. A falsy ``raw_usage`` yields a
            default-constructed StandardizedUsage.
        """
        if not raw_usage:
            return StandardizedUsage()

        mapping = cls._get_mapping(api_format)
        if extra_mapping:
            # Build a new dict so the shared class-level table is never mutated.
            mapping = {**mapping, **extra_mapping}

        result = StandardizedUsage()
        for source_path, target_field in mapping.items():
            value = cls._get_nested_value(raw_usage, source_path)
            if value is not None:
                result.set(target_field, value)
        return result

    @classmethod
    def map_from_response(
        cls,
        response: Dict[str, Any],
        api_format: str,
    ) -> "StandardizedUsage":
        """
        Extract the usage section from a full API response and map it.

        Where the usage section is read from:
            - Formats whose upper-cased name starts with "GEMINI":
              ``response["usageMetadata"]``, falling back to the
              ``usageMetadata`` of the first entry in ``response["candidates"]``.
            - Every other format: ``response["usage"]``.

        A missing or malformed response (not a dict, ``candidates`` not a
        sequence, first candidate not a dict) is treated as "no usage data"
        instead of raising, matching how ``map`` handles empty input.

        Args:
            response: Full API response.
            api_format: API format name.

        Returns:
            A StandardizedUsage object.
        """
        if not isinstance(response, dict):
            return cls.map({}, api_format)

        format_upper = api_format.upper() if api_format else ""

        if format_upper.startswith("GEMINI"):
            usage_data = response.get("usageMetadata")
            if not usage_data:
                # Fall back to the first candidate's usageMetadata.
                candidates = response.get("candidates")
                first = (
                    candidates[0]
                    if isinstance(candidates, (list, tuple)) and candidates
                    else None
                )
                usage_data = first.get("usageMetadata") if isinstance(first, dict) else None
        else:
            usage_data = response.get("usage")

        return cls.map(usage_data or {}, api_format)

    @classmethod
    def _get_mapping(cls, api_format: str) -> Dict[str, str]:
        """
        Resolve the mapping table for an API format name.

        Resolution order: exact (case-insensitive) match in FORMAT_MAPPINGS,
        then the first registered entry whose leading segment (the text before
        the first underscore) is a prefix of the requested name, then the
        Claude mapping as the default.
        """
        if not api_format:
            return cls.CLAUDE_MAPPING

        format_upper = api_format.upper()

        # Exact match
        exact = cls.FORMAT_MAPPINGS.get(format_upper)
        if exact is not None:
            return exact

        # Prefix match
        for key, mapping in cls.FORMAT_MAPPINGS.items():
            family = key.split("_", 1)[0]
            # An empty family (registered name starting with "_") would be a
            # prefix of every string and hijack all unknown formats.
            if family and format_upper.startswith(family):
                return mapping

        # Fall back to the Claude mapping
        return cls.CLAUDE_MAPPING

    @classmethod
    def _get_nested_value(cls, data: Dict[str, Any], path: str) -> Any:
        """
        Read a possibly nested value from a dictionary.

        Args:
            data: Dictionary to read from.
            path: Dot-separated key path, e.g.
                "prompt_tokens_details.cached_tokens".

        Returns:
            The value at ``path``, or None when ``data`` or ``path`` is empty,
            a key is missing, a value along the path is None, or an
            intermediate value is not a dictionary.
        """
        if not data or not path:
            return None

        value: Any = data
        for key in path.split("."):
            if not isinstance(value, dict):
                return None
            value = value.get(key)
            if value is None:
                return None
        return value

    @classmethod
    def register_format(cls, format_name: str, mapping: Dict[str, str]) -> None:
        """
        Register (or replace) the mapping table for a format.

        Args:
            format_name: Format name; stored upper-cased.
            mapping: ``source path -> target field`` table.
        """
        cls.FORMAT_MAPPINGS[format_name.upper()] = mapping

    @classmethod
    def get_supported_formats(cls) -> list:
        """Return the names of all registered formats."""
        return list(cls.FORMAT_MAPPINGS.keys())
# =========================================================================
# 便捷函数
# =========================================================================
def map_usage(
    raw_usage: Dict[str, Any],
    api_format: str,
) -> StandardizedUsage:
    """
    Convenience wrapper around ``UsageMapper.map``.

    Args:
        raw_usage: Raw usage dictionary.
        api_format: API format name.

    Returns:
        The StandardizedUsage object produced by ``UsageMapper.map``.
    """
    standardized = UsageMapper.map(raw_usage, api_format)
    return standardized
def map_usage_from_response(
    response: Dict[str, Any],
    api_format: str,
) -> StandardizedUsage:
    """
    Convenience wrapper around ``UsageMapper.map_from_response``.

    Args:
        response: Full API response.
        api_format: API format name.

    Returns:
        The StandardizedUsage object produced by
        ``UsageMapper.map_from_response``.
    """
    standardized = UsageMapper.map_from_response(response, api_format)
    return standardized

View File

View File

@@ -0,0 +1,440 @@
"""
Billing 模块测试
测试计费模块的核心功能:
- BillingCalculator 计费计算
- 计费模板
- 阶梯计费
- calculate_request_cost 便捷函数
"""
import pytest
from src.services.billing import (
BillingCalculator,
BillingDimension,
BillingTemplates,
BillingUnit,
CostBreakdown,
StandardizedUsage,
calculate_request_cost,
)
from src.services.billing.templates import get_template, list_templates
class TestBillingDimension:
    """Tests for BillingDimension cost calculation and (de)serialization."""

    def test_calculate_per_1m_tokens(self) -> None:
        """A dimension created without an explicit unit bills per 1M tokens."""
        dim = BillingDimension(
            name="input",
            usage_field="input_tokens",
            price_field="input_price_per_1m",
        )

        # 1000 tokens * $3 / 1M = $0.003
        cost = dim.calculate(1000, 3.0)
        # approx uses a relative tolerance, so the check scales with the value
        # instead of hiding errors behind a fixed 1e-4 window.
        assert cost == pytest.approx(0.003)

    def test_calculate_per_request(self) -> None:
        """A PER_REQUEST dimension bills request_count * price."""
        dim = BillingDimension(
            name="request",
            usage_field="request_count",
            price_field="price_per_request",
            unit=BillingUnit.PER_REQUEST,
        )

        # Per-request billing: cost = request_count * price
        cost = dim.calculate(1, 0.05)
        assert cost == pytest.approx(0.05)

        # Several requests are billed by count
        cost = dim.calculate(3, 0.05)
        assert cost == pytest.approx(0.15)

    def test_calculate_zero_usage(self) -> None:
        """Zero usage costs nothing."""
        dim = BillingDimension(
            name="input",
            usage_field="input_tokens",
            price_field="input_price_per_1m",
        )

        cost = dim.calculate(0, 3.0)
        assert cost == 0.0

    def test_calculate_zero_price(self) -> None:
        """A zero price costs nothing."""
        dim = BillingDimension(
            name="input",
            usage_field="input_tokens",
            price_field="input_price_per_1m",
        )

        cost = dim.calculate(1000, 0.0)
        assert cost == 0.0

    def test_to_dict_and_from_dict(self) -> None:
        """to_dict followed by from_dict preserves every compared field."""
        dim = BillingDimension(
            name="cache_read",
            usage_field="cache_read_tokens",
            price_field="cache_read_price_per_1m",
            unit=BillingUnit.PER_1M_TOKENS,
            default_price=0.3,
        )

        d = dim.to_dict()
        restored = BillingDimension.from_dict(d)

        assert restored.name == dim.name
        assert restored.usage_field == dim.usage_field
        assert restored.price_field == dim.price_field
        assert restored.unit == dim.unit
        assert restored.default_price == dim.default_price
class TestStandardizedUsage:
    """Tests for the StandardizedUsage container."""

    def test_basic_usage(self) -> None:
        """Given fields are stored; cache counters not given read as 0."""
        std_usage = StandardizedUsage(input_tokens=1000, output_tokens=500)

        assert std_usage.input_tokens == 1000
        assert std_usage.output_tokens == 500
        for counter in ("cache_creation_tokens", "cache_read_tokens"):
            assert getattr(std_usage, counter) == 0

    def test_get_field(self) -> None:
        """get() returns a known field and the given default for an unknown one."""
        std_usage = StandardizedUsage(input_tokens=1000, output_tokens=500)

        known = std_usage.get("input_tokens")
        missing = std_usage.get("nonexistent", 0)

        assert known == 1000
        assert missing == 0

    def test_extra_fields(self) -> None:
        """Keys supplied through ``extra`` are readable via get()."""
        extras = {"custom_field": 123}
        std_usage = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            extra=extras,
        )

        assert std_usage.get("custom_field") == 123

    def test_to_dict(self) -> None:
        """to_dict() exposes the token counters under their field names."""
        std_usage = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_tokens=100,
        )

        as_dict = std_usage.to_dict()

        assert as_dict["input_tokens"] == 1000
        assert as_dict["output_tokens"] == 500
        assert as_dict["cache_creation_tokens"] == 100
class TestCostBreakdown:
    """Tests for the CostBreakdown result object."""

    def test_basic_breakdown(self) -> None:
        """Per-dimension costs and the total are exposed as attributes."""
        breakdown = CostBreakdown(
            costs={"input": 0.003, "output": 0.0075},
            total_cost=0.0105,
        )

        # These values are stored, not computed, so exact equality is fine.
        assert breakdown.input_cost == 0.003
        assert breakdown.output_cost == 0.0075
        assert breakdown.total_cost == 0.0105

    def test_cache_cost_calculation(self) -> None:
        """cache_cost is the sum of the cache_creation and cache_read costs."""
        breakdown = CostBreakdown(
            costs={
                "input": 0.003,
                "output": 0.0075,
                "cache_creation": 0.001,
                "cache_read": 0.0005,
            },
            total_cost=0.012,
        )

        # cache_cost = cache_creation + cache_read
        # The sum is computed in floating point, so compare with a relative
        # tolerance rather than a fixed 1e-4 window.
        assert breakdown.cache_cost == pytest.approx(0.0015)

    def test_to_dict(self) -> None:
        """to_dict() includes the total, the tier index and per-dimension costs."""
        breakdown = CostBreakdown(
            costs={"input": 0.003, "output": 0.0075},
            total_cost=0.0105,
            tier_index=1,
        )

        d = breakdown.to_dict()

        assert d["total_cost"] == 0.0105
        assert d["tier_index"] == 1
        assert d["input_cost"] == 0.003
class TestBillingTemplates:
    """Tests for the predefined billing templates and the registry helpers."""

    def test_claude_template(self) -> None:
        """The Claude template has input, output and both cache dimensions."""
        names = [dim.name for dim in BillingTemplates.CLAUDE_STANDARD]

        for expected in ("input", "output", "cache_creation", "cache_read"):
            assert expected in names

    def test_openai_template(self) -> None:
        """The OpenAI template has input, output and cache_read only."""
        names = [dim.name for dim in BillingTemplates.OPENAI_STANDARD]

        for expected in ("input", "output", "cache_read"):
            assert expected in names
        # OpenAI has no cache-creation charge
        assert "cache_creation" not in names

    def test_gemini_template(self) -> None:
        """The Gemini template has input, output and cache_read dimensions."""
        names = [dim.name for dim in BillingTemplates.GEMINI_STANDARD]

        for expected in ("input", "output", "cache_read"):
            assert expected in names

    def test_per_request_template(self) -> None:
        """The per-request template is a single PER_REQUEST dimension."""
        dimensions = BillingTemplates.PER_REQUEST

        assert len(dimensions) == 1
        only = dimensions[0]
        assert only.name == "request"
        assert only.unit == BillingUnit.PER_REQUEST

    def test_get_template(self) -> None:
        """get_template resolves names case-insensitively and rejects unknown ones."""
        lookups = (
            ("claude", BillingTemplates.CLAUDE_STANDARD),
            ("openai", BillingTemplates.OPENAI_STANDARD),
            # Lookup is case-insensitive
            ("CLAUDE", BillingTemplates.CLAUDE_STANDARD),
        )
        for template_name, expected in lookups:
            assert get_template(template_name) == expected

        with pytest.raises(ValueError, match="Unknown billing template"):
            get_template("unknown_template")

    def test_list_templates(self) -> None:
        """list_templates includes the vendor and per-request names."""
        available = list_templates()

        for template_name in ("claude", "openai", "gemini", "per_request"):
            assert template_name in available
class TestBillingCalculator:
    """Tests for BillingCalculator."""

    def test_basic_calculation(self) -> None:
        """Input and output tokens are billed at their per-1M prices."""
        calculator = BillingCalculator(template="claude")
        usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
        prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}

        result = calculator.calculate(usage, prices)

        # 1000 * 3 / 1M = 0.003
        assert result.input_cost == pytest.approx(0.003)
        # 500 * 15 / 1M = 0.0075
        assert result.output_cost == pytest.approx(0.0075)
        # Total = 0.0105
        assert result.total_cost == pytest.approx(0.0105)

    def test_calculation_with_cache(self) -> None:
        """Cache creation and cache read tokens are billed at their own prices."""
        calculator = BillingCalculator(template="claude")
        usage = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_tokens=200,
            cache_read_tokens=300,
        )
        prices = {
            "input_price_per_1m": 3.0,
            "output_price_per_1m": 15.0,
            "cache_creation_price_per_1m": 3.75,
            "cache_read_price_per_1m": 0.3,
        }

        result = calculator.calculate(usage, prices)

        # cache_creation: 200 * 3.75 / 1M = 0.00075
        assert result.cache_creation_cost == pytest.approx(0.00075)
        # cache_read: 300 * 0.3 / 1M = 0.00009
        # A fixed 1e-4 tolerance is larger than this expected value and would
        # accept a cost of 0, so a relative comparison is required here.
        assert result.cache_read_cost == pytest.approx(0.00009)

    def test_tiered_pricing(self) -> None:
        """Usage above the first tier's limit is billed at the second tier's price."""
        calculator = BillingCalculator(template="claude")
        usage = StandardizedUsage(input_tokens=250000, output_tokens=10000)

        # More than 200k input tokens falls into the second tier
        tiered_pricing = {
            "tiers": [
                {"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
                {"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
            ]
        }
        prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}

        result = calculator.calculate(usage, prices, tiered_pricing)

        # The second tier's prices should be used
        assert result.tier_index == 1
        # 250000 * 1.5 / 1M = 0.375
        assert result.input_cost == pytest.approx(0.375)

    def test_openai_no_cache_creation(self) -> None:
        """The OpenAI template does not bill cache creation tokens."""
        calculator = BillingCalculator(template="openai")
        usage = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_tokens=200,  # must not be billed
            cache_read_tokens=300,
        )
        prices = {
            "input_price_per_1m": 3.0,
            "output_price_per_1m": 15.0,
            "cache_creation_price_per_1m": 3.75,
            "cache_read_price_per_1m": 0.3,
        }

        result = calculator.calculate(usage, prices)

        # The OpenAI template has no cache_creation dimension
        assert result.cache_creation_cost == 0.0
        # cache_read is still billed
        assert result.cache_read_cost > 0

    def test_from_config(self) -> None:
        """from_config builds a calculator for the configured template."""
        config = {"template": "openai"}
        calculator = BillingCalculator.from_config(config)

        assert calculator.template_name == "openai"
class TestCalculateRequestCost:
    """Tests for the calculate_request_cost convenience function."""

    def test_basic_usage(self) -> None:
        """The result dict carries input, output and total costs."""
        result = calculate_request_cost(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=0,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            cache_creation_price_per_1m=None,
            cache_read_price_per_1m=None,
            price_per_request=None,
            billing_template="claude",
        )

        assert "input_cost" in result
        assert "output_cost" in result
        assert "total_cost" in result
        # 1000 * 3 / 1M = 0.003 and 500 * 15 / 1M = 0.0075
        assert result["input_cost"] == pytest.approx(0.003)
        assert result["output_cost"] == pytest.approx(0.0075)

    def test_with_cache(self) -> None:
        """Cache costs are reported separately and summed into cache_cost."""
        result = calculate_request_cost(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_input_tokens=200,
            cache_read_input_tokens=300,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            cache_creation_price_per_1m=3.75,
            cache_read_price_per_1m=0.3,
            price_per_request=None,
            billing_template="claude",
        )

        assert result["cache_creation_cost"] > 0
        assert result["cache_read_cost"] > 0
        # Compare the float sum with a tolerance so the test does not depend
        # on the order in which the implementation adds the two parts.
        assert result["cache_cost"] == pytest.approx(
            result["cache_creation_cost"] + result["cache_read_cost"]
        )

    def test_different_templates(self) -> None:
        """The same request is billed differently under different templates."""
        # Token counts and prices shared by both calls below.
        request_kwargs = {
            "input_tokens": 1000,
            "output_tokens": 500,
            "cache_creation_input_tokens": 200,
            "cache_read_input_tokens": 300,
            "input_price_per_1m": 3.0,
            "output_price_per_1m": 15.0,
            "cache_creation_price_per_1m": 3.75,
            "cache_read_price_per_1m": 0.3,
            "price_per_request": None,
        }

        # The Claude template bills cache creation
        result_claude = calculate_request_cost(**request_kwargs, billing_template="claude")
        assert result_claude["cache_creation_cost"] > 0

        # The OpenAI template does not bill cache creation
        result_openai = calculate_request_cost(**request_kwargs, billing_template="openai")
        assert result_openai["cache_creation_cost"] == 0

    def test_tiered_pricing_with_total_context(self) -> None:
        """A caller-supplied total_input_context is used for tier selection."""
        tiered_pricing = {
            "tiers": [
                {"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
                {"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
            ]
        }

        # Pass a precomputed total_input_context
        result = calculate_request_cost(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=0,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            cache_creation_price_per_1m=None,
            cache_read_price_per_1m=None,
            price_per_request=None,
            tiered_pricing=tiered_pricing,
            total_input_context=250000,  # precomputed value, above 200k
            billing_template="claude",
        )

        # The second tier's prices should be used
        assert result["tier_index"] == 1