# Aether/src/plugins/token/base.py
"""
Token counter plugin base classes.

Defines the interface for token counting.
"""
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from src.plugins.common import BasePlugin
@dataclass
class TokenUsage:
    """Token usage counters for a single request/response exchange."""
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cache_read_tokens: int = 0   # Claude prompt-cache reads
    cache_write_tokens: int = 0  # Claude prompt-cache writes
    reasoning_tokens: int = 0    # OpenAI o1 reasoning tokens

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Return a new TokenUsage with every counter summed field-wise."""
        combined = {
            field_name: count + getattr(other, field_name)
            for field_name, count in self.__dict__.items()
        }
        return TokenUsage(**combined)

    def to_dict(self) -> Dict[str, int]:
        """Return a plain-dict snapshot of all counters."""
        return dict(self.__dict__)
class TokenCounterPlugin(BasePlugin):
    """
    Base class for token-counter plugins.

    Subclasses implement model-specific token counting; this base class
    provides the shared logic for extracting usage from provider responses
    and estimating cost from configured pricing tables.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        # Parent initializer sets up plugin metadata (name/config/description/version).
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )
        # Models this counter claims to handle; empty list means "unspecified".
        self.supported_models = self.config.get("supported_models", [])
        # Fallback model when a request carries no explicit model.
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Return True if this plugin can count tokens for *model*."""

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Return the token count of *text* for the given model."""

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Return the token count of a chat *messages* list for the given model."""

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Count the prompt tokens of a chat-completion *request* dict."""
        # Precedence: explicit argument > model field in the request > configured default.
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        return await self.count_messages(messages, model)

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """
        Extract token usage from a provider *response* dict.

        Handles both OpenAI-style ("prompt_tokens"/"completion_tokens") and
        Claude-style ("input_tokens"/"output_tokens") usage payloads; returns
        an empty TokenUsage when the format is unrecognized.
        """
        # `or {}` guards against an explicit `"usage": null` in the response.
        usage = response.get("usage") or {}
        # OpenAI format.
        if "prompt_tokens" in usage:
            prompt = usage.get("prompt_tokens", 0)
            completion = usage.get("completion_tokens", 0)
            # Fix: "completion_tokens_details" may be present but null in real
            # OpenAI responses; `or {}` prevents AttributeError on None.
            details = usage.get("completion_tokens_details") or {}
            return TokenUsage(
                input_tokens=prompt,
                output_tokens=completion,
                # Fix: fall back to prompt+completion when total_tokens is absent,
                # instead of silently reporting 0.
                total_tokens=usage.get("total_tokens", prompt + completion),
                reasoning_tokens=details.get("reasoning_tokens", 0),
            )
        # Claude format.
        if "input_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
                cache_read_tokens=usage.get("cache_read_input_tokens", 0),
                cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
            )
        return TokenUsage()

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, float]:
        """
        Estimate the USD cost of *usage* for *model*.

        Prices are read from config["pricing"] and are expressed per 1M tokens.
        Returns {"error": ...} when no pricing entry matches the model.
        """
        pricing = self.config.get("pricing", {})
        # Exact model match first.
        model_pricing = pricing.get(model, {})
        if not model_pricing:
            # Fix: prefer the LONGEST matching prefix so lookup is deterministic
            # regardless of the pricing dict's insertion order.
            best_prefix = ""
            for model_prefix, price_info in pricing.items():
                if model.startswith(model_prefix) and len(model_prefix) > len(best_prefix):
                    best_prefix = model_prefix
                    model_pricing = price_info
        if not model_pricing:
            return {"error": "No pricing information available"}
        per_million = 1_000_000
        input_cost = (usage.input_tokens / per_million) * model_pricing.get("input", 0)
        output_cost = (usage.output_tokens / per_million) * model_pricing.get("output", 0)
        # Cache costs (Claude-specific).
        cache_read_cost = (usage.cache_read_tokens / per_million) * model_pricing.get(
            "cache_read", 0
        )
        cache_write_cost = (usage.cache_write_tokens / per_million) * model_pricing.get(
            "cache_write", 0
        )
        # Reasoning cost (OpenAI o1-specific).
        reasoning_cost = (usage.reasoning_tokens / per_million) * model_pricing.get(
            "reasoning", 0
        )
        total_cost = (
            input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost
        )
        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return metadata about *model* (e.g. context window, tokenizer)."""

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics for monitoring/introspection."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }