mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-10 03:32:26 +08:00
refactor: 抽取统一计费模块,支持配置驱动的多厂商计费
- 新增 src/services/billing/ 模块,包含计费计算器、模板和使用量映射 - 将 ChatAdapterBase 和 CliAdapterBase 中的计费逻辑重构为调用 billing 模块 - 为每个 adapter 添加 BILLING_TEMPLATE 类属性,指定计费模板 - 支持 Claude/OpenAI/Gemini 三种计费模板,支持阶梯计费和缓存 TTL 定价 - 新增 tests/services/billing/ 单元测试
This commit is contained in:
@@ -40,6 +40,7 @@ from src.core.exceptions import (
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
|
||||
name: str = "chat.base"
|
||||
mode = ApiMode.STANDARD
|
||||
|
||||
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||
BILLING_TEMPLATE: str = "claude"
|
||||
|
||||
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||
@classmethod
|
||||
def build_endpoint_url(cls, base_url: str) -> str:
|
||||
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
使用 billing 模块的配置驱动计费。
|
||||
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||
或覆盖此方法实现完全自定义的计费逻辑。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
|
||||
"tier_index": Optional[int], # 命中的阶梯索引
|
||||
}
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""
|
||||
根据总输入 token 数确定价格阶梯
|
||||
|
||||
Args:
|
||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
||||
total_input_tokens: 总输入 token 数
|
||||
|
||||
Returns:
|
||||
匹配的阶梯配置
|
||||
"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||
return tiers[-1] if tiers else None
|
||||
return _calculate_request_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price_per_1m,
|
||||
output_price_per_1m=output_price_per_1m,
|
||||
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||
price_per_request=price_per_request,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
total_input_context=total_input_context,
|
||||
billing_template=self.BILLING_TEMPLATE,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
|
||||
@@ -38,6 +38,7 @@ from src.core.exceptions import (
|
||||
UpstreamClientException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||
from src.services.request.result import RequestResult
|
||||
from src.services.usage.recorder import UsageRecorder
|
||||
|
||||
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
|
||||
name: str = "cli.base"
|
||||
mode = ApiMode.PROXY
|
||||
|
||||
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||
BILLING_TEMPLATE: str = "claude"
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
|
||||
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
return input_tokens + cache_read_input_tokens
|
||||
|
||||
def get_cache_read_price_for_ttl(
|
||||
self,
|
||||
tier: dict,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
根据缓存 TTL 获取缓存读取价格
|
||||
|
||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
||||
|
||||
Args:
|
||||
tier: 当前阶梯配置
|
||||
cache_ttl_minutes: 缓存时长(分钟)
|
||||
|
||||
Returns:
|
||||
缓存读取价格(每 1M tokens)
|
||||
"""
|
||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||
if ttl_pricing and cache_ttl_minutes is not None:
|
||||
matched_price = None
|
||||
for ttl_config in ttl_pricing:
|
||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||
if cache_ttl_minutes <= ttl_limit:
|
||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
||||
break
|
||||
if matched_price is not None:
|
||||
return matched_price
|
||||
# 超过所有配置的 TTL,使用最后一个
|
||||
if ttl_pricing:
|
||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||
|
||||
return tier.get("cache_read_price_per_1m")
|
||||
|
||||
def compute_cost(
|
||||
self,
|
||||
input_tokens: int,
|
||||
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
|
||||
"""
|
||||
计算请求成本
|
||||
|
||||
默认实现:支持固定价格和阶梯计费
|
||||
子类可覆盖此方法实现完全不同的计费逻辑
|
||||
使用 billing 模块的配置驱动计费。
|
||||
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||
或覆盖此方法实现完全自定义的计费逻辑。
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 token 数
|
||||
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
|
||||
Returns:
|
||||
包含各项成本的字典
|
||||
"""
|
||||
tier_index = None
|
||||
effective_input_price = input_price_per_1m
|
||||
effective_output_price = output_price_per_1m
|
||||
effective_cache_creation_price = cache_creation_price_per_1m
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
|
||||
# 检查阶梯计费
|
||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||
total_input_context = self.compute_total_input_context(
|
||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||
)
|
||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
||||
|
||||
if tier:
|
||||
tier_index = tiered_pricing["tiers"].index(tier)
|
||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
||||
effective_cache_creation_price = tier.get(
|
||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
||||
)
|
||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
||||
tier, cache_ttl_minutes
|
||||
)
|
||||
if effective_cache_read_price is None:
|
||||
effective_cache_read_price = cache_read_price_per_1m
|
||||
|
||||
# 计算各项成本
|
||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
||||
|
||||
cache_creation_cost = 0.0
|
||||
cache_read_cost = 0.0
|
||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
||||
cache_creation_cost = (
|
||||
cache_creation_input_tokens / 1_000_000
|
||||
) * effective_cache_creation_price
|
||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
||||
cache_read_cost = (
|
||||
cache_read_input_tokens / 1_000_000
|
||||
) * effective_cache_read_price
|
||||
|
||||
cache_cost = cache_creation_cost + cache_read_cost
|
||||
request_cost = price_per_request if price_per_request else 0.0
|
||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
||||
|
||||
return {
|
||||
"input_cost": input_cost,
|
||||
"output_cost": output_cost,
|
||||
"cache_creation_cost": cache_creation_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_cost": cache_cost,
|
||||
"request_cost": request_cost,
|
||||
"total_cost": total_cost,
|
||||
"tier_index": tier_index,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
||||
"""根据总输入 token 数确定价格阶梯"""
|
||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
||||
return None
|
||||
|
||||
tiers = tiered_pricing.get("tiers", [])
|
||||
if not tiers:
|
||||
return None
|
||||
|
||||
for tier in tiers:
|
||||
up_to = tier.get("up_to")
|
||||
if up_to is None or total_input_tokens <= up_to:
|
||||
return tier
|
||||
|
||||
return tiers[-1] if tiers else None
|
||||
return _calculate_request_cost(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
input_price_per_1m=input_price_per_1m,
|
||||
output_price_per_1m=output_price_per_1m,
|
||||
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||
price_per_request=price_per_request,
|
||||
tiered_pricing=tiered_pricing,
|
||||
cache_ttl_minutes=cache_ttl_minutes,
|
||||
total_input_context=total_input_context,
|
||||
billing_template=self.BILLING_TEMPLATE,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# 模型列表查询 - 子类应覆盖此方法
|
||||
|
||||
@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE"
|
||||
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||
name = "claude.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "CLAUDE_CLI"
|
||||
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||
name = "claude.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI"
|
||||
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||
name = "gemini.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "GEMINI_CLI"
|
||||
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||
name = "gemini.cli"
|
||||
|
||||
@property
|
||||
|
||||
@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI"
|
||||
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||
name = "openai.chat"
|
||||
|
||||
@property
|
||||
|
||||
@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
|
||||
"""
|
||||
|
||||
FORMAT_ID = "OPENAI_CLI"
|
||||
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||
name = "openai.cli"
|
||||
|
||||
@property
|
||||
|
||||
Reference in New Issue
Block a user