Aether/src/api/handlers/base/stream_context.py
fawney19 ad1c8c394c refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
2025-12-16 02:39:03 +08:00


"""
流式处理上下文 - 类型安全的数据类替代 dict
提供流式请求处理过程中的状态跟踪,包括:
- Provider/Endpoint/Key 信息
- Token 统计
- 响应状态
- 请求/响应数据
"""
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class StreamContext:
"""
流式处理上下文
用于在流式请求处理过程中跟踪状态,替代原有的 ctx dict。
所有字段都有类型注解,提供更好的 IDE 支持和运行时类型安全。
"""
# 请求基本信息
model: str
api_format: str
    # Request identity (needed by the CLI handler)
request_id: str = ""
user_id: int = 0
api_key_id: int = 0
    # Provider information (populated when the request is executed)
provider_name: Optional[str] = None
provider_id: Optional[str] = None
endpoint_id: Optional[str] = None
key_id: Optional[str] = None
attempt_id: Optional[str] = None
attempt_synced: bool = False
    provider_api_format: Optional[str] = None  # The provider's response format
    # Model mapping
mapped_model: Optional[str] = None
    # Token statistics
input_tokens: int = 0
output_tokens: int = 0
cached_tokens: int = 0
cache_creation_tokens: int = 0
    # Response content
_collected_text_parts: List[str] = field(default_factory=list, repr=False)
response_id: Optional[str] = None
final_usage: Optional[Dict[str, Any]] = None
final_response: Optional[Dict[str, Any]] = None
    # Timing metrics
    first_byte_time_ms: Optional[int] = None  # Time to first byte (TTFB)
start_time: float = field(default_factory=time.time)
    # Response status
status_code: int = 200
error_message: Optional[str] = None
has_completion: bool = False
    # Request/response data
response_headers: Dict[str, str] = field(default_factory=dict)
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
    # Format-conversion information (needed by the CLI handler)
client_api_format: str = ""
    # Provider response metadata (needed by the CLI handler)
response_metadata: Dict[str, Any] = field(default_factory=dict)
    # Streaming statistics
data_count: int = 0
chunk_count: int = 0
parsed_chunks: List[Dict[str, Any]] = field(default_factory=list)

    def reset_for_retry(self) -> None:
        """
        Reset state for a retry.

        Called on failover retry to clear data accumulated by the previous attempt so
        it does not carry over. model and api_format are preserved; the per-attempt
        stream state below is reset.
        """
self.parsed_chunks = []
self.chunk_count = 0
self.data_count = 0
self.has_completion = False
self._collected_text_parts = []
self.input_tokens = 0
self.output_tokens = 0
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.status_code = 200
self.first_byte_time_ms = None
self.response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
self.response_id = None
self.final_usage = None
self.final_response = None
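
    # For illustration (hypothetical flow, not from the original handler code): after a
    # failed attempt the caller can do
    #     ctx.reset_for_retry()
    # and route the same request to the next provider; counters such as chunk_count and
    # the collected text then start from zero for the new attempt.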

    @property
    def collected_text(self) -> str:
        """Collected text so far (joined on demand to avoid frequent string copies during streaming)."""
        return "".join(self._collected_text_parts)

    def append_text(self, text: str) -> None:
        """Append a piece of text (call only when text collection is needed)."""
        if text:
            self._collected_text_parts.append(text)
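
    # For illustration: a stream processor would call append_text(delta_text) once per
    # chunk and read collected_text a single time at the end, avoiding the quadratic cost
    # of repeated string concatenation inside the streaming loop.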

    def update_provider_info(
self,
provider_name: str,
provider_id: str,
endpoint_id: str,
key_id: str,
provider_api_format: Optional[str] = None,
) -> None:
"""更新 Provider 信息"""
self.provider_name = provider_name
self.provider_id = provider_id
self.endpoint_id = endpoint_id
self.key_id = key_id
self.provider_api_format = provider_api_format

    def update_usage(
self,
input_tokens: Optional[int] = None,
output_tokens: Optional[int] = None,
cached_tokens: Optional[int] = None,
cache_creation_tokens: Optional[int] = None,
) -> None:
"""
更新 Token 使用统计
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
设计原理:
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
- 后续事件可能会提供完整的统计数据
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
示例场景:
- message_start 事件input_tokens=100, output_tokens=0
- message_delta 事件input_tokens=0, output_tokens=50
- 最终结果input_tokens=100, output_tokens=50
注意事项:
- 此策略假设初始值为 0 是正确的默认状态
- 如果需要将已有值重置为 0请直接修改实例属性不使用此方法
Args:
input_tokens: 输入 tokens 数量
output_tokens: 输出 tokens 数量
cached_tokens: 缓存命中 tokens 数量
cache_creation_tokens: 缓存创建 tokens 数量
"""
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
self.input_tokens = input_tokens
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
self.output_tokens = output_tokens
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
self.cached_tokens = cached_tokens
if cache_creation_tokens is not None and (
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
):
self.cache_creation_tokens = cache_creation_tokens
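
    # Illustration of the defensive merge above (hypothetical event values):
    #     ctx.update_usage(input_tokens=100, output_tokens=0)   # e.g. message_start
    #     ctx.update_usage(input_tokens=0, output_tokens=50)    # e.g. message_delta
    #     assert ctx.input_tokens == 100 and ctx.output_tokens == 50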

    def mark_failed(self, status_code: int, error_message: str) -> None:
        """Mark the request as failed."""
self.status_code = status_code
self.error_message = error_message

    def record_first_byte_time(self, start_time: float) -> None:
        """
        Record the time to first byte (TTFB).

        Should be called when data is first sent to the client.
        If a value has already been recorded, it is not overwritten (so duplicate calls
        do not change it).

        Args:
            start_time: Request start time (time.time())
        """
if self.first_byte_time_ms is None:
self.first_byte_time_ms = int((time.time() - start_time) * 1000)
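
    # For illustration: the handler typically calls ctx.record_first_byte_time(ctx.start_time)
    # just before forwarding the first chunk to the client; later calls are no-ops because
    # of the guard above.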

    def is_success(self) -> bool:
        """Check whether the request was successful (status code < 400)."""
return self.status_code < 400

    def build_response_body(self, response_time_ms: int) -> Dict[str, Any]:
        """
        Build the response-body metadata.

        Used for the response_body field recorded in the Usage table.
        """
return {
"chunks": self.parsed_chunks,
"metadata": {
"stream": True,
"total_chunks": len(self.parsed_chunks),
"data_count": self.data_count,
"has_completion": self.has_completion,
"response_time_ms": response_time_ms,
},
}

    def get_log_summary(self, request_id: str, response_time_ms: int) -> str:
        """
        Build a log summary.

        Used for the log line emitted when a request completes or fails.
        Includes the time to first byte (TTFB) and the total response time, on two lines.
        """
status = "OK" if self.is_success() else "FAIL"
        # First line: basic info + time to first byte
line1 = (
f"[{status}] {request_id[:8]} | {self.model} | "
f"{self.provider_name or 'unknown'}"
)
if self.first_byte_time_ms is not None:
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
        # Second line: total response time + tokens
line2 = (
f" Total: {response_time_ms}ms | "
f"in:{self.input_tokens} out:{self.output_tokens}"
)
return f"{line1}\n{line2}"