mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-12 04:28:28 +08:00
Initial commit
This commit is contained in:
170
src/plugins/token/base.py
Normal file
170
src/plugins/token/base.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Token计数插件基类
|
||||
定义Token计数的接口
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from src.plugins.common import BasePlugin
|
||||
|
||||
|
||||
@dataclass
class TokenUsage:
    """Token usage counters for one request, or an aggregate of several.

    Instances add field-wise with ``+``. ``0 + usage`` is also supported so
    that the built-in ``sum()`` works on an iterable of usages.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    cache_read_tokens: int = 0  # Claude cache reads
    cache_write_tokens: int = 0  # Claude cache writes
    reasoning_tokens: int = 0  # OpenAI o1 reasoning tokens

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Return a new ``TokenUsage`` holding the field-wise sum of both operands."""
        if not isinstance(other, TokenUsage):
            # Let Python try the reflected operation / raise a proper TypeError
            # instead of failing with an AttributeError on a foreign operand.
            return NotImplemented
        return TokenUsage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
            cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
            cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
            reasoning_tokens=self.reasoning_tokens + other.reasoning_tokens,
        )

    def __radd__(self, other: Any) -> "TokenUsage":
        """Support ``0 + usage`` so that ``sum(usages)`` works with its default start."""
        if isinstance(other, int) and other == 0:
            # Return a copy so the result never aliases an element of the input.
            return self + TokenUsage()
        return NotImplemented

    def to_dict(self) -> Dict[str, int]:
        """Return the counters as a plain dictionary keyed by field name."""
        return {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "total_tokens": self.total_tokens,
            "cache_read_tokens": self.cache_read_tokens,
            "cache_write_tokens": self.cache_write_tokens,
            "reasoning_tokens": self.reasoning_tokens,
        }
|
||||
|
||||
|
||||
class TokenCounterPlugin(BasePlugin):
    """Base class for token-counting plugins.

    Subclasses implement the model-specific primitives (``supports_model``,
    ``count_tokens``, ``count_messages``, ``get_model_info``). This class
    builds request counting, usage extraction from provider responses and
    cost estimation on top of them.
    """

    def __init__(self, name: str = "token_counter", config: Optional[Dict[str, Any]] = None):
        # Delegate to the parent initializer, which sets up the plugin metadata.
        super().__init__(
            name=name, config=config, description="Token Counter Plugin", version="1.0.0"
        )

        # NOTE(review): relies on BasePlugin.__init__ leaving ``self.config`` as a
        # dict even when ``config`` is None -- confirm against BasePlugin.
        self.supported_models = self.config.get("supported_models", [])
        self.default_model = self.config.get("default_model")

    @abstractmethod
    def supports_model(self, model: str) -> bool:
        """Return whether this plugin can count tokens for ``model``."""

    @abstractmethod
    async def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Return the number of tokens in ``text`` for ``model``."""

    @abstractmethod
    async def count_messages(
        self, messages: List[Dict[str, Any]], model: Optional[str] = None
    ) -> int:
        """Return the number of tokens in a list of chat messages for ``model``."""

    async def count_request(self, request: Dict[str, Any], model: Optional[str] = None) -> int:
        """Return the token count of the ``messages`` carried by ``request``.

        The model is resolved in order from the ``model`` argument, the
        request's ``"model"`` entry, then ``self.default_model``.
        """
        model = model or request.get("model") or self.default_model
        # ``or []`` also covers an explicit ``"messages": None``.
        messages = request.get("messages") or []
        return await self.count_messages(messages, model)

    async def count_response(
        self, response: Dict[str, Any], model: Optional[str] = None
    ) -> TokenUsage:
        """Extract token usage from a provider response.

        Recognizes the OpenAI layout (``prompt_tokens`` / ``completion_tokens``)
        and the Claude layout (``input_tokens`` / ``output_tokens``) of the
        response's ``"usage"`` object. Missing or ``null`` counters are treated
        as 0. Returns an all-zero ``TokenUsage`` when neither layout matches.
        """
        # Providers may send ``"usage": null`` (e.g. intermediate stream chunks).
        usage = response.get("usage") or {}

        def counter(key: str) -> int:
            # A key that is present but null must not leak None into TokenUsage,
            # otherwise later arithmetic (addition, cost estimation) fails.
            return usage.get(key) or 0

        # OpenAI format
        if "prompt_tokens" in usage:
            prompt_tokens = counter("prompt_tokens")
            completion_tokens = counter("completion_tokens")
            total_tokens = usage.get("total_tokens")
            if total_tokens is None:
                # Derive the total when the provider omitted it.
                total_tokens = prompt_tokens + completion_tokens
            completion_details = usage.get("completion_tokens_details") or {}
            return TokenUsage(
                input_tokens=prompt_tokens,
                output_tokens=completion_tokens,
                total_tokens=total_tokens,
                reasoning_tokens=completion_details.get("reasoning_tokens") or 0,
            )

        # Claude format
        if "input_tokens" in usage:
            input_tokens = counter("input_tokens")
            output_tokens = counter("output_tokens")
            return TokenUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
                cache_read_tokens=counter("cache_read_input_tokens"),
                cache_write_tokens=counter("cache_creation_input_tokens"),
            )

        return TokenUsage()

    async def estimate_cost(
        self, usage: TokenUsage, model: str, provider: Optional[str] = None
    ) -> Dict[str, Any]:
        """Estimate the cost of ``usage`` for ``model``.

        Prices are read from ``self.config["pricing"]``, a mapping from model
        name (or model-name prefix) to per-1M-token rates under the keys
        ``input``, ``output``, ``cache_read``, ``cache_write`` and
        ``reasoning``. An exact model match wins. Otherwise the longest
        configured prefix of ``model`` with non-empty pricing is used.

        Returns a dict of per-category costs rounded to 6 decimals, plus
        ``total_cost`` and ``currency``. Returns ``{"error": ...}`` when no
        pricing applies. ``provider`` is currently unused.
        """
        # Price table (rates are per 1M tokens).
        pricing = self.config.get("pricing") or {}

        model_pricing = pricing.get(model, {})
        if not model_pricing:
            # Prefix fallback: pick the most specific (longest) matching prefix so
            # that e.g. "gpt-4o" wins over "gpt-4" for "gpt-4o-mini", regardless
            # of the order the entries were configured in.
            candidates = [
                prefix
                for prefix, price_info in pricing.items()
                if price_info and model.startswith(prefix)
            ]
            if candidates:
                model_pricing = pricing[max(candidates, key=len)]

        if not model_pricing:
            return {"error": "No pricing information available"}

        def cost_of(tokens: int, rate_key: str) -> float:
            return (tokens / 1_000_000) * model_pricing.get(rate_key, 0)

        input_cost = cost_of(usage.input_tokens, "input")
        output_cost = cost_of(usage.output_tokens, "output")

        # Cache costs (Claude-specific)
        cache_read_cost = cost_of(usage.cache_read_tokens, "cache_read")
        cache_write_cost = cost_of(usage.cache_write_tokens, "cache_write")

        # Reasoning cost (OpenAI o1-specific)
        reasoning_cost = cost_of(usage.reasoning_tokens, "reasoning")

        total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost + reasoning_cost

        return {
            "input_cost": round(input_cost, 6),
            "output_cost": round(output_cost, 6),
            "cache_read_cost": round(cache_read_cost, 6),
            "cache_write_cost": round(cache_write_cost, 6),
            "reasoning_cost": round(reasoning_cost, 6),
            "total_cost": round(total_cost, 6),
            "currency": "USD",
        }

    @abstractmethod
    async def get_model_info(self, model: str) -> Dict[str, Any]:
        """Return information about ``model``."""

    async def get_stats(self) -> Dict[str, Any]:
        """Return a summary of this plugin: name, enabled flag and model settings."""
        return {
            "type": self.name,
            "enabled": self.enabled,
            "supported_models": self.supported_models,
            "default_model": self.default_model,
        }
|
||||
Reference in New Issue
Block a user