mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats
- Implement defensive update strategy in StreamContext to prevent zero values from overwriting valid data
- Simplify cache creation token parsing in Claude handler using new utility
- Add comprehensive test suite for cache creation token extraction
- Improve type hints in handler classes
This commit is contained in:
@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
|
|||||||
ResponseParser,
|
ResponseParser,
|
||||||
StreamStats,
|
StreamStats,
|
||||||
)
|
)
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||||
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
|
|||||||
usage = response.get("usage", {})
|
usage = response.get("usage", {})
|
||||||
result.input_tokens = usage.get("input_tokens", 0)
|
result.input_tokens = usage.get("input_tokens", 0)
|
||||||
result.output_tokens = usage.get("output_tokens", 0)
|
result.output_tokens = usage.get("output_tokens", 0)
|
||||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
|
||||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||||
|
|
||||||
# 检查错误(支持嵌套错误格式)
|
# 检查错误(支持嵌套错误格式)
|
||||||
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||||
|
# 对于 message_start 事件,usage 在 message.usage 路径下
|
||||||
|
# 对于其他响应,usage 在顶层
|
||||||
usage = response.get("usage", {})
|
usage = response.get("usage", {})
|
||||||
|
if not usage and "message" in response:
|
||||||
|
usage = response.get("message", {}).get("usage", {})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -104,14 +104,40 @@ class StreamContext:
|
|||||||
cached_tokens: Optional[int] = None,
|
cached_tokens: Optional[int] = None,
|
||||||
cache_creation_tokens: Optional[int] = None,
|
cache_creation_tokens: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""更新 Token 使用统计"""
|
"""
|
||||||
if input_tokens is not None:
|
更新 Token 使用统计
|
||||||
|
|
||||||
|
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
|
||||||
|
|
||||||
|
设计原理:
|
||||||
|
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
|
||||||
|
- 后续事件可能会提供完整的统计数据
|
||||||
|
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
|
||||||
|
|
||||||
|
示例场景:
|
||||||
|
- message_start 事件:input_tokens=100, output_tokens=0
|
||||||
|
- message_delta 事件:input_tokens=0, output_tokens=50
|
||||||
|
- 最终结果:input_tokens=100, output_tokens=50
|
||||||
|
|
||||||
|
注意事项:
|
||||||
|
- 此策略假设初始值为 0 是正确的默认状态
|
||||||
|
- 如果需要将已有值重置为 0,请直接修改实例属性(不使用此方法)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tokens: 输入 tokens 数量
|
||||||
|
output_tokens: 输出 tokens 数量
|
||||||
|
cached_tokens: 缓存命中 tokens 数量
|
||||||
|
cache_creation_tokens: 缓存创建 tokens 数量
|
||||||
|
"""
|
||||||
|
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
|
||||||
self.input_tokens = input_tokens
|
self.input_tokens = input_tokens
|
||||||
if output_tokens is not None:
|
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
|
||||||
self.output_tokens = output_tokens
|
self.output_tokens = output_tokens
|
||||||
if cached_tokens is not None:
|
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
|
||||||
self.cached_tokens = cached_tokens
|
self.cached_tokens = cached_tokens
|
||||||
if cache_creation_tokens is not None:
|
if cache_creation_tokens is not None and (
|
||||||
|
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
|
||||||
|
):
|
||||||
self.cache_creation_tokens = cache_creation_tokens
|
self.cache_creation_tokens = cache_creation_tokens
|
||||||
|
|
||||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
def mark_failed(self, status_code: int, error_message: str) -> None:
|
||||||
|
|||||||
31
src/api/handlers/base/utils.py
Normal file
31
src/api/handlers/base/utils.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""
|
||||||
|
Handler 基础工具函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
    """
    Return the total number of cache-creation tokens reported in ``usage``.

    Three layouts are recognised. They are checked in the order below and the
    first one that yields a positive total wins:

    1. Split flat fields ``claude_cache_creation_5_m_tokens`` and
       ``claude_cache_creation_1_h_tokens`` (5-minute and 1-hour caches),
       which are summed.
    2. The legacy aggregate field ``cache_creation_input_tokens``.
    3. A nested ``cache_creation`` object carrying
       ``ephemeral_5m_input_tokens`` and ``ephemeral_1h_input_tokens``,
       which are summed.

    A field that is missing or explicitly ``None`` (JSON ``null``) counts
    as 0 instead of raising ``TypeError``.

    Args:
        usage: The ``usage`` mapping taken from an API response. An empty
            mapping (or ``None``) yields 0.

    Returns:
        The cache-creation token total. When no layout reports a positive
        count, the value of the legacy field (0 if absent) is returned.
    """
    if not usage:
        return 0

    # Layout 1: per-TTL flat fields take priority when they carry data.
    split_total = _token_count(
        usage.get("claude_cache_creation_5_m_tokens")
    ) + _token_count(usage.get("claude_cache_creation_1_h_tokens"))
    if split_total > 0:
        return split_total

    # Layout 2: legacy aggregate field.
    legacy_total = _token_count(usage.get("cache_creation_input_tokens"))
    if legacy_total > 0:
        return legacy_total

    # Layout 3: nested per-TTL breakdown, only consulted when the aggregate
    # field carries nothing, so an explicit aggregate always wins over it.
    nested = usage.get("cache_creation")
    if isinstance(nested, dict):
        nested_total = _token_count(
            nested.get("ephemeral_5m_input_tokens")
        ) + _token_count(nested.get("ephemeral_1h_input_tokens"))
        if nested_total > 0:
            return nested_total

    return legacy_total


def _token_count(value: Any) -> int:
    """Coerce a raw usage field to ``int``, treating ``None`` as 0."""
    return 0 if value is None else int(value)
|
||||||
@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
class ClaudeChatHandler(ChatHandlerBase):
|
class ClaudeChatHandler(ChatHandlerBase):
|
||||||
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
result["model"] = mapped_model
|
result["model"] = mapped_model
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _convert_request(self, request):
|
async def _convert_request(self, request: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
将请求转换为 Claude 格式
|
将请求转换为 Claude 格式
|
||||||
|
|
||||||
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
Claude 格式使用:
|
Claude 格式使用:
|
||||||
- input_tokens / output_tokens
|
- input_tokens / output_tokens
|
||||||
- cache_creation_input_tokens / cache_read_input_tokens
|
- cache_creation_input_tokens / cache_read_input_tokens
|
||||||
|
- 新格式:claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
|
||||||
"""
|
"""
|
||||||
usage = response.get("usage", {})
|
usage = response.get("usage", {})
|
||||||
|
|
||||||
input_tokens = usage.get("input_tokens", 0)
|
|
||||||
output_tokens = usage.get("output_tokens", 0)
|
|
||||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
|
||||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
|
||||||
|
|
||||||
# 处理新的 cache_creation 格式
|
|
||||||
if "cache_creation" in usage:
|
|
||||||
cache_creation_data = usage.get("cache_creation", {})
|
|
||||||
if not cache_creation_input_tokens:
|
|
||||||
cache_creation_input_tokens = cache_creation_data.get(
|
|
||||||
"ephemeral_5m_input_tokens", 0
|
|
||||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_tokens": input_tokens,
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": output_tokens,
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_input_tokens": cache_read_input_tokens,
|
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _normalize_response(self, response: Dict) -> Dict:
|
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
规范化 Claude 响应
|
规范化 Claude 响应
|
||||||
|
|
||||||
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
规范化后的响应
|
规范化后的响应
|
||||||
"""
|
"""
|
||||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||||
return self.response_normalizer.normalize_claude_response(
|
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
||||||
response_data=response,
|
response_data=response,
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
class ClaudeStreamParser:
|
class ClaudeStreamParser:
|
||||||
"""
|
"""
|
||||||
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
|
|||||||
return {
|
return {
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
|
|||||||
return {
|
return {
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
|
|||||||
CliMessageHandlerBase,
|
CliMessageHandlerBase,
|
||||||
StreamContext,
|
StreamContext,
|
||||||
)
|
)
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||||
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
|||||||
usage = message.get("usage", {})
|
usage = message.get("usage", {})
|
||||||
if usage:
|
if usage:
|
||||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||||
# Claude 的缓存 tokens 使用不同的字段名
|
|
||||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||||
if cache_read:
|
if cache_read:
|
||||||
ctx.cached_tokens = cache_read
|
ctx.cached_tokens = cache_read
|
||||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
|
||||||
|
cache_creation = extract_cache_creation_tokens(usage)
|
||||||
if cache_creation:
|
if cache_creation:
|
||||||
ctx.cache_creation_tokens = cache_creation
|
ctx.cache_creation_tokens = cache_creation
|
||||||
|
|
||||||
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
|||||||
ctx.input_tokens = usage["input_tokens"]
|
ctx.input_tokens = usage["input_tokens"]
|
||||||
if "output_tokens" in usage:
|
if "output_tokens" in usage:
|
||||||
ctx.output_tokens = usage["output_tokens"]
|
ctx.output_tokens = usage["output_tokens"]
|
||||||
# 更新缓存 tokens
|
|
||||||
|
# 更新缓存读取 tokens
|
||||||
if "cache_read_input_tokens" in usage:
|
if "cache_read_input_tokens" in usage:
|
||||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||||
if "cache_creation_input_tokens" in usage:
|
|
||||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
# 更新缓存创建 tokens
|
||||||
|
cache_creation = extract_cache_creation_tokens(usage)
|
||||||
|
if cache_creation > 0:
|
||||||
|
ctx.cache_creation_tokens = cache_creation
|
||||||
|
|
||||||
# 检查是否结束
|
# 检查是否结束
|
||||||
delta = data.get("delta", {})
|
delta = data.get("delta", {})
|
||||||
|
|||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""测试模块"""
|
||||||
90
tests/api/handlers/base/test_utils.py
Normal file
90
tests/api/handlers/base/test_utils.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""测试 handler 基础工具函数"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractCacheCreationTokens:
    """Behavioural tests for ``extract_cache_creation_tokens``."""

    def test_new_format_only(self) -> None:
        """Both split fields present: the result is their sum."""
        result = extract_cache_creation_tokens(
            {
                "claude_cache_creation_5_m_tokens": 100,
                "claude_cache_creation_1_h_tokens": 200,
            }
        )
        assert result == 300

    def test_new_format_5m_only(self) -> None:
        """Only the 5-minute cache field carries a non-zero count."""
        result = extract_cache_creation_tokens(
            {
                "claude_cache_creation_5_m_tokens": 150,
                "claude_cache_creation_1_h_tokens": 0,
            }
        )
        assert result == 150

    def test_new_format_1h_only(self) -> None:
        """Only the 1-hour cache field carries a non-zero count."""
        result = extract_cache_creation_tokens(
            {
                "claude_cache_creation_5_m_tokens": 0,
                "claude_cache_creation_1_h_tokens": 250,
            }
        )
        assert result == 250

    def test_old_format_only(self) -> None:
        """Only the legacy aggregate field is present."""
        result = extract_cache_creation_tokens({"cache_creation_input_tokens": 500})
        assert result == 500

    def test_both_formats_prefers_new(self) -> None:
        """When both formats are present, the split fields take priority."""
        payload = {
            "claude_cache_creation_5_m_tokens": 100,
            "claude_cache_creation_1_h_tokens": 200,
            "cache_creation_input_tokens": 999,  # must be ignored
        }
        result = extract_cache_creation_tokens(payload)
        assert result == 300

    def test_empty_usage(self) -> None:
        """An empty mapping yields zero."""
        result = extract_cache_creation_tokens({})
        assert result == 0

    def test_all_zeros(self) -> None:
        """Every known field set to zero yields zero."""
        payload = {
            "claude_cache_creation_5_m_tokens": 0,
            "claude_cache_creation_1_h_tokens": 0,
            "cache_creation_input_tokens": 0,
        }
        result = extract_cache_creation_tokens(payload)
        assert result == 0

    def test_partial_new_format_with_old_format_fallback(self) -> None:
        """Split fields absent: fall back to the legacy field."""
        result = extract_cache_creation_tokens({"cache_creation_input_tokens": 123})
        assert result == 123

    def test_new_format_zero_fallback_to_old(self) -> None:
        """Split fields present but zero: fall back to the legacy field."""
        payload = {
            "claude_cache_creation_5_m_tokens": 0,
            "claude_cache_creation_1_h_tokens": 0,
            "cache_creation_input_tokens": 456,
        }
        result = extract_cache_creation_tokens(payload)
        assert result == 456

    def test_unrelated_fields_ignored(self) -> None:
        """Fields unrelated to cache creation do not affect the total."""
        payload = {
            "input_tokens": 1000,
            "output_tokens": 2000,
            "cache_read_input_tokens": 300,
            "claude_cache_creation_5_m_tokens": 50,
            "claude_cache_creation_1_h_tokens": 75,
        }
        result = extract_cache_creation_tokens(payload)
        assert result == 125
|
||||||
Reference in New Issue
Block a user