From ad1c8c394cf12e3a4c7bd77cb7a1cb2f0fc0a282 Mon Sep 17 00:00:00 2001
From: fawney19
Date: Tue, 16 Dec 2025 02:39:03 +0800
Subject: [PATCH] refactor(handler): optimize stream processing and telemetry pipeline

- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
---
 src/api/handlers/base/base_handler.py        |   3 +
 src/api/handlers/base/chat_handler_base.py   |  20 +-
 src/api/handlers/base/stream_context.py      |  67 ++++-
 src/api/handlers/base/stream_processor.py    | 259 ++++++++++++------
 src/api/handlers/base/stream_telemetry.py    |  10 +-
 src/api/handlers/base/utils.py               |  38 ++-
 .../api/handlers/base/test_stream_context.py | 117 ++++++++
 tests/api/handlers/base/test_utils.py        |  22 +-
 8 files changed, 428 insertions(+), 108 deletions(-)
 create mode 100644 tests/api/handlers/base/test_stream_context.py

diff --git a/src/api/handlers/base/base_handler.py b/src/api/handlers/base/base_handler.py
index fc5651a..81aca7c 100644
--- a/src/api/handlers/base/base_handler.py
+++ b/src/api/handlers/base/base_handler.py
@@ -100,6 +100,8 @@ class MessageTelemetry:
         cache_read_tokens: int = 0,
         is_stream: bool = False,
         provider_request_headers: Optional[Dict[str, Any]] = None,
+        # Timing metrics
+        first_byte_time_ms: Optional[int] = None,  # time to first byte (TTFB)
         # Provider-side tracking info (used to record the real cost)
         provider_id: Optional[str] = None,
         provider_endpoint_id: Optional[str] = None,
@@ -133,6 +135,7 @@ class MessageTelemetry:
             api_format=api_format,
             is_stream=is_stream,
             response_time_ms=response_time_ms,
+            first_byte_time_ms=first_byte_time_ms,  # pass through the TTFB
             status_code=status_code,
             request_headers=request_headers,
             request_body=request_body,
diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py
index fa4d019..1ebe285 100644
--- a/src/api/handlers/base/chat_handler_base.py
+++ b/src/api/handlers/base/chat_handler_base.py
@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
 from src.api.handlers.base.stream_context import StreamContext
 from src.api.handlers.base.stream_processor import StreamProcessor
 from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
+from src.api.handlers.base.utils import build_sse_headers
 from src.config.settings import config
 from src.core.exceptions import (
     EmbeddedErrorException,
@@ -365,7 +366,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             ctx,
             original_headers,
             original_request_body,
-            self.elapsed_ms(),
+            self.start_time,  # pass the start time; telemetry computes the response time after the stream ends
         )

         # Create the monitored stream
@@ -378,6 +379,7 @@
         return StreamingResponse(
             monitored_stream,
             media_type="text/event-stream",
+            headers=build_sse_headers(),
             background=background_tasks,
         )

@@ -473,12 +475,13 @@

         stream_response.raise_for_status()

-        # Create the line iterator
-        line_iterator = stream_response.aiter_lines()
+        # Use a byte-stream iterator (avoids the performance issues of aiter_lines)
+        # aiter_raw() yields raw chunks without buffering, enabling true streaming
+        byte_iterator = stream_response.aiter_raw()

         # Prefetch to detect embedded errors
-        prefetched_lines = await stream_processor.prefetch_and_check_error(
-            line_iterator,
+        prefetched_chunks = await stream_processor.prefetch_and_check_error(
+            byte_iterator,
             provider,
             endpoint,
             ctx,
@@ -503,13 +506,14 @@
             await http_client.aclose()
             raise

-        # Create the stream generator
+        # Create the stream generator (pass in the byte-stream iterator)
         return stream_processor.create_response_stream(
             ctx,
-            line_iterator,
+            byte_iterator,
             response_ctx,
             http_client,
-            prefetched_lines,
+            prefetched_chunks,
+            start_time=self.start_time,
         )

     async def _record_stream_failure(
diff --git a/src/api/handlers/base/stream_context.py b/src/api/handlers/base/stream_context.py
index 9bff0ec..342ac89 100644
--- a/src/api/handlers/base/stream_context.py
+++ b/src/api/handlers/base/stream_context.py
@@ -8,6 +8,7 @@
 - Request/response data
 """

+import time
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional

@@ -25,12 +26,18 @@ class StreamContext:
     model: str
     api_format: str

+    # Request identity info (needed by the CLI handler)
+    request_id: str = ""
+    user_id: int = 0
+    api_key_id: int = 0
+
     # Provider info (filled in while the request executes)
     provider_name: Optional[str] = None
     provider_id: Optional[str] = None
     endpoint_id: Optional[str] = None
     key_id: Optional[str] = None
     attempt_id: Optional[str] = None
+    attempt_synced: bool = False
     provider_api_format: Optional[str] = None  # the provider's response format

     # Model mapping
@@ -43,7 +50,14 @@
     cache_creation_tokens: int = 0

     # Response content
-    collected_text: str = ""
+    _collected_text_parts: List[str] = field(default_factory=list, repr=False)
+    response_id: Optional[str] = None
+    final_usage: Optional[Dict[str, Any]] = None
+    final_response: Optional[Dict[str, Any]] = None
+
+    # Timing metrics
+    first_byte_time_ms: Optional[int] = None  # time to first byte (TTFB)
+    start_time: float = field(default_factory=time.time)

     # Response status
     status_code: int = 200
@@ -55,6 +69,12 @@
     provider_request_headers: Dict[str, str] = field(default_factory=dict)
     provider_request_body: Optional[Dict[str, Any]] = None

+    # Format-conversion info (needed by the CLI handler)
+    client_api_format: str = ""
+
+    # Provider response metadata (needed by the CLI handler)
+    response_metadata: Dict[str, Any] = field(default_factory=dict)
+
     # Streaming statistics
     data_count: int = 0
     chunk_count: int = 0
@@ -71,16 +91,30 @@
         self.chunk_count = 0
         self.data_count = 0
         self.has_completion = False
-        self.collected_text = ""
+        self._collected_text_parts = []
         self.input_tokens = 0
         self.output_tokens = 0
         self.cached_tokens = 0
         self.cache_creation_tokens = 0
         self.error_message = None
         self.status_code = 200
+        self.first_byte_time_ms = None
         self.response_headers = {}
         self.provider_request_headers = {}
         self.provider_request_body = None
+        self.response_id = None
+        self.final_usage = None
+        self.final_response = None
+
+    @property
+    def collected_text(self) -> str:
+        """Collected text content (joined on demand, avoiding repeated string copies while streaming)"""
+        return "".join(self._collected_text_parts)
+
+    def append_text(self, text: str) -> None:
+        """Append text content (called only when text collection is needed)"""
+        if text:
+            self._collected_text_parts.append(text)

     def update_provider_info(
         self,
@@ -145,6 +179,19 @@
         self.status_code = status_code
         self.error_message = error_message

+    def record_first_byte_time(self, start_time: float) -> None:
+        """
+        Record the time to first byte (TTFB).
+
+        Should be called the first time data is sent to the client.
+        If a value was already recorded, it is not overwritten
+        (avoids double-recording on retries).
+
+        Args:
+            start_time: request start time (time.time())
+        """
+        if self.first_byte_time_ms is None:
+            self.first_byte_time_ms = int((time.time() - start_time) * 1000)
+
     def is_success(self) -> bool:
         """Check whether the request succeeded"""
         return self.status_code < 400
@@ -171,10 +218,22 @@
         Get the log summary.

         Used for log output when a request completes or fails.
+        Includes the time to first byte (TTFB) and the total response time, on two lines.
         """
         status = "OK" if self.is_success() else "FAIL"
-        return (
+
+        # Line 1: basic info + TTFB
+        line1 = (
             f"[{status}] {request_id[:8]} | {self.model} | "
-            f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
+            f"{self.provider_name or 'unknown'}"
+        )
+        if self.first_byte_time_ms is not None:
+            line1 += f" | TTFB: {self.first_byte_time_ms}ms"
+
+        # Line 2: total response time + tokens
+        line2 = (
+            f"  Total: {response_time_ms}ms | "
             f"in:{self.input_tokens} out:{self.output_tokens}"
         )
+
+        return f"{line1}\n{line2}"
diff --git a/src/api/handlers/base/stream_processor.py b/src/api/handlers/base/stream_processor.py
index 5ff85ee..a5bcb22 100644
--- a/src/api/handlers/base/stream_processor.py
+++ b/src/api/handlers/base/stream_processor.py
@@ -9,7 +9,9 @@
 """

 import asyncio
+import codecs
 import json
+import time
 from typing import Any, AsyncGenerator, Callable, Optional

 import httpx
@@ -36,6 +38,8 @@ class StreamProcessor:
         request_id: str,
         default_parser: ResponseParser,
         on_streaming_start: Optional[Callable[[], None]] = None,
+        *,
+        collect_text: bool = False,
     ):
         """
         Initialize the stream processor
@@ -48,6 +52,7 @@
         self.request_id = request_id
         self.default_parser = default_parser
         self.on_streaming_start = on_streaming_start
+        self.collect_text = collect_text

     def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
         """
@@ -112,9 +117,10 @@
         )

         # Extract text
-        text = parser.extract_text_content(data)
-        if text:
-            ctx.collected_text += text
+        if self.collect_text:
+            text = parser.extract_text_content(data)
+            if text:
+                ctx.append_text(text)

         # Check for completion
         event_type = event_name or data.get("type", "")
@@ -123,7 +129,7 @@

     async def prefetch_and_check_error(
         self,
-        line_iterator: Any,
+        byte_iterator: Any,
         provider: Provider,
         endpoint: ProviderEndpoint,
         ctx: StreamContext,
@@ -136,97 +142,126 @@
         This case must be detected before the stream starts emitting output,
         so that the retry logic can be triggered.

         Args:
-            line_iterator: line iterator
+            byte_iterator: byte-stream iterator
             provider: Provider object
             endpoint: Endpoint object
             ctx: streaming context
             max_prefetch_lines: maximum number of lines to prefetch

         Returns:
-            List of prefetched lines
+            List of prefetched byte chunks

         Raises:
             EmbeddedErrorException: if an embedded error is detected
         """
-        prefetched_lines: list = []
+        prefetched_chunks: list = []
         parser = self.get_parser_for_provider(ctx)
+        buffer = b""
+        line_count = 0
+        should_stop = False
+        # Use an incremental decoder for UTF-8 characters that span chunks
+        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

         try:
-            line_count = 0
-            async for line in line_iterator:
-                prefetched_lines.append(line)
-                line_count += 1
-
-                normalized_line = line.rstrip("\r")
-                if not normalized_line or normalized_line.startswith(":"):
-                    if line_count >= max_prefetch_lines:
-                        break
-                    continue
-
-                # Try to parse SSE data
-                data_str = normalized_line
-                if normalized_line.startswith("data: "):
-                    data_str = normalized_line[6:]
-
-                if data_str == "[DONE]":
-                    break
-
-                try:
-                    data = json.loads(data_str)
-                except json.JSONDecodeError:
-                    if line_count >= max_prefetch_lines:
-                        break
-                    continue
-
-                # Use the parser to check for an error response
-                if isinstance(data, dict) and parser.is_error_response(data):
-                    parsed = parser.parse_response(data, 200)
-                    logger.warning(
-                        f" [{self.request_id}] Embedded error detected: "
-                        f"Provider={provider.name}, "
-                        f"error_type={parsed.error_type}, "
-                        f"message={parsed.error_message}"
-                    )
-                    raise EmbeddedErrorException(
-                        provider_name=str(provider.name),
-                        error_code=(
-                            int(parsed.error_type)
-                            if parsed.error_type and parsed.error_type.isdigit()
-                            else None
-                        ),
-                        error_message=parsed.error_message,
-                        error_status=parsed.error_type,
-                    )
-
-                # Valid data prefetched with no error; stop prefetching
-                break
+            async for chunk in byte_iterator:
+                prefetched_chunks.append(chunk)
+                buffer += chunk
+
+                # Try to parse the buffer line by line
+                while b"\n" in buffer:
+                    line_bytes, buffer = buffer.split(b"\n", 1)
+                    try:
+                        # The incremental decoder correctly handles multi-byte
+                        # characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
+                    except Exception as e:
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decoding failed during prefetch: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
+                    line_count += 1
+
+                    # Skip blank lines and comment lines
+                    if not line or line.startswith(":"):
+                        if line_count >= max_prefetch_lines:
+                            should_stop = True
+                            break
+                        continue
+
+                    # Try to parse SSE data
+                    data_str = line
+                    if line.startswith("data: "):
+                        data_str = line[6:]
+
+                    if data_str == "[DONE]":
+                        should_stop = True
+                        break
+
+                    try:
+                        data = json.loads(data_str)
+                    except json.JSONDecodeError:
+                        if line_count >= max_prefetch_lines:
+                            should_stop = True
+                            break
+                        continue
+
+                    # Use the parser to check for an error response
+                    if isinstance(data, dict) and parser.is_error_response(data):
+                        parsed = parser.parse_response(data, 200)
+                        logger.warning(
+                            f" [{self.request_id}] Embedded error detected: "
+                            f"Provider={provider.name}, "
f"error_type={parsed.error_type}, " + f"message={parsed.error_message}" + ) + raise EmbeddedErrorException( + provider_name=str(provider.name), + error_code=( + int(parsed.error_type) + if parsed.error_type and parsed.error_type.isdigit() + else None + ), + error_message=parsed.error_message, + error_status=parsed.error_type, + ) + + # 预读到有效数据,没有错误,停止预读 + should_stop = True break - try: - data = json.loads(data_str) - except json.JSONDecodeError: - if line_count >= max_prefetch_lines: - break - continue - - # 使用解析器检查是否为错误响应 - if isinstance(data, dict) and parser.is_error_response(data): - parsed = parser.parse_response(data, 200) - logger.warning( - f" [{self.request_id}] 检测到嵌套错误: " - f"Provider={provider.name}, " - f"error_type={parsed.error_type}, " - f"message={parsed.error_message}" - ) - raise EmbeddedErrorException( - provider_name=str(provider.name), - error_code=( - int(parsed.error_type) - if parsed.error_type and parsed.error_type.isdigit() - else None - ), - error_message=parsed.error_message, - error_status=parsed.error_type, - ) - - # 预读到有效数据,没有错误,停止预读 - break + if should_stop or line_count >= max_prefetch_lines: + break except EmbeddedErrorException: raise except Exception as e: logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") - return prefetched_lines + return prefetched_chunks async def create_response_stream( self, ctx: StreamContext, - line_iterator: Any, + byte_iterator: Any, response_ctx: Any, http_client: httpx.AsyncClient, - prefetched_lines: Optional[list] = None, + prefetched_chunks: Optional[list] = None, + *, + start_time: Optional[float] = None, ) -> AsyncGenerator[bytes, None]: """ 创建响应流生成器 - 统一的流生成器,支持带预读数据和不带预读数据两种情况。 + 从字节流中解析 SSE 数据并转发,支持预读数据。 Args: ctx: 流式上下文 - line_iterator: 行迭代器 + byte_iterator: 字节流迭代器 response_ctx: HTTP 响应上下文管理器 http_client: HTTP 客户端 - prefetched_lines: 预读的行列表(可选) + prefetched_chunks: 预读的字节块列表(可选) + start_time: 请求开始时间,用于计算 TTFB(可选) Yields: 编码后的响应数据块 @@ -234,25 +269,82 @@ class StreamProcessor: try: sse_parser = SSEEventParser() streaming_started = False + buffer = b"" + # 使用增量解码器处理跨 chunk 的 UTF-8 字符 + decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") # 处理预读数据 - if prefetched_lines: + if prefetched_chunks: if not streaming_started and self.on_streaming_start: self.on_streaming_start() streaming_started = True - for line in prefetched_lines: - for chunk in self._process_line(ctx, sse_parser, line): - yield chunk + for chunk in prefetched_chunks: + # 记录首字时间 (TTFB) - 在 yield 之前记录 + if start_time is not None: + ctx.record_first_byte_time(start_time) + start_time = None # 只记录一次 + + # 把原始数据转发给客户端 + yield chunk + + buffer += chunk + # 处理缓冲区中的完整行 + while b"\n" in buffer: + line_bytes, buffer = buffer.split(b"\n", 1) + try: + # 使用增量解码器,可以正确处理跨 chunk 的多字节字符 + line = decoder.decode(line_bytes + b"\n", False) + self._process_line(ctx, sse_parser, line) + except Exception as e: + # 解码失败,记录警告但继续处理 + logger.warning( + f"[{self.request_id}] UTF-8 解码失败: {e}, " + f"bytes={line_bytes[:50]!r}" + ) + continue # 处理剩余的流数据 - async for line in line_iterator: + async for chunk in byte_iterator: if not streaming_started and self.on_streaming_start: self.on_streaming_start() streaming_started = True - for chunk in self._process_line(ctx, sse_parser, line): - yield chunk + # 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空) + if start_time is not None: + ctx.record_first_byte_time(start_time) + start_time = None # 只记录一次 + + # 原始数据透传 + yield chunk + + buffer += chunk + # 处理缓冲区中的完整行 + while b"\n" in buffer: + line_bytes, buffer = buffer.split(b"\n", 1) + try: + 
+                        # The incremental decoder correctly handles multi-byte
+                        # characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False)
+                        self._process_line(ctx, sse_parser, line)
+                    except Exception as e:
+                        # Decoding failed; log a warning and keep going
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decoding failed: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
+            # Process any remaining buffered data (an unfinished final line)
+            if buffer:
+                try:
+                    # Use final=True to flush a trailing incomplete character
+                    line = decoder.decode(buffer, True)
+                    self._process_line(ctx, sse_parser, line)
+                except Exception as e:
+                    logger.warning(
+                        f"[{self.request_id}] Failed to process the remaining buffer: {e}, "
+                        f"bytes={buffer[:50]!r}"
+                    )

             # Process remaining events
             for event in sse_parser.flush():
@@ -268,7 +360,7 @@
         ctx: StreamContext,
         sse_parser: SSEEventParser,
         line: str,
-    ) -> list[bytes]:
+    ) -> None:
         """
         Process a single line of data

@@ -276,26 +368,17 @@
             ctx: streaming context
             sse_parser: SSE parser
             line: raw line data
-
-        Returns:
-            List of data chunks to send
         """
-        result: list[bytes] = []
-        normalized_line = line.rstrip("\r")
+        # SSEEventParser expects single-line input with the newline already
+        # stripped; strip CR/LF uniformly here so a blank line is not misread
+        # as "\n", which would break event-boundary parsing.
+        normalized_line = line.rstrip("\r\n")
         events = sse_parser.feed_line(normalized_line)

-        if normalized_line == "":
-            for event in events:
-                self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
-            result.append(b"\n")
-        else:
+        if normalized_line != "":
             ctx.chunk_count += 1
-            result.append((line + "\n").encode("utf-8"))
-            for event in events:
-                self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
-
-        return result
+
+        for event in events:
+            self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")

     async def create_monitored_stream(
         self,
@@ -317,16 +400,26 @@
             Response data chunks
         """
         try:
+            # Disconnect-check frequency: every await adds scheduling overhead, and
+            # checking too often makes the stream emit in bursts. Throttle by time
+            # interval to balance stopping upstream reads promptly against smooth
+            # throughput.
+            next_disconnect_check_at = 0.0
+            disconnect_check_interval_s = 0.25
+
             async for chunk in stream_generator:
-                if await is_disconnected():
-                    logger.warning(f"ID:{self.request_id} | Client disconnected")
-                    ctx.status_code = 499  # Client Closed Request
-                    ctx.error_message = "client_disconnected"
-                    break
+                now = time.monotonic()
+                if now >= next_disconnect_check_at:
+                    next_disconnect_check_at = now + disconnect_check_interval_s
+                    if await is_disconnected():
+                        logger.warning(f"ID:{self.request_id} | Client disconnected")
+                        ctx.status_code = 499  # Client Closed Request
+                        ctx.error_message = "client_disconnected"
+                        break
                 yield chunk
         except asyncio.CancelledError:
             ctx.status_code = 499
             ctx.error_message = "client_disconnected"
+            raise
         except Exception as e:
             ctx.status_code = 500
diff --git a/src/api/handlers/base/stream_telemetry.py b/src/api/handlers/base/stream_telemetry.py
index 2d3d7ba..8d4f03f 100644
--- a/src/api/handlers/base/stream_telemetry.py
+++ b/src/api/handlers/base/stream_telemetry.py
@@ -8,6 +8,7 @@
 """

 import asyncio
+import time
 from typing import Any, Dict, Optional

 from sqlalchemy.orm import Session
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
         ctx: StreamContext,
         original_headers: Dict[str, str],
         original_request_body: Dict[str, Any],
-        response_time_ms: int,
+        start_time: float,
     ) -> None:
         """
         Record streaming statistics

@@ -66,11 +67,15 @@
             ctx: streaming context
             original_headers: original request headers
             original_request_body: original request body
-            response_time_ms: response time (milliseconds)
+            start_time: request start time (time.time())
         """
         bg_db = None

         try:
+            # Compute the response time after the stream ends, on the same time base as the TTFB
+            # Note: do not count the stats delay (stream_stats_delay) toward the response time
+            response_time_ms = int((time.time() - start_time) * 1000)
+
             await asyncio.sleep(config.stream_stats_delay)  # wait for the stream to fully close

             if not ctx.provider_name:
@@ -155,6 +160,7 @@
                 input_tokens=ctx.input_tokens,
                 output_tokens=ctx.output_tokens,
                 response_time_ms=response_time_ms,
+                first_byte_time_ms=ctx.first_byte_time_ms,  # pass through the TTFB
                 status_code=ctx.status_code,
                 request_headers=original_headers,
                 request_body=actual_request_body,
diff --git a/src/api/handlers/base/utils.py b/src/api/handlers/base/utils.py
index f4e2069..92b50fc 100644
--- a/src/api/handlers/base/utils.py
+++ b/src/api/handlers/base/utils.py
@@ -2,7 +2,7 @@
 Basic utility functions for handlers
 """

-from typing import Any, Dict
+from typing import Any, Dict, Optional


 def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
@@ -22,10 +22,34 @@
     Returns:
         Total cache-creation tokens
     """
-    # Prefer the new format
-    cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
-    cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
-    total = int(cache_5m) + int(cache_1h)
+    # Check whether the new-format fields exist (not whether their values are 0).
+    # If a field exists, a value of 0 is legitimate and must not fall back to the old format.
+    has_new_format = (
+        "claude_cache_creation_5_m_tokens" in usage
+        or "claude_cache_creation_1_h_tokens" in usage
+    )

-    # If the new format is absent (total == 0), fall back to the old format
-    return total if total > 0 else int(usage.get("cache_creation_input_tokens", 0))
+    if has_new_format:
+        cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
+        cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
+        return int(cache_5m) + int(cache_1h)
+
+    # Fall back to the old format
+    return int(usage.get("cache_creation_input_tokens", 0))
+
+
+def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
+    """
+    Build the recommended SSE (text/event-stream) response headers, reducing the
+    stutter/bursty output caused by proxy buffering.
+
+    Notes:
+    - Cache-Control: no-transform keeps some proxies from compressing/rewriting the stream, which would cause buffering
+    - X-Accel-Buffering: no explicitly tells Nginx to disable buffering (harmless even if it is already off globally)
+    """
+    headers: Dict[str, str] = {
+        "Cache-Control": "no-cache, no-transform",
+        "X-Accel-Buffering": "no",
+    }
+    if extra_headers:
+        headers.update(extra_headers)
+    return headers
diff --git a/tests/api/handlers/base/test_stream_context.py b/tests/api/handlers/base/test_stream_context.py
new file mode 100644
index 0000000..fb0bb0c
--- /dev/null
+++ b/tests/api/handlers/base/test_stream_context.py
@@ -0,0 +1,117 @@
+from src.api.handlers.base import stream_context
+from src.api.handlers.base.stream_context import StreamContext
+
+
+def test_collected_text_append_and_property() -> None:
+    ctx = StreamContext(model="test-model", api_format="OPENAI")
+    assert ctx.collected_text == ""
+
+    ctx.append_text("hello")
+    ctx.append_text(" ")
+    ctx.append_text("world")
+    assert ctx.collected_text == "hello world"
+
+
+def test_reset_for_retry_clears_state() -> None:
+    ctx = StreamContext(model="test-model", api_format="OPENAI")
+    ctx.append_text("x")
+    ctx.update_usage(input_tokens=10, output_tokens=5)
+    ctx.parsed_chunks.append({"type": "chunk"})
+    ctx.chunk_count = 3
+    ctx.data_count = 2
+    ctx.has_completion = True
+    ctx.status_code = 418
+    ctx.error_message = "boom"
+
+    ctx.reset_for_retry()
+
+    assert ctx.collected_text == ""
+    assert ctx.input_tokens == 0
+    assert ctx.output_tokens == 0
+    assert ctx.parsed_chunks == []
+    assert ctx.chunk_count == 0
+    assert ctx.data_count == 0
+    assert ctx.has_completion is False
+    assert ctx.status_code == 200
+    assert ctx.error_message is None
+
+
+def test_record_first_byte_time(monkeypatch) -> None:
+    """Test recording the time to first byte"""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    start_time = 100.0
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.0123)  # 12.3ms
+
+    # Record the TTFB
+    ctx.record_first_byte_time(start_time)
+
+    # Verify the TTFB was recorded (int() truncates 12.3 to 12)
+    assert ctx.first_byte_time_ms == 12
+
+
+def test_record_first_byte_time_idempotent(monkeypatch) -> None:
+    """Test that the TTFB is recorded only once"""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    start_time = 100.0
+
+    # First recording
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
+    ctx.record_first_byte_time(start_time)
+    first_value = ctx.first_byte_time_ms
+
+    # Second recording (should be ignored)
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.020)
+    ctx.record_first_byte_time(start_time)
+    second_value = ctx.first_byte_time_ms
+
+    # Verify the value did not change
+    assert first_value == second_value
+
+
+def test_reset_for_retry_clears_first_byte_time(monkeypatch) -> None:
+    """Test that retry clears the TTFB"""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    start_time = 100.0
+
+    # Record the TTFB
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
+    ctx.record_first_byte_time(start_time)
+    assert ctx.first_byte_time_ms is not None
+
+    # Reset
+    ctx.reset_for_retry()
+
+    # Verify the TTFB was cleared
+    assert ctx.first_byte_time_ms is None
+
+
+def test_get_log_summary_with_first_byte_time() -> None:
+    """Test that the log summary includes the TTFB"""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    ctx.provider_name = "anthropic"
+    ctx.input_tokens = 100
+    ctx.output_tokens = 50
+    ctx.first_byte_time_ms = 123
+
+    summary = ctx.get_log_summary("request-id-123", 456)
+
+    # Verify both the TTFB and the total time appear
+    assert "TTFB: 123ms" in summary
+    assert "Total: 456ms" in summary
+    assert "in:100 out:50" in summary
+
+
+def test_get_log_summary_without_first_byte_time() -> None:
+    """Test the log summary format when there is no TTFB"""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    ctx.provider_name = "anthropic"
+    ctx.input_tokens = 100
+    ctx.output_tokens = 50
+    # first_byte_time_ms stays None
+
+    summary = ctx.get_log_summary("request-id-123", 456)
+
+    # Verify there is no TTFB marker while the total time is still present
+    assert "TTFB:" not in summary
+    assert "Total: 456ms" in summary
+    assert "in:100 out:50" in summary
diff --git a/tests/api/handlers/base/test_utils.py b/tests/api/handlers/base/test_utils.py
index 87aa01c..7ba39eb 100644
--- a/tests/api/handlers/base/test_utils.py
+++ b/tests/api/handlers/base/test_utils.py
@@ -2,7 +2,7 @@

 import pytest

-from src.api.handlers.base.utils import extract_cache_creation_tokens
+from src.api.handlers.base.utils import build_sse_headers, extract_cache_creation_tokens


 class TestExtractCacheCreationTokens:
@@ -69,14 +69,16 @@ class TestExtractCacheCreationTokens:
         }
         assert extract_cache_creation_tokens(usage) == 123

-    def test_new_format_zero_fallback_to_old(self) -> None:
-        """Test falling back to the old format when the new format is 0"""
+    def test_new_format_zero_should_not_fallback(self) -> None:
+        """Test that existing-but-zero new-format fields do not fall back to the old format"""
         usage = {
             "claude_cache_creation_5_m_tokens": 0,
             "claude_cache_creation_1_h_tokens": 0,
             "cache_creation_input_tokens": 456,
         }
-        assert extract_cache_creation_tokens(usage) == 456
+        # The new-format fields exist, so even a value of 0 should use the new
+        # format (return 0) rather than fall back to the old format (return 456)
+        assert extract_cache_creation_tokens(usage) == 0

     def test_unrelated_fields_ignored(self) -> None:
         """Test that unrelated fields are ignored"""
@@ -88,3 +90,15 @@
             "claude_cache_creation_1_h_tokens": 75,
         }
         assert extract_cache_creation_tokens(usage) == 125
+
+
+class TestBuildSSEHeaders:
+    def test_default_headers(self) -> None:
+        headers = build_sse_headers()
+        assert headers["Cache-Control"] == "no-cache, no-transform"
+        assert headers["X-Accel-Buffering"] == "no"
+
+    def test_merge_extra_headers(self) -> None:
+        headers = build_sse_headers({"X-Test": "1", "Cache-Control": "custom"})
+        assert headers["X-Test"] == "1"
headers["X-Test"] == "1" + assert headers["Cache-Control"] == "custom"