Files
Aether/src/api/handlers/base/stream_context.py
fawney19 f3a69a6160 refactor(handler): implement defensive token update strategy and extract cache creation token utility
- Add extract_cache_creation_tokens utility to handle new/old cache creation token formats
- Implement defensive update strategy in StreamContext to prevent zero values overwriting valid data
- Simplify cache creation token parsing in Claude handler using new utility
- Add comprehensive test suite for cache creation token extraction
- Improve type hints in handler classes
2025-12-16 00:02:49 +08:00

181 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
流式处理上下文 - 类型安全的数据类替代 dict
提供流式请求处理过程中的状态跟踪,包括:
- Provider/Endpoint/Key 信息
- Token 统计
- 响应状态
- 请求/响应数据
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class StreamContext:
"""
流式处理上下文
用于在流式请求处理过程中跟踪状态,替代原有的 ctx dict。
所有字段都有类型注解,提供更好的 IDE 支持和运行时类型安全。
"""
# 请求基本信息
model: str
api_format: str
# Provider 信息(在请求执行时填充)
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
provider_api_format: Optional[str] = None # Provider 的响应格式
# 模型映射
mapped_model: Optional[str] = None
# Token 统计
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
# 响应内容
collected_text: str = ""
# 响应状态
status_code: int = 200
error_message: Optional[str] = None
has_completion: bool = False
# 请求/响应数据
response_headers: Dict[str, str] = field(default_factory=dict)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
# 流式处理统计
data_count: int = 0
chunk_count: int = 0
parsed_chunks: List[Dict[str, Any]] = field(default_factory=list)
def reset_for_retry(self) -> None:
"""
重试时重置状态
在故障转移重试时调用,清除之前的数据避免累积。
保留 model 和 api_format重置其他所有状态。
"""
self.parsed_chunks = []
self.chunk_count = 0
self.data_count = 0
self.has_completion = False
self.collected_text = ""
self.input_tokens = 0
self.output_tokens = 0
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.status_code = 200
self.response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
def update_provider_info(
self,
provider_name: str,
provider_id: str,
endpoint_id: str,
key_id: str,
provider_api_format: Optional[str] = None,
) -> None:
"""更新 Provider 信息"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.key_id = key_id
self.provider_api_format = provider_api_format
def update_usage(
self,
input_tokens: Optional[int] = None,
output_tokens: Optional[int] = None,
cached_tokens: Optional[int] = None,
cache_creation_tokens: Optional[int] = None,
) -> None:
"""
更新 Token 使用统计
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
设计原理:
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
- 后续事件可能会提供完整的统计数据
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
示例场景:
- message_start 事件input_tokens=100, output_tokens=0
- message_delta 事件input_tokens=0, output_tokens=50
- 最终结果input_tokens=100, output_tokens=50
注意事项:
- 此策略假设初始值为 0 是正确的默认状态
- 如果需要将已有值重置为 0请直接修改实例属性不使用此方法
Args:
input_tokens: 输入 tokens 数量
output_tokens: 输出 tokens 数量
cached_tokens: 缓存命中 tokens 数量
cache_creation_tokens: 缓存创建 tokens 数量
"""
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
self.input_tokens = input_tokens
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
self.output_tokens = output_tokens
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
self.cached_tokens = cached_tokens
if cache_creation_tokens is not None and (
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
):
self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None:
"""标记请求失败"""
self.status_code = status_code
self.error_message = error_message
def is_success(self) -> bool:
"""检查请求是否成功"""
return self.status_code < 400
def build_response_body(self, response_time_ms: int) -> Dict[str, Any]:
"""
构建响应体元数据
用于记录到 Usage 表的 response_body 字段。
"""
return {
"chunks": self.parsed_chunks,
"metadata": {
"stream": True,
"total_chunks": len(self.parsed_chunks),
"data_count": self.data_count,
"has_completion": self.has_completion,
"response_time_ms": response_time_ms,
},
}
def get_log_summary(self, request_id: str, response_time_ms: int) -> str:
"""
获取日志摘要
用于请求完成/失败时的日志输出。
"""
status = "OK" if self.is_success() else "FAIL"
return (
f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}"
)