Aether/src/plugins/token/base.py

"""
Token计数插件基类
定义Token计数的接口
"""
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from src.plugins.common import BasePlugin
@dataclass
class TokenUsage:
"""令牌使用情况"""
input_tokens: int = 0
output_tokens: int = 0
total_tokens: int = 0
cache_read_tokens: int = 0 # Claude缓存读取
cache_write_tokens: int = 0 # Claude缓存写入
reasoning_tokens: int = 0 # OpenAI o1推理令牌
def __add__(self, other: "TokenUsage") -> "TokenUsage":
"""令牌使用相加"""
return TokenUsage(
input_tokens=self.input_tokens + other.input_tokens,
output_tokens=self.output_tokens + other.output_tokens,
total_tokens=self.total_tokens + other.total_tokens,
cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
)
def to_dict(self) -> Dict[str, int]:
"""转换为字典"""
return {
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"total_tokens": self.total_tokens,
"cache_read_tokens": self.cache_read_tokens,
"cache_write_tokens": self.cache_write_tokens,
"reasoning_tokens": self.reasoning_tokens,
}
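

# Illustrative sketch (not part of the original class definition): TokenUsage values
# can be accumulated across calls via __add__ and serialized with to_dict, e.g.:
#
#   first = TokenUsage(input_tokens=120, output_tokens=45, total_tokens=165)
#   second = TokenUsage(input_tokens=80, output_tokens=30, total_tokens=110)
#   combined = (first + second).to_dict()
#   # -> {"input_tokens": 200, "output_tokens": 75, "total_tokens": 275,
#   #     "cache_read_tokens": 0, "cache_write_tokens": 0, "reasoning_tokens": 0}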


class TokenCounterPlugin(BasePlugin):
    """
    Base class for token counting plugins.
    Supports token counting for different models.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        # Call the parent initializer to set up the plugin metadata
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )
        self.supported_models = self.config.get("supported_models", [])
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Check whether the given model is supported."""
        pass

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Count the number of tokens in a piece of text."""
        pass

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Count the number of tokens in a list of messages."""
        pass

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Count the number of tokens in a request."""
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        return await self.count_messages(messages, model)
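
    # Illustrative sketch: count_request expects an OpenAI-style chat request body and
    # resolves the model from the explicit argument, then request["model"], then
    # default_model. "counter" below stands for an instance of a concrete subclass:
    #
    #   request = {
    #       "model": "gpt-4o",
    #       "messages": [{"role": "user", "content": "Hello"}],
    #   }
    #   tokens = await counter.count_request(request)  # delegates to count_messages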

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """Extract token usage from a response."""
        usage = response.get("usage", {})
        # OpenAI format
        if "prompt_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("prompt_tokens", 0),
                output_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
                reasoning_tokens=usage.get("completion_tokens_details", {}).get(
                    "reasoning_tokens", 0
                ),
            )
        # Claude format
        elif "input_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
                cache_read_tokens=usage.get("cache_read_input_tokens", 0),
                cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
            )
        return TokenUsage()
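
    # Illustrative sketch of the two usage payload shapes handled by count_response.
    # The field names mirror the OpenAI and Claude formats referenced above; the
    # surrounding response objects are abbreviated examples, not captured API output:
    #
    #   openai_response = {"usage": {"prompt_tokens": 100, "completion_tokens": 40,
    #                                "total_tokens": 140,
    #                                "completion_tokens_details": {"reasoning_tokens": 12}}}
    #   claude_response = {"usage": {"input_tokens": 100, "output_tokens": 40,
    #                                "cache_read_input_tokens": 64,
    #                                "cache_creation_input_tokens": 0}}
    #
    #   usage = await counter.count_response(openai_response)
    #   # -> TokenUsage(input_tokens=100, output_tokens=40, total_tokens=140,
    #   #               reasoning_tokens=12)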

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """Estimate the cost of the given usage."""
        # Default price table; prices are per 1M tokens
        pricing = self.config.get("pricing", {})
        # Look up the price entry for the model
        model_pricing = pricing.get(model, {})
        if not model_pricing:
            # Fall back to prefix matching
            for model_prefix, price_info in pricing.items():
                if model.startswith(model_prefix):
                    model_pricing = price_info
                    break
        if not model_pricing:
            return {"error": "No pricing information available"}
        # Compute the cost
        input_cost = (usage.input_tokens / 1_000_000) * model_pricing.get("input", 0)
        output_cost = (usage.output_tokens / 1_000_000) * model_pricing.get("output", 0)
        # Cache cost (Claude-specific)
        cache_read_cost = (usage.cache_read_tokens / 1_000_000) * model_pricing.get("cache_read", 0)
        cache_write_cost = (usage.cache_write_tokens / 1_000_000) * model_pricing.get(
            "cache_write", 0
        )
        # Reasoning cost (OpenAI o1-specific)
        reasoning_cost = (usage.reasoning_tokens / 1_000_000) * model_pricing.get("reasoning", 0)
        total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost
        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }
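
    # Illustrative sketch: estimate_cost reads a per-1M-token price table from the
    # plugin config. The model names and prices below are placeholders, not
    # authoritative rates:
    #
    #   config = {
    #       "pricing": {
    #           "gpt-4o": {"input": 2.50, "output": 10.00},
    #           "claude-3-5-sonnet": {"input": 3.00, "output": 15.00,
    #                                 "cache_read": 0.30, "cache_write": 3.75},
    #       }
    #   }
    #   # With usage = TokenUsage(input_tokens=1_000_000, output_tokens=200_000) and
    #   # model "gpt-4o", input_cost = 2.50, output_cost = 2.00, total_cost = 4.50 USD.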

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Get information about a model."""
        pass

    async def get_stats(self) -> Dict[str, Any]:
        """Get plugin statistics."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }
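

# Minimal illustrative subclass, provided as a sketch for exercising the base class.
# It assumes a crude ~4-characters-per-token heuristic and that BasePlugin declares no
# additional abstract methods beyond what this module uses; it is not a real tokenizer.
class NaiveTokenCounter(TokenCounterPlugin):
    def supports_model(self, model: str) -> bool:
        # Accept any model when no whitelist is configured.
        return not self.supported_models or model in self.supported_models

    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        # Rough estimate: about 4 characters per token.
        return max(1, len(text) // 4)

    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        total = 0
        for message in messages:
            content = message.get("content")
            if isinstance(content, str):
                total += await self.count_tokens(content, model)
        return total

    async def get_model_info(self, model: str) -> Dict[str, Any]:
        return {"model": model, "tokenizer": "naive-4-chars-per-token"}


if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # "demo-model" is a placeholder model name used only for this sketch.
        counter = NaiveTokenCounter(config={"default_model": "demo-model"})
        request = {"messages": [{"role": "user", "content": "Hello, world!"}]}
        print("request tokens:", await counter.count_request(request))
        print("stats:", await counter.get_stats())

    asyncio.run(_demo())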