mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 04:28:28 +08:00
171 lines
6.0 KiB
Python
171 lines
6.0 KiB
Python
|
|
"""
|
|||
|
|
Token计数插件基类
|
|||
|
|
定义Token计数的接口
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from abc import abstractmethod
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from typing import Any, Dict, List, Optional, Union
|
|||
|
|
|
|||
|
|
from src.plugins.common import BasePlugin
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class TokenUsage:
    """Token usage accounting for a single request or an aggregate.

    Tracks provider-specific counters alongside the basic input/output
    totals: Claude prompt-cache reads/writes and OpenAI o1 reasoning tokens.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cache_read_tokens: int = 0  # Claude: tokens served from the prompt cache
    cache_write_tokens: int = 0  # Claude: tokens written into the prompt cache
    reasoning_tokens: int = 0  # OpenAI o1: hidden reasoning tokens

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Return a new TokenUsage with every counter summed field-wise."""
        combined = TokenUsage()
        combined.input_tokens = self.input_tokens + other.input_tokens
        combined.output_tokens = self.output_tokens + other.output_tokens
        combined.total_tokens = self.total_tokens + other.total_tokens
        combined.cache_read_tokens = self.cache_read_tokens + other.cache_read_tokens
        combined.cache_write_tokens = self.cache_write_tokens + other.cache_write_tokens
        combined.reasoning_tokens = self.reasoning_tokens + other.reasoning_tokens
        return combined

    def to_dict(self) -> Dict[str, int]:
        """Serialize every counter into a plain dict."""
        return dict(
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            total_tokens=self.total_tokens,
            cache_read_tokens=self.cache_read_tokens,
            cache_write_tokens=self.cache_write_tokens,
            reasoning_tokens=self.reasoning_tokens,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TokenCounterPlugin(BasePlugin):
    """
    Base class for token-counting plugins.

    Subclasses implement model-specific counting; this base provides shared
    helpers for counting whole requests, extracting usage from provider
    responses (OpenAI and Claude formats), and estimating dollar cost from
    a configurable pricing table.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        # Fix: default is None, so the annotation must be Optional.
        # Delegate metadata setup (name/config/description/version) to BasePlugin.
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )

        # Models this counter declares support for; empty means "none declared".
        # NOTE(review): assumes BasePlugin.__init__ stores a dict-like
        # self.config — confirm against src.plugins.common.
        self.supported_models = self.config.get("supported_models", [])
        # Fallback model when neither the caller nor the request names one.
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Return True if this plugin can count tokens for *model*."""

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Return the token count of *text* for the given (or default) model."""

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Return the token count of a chat *messages* list."""

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Count the tokens of a full chat-completion request body.

        Model resolution order: explicit argument, then the request's own
        "model" field, then self.default_model.
        """
        model = model or request.get("model") or self.default_model
        messages = request.get("messages", [])
        return await self.count_messages(messages, model)

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """Extract token usage from a provider response's "usage" object.

        Detects the provider by key shape; returns an all-zero TokenUsage
        when the usage block is missing or unrecognized.
        """
        usage = response.get("usage", {})

        # OpenAI format: prompt/completion totals, plus optional o1
        # reasoning tokens nested under completion_tokens_details.
        if "prompt_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("prompt_tokens", 0),
                output_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
                reasoning_tokens=usage.get("completion_tokens_details", {}).get(
                    "reasoning_tokens", 0
                ),
            )

        # Claude format: no precomputed total, so derive it; cache counters
        # map from Anthropic's cache_read/cache_creation fields.
        elif "input_tokens" in usage:
            return TokenUsage(
                input_tokens=usage.get("input_tokens", 0),
                output_tokens=usage.get("output_tokens", 0),
                total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
                cache_read_tokens=usage.get("cache_read_input_tokens", 0),
                cache_write_tokens=usage.get("cache_creation_input_tokens", 0),
            )

        return TokenUsage()

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """Estimate the USD cost of *usage* for *model*.

        Prices come from config["pricing"], keyed by model name (or prefix),
        expressed per 1M tokens. Returns a per-category cost breakdown, or
        {"error": ...} when no pricing entry matches.

        Fix: the return annotation was Dict[str, float], but the returned
        dict contains string values ("currency", "error") — now Dict[str, Any].
        *provider* is currently unused; kept for interface compatibility.
        """
        # Pricing table: price per 1M tokens, keyed by model name or prefix.
        pricing = self.config.get("pricing", {})

        # Exact match first, then first prefix match (depends on dict
        # insertion order of the configured pricing table).
        model_pricing = pricing.get(model, {})
        if not model_pricing:
            for model_prefix, price_info in pricing.items():
                if model.startswith(model_prefix):
                    model_pricing = price_info
                    break

        if not model_pricing:
            return {"error": "No pricing information available"}

        # Per-category costs; missing price entries default to free.
        input_cost = (usage.input_tokens / 1_000_000) * model_pricing.get("input", 0)
        output_cost = (usage.output_tokens / 1_000_000) * model_pricing.get("output", 0)

        # Prompt-cache costs (Claude-specific).
        cache_read_cost = (usage.cache_read_tokens / 1_000_000) * model_pricing.get("cache_read", 0)
        cache_write_cost = (usage.cache_write_tokens / 1_000_000) * model_pricing.get(
            "cache_write", 0
        )

        # Reasoning-token cost (OpenAI o1-specific).
        reasoning_cost = (usage.reasoning_tokens / 1_000_000) * model_pricing.get("reasoning", 0)

        total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost

        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return model metadata (subclass-defined contents)."""

    async def get_stats(self) -> Dict[str, Any]:
        """Return this plugin's current configuration summary."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }
|