From 15a9b88fc88d27210f5dba47c5d4be2e27e63fd4 Mon Sep 17 00:00:00 2001 From: Hwwwww-dev <47653238+Hwwwww-dev@users.noreply.github.com> Date: Wed, 24 Dec 2025 01:31:45 +0800 Subject: [PATCH] feat: enhance extract_cache_creation_tokens function to support three formats[#41] (#42) - Updated the function to prioritize nested format, followed by flat new format, and finally old format for cache creation tokens. - Added fallback logic for cases where the preferred formats return zero. - Expanded unit tests to cover new format scenarios and ensure proper functionality across all formats. Co-authored-by: heweimin --- src/api/handlers/base/utils.py | 84 +++++++++++++++---- tests/api/handlers/base/test_utils.py | 112 ++++++++++++++++---------- 2 files changed, 140 insertions(+), 56 deletions(-) diff --git a/src/api/handlers/base/utils.py b/src/api/handlers/base/utils.py index 92b50fc..54172d1 100644 --- a/src/api/handlers/base/utils.py +++ b/src/api/handlers/base/utils.py @@ -4,17 +4,28 @@ Handler 基础工具函数 from typing import Any, Dict, Optional +from src.core.logger import logger + def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int: """ - 提取缓存创建 tokens(兼容新旧格式) + 提取缓存创建 tokens(兼容三种格式) - Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens: - - 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和 - claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存 - - 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens + 根据 Anthropic API 文档,支持三种格式(按优先级): - 此函数自动检测并适配两种格式,优先使用新格式。 + 1. **嵌套格式(优先级最高)**: + usage.cache_creation.ephemeral_5m_input_tokens + usage.cache_creation.ephemeral_1h_input_tokens + + 2. **扁平新格式(优先级第二)**: + usage.claude_cache_creation_5_m_tokens + usage.claude_cache_creation_1_h_tokens + + 3. **旧格式(优先级第三)**: + usage.cache_creation_input_tokens + + 优先使用嵌套格式,如果嵌套格式字段存在但值为 0,则智能 fallback 到旧格式。 + 扁平格式和嵌套格式互斥,按顺序检查。 Args: usage: API 响应中的 usage 字典 @@ -22,20 +33,63 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int: Returns: 缓存创建 tokens 总数 """ - # 检查新格式字段是否存在(而非值是否为 0) - # 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式 - has_new_format = ( + # 1. 检查嵌套格式(最新格式) + cache_creation = usage.get("cache_creation") + if isinstance(cache_creation, dict): + cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0)) + cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0)) + total = cache_5m + cache_1h + + if total > 0: + logger.debug( + f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}" + ) + return total + + # 嵌套格式存在但为 0,fallback 到旧格式 + old_format = int(usage.get("cache_creation_input_tokens", 0)) + if old_format > 0: + logger.debug( + f"Nested cache_creation is 0, using old format: {old_format}" + ) + return old_format + + # 都是 0,返回 0 + return 0 + + # 2. 检查扁平新格式 + has_flat_format = ( "claude_cache_creation_5_m_tokens" in usage or "claude_cache_creation_1_h_tokens" in usage ) - if has_new_format: - cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0) - cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0) - return int(cache_5m) + int(cache_1h) + if has_flat_format: + cache_5m = int(usage.get("claude_cache_creation_5_m_tokens", 0)) + cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0)) + total = cache_5m + cache_1h - # 回退到旧格式 - return int(usage.get("cache_creation_input_tokens", 0)) + if total > 0: + logger.debug( + f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}" + ) + return total + + # 扁平格式存在但为 0,fallback 到旧格式 + old_format = int(usage.get("cache_creation_input_tokens", 0)) + if old_format > 0: + logger.debug( + f"Flat cache_creation is 0, using old format: {old_format}" + ) + return old_format + + # 都是 0,返回 0 + return 0 + + # 3. 回退到旧格式 + old_format = int(usage.get("cache_creation_input_tokens", 0)) + if old_format > 0: + logger.debug(f"Using old format: cache_creation_input_tokens={old_format}") + return old_format def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: diff --git a/tests/api/handlers/base/test_utils.py b/tests/api/handlers/base/test_utils.py index 7ba39eb..6670b38 100644 --- a/tests/api/handlers/base/test_utils.py +++ b/tests/api/handlers/base/test_utils.py @@ -8,86 +8,116 @@ from src.api.handlers.base.utils import build_sse_headers, extract_cache_creatio class TestExtractCacheCreationTokens: """测试 extract_cache_creation_tokens 函数""" - def test_new_format_only(self) -> None: - """测试只有新格式字段""" + # === 嵌套格式测试(优先级最高)=== + + def test_nested_cache_creation_format(self) -> None: + """测试嵌套格式正常情况""" + usage = { + "cache_creation": { + "ephemeral_5m_input_tokens": 456, + "ephemeral_1h_input_tokens": 100, + } + } + assert extract_cache_creation_tokens(usage) == 556 + + def test_nested_cache_creation_with_old_format_fallback(self) -> None: + """测试嵌套格式为 0 时回退到旧格式""" + usage = { + "cache_creation": { + "ephemeral_5m_input_tokens": 0, + "ephemeral_1h_input_tokens": 0, + }, + "cache_creation_input_tokens": 549, + } + assert extract_cache_creation_tokens(usage) == 549 + + def test_nested_has_priority_over_flat(self) -> None: + """测试嵌套格式优先于扁平格式""" + usage = { + "cache_creation": { + "ephemeral_5m_input_tokens": 100, + "ephemeral_1h_input_tokens": 200, + }, + "claude_cache_creation_5_m_tokens": 999, # 应该被忽略 + "claude_cache_creation_1_h_tokens": 888, # 应该被忽略 + "cache_creation_input_tokens": 777, # 应该被忽略 + } + assert extract_cache_creation_tokens(usage) == 300 + + # === 扁平格式测试(优先级第二)=== + + def test_flat_new_format_still_works(self) -> None: + """测试扁平新格式兼容性""" usage = { "claude_cache_creation_5_m_tokens": 100, "claude_cache_creation_1_h_tokens": 200, } assert extract_cache_creation_tokens(usage) == 300 - def test_new_format_5m_only(self) -> None: - """测试只有 5 分钟缓存""" + def test_flat_new_format_with_old_format_fallback(self) -> None: + """测试扁平格式为 0 时回退到旧格式""" + usage = { + "claude_cache_creation_5_m_tokens": 0, + "claude_cache_creation_1_h_tokens": 0, + "cache_creation_input_tokens": 549, + } + assert extract_cache_creation_tokens(usage) == 549 + + def test_flat_new_format_5m_only(self) -> None: + """测试只有 5 分钟扁平缓存""" usage = { "claude_cache_creation_5_m_tokens": 150, "claude_cache_creation_1_h_tokens": 0, } assert extract_cache_creation_tokens(usage) == 150 - def test_new_format_1h_only(self) -> None: - """测试只有 1 小时缓存""" + def test_flat_new_format_1h_only(self) -> None: + """测试只有 1 小时扁平缓存""" usage = { "claude_cache_creation_5_m_tokens": 0, "claude_cache_creation_1_h_tokens": 250, } assert extract_cache_creation_tokens(usage) == 250 + # === 旧格式测试(优先级第三)=== + def test_old_format_only(self) -> None: - """测试只有旧格式字段""" + """测试只有旧格式""" usage = { - "cache_creation_input_tokens": 500, + "cache_creation_input_tokens": 549, } - assert extract_cache_creation_tokens(usage) == 500 + assert extract_cache_creation_tokens(usage) == 549 - def test_both_formats_prefers_new(self) -> None: - """测试同时存在时优先使用新格式""" - usage = { - "claude_cache_creation_5_m_tokens": 100, - "claude_cache_creation_1_h_tokens": 200, - "cache_creation_input_tokens": 999, # 应该被忽略 - } - assert extract_cache_creation_tokens(usage) == 300 + # === 边界情况测试 === - def test_empty_usage(self) -> None: - """测试空字典""" + def test_no_cache_creation_tokens(self) -> None: + """测试没有任何缓存字段""" usage = {} assert extract_cache_creation_tokens(usage) == 0 - def test_all_zeros(self) -> None: - """测试所有字段都为 0""" + def test_all_formats_zero(self) -> None: + """测试所有格式都为 0""" usage = { + "cache_creation": { + "ephemeral_5m_input_tokens": 0, + "ephemeral_1h_input_tokens": 0, + }, "claude_cache_creation_5_m_tokens": 0, "claude_cache_creation_1_h_tokens": 0, "cache_creation_input_tokens": 0, } assert extract_cache_creation_tokens(usage) == 0 - def test_partial_new_format_with_old_format_fallback(self) -> None: - """测试新格式字段不存在时回退到旧格式""" - usage = { - "cache_creation_input_tokens": 123, - } - assert extract_cache_creation_tokens(usage) == 123 - - def test_new_format_zero_should_not_fallback(self) -> None: - """测试新格式字段存在但为 0 时,不应 fallback 到旧格式""" - usage = { - "claude_cache_creation_5_m_tokens": 0, - "claude_cache_creation_1_h_tokens": 0, - "cache_creation_input_tokens": 456, - } - # 新格式字段存在,即使值为 0 也应该使用新格式(返回 0) - # 而不是 fallback 到旧格式(返回 456) - assert extract_cache_creation_tokens(usage) == 0 - def test_unrelated_fields_ignored(self) -> None: """测试忽略无关字段""" usage = { "input_tokens": 1000, "output_tokens": 2000, "cache_read_input_tokens": 300, - "claude_cache_creation_5_m_tokens": 50, - "claude_cache_creation_1_h_tokens": 75, + "cache_creation": { + "ephemeral_5m_input_tokens": 50, + "ephemeral_1h_input_tokens": 75, + }, } assert extract_cache_creation_tokens(usage) == 125