diff --git a/src/api/handlers/base/parsers.py b/src/api/handlers/base/parsers.py index d268478..42fc3ee 100644 --- a/src/api/handlers/base/parsers.py +++ b/src/api/handlers/base/parsers.py @@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import ( ResponseParser, StreamStats, ) +from src.api.handlers.base.utils import extract_cache_creation_tokens def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]: @@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser): usage = response.get("usage", {}) result.input_tokens = usage.get("input_tokens", 0) result.output_tokens = usage.get("output_tokens", 0) - result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0) + result.cache_creation_tokens = extract_cache_creation_tokens(usage) result.cache_read_tokens = usage.get("cache_read_input_tokens", 0) # 检查错误(支持嵌套错误格式) @@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser): return result def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]: + # 对于 message_start 事件,usage 在 message.usage 路径下 + # 对于其他响应,usage 在顶层 usage = response.get("usage", {}) + if not usage and "message" in response: + usage = response.get("message", {}).get("usage", {}) + return { "input_tokens": usage.get("input_tokens", 0), "output_tokens": usage.get("output_tokens", 0), - "cache_creation_tokens": usage.get("cache_creation_input_tokens", 0), + "cache_creation_tokens": extract_cache_creation_tokens(usage), "cache_read_tokens": usage.get("cache_read_input_tokens", 0), } diff --git a/src/api/handlers/base/stream_context.py b/src/api/handlers/base/stream_context.py index 711f6a1..9bff0ec 100644 --- a/src/api/handlers/base/stream_context.py +++ b/src/api/handlers/base/stream_context.py @@ -104,14 +104,40 @@ class StreamContext: cached_tokens: Optional[int] = None, cache_creation_tokens: Optional[int] = None, ) -> None: - """更新 Token 使用统计""" - if input_tokens is not None: + """ + 更新 Token 使用统计 + + 采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。 + + 设计原理: + - 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在) + - 后续事件可能会提供完整的统计数据 + - 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖 + + 示例场景: + - message_start 事件:input_tokens=100, output_tokens=0 + - message_delta 事件:input_tokens=0, output_tokens=50 + - 最终结果:input_tokens=100, output_tokens=50 + + 注意事项: + - 此策略假设初始值为 0 是正确的默认状态 + - 如果需要将已有值重置为 0,请直接修改实例属性(不使用此方法) + + Args: + input_tokens: 输入 tokens 数量 + output_tokens: 输出 tokens 数量 + cached_tokens: 缓存命中 tokens 数量 + cache_creation_tokens: 缓存创建 tokens 数量 + """ + if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0): self.input_tokens = input_tokens - if output_tokens is not None: + if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0): self.output_tokens = output_tokens - if cached_tokens is not None: + if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0): self.cached_tokens = cached_tokens - if cache_creation_tokens is not None: + if cache_creation_tokens is not None and ( + cache_creation_tokens > 0 or self.cache_creation_tokens == 0 + ): self.cache_creation_tokens = cache_creation_tokens def mark_failed(self, status_code: int, error_message: str) -> None: diff --git a/src/api/handlers/base/utils.py b/src/api/handlers/base/utils.py new file mode 100644 index 0000000..f4e2069 --- /dev/null +++ b/src/api/handlers/base/utils.py @@ -0,0 +1,31 @@ +""" +Handler 基础工具函数 +""" + +from typing import Any, Dict + + +def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int: + """ + 提取缓存创建 tokens(兼容新旧格式) + + Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens: + - 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和 + claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存 + - 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens + + 此函数自动检测并适配两种格式,优先使用新格式。 + + Args: + usage: API 响应中的 usage 字典 + + Returns: + 缓存创建 tokens 总数 + """ + # 优先使用新格式 + cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0) + cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0) + total = int(cache_5m) + int(cache_1h) + + # 如果新格式不存在(total == 0),回退到旧格式 + return total if total > 0 else int(usage.get("cache_creation_input_tokens", 0)) diff --git a/src/api/handlers/claude/handler.py b/src/api/handlers/claude/handler.py index 5135c5d..6b45b01 100644 --- a/src/api/handlers/claude/handler.py +++ b/src/api/handlers/claude/handler.py @@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现 from typing import Any, Dict, Optional from src.api.handlers.base.chat_handler_base import ChatHandlerBase +from src.api.handlers.base.utils import extract_cache_creation_tokens class ClaudeChatHandler(ChatHandlerBase): @@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase): result["model"] = mapped_model return result - async def _convert_request(self, request): + async def _convert_request(self, request: Any) -> Any: """ 将请求转换为 Claude 格式 @@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase): Claude 格式使用: - input_tokens / output_tokens - cache_creation_input_tokens / cache_read_input_tokens + - 新格式:claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens """ usage = response.get("usage", {}) - input_tokens = usage.get("input_tokens", 0) - output_tokens = usage.get("output_tokens", 0) - cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0) - cache_read_input_tokens = usage.get("cache_read_input_tokens", 0) - - # 处理新的 cache_creation 格式 - if "cache_creation" in usage: - cache_creation_data = usage.get("cache_creation", {}) - if not cache_creation_input_tokens: - cache_creation_input_tokens = cache_creation_data.get( - "ephemeral_5m_input_tokens", 0 - ) + cache_creation_data.get("ephemeral_1h_input_tokens", 0) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "cache_creation_input_tokens": cache_creation_input_tokens, - "cache_read_input_tokens": cache_read_input_tokens, + "input_tokens": usage.get("input_tokens", 0), + "output_tokens": usage.get("output_tokens", 0), + "cache_creation_input_tokens": extract_cache_creation_tokens(usage), + "cache_read_input_tokens": usage.get("cache_read_input_tokens", 0), } - def _normalize_response(self, response: Dict) -> Dict: + def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]: """ 规范化 Claude 响应 @@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase): 规范化后的响应 """ if self.response_normalizer and self.response_normalizer.should_normalize(response): - return self.response_normalizer.normalize_claude_response( + result: Dict[str, Any] = self.response_normalizer.normalize_claude_response( response_data=response, request_id=self.request_id, ) + return result return response diff --git a/src/api/handlers/claude/stream_parser.py b/src/api/handlers/claude/stream_parser.py index 88ac955..cf51ba5 100644 --- a/src/api/handlers/claude/stream_parser.py +++ b/src/api/handlers/claude/stream_parser.py @@ -9,6 +9,8 @@ from __future__ import annotations import json from typing import Any, Dict, List, Optional +from src.api.handlers.base.utils import extract_cache_creation_tokens + class ClaudeStreamParser: """ @@ -193,7 +195,7 @@ class ClaudeStreamParser: return { "input_tokens": usage.get("input_tokens", 0), "output_tokens": usage.get("output_tokens", 0), - "cache_creation_tokens": usage.get("cache_creation_input_tokens", 0), + "cache_creation_tokens": extract_cache_creation_tokens(usage), "cache_read_tokens": usage.get("cache_read_input_tokens", 0), } @@ -204,7 +206,7 @@ class ClaudeStreamParser: return { "input_tokens": usage.get("input_tokens", 0), "output_tokens": usage.get("output_tokens", 0), - "cache_creation_tokens": usage.get("cache_creation_input_tokens", 0), + "cache_creation_tokens": extract_cache_creation_tokens(usage), "cache_read_tokens": usage.get("cache_read_input_tokens", 0), } diff --git a/src/api/handlers/claude_cli/handler.py b/src/api/handlers/claude_cli/handler.py index 136f829..bc5e822 100644 --- a/src/api/handlers/claude_cli/handler.py +++ b/src/api/handlers/claude_cli/handler.py @@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import ( CliMessageHandlerBase, StreamContext, ) +from src.api.handlers.base.utils import extract_cache_creation_tokens class ClaudeCliMessageHandler(CliMessageHandlerBase): @@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase): usage = message.get("usage", {}) if usage: ctx.input_tokens = usage.get("input_tokens", 0) - # Claude 的缓存 tokens 使用不同的字段名 + cache_read = usage.get("cache_read_input_tokens", 0) if cache_read: ctx.cached_tokens = cache_read - cache_creation = usage.get("cache_creation_input_tokens", 0) + + cache_creation = extract_cache_creation_tokens(usage) if cache_creation: ctx.cache_creation_tokens = cache_creation @@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase): ctx.input_tokens = usage["input_tokens"] if "output_tokens" in usage: ctx.output_tokens = usage["output_tokens"] - # 更新缓存 tokens + + # 更新缓存读取 tokens if "cache_read_input_tokens" in usage: ctx.cached_tokens = usage["cache_read_input_tokens"] - if "cache_creation_input_tokens" in usage: - ctx.cache_creation_tokens = usage["cache_creation_input_tokens"] + + # 更新缓存创建 tokens + cache_creation = extract_cache_creation_tokens(usage) + if cache_creation > 0: + ctx.cache_creation_tokens = cache_creation # 检查是否结束 delta = data.get("delta", {}) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..412ecbc --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""测试模块""" diff --git a/tests/api/handlers/base/test_utils.py b/tests/api/handlers/base/test_utils.py new file mode 100644 index 0000000..87aa01c --- /dev/null +++ b/tests/api/handlers/base/test_utils.py @@ -0,0 +1,90 @@ +"""测试 handler 基础工具函数""" + +import pytest + +from src.api.handlers.base.utils import extract_cache_creation_tokens + + +class TestExtractCacheCreationTokens: + """测试 extract_cache_creation_tokens 函数""" + + def test_new_format_only(self) -> None: + """测试只有新格式字段""" + usage = { + "claude_cache_creation_5_m_tokens": 100, + "claude_cache_creation_1_h_tokens": 200, + } + assert extract_cache_creation_tokens(usage) == 300 + + def test_new_format_5m_only(self) -> None: + """测试只有 5 分钟缓存""" + usage = { + "claude_cache_creation_5_m_tokens": 150, + "claude_cache_creation_1_h_tokens": 0, + } + assert extract_cache_creation_tokens(usage) == 150 + + def test_new_format_1h_only(self) -> None: + """测试只有 1 小时缓存""" + usage = { + "claude_cache_creation_5_m_tokens": 0, + "claude_cache_creation_1_h_tokens": 250, + } + assert extract_cache_creation_tokens(usage) == 250 + + def test_old_format_only(self) -> None: + """测试只有旧格式字段""" + usage = { + "cache_creation_input_tokens": 500, + } + assert extract_cache_creation_tokens(usage) == 500 + + def test_both_formats_prefers_new(self) -> None: + """测试同时存在时优先使用新格式""" + usage = { + "claude_cache_creation_5_m_tokens": 100, + "claude_cache_creation_1_h_tokens": 200, + "cache_creation_input_tokens": 999, # 应该被忽略 + } + assert extract_cache_creation_tokens(usage) == 300 + + def test_empty_usage(self) -> None: + """测试空字典""" + usage = {} + assert extract_cache_creation_tokens(usage) == 0 + + def test_all_zeros(self) -> None: + """测试所有字段都为 0""" + usage = { + "claude_cache_creation_5_m_tokens": 0, + "claude_cache_creation_1_h_tokens": 0, + "cache_creation_input_tokens": 0, + } + assert extract_cache_creation_tokens(usage) == 0 + + def test_partial_new_format_with_old_format_fallback(self) -> None: + """测试新格式字段不存在时回退到旧格式""" + usage = { + "cache_creation_input_tokens": 123, + } + assert extract_cache_creation_tokens(usage) == 123 + + def test_new_format_zero_fallback_to_old(self) -> None: + """测试新格式为 0 时回退到旧格式""" + usage = { + "claude_cache_creation_5_m_tokens": 0, + "claude_cache_creation_1_h_tokens": 0, + "cache_creation_input_tokens": 456, + } + assert extract_cache_creation_tokens(usage) == 456 + + def test_unrelated_fields_ignored(self) -> None: + """测试忽略无关字段""" + usage = { + "input_tokens": 1000, + "output_tokens": 2000, + "cache_read_input_tokens": 300, + "claude_cache_creation_5_m_tokens": 50, + "claude_cache_creation_1_h_tokens": 75, + } + assert extract_cache_creation_tokens(usage) == 125