mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats - Implement defensive update strategy in StreamContext to prevent zero values overwriting valid data - Simplify cache creation token parsing in Claude handler using new utility - Add comprehensive test suite for cache creation token extraction - Improve type hints in handler classes
This commit is contained in:
@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
|
||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 检查错误(支持嵌套错误格式)
|
||||
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
# 对于 message_start 事件,usage 在 message.usage 路径下
|
||||
# 对于其他响应,usage 在顶层
|
||||
usage = response.get("usage", {})
|
||||
if not usage and "message" in response:
|
||||
usage = response.get("message", {}).get("usage", {})
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
|
||||
@@ -104,14 +104,40 @@ class StreamContext:
|
||||
cached_tokens: Optional[int] = None,
|
||||
cache_creation_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
"""更新 Token 使用统计"""
|
||||
if input_tokens is not None:
|
||||
"""
|
||||
更新 Token 使用统计
|
||||
|
||||
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
|
||||
|
||||
设计原理:
|
||||
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
|
||||
- 后续事件可能会提供完整的统计数据
|
||||
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
|
||||
|
||||
示例场景:
|
||||
- message_start 事件:input_tokens=100, output_tokens=0
|
||||
- message_delta 事件:input_tokens=0, output_tokens=50
|
||||
- 最终结果:input_tokens=100, output_tokens=50
|
||||
|
||||
注意事项:
|
||||
- 此策略假设初始值为 0 是正确的默认状态
|
||||
- 如果需要将已有值重置为 0,请直接修改实例属性(不使用此方法)
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 tokens 数量
|
||||
output_tokens: 输出 tokens 数量
|
||||
cached_tokens: 缓存命中 tokens 数量
|
||||
cache_creation_tokens: 缓存创建 tokens 数量
|
||||
"""
|
||||
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
|
||||
self.input_tokens = input_tokens
|
||||
if output_tokens is not None:
|
||||
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
|
||||
self.output_tokens = output_tokens
|
||||
if cached_tokens is not None:
|
||||
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
|
||||
self.cached_tokens = cached_tokens
|
||||
if cache_creation_tokens is not None:
|
||||
if cache_creation_tokens is not None and (
|
||||
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
|
||||
):
|
||||
self.cache_creation_tokens = cache_creation_tokens
|
||||
|
||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
||||
|
||||
31
src/api/handlers/base/utils.py
Normal file
31
src/api/handlers/base/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Handler 基础工具函数
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
||||
"""
|
||||
提取缓存创建 tokens(兼容新旧格式)
|
||||
|
||||
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens:
|
||||
- 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和
|
||||
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
|
||||
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
|
||||
|
||||
此函数自动检测并适配两种格式,优先使用新格式。
|
||||
|
||||
Args:
|
||||
usage: API 响应中的 usage 字典
|
||||
|
||||
Returns:
|
||||
缓存创建 tokens 总数
|
||||
"""
|
||||
# 优先使用新格式
|
||||
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
|
||||
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
|
||||
total = int(cache_5m) + int(cache_1h)
|
||||
|
||||
# 如果新格式不存在(total == 0),回退到旧格式
|
||||
return total if total > 0 else int(usage.get("cache_creation_input_tokens", 0))
|
||||
@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeChatHandler(ChatHandlerBase):
|
||||
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
async def _convert_request(self, request: Any) -> Any:
|
||||
"""
|
||||
将请求转换为 Claude 格式
|
||||
|
||||
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
Claude 格式使用:
|
||||
- input_tokens / output_tokens
|
||||
- cache_creation_input_tokens / cache_read_input_tokens
|
||||
- 新格式:claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
input_tokens = usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("output_tokens", 0)
|
||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 处理新的 cache_creation 格式
|
||||
if "cache_creation" in usage:
|
||||
cache_creation_data = usage.get("cache_creation", {})
|
||||
if not cache_creation_input_tokens:
|
||||
cache_creation_input_tokens = cache_creation_data.get(
|
||||
"ephemeral_5m_input_tokens", 0
|
||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
||||
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": cache_read_input_tokens,
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化 Claude 响应
|
||||
|
||||
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_claude_response(
|
||||
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
)
|
||||
return result
|
||||
return response
|
||||
|
||||
@@ -9,6 +9,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeStreamParser:
|
||||
"""
|
||||
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||
# Claude 的缓存 tokens 使用不同的字段名
|
||||
|
||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||
if cache_read:
|
||||
ctx.cached_tokens = cache_read
|
||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
||||
|
||||
cache_creation = extract_cache_creation_tokens(usage)
|
||||
if cache_creation:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
# 更新缓存 tokens
|
||||
|
||||
# 更新缓存读取 tokens
|
||||
if "cache_read_input_tokens" in usage:
|
||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||
if "cache_creation_input_tokens" in usage:
|
||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
||||
|
||||
# 更新缓存创建 tokens
|
||||
cache_creation = extract_cache_creation_tokens(usage)
|
||||
if cache_creation > 0:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
# 检查是否结束
|
||||
delta = data.get("delta", {})
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""测试模块"""
|
||||
90
tests/api/handlers/base/test_utils.py
Normal file
90
tests/api/handlers/base/test_utils.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""测试 handler 基础工具函数"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class TestExtractCacheCreationTokens:
|
||||
"""测试 extract_cache_creation_tokens 函数"""
|
||||
|
||||
def test_new_format_only(self) -> None:
|
||||
"""测试只有新格式字段"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 100,
|
||||
"claude_cache_creation_1_h_tokens": 200,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 300
|
||||
|
||||
def test_new_format_5m_only(self) -> None:
|
||||
"""测试只有 5 分钟缓存"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 150,
|
||||
"claude_cache_creation_1_h_tokens": 0,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 150
|
||||
|
||||
def test_new_format_1h_only(self) -> None:
|
||||
"""测试只有 1 小时缓存"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 0,
|
||||
"claude_cache_creation_1_h_tokens": 250,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 250
|
||||
|
||||
def test_old_format_only(self) -> None:
|
||||
"""测试只有旧格式字段"""
|
||||
usage = {
|
||||
"cache_creation_input_tokens": 500,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 500
|
||||
|
||||
def test_both_formats_prefers_new(self) -> None:
|
||||
"""测试同时存在时优先使用新格式"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 100,
|
||||
"claude_cache_creation_1_h_tokens": 200,
|
||||
"cache_creation_input_tokens": 999, # 应该被忽略
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 300
|
||||
|
||||
def test_empty_usage(self) -> None:
|
||||
"""测试空字典"""
|
||||
usage = {}
|
||||
assert extract_cache_creation_tokens(usage) == 0
|
||||
|
||||
def test_all_zeros(self) -> None:
|
||||
"""测试所有字段都为 0"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 0,
|
||||
"claude_cache_creation_1_h_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 0
|
||||
|
||||
def test_partial_new_format_with_old_format_fallback(self) -> None:
|
||||
"""测试新格式字段不存在时回退到旧格式"""
|
||||
usage = {
|
||||
"cache_creation_input_tokens": 123,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 123
|
||||
|
||||
def test_new_format_zero_fallback_to_old(self) -> None:
|
||||
"""测试新格式为 0 时回退到旧格式"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 0,
|
||||
"claude_cache_creation_1_h_tokens": 0,
|
||||
"cache_creation_input_tokens": 456,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 456
|
||||
|
||||
def test_unrelated_fields_ignored(self) -> None:
|
||||
"""测试忽略无关字段"""
|
||||
usage = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 2000,
|
||||
"cache_read_input_tokens": 300,
|
||||
"claude_cache_creation_5_m_tokens": 50,
|
||||
"claude_cache_creation_1_h_tokens": 75,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 125
|
||||
Reference in New Issue
Block a user