mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 04:28:28 +08:00
171 lines
6.0 KiB
Python
171 lines
6.0 KiB
Python
"""
Token counter plugin base class.

Defines the interface for token counting.
"""
|
||
|
||
from abc import abstractmethod
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
from src.plugins.common import BasePlugin
|
||
|
||
|
||
@dataclass
class TokenUsage:
    """Token usage counters.

    Every counter defaults to 0. Two instances can be combined with ``+``,
    which adds each counter field-by-field and returns a new instance.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cache_read_tokens: int = 0  # Claude cache reads
    cache_write_tokens: int = 0  # Claude cache writes
    reasoning_tokens: int = 0  # OpenAI o1 reasoning tokens

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Return a new ``TokenUsage`` whose counters are the field-wise sums.

        Neither operand is modified. For an operand that is not a
        ``TokenUsage`` this returns ``NotImplemented`` so that Python raises
        the standard ``TypeError`` (or tries the other operand's reflected
        method) instead of failing with an ``AttributeError`` on a missing
        counter attribute.
        """
        if not isinstance(other, TokenUsage):
            return NotImplemented
        return TokenUsage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
            cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
            cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
            reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
        )

    def to_dict(self) -> Dict[str, int]:
        """Return the counters as a plain dict keyed by field name."""
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
            "cache_read_tokens": self.cache_read_tokens,
            "cache_write_tokens": self.cache_write_tokens,
            "reasoning_tokens": self.reasoning_tokens,
        }
|
||
|
||
|
||
class TokenCounterPlugin(BasePlugin):
    """
    Base class for token counter plugins.

    Supports token counting for different models. Subclasses implement
    ``supports_model``, ``count_tokens``, ``count_messages`` and
    ``get_model_info``; this class provides request counting, usage
    extraction from provider responses, and cost estimation on top of them.

    Configuration keys read by this class:
        supported_models: list of model names (defaults to ``[]``).
        default_model: model used when none is given (defaults to ``None``).
        pricing: mapping of model name (or model-name prefix) to a dict of
            per-1M-token prices with the optional keys ``input``, ``output``,
            ``cache_read``, ``cache_write`` and ``reasoning``.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        """Initialize the plugin and cache model settings from the config.

        Args:
            name: Plugin name passed to the parent class.
            config: Plugin configuration; see the class docstring for the keys used.
        """
        # Call the parent initializer, which sets up the plugin metadata.
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )

        # NOTE(review): assumes BasePlugin.__init__ leaves self.config as a dict
        # even when config is None - confirm against BasePlugin.
        self.supported_models = self.config.get("supported_models", [])
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Return whether the given model is supported."""
        pass

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Count the tokens in a piece of text."""
        pass

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Count the tokens in a list of messages."""
        pass

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Count the tokens in a request.

        The model is resolved in order from the ``model`` argument, the
        request's ``"model"`` field, then ``self.default_model``. The
        request's ``"messages"`` list is passed to ``count_messages``.

        Args:
            request: Request payload.
            model: Optional model override.

        Returns:
            Whatever ``count_messages`` returns for the request's messages.
        """
        model = model or request.get("model") or self.default_model
        # A missing or explicitly null "messages" field is treated as no messages.
        messages = request.get("messages") or []
        return await self.count_messages(messages, model)

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """Extract token usage from a response.

        Two ``usage`` layouts are recognised:

        * OpenAI format (has ``prompt_tokens``): reads ``prompt_tokens``,
          ``completion_tokens``, ``total_tokens`` and
          ``completion_tokens_details.reasoning_tokens``.
        * Claude format (has ``input_tokens``): reads ``input_tokens``,
          ``output_tokens``, ``cache_read_input_tokens`` and
          ``cache_creation_input_tokens``; the total is input + output.

        Missing or null fields count as 0.

        Args:
            response: Response payload.
            model: Unused by this implementation.

        Returns:
            The extracted usage, or an all-zero ``TokenUsage`` when the
            response has no recognisable usage data.
        """
        # "usage" may be absent or explicitly null; both mean no usage data.
        usage = response.get("usage") or {}

        # OpenAI format
        if "prompt_tokens" in usage:
            # The details object may also be absent or null.
            details = usage.get("completion_tokens_details") or {}
            return TokenUsage(
                input_tokens=usage.get("prompt_tokens") or 0,
                output_tokens=usage.get("completion_tokens") or 0,
                total_tokens=usage.get("total_tokens") or 0,
                reasoning_tokens=details.get("reasoning_tokens") or 0,
            )

        # Claude format
        if "input_tokens" in usage:
            input_tokens = usage.get("input_tokens") or 0
            output_tokens = usage.get("output_tokens") or 0
            return TokenUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
                cache_read_tokens=usage.get("cache_read_input_tokens") or 0,
                cache_write_tokens=usage.get("cache_creation_input_tokens") or 0,
            )

        return TokenUsage()

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """Estimate the cost of the given usage.

        Prices come from the ``pricing`` config entry and are per 1M tokens.
        An exact match on ``model`` is tried first; if that yields nothing,
        the longest pricing key that is a prefix of ``model`` is used, so a
        more specific entry wins over a shorter, more general one.

        Args:
            usage: Token counters to price.
            model: Model name used to look up prices.
            provider: Unused by this implementation.

        Returns:
            A dict with ``input_cost``, ``output_cost``, ``cache_read_cost``,
            ``cache_write_cost``, ``reasoning_cost`` and ``total_cost``
            (each rounded to 6 decimals) plus ``"currency": "USD"``; or
            ``{"error": "No pricing information available"}`` when no pricing
            is configured for the model.
        """
        # Price table (prices are per 1M tokens)
        pricing = self.config.get("pricing", {})

        # Look up the model's prices
        model_pricing = pricing.get(model, {})
        if not model_pricing:
            # Fall back to prefix matching. The longest matching prefix is
            # taken so the result does not depend on dict insertion order.
            best_prefix = max(
                (prefix for prefix in pricing if model.startswith(prefix)),
                key=len,
                default=None,
            )
            if best_prefix is not None:
                model_pricing = pricing[best_prefix]

        if not model_pricing:
            return {"error": "No pricing information available"}

        def cost_of(tokens: int, rate_key: str) -> float:
            # Convert a token count into cost using the per-1M-token rate.
            return (tokens / 1_000_000) * model_pricing.get(rate_key, 0)

        # Base costs
        input_cost = cost_of(usage.input_tokens, "input")
        output_cost = cost_of(usage.output_tokens, "output")

        # Cache costs (Claude-specific)
        cache_read_cost = cost_of(usage.cache_read_tokens, "cache_read")
        cache_write_cost = cost_of(usage.cache_write_tokens, "cache_write")

        # Reasoning cost (OpenAI o1-specific)
        # NOTE(review): if the provider already includes reasoning tokens in
        # output_tokens, a non-zero "reasoning" rate would bill them twice -
        # confirm against the provider's usage semantics.
        reasoning_cost = cost_of(usage.reasoning_tokens, "reasoning")

        total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost

        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return information about the given model."""
        pass

    async def get_stats(self) -> Dict[str, Any]:
        """Return plugin statistics: name, enabled flag and model settings."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }
|