Mirror of https://github.com/fawney19/Aether.git
refactor(handler): optimize stream processing and telemetry pipeline
- Enhance stream context for better token and latency tracking
- Refactor stream processor for improved performance metrics
- Improve telemetry integration with first_byte_time_ms support
- Add comprehensive stream context unit tests
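The heart of the change is a shared timing model: a single start_time drives both the time to first byte (TTFB, recorded once when the first chunk is yielded) and the total response time (computed after the stream ends). A minimal self-contained sketch of that model — the class and method names here are illustrative, not the project's exact API:

```python
import time
from typing import Optional


class TimingDemo:
    """Sketch: one start_time feeds both TTFB and total latency."""

    def __init__(self) -> None:
        self.start_time = time.time()
        self.first_byte_time_ms: Optional[int] = None

    def record_first_byte(self) -> None:
        # Record only once, so retries or later chunks cannot overwrite it
        if self.first_byte_time_ms is None:
            self.first_byte_time_ms = int((time.time() - self.start_time) * 1000)

    def total_ms(self) -> int:
        # Computed after the stream ends, on the same clock base as TTFB
        return int((time.time() - self.start_time) * 1000)
```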
@@ -100,6 +100,8 @@ class MessageTelemetry:
         cache_read_tokens: int = 0,
         is_stream: bool = False,
         provider_request_headers: Optional[Dict[str, Any]] = None,
+        # Timing metrics
+        first_byte_time_ms: Optional[int] = None,  # time to first byte / TTFB
         # Provider-side tracing info (used to record the real cost)
         provider_id: Optional[str] = None,
         provider_endpoint_id: Optional[str] = None,
@@ -133,6 +135,7 @@ class MessageTelemetry:
             api_format=api_format,
             is_stream=is_stream,
             response_time_ms=response_time_ms,
+            first_byte_time_ms=first_byte_time_ms,  # pass through the time to first byte
             status_code=status_code,
             request_headers=request_headers,
             request_body=request_body,
@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
 from src.api.handlers.base.stream_context import StreamContext
 from src.api.handlers.base.stream_processor import StreamProcessor
 from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
+from src.api.handlers.base.utils import build_sse_headers
 from src.config.settings import config
 from src.core.exceptions import (
     EmbeddedErrorException,
@@ -365,7 +366,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             ctx,
             original_headers,
             original_request_body,
-            self.elapsed_ms(),
+            self.start_time,  # pass the start time so telemetry can compute the response time after the stream ends
         )

         # Create the monitored stream
@@ -378,6 +379,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         return StreamingResponse(
             monitored_stream,
             media_type="text/event-stream",
+            headers=build_sse_headers(),
             background=background_tasks,
         )

@@ -473,12 +475,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):

             stream_response.raise_for_status()

-            # Create the line iterator
-            line_iterator = stream_response.aiter_lines()
+            # Use a byte-stream iterator (avoids the performance issues of aiter_lines)
+            # aiter_raw() yields raw chunks without buffering, giving true streaming
+            byte_iterator = stream_response.aiter_raw()

             # Prefetch to detect embedded errors
-            prefetched_lines = await stream_processor.prefetch_and_check_error(
-                line_iterator,
+            prefetched_chunks = await stream_processor.prefetch_and_check_error(
+                byte_iterator,
                 provider,
                 endpoint,
                 ctx,
@@ -503,13 +506,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
                 await http_client.aclose()
                 raise

-        # Create the stream generator
+        # Create the stream generator (passing the byte-stream iterator)
         return stream_processor.create_response_stream(
             ctx,
-            line_iterator,
+            byte_iterator,
             response_ctx,
             http_client,
-            prefetched_lines,
+            prefetched_chunks,
+            start_time=self.start_time,
         )

     async def _record_stream_failure(
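The switch from aiter_lines() to aiter_raw() above is the key latency fix in this handler: httpx's aiter_lines() buffers until a full decoded line is available, while aiter_raw() yields chunks exactly as they arrive from the network. A standalone httpx sketch, where the URL and the print are placeholders for the real forwarding logic:

```python
import httpx


async def stream_raw(url: str) -> None:
    """Sketch: forward raw bytes instead of decoded lines."""
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            # aiter_raw() yields chunks as they arrive, with no decoding or
            # line buffering, so a proxy can forward them immediately.
            async for chunk in response.aiter_raw():
                print(len(chunk))  # the real handler yields this to the client
```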
@@ -8,6 +8,7 @@
 - Request/response data
 """

+import time
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional

@@ -25,12 +26,18 @@ class StreamContext:
     model: str
     api_format: str

+    # Request identity info (needed by the CLI handlers)
+    request_id: str = ""
+    user_id: int = 0
+    api_key_id: int = 0
+
     # Provider info (filled in while the request executes)
     provider_name: Optional[str] = None
     provider_id: Optional[str] = None
     endpoint_id: Optional[str] = None
     key_id: Optional[str] = None
     attempt_id: Optional[str] = None
+    attempt_synced: bool = False
     provider_api_format: Optional[str] = None  # the provider's response format

     # Model mapping
@@ -43,7 +50,14 @@ class StreamContext:
     cache_creation_tokens: int = 0

     # Response content
-    collected_text: str = ""
+    _collected_text_parts: List[str] = field(default_factory=list, repr=False)
+    response_id: Optional[str] = None
+    final_usage: Optional[Dict[str, Any]] = None
+    final_response: Optional[Dict[str, Any]] = None

+    # Timing metrics
+    first_byte_time_ms: Optional[int] = None  # time to first byte (TTFB)
+    start_time: float = field(default_factory=time.time)
+
     # Response status
     status_code: int = 200
@@ -55,6 +69,12 @@ class StreamContext:
     provider_request_headers: Dict[str, str] = field(default_factory=dict)
     provider_request_body: Optional[Dict[str, Any]] = None

+    # Format conversion info (needed by the CLI handlers)
+    client_api_format: str = ""
+
+    # Provider response metadata (needed by the CLI handlers)
+    response_metadata: Dict[str, Any] = field(default_factory=dict)
+
     # Streaming statistics
     data_count: int = 0
     chunk_count: int = 0
@@ -71,16 +91,30 @@ class StreamContext:
         self.chunk_count = 0
         self.data_count = 0
         self.has_completion = False
-        self.collected_text = ""
+        self._collected_text_parts = []
         self.input_tokens = 0
         self.output_tokens = 0
         self.cached_tokens = 0
         self.cache_creation_tokens = 0
         self.error_message = None
         self.status_code = 200
+        self.first_byte_time_ms = None
         self.response_headers = {}
         self.provider_request_headers = {}
         self.provider_request_body = None
+        self.response_id = None
+        self.final_usage = None
+        self.final_response = None

+    @property
+    def collected_text(self) -> str:
+        """Collected text content (joined on demand, avoiding repeated string copies while streaming)."""
+        return "".join(self._collected_text_parts)
+
+    def append_text(self, text: str) -> None:
+        """Append text content (called only when text collection is enabled)."""
+        if text:
+            self._collected_text_parts.append(text)
+
     def update_provider_info(
         self,
@@ -145,6 +179,19 @@ class StreamContext:
         self.status_code = status_code
         self.error_message = error_message

+    def record_first_byte_time(self, start_time: float) -> None:
+        """
+        Record the time to first byte (TTFB).
+
+        Should be called when the first data is sent to the client.
+        If a value has already been recorded it is not overwritten (avoids double recording on retries).
+
+        Args:
+            start_time: request start time (time.time())
+        """
+        if self.first_byte_time_ms is None:
+            self.first_byte_time_ms = int((time.time() - start_time) * 1000)
+
     def is_success(self) -> bool:
         """Check whether the request succeeded."""
         return self.status_code < 400
@@ -171,10 +218,22 @@ class StreamContext:
         Get a log summary

         Used for log output when a request completes or fails.
+        Includes the time to first byte (TTFB) and the total response time, shown on two lines.
         """
         status = "OK" if self.is_success() else "FAIL"
-        return (
+        # Line 1: basic info + time to first byte
+        line1 = (
             f"[{status}] {request_id[:8]} | {self.model} | "
-            f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
+            f"{self.provider_name or 'unknown'}"
+        )
+        if self.first_byte_time_ms is not None:
+            line1 += f" | TTFB: {self.first_byte_time_ms}ms"
+
+        # Line 2: total response time + tokens
+        line2 = (
+            f"  Total: {response_time_ms}ms | "
             f"in:{self.input_tokens} out:{self.output_tokens}"
         )
+
+        return f"{line1}\n{line2}"
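The _collected_text_parts change above swaps repeated string concatenation (which copies the whole accumulated string on every chunk, O(n²) overall) for list appends plus a single join on read. A tiny sketch of the pattern:

```python
# Accumulation pattern adopted by StreamContext: appending to a list is O(1)
# amortized per chunk, and the full string is materialized once, on demand.
parts: list[str] = []
for piece in ("hel", "lo ", "world"):
    parts.append(piece)
collected = "".join(parts)  # one copy instead of one per chunk
assert collected == "hello world"
```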
@@ -9,7 +9,9 @@
 """

 import asyncio
+import codecs
 import json
+import time
 from typing import Any, AsyncGenerator, Callable, Optional

 import httpx
@@ -36,6 +38,8 @@ class StreamProcessor:
         request_id: str,
         default_parser: ResponseParser,
         on_streaming_start: Optional[Callable[[], None]] = None,
+        *,
+        collect_text: bool = False,
     ):
         """
         Initialize the stream processor
@@ -48,6 +52,7 @@ class StreamProcessor:
         self.request_id = request_id
         self.default_parser = default_parser
         self.on_streaming_start = on_streaming_start
+        self.collect_text = collect_text

     def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
         """
@@ -112,9 +117,10 @@ class StreamProcessor:
             )

         # Extract text
+        if self.collect_text:
             text = parser.extract_text_content(data)
             if text:
-                ctx.collected_text += text
+                ctx.append_text(text)

         # Check for completion
         event_type = event_name or data.get("type", "")
@@ -123,7 +129,7 @@ class StreamProcessor:

     async def prefetch_and_check_error(
         self,
-        line_iterator: Any,
+        byte_iterator: Any,
         provider: Provider,
         endpoint: ProviderEndpoint,
         ctx: StreamContext,
@@ -136,45 +142,67 @@ class StreamProcessor:
        This must be detected before the stream starts emitting output, so the retry logic can still be triggered.

        Args:
-            line_iterator: line iterator
+            byte_iterator: byte-stream iterator
            provider: the Provider object
            endpoint: the Endpoint object
            ctx: streaming context
            max_prefetch_lines: maximum number of lines to prefetch

        Returns:
-            the prefetched lines
+            the prefetched byte chunks

        Raises:
            EmbeddedErrorException: if an embedded error is detected
        """
-        prefetched_lines: list = []
+        prefetched_chunks: list = []
        parser = self.get_parser_for_provider(ctx)
+        buffer = b""
+        line_count = 0
+        should_stop = False
+        # Use an incremental decoder to handle UTF-8 characters that span chunks
+        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
+
        try:
-            line_count = 0
-            async for line in line_iterator:
-                prefetched_lines.append(line)
+            async for chunk in byte_iterator:
+                prefetched_chunks.append(chunk)
+                buffer += chunk
+
+                # Try to parse the buffer line by line
+                while b"\n" in buffer:
+                    line_bytes, buffer = buffer.split(b"\n", 1)
+                    try:
+                        # The incremental decoder correctly handles multi-byte characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
+                    except Exception as e:
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decode failed while prefetching: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
                    line_count += 1

-                normalized_line = line.rstrip("\r")
-                if not normalized_line or normalized_line.startswith(":"):
-                    if line_count >= max_prefetch_lines:
-                        break
-                    continue
+                    # Skip blank lines and comment lines
+                    if not line or line.startswith(":"):
+                        if line_count >= max_prefetch_lines:
+                            should_stop = True
+                            break
+                        continue

                    # Try to parse the SSE data
-                data_str = normalized_line
-                if normalized_line.startswith("data: "):
-                    data_str = normalized_line[6:]
+                    data_str = line
+                    if line.startswith("data: "):
+                        data_str = line[6:]

                    if data_str == "[DONE]":
+                        should_stop = True
                        break

                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        if line_count >= max_prefetch_lines:
+                            should_stop = True
                            break
                        continue

@@ -199,6 +227,10 @@ class StreamProcessor:
                    )

                    # Prefetched valid data with no error; stop prefetching
+                    should_stop = True
+                    break
+
+                if should_stop or line_count >= max_prefetch_lines:
                    break

        except EmbeddedErrorException:
@@ -206,27 +238,30 @@ class StreamProcessor:
        except Exception as e:
            logger.debug(f" [{self.request_id}] Exception while prefetching the stream: {e}")

-        return prefetched_lines
+        return prefetched_chunks

    async def create_response_stream(
        self,
        ctx: StreamContext,
-        line_iterator: Any,
+        byte_iterator: Any,
        response_ctx: Any,
        http_client: httpx.AsyncClient,
-        prefetched_lines: Optional[list] = None,
+        prefetched_chunks: Optional[list] = None,
+        *,
+        start_time: Optional[float] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Create the response stream generator

-        A unified stream generator that supports both prefetched and non-prefetched data.
+        Parses SSE data out of the byte stream and forwards it; supports prefetched data.

        Args:
            ctx: streaming context
-            line_iterator: line iterator
+            byte_iterator: byte-stream iterator
            response_ctx: HTTP response context manager
            http_client: HTTP client
-            prefetched_lines: the prefetched lines (optional)
+            prefetched_chunks: the prefetched byte chunks (optional)
+            start_time: request start time, used to compute TTFB (optional)

        Yields:
            Encoded response data chunks
@@ -234,26 +269,83 @@ class StreamProcessor:
        try:
            sse_parser = SSEEventParser()
            streaming_started = False
+            buffer = b""
+            # Use an incremental decoder to handle UTF-8 characters that span chunks
+            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
+
            # Handle prefetched data
-            if prefetched_lines:
+            if prefetched_chunks:
                if not streaming_started and self.on_streaming_start:
                    self.on_streaming_start()
                    streaming_started = True

-                for line in prefetched_lines:
-                    for chunk in self._process_line(ctx, sse_parser, line):
+                for chunk in prefetched_chunks:
+                    # Record the time to first byte (TTFB) before yielding
+                    if start_time is not None:
+                        ctx.record_first_byte_time(start_time)
+                        start_time = None  # record only once
+
+                    # Forward the raw data to the client
                    yield chunk
+
+                    buffer += chunk
+                    # Process complete lines in the buffer
+                    while b"\n" in buffer:
+                        line_bytes, buffer = buffer.split(b"\n", 1)
+                        try:
+                            # The incremental decoder correctly handles multi-byte characters split across chunks
+                            line = decoder.decode(line_bytes + b"\n", False)
+                            self._process_line(ctx, sse_parser, line)
+                        except Exception as e:
+                            # Decoding failed: log a warning and keep going
+                            logger.warning(
+                                f"[{self.request_id}] UTF-8 decode failed: {e}, "
+                                f"bytes={line_bytes[:50]!r}"
+                            )
+                            continue

            # Process the rest of the stream
-            async for line in line_iterator:
+            async for chunk in byte_iterator:
                if not streaming_started and self.on_streaming_start:
                    self.on_streaming_start()
                    streaming_started = True

-                for chunk in self._process_line(ctx, sse_parser, line):
+                # Record the TTFB before yielding (when there was no prefetched data)
+                if start_time is not None:
+                    ctx.record_first_byte_time(start_time)
+                    start_time = None  # record only once
+
+                # Pass the raw data straight through
                yield chunk
+
+                buffer += chunk
+                # Process complete lines in the buffer
+                while b"\n" in buffer:
+                    line_bytes, buffer = buffer.split(b"\n", 1)
+                    try:
+                        # The incremental decoder correctly handles multi-byte characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False)
+                        self._process_line(ctx, sse_parser, line)
+                    except Exception as e:
+                        # Decoding failed: log a warning and keep going
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decode failed: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
+            # Process any data left in the buffer (an unfinished final line)
+            if buffer:
+                try:
+                    # final=True flushes a trailing incomplete character
+                    line = decoder.decode(buffer, True)
+                    self._process_line(ctx, sse_parser, line)
+                except Exception as e:
+                    logger.warning(
+                        f"[{self.request_id}] Failed to process the remaining buffer: {e}, "
+                        f"bytes={buffer[:50]!r}"
+                    )

            # Process the remaining events
            for event in sse_parser.flush():
                self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
@@ -268,7 +360,7 @@ class StreamProcessor:
        ctx: StreamContext,
        sse_parser: SSEEventParser,
        line: str,
-    ) -> list[bytes]:
+    ) -> None:
        """
        Process a single line of data

@@ -276,27 +368,18 @@ class StreamProcessor:
            ctx: streaming context
            sse_parser: SSE parser
            line: the raw line data
-
-        Returns:
-            the list of data chunks to send
        """
-        result: list[bytes] = []
-        normalized_line = line.rstrip("\r")
+        # SSEEventParser expects single-line input with the newline already stripped; strip
+        # CR/LF uniformly here so a blank line is not misread as "\n" and event boundaries parse correctly.
+        normalized_line = line.rstrip("\r\n")
        events = sse_parser.feed_line(normalized_line)

-        if normalized_line == "":
-            for event in events:
-                self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
-            result.append(b"\n")
-        else:
+        if normalized_line != "":
            ctx.chunk_count += 1
-            result.append((line + "\n").encode("utf-8"))

        for event in events:
            self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
-
-        return result

    async def create_monitored_stream(
        self,
        ctx: StreamContext,
@@ -317,16 +400,26 @@ class StreamProcessor:
            Response data chunks
        """
        try:
+            # Disconnect-check frequency: every await adds scheduling overhead, and checking too often makes the stream emit in stop-and-go bursts.
+            # Throttle the checks by time interval to balance stopping upstream reads promptly against smooth throughput.
+            next_disconnect_check_at = 0.0
+            disconnect_check_interval_s = 0.25
+
            async for chunk in stream_generator:
+                now = time.monotonic()
+                if now >= next_disconnect_check_at:
+                    next_disconnect_check_at = now + disconnect_check_interval_s
                    if await is_disconnected():
                        logger.warning(f"ID:{self.request_id} | Client disconnected")
                        ctx.status_code = 499  # Client Closed Request
                        ctx.error_message = "client_disconnected"

                        break
                yield chunk
        except asyncio.CancelledError:
            ctx.status_code = 499
            ctx.error_message = "client_disconnected"

            raise
        except Exception as e:
            ctx.status_code = 500
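Both loops above route decoding through codecs.getincrementaldecoder because a naive per-chunk bytes.decode() breaks when a multi-byte UTF-8 character straddles a chunk boundary. A self-contained demonstration:

```python
import codecs

# A multi-byte UTF-8 character split across two network chunks: decoding each
# chunk independently would fail or corrupt it, while the incremental decoder
# carries the partial byte sequence over between calls.
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
raw = "你好".encode("utf-8")          # 6 bytes, 3 per character
chunks = [raw[:4], raw[4:]]           # split mid-character
text = "".join(decoder.decode(c, False) for c in chunks)
text += decoder.decode(b"", True)     # final=True flushes any trailing state
assert text == "你好"
```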
@@ -8,6 +8,7 @@
 """

 import asyncio
+import time
 from typing import Any, Dict, Optional

 from sqlalchemy.orm import Session
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
         ctx: StreamContext,
         original_headers: Dict[str, str],
         original_request_body: Dict[str, Any],
-        response_time_ms: int,
+        start_time: float,
     ) -> None:
         """
         Record streaming statistics
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
            ctx: streaming context
            original_headers: the original request headers
            original_request_body: the original request body
-            response_time_ms: response time in milliseconds
+            start_time: request start time (time.time())
        """
        bg_db = None

        try:
+            # Compute the response time after the stream ends, on the same time base as the TTFB.
+            # Note: the stats delay (stream_stats_delay) must not be counted into the response time.
+            response_time_ms = int((time.time() - start_time) * 1000)
+
            await asyncio.sleep(config.stream_stats_delay)  # wait for the stream to fully close

            if not ctx.provider_name:
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
                input_tokens=ctx.input_tokens,
                output_tokens=ctx.output_tokens,
                response_time_ms=response_time_ms,
+                first_byte_time_ms=ctx.first_byte_time_ms,  # pass through the time to first byte
                status_code=ctx.status_code,
                request_headers=original_headers,
                request_body=actual_request_body,
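Note the ordering above: the response time is frozen before the stream_stats_delay sleep, so bookkeeping latency never inflates the reported metric. A minimal sketch of that ordering, with illustrative names:

```python
import asyncio
import time


async def record_stats(start_time: float, stats_delay: float) -> int:
    # Freeze the response time first, on the same time.time() base as TTFB,
    # so the bookkeeping sleep below is excluded from the metric.
    response_time_ms = int((time.time() - start_time) * 1000)
    await asyncio.sleep(stats_delay)  # wait for the stream to fully close
    return response_time_ms
```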
@@ -2,7 +2,7 @@
 Basic utility functions for handlers
 """

-from typing import Any, Dict
+from typing import Any, Dict, Optional


 def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
@@ -22,10 +22,34 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
    Returns:
        Total cache-creation tokens
    """
-    # Prefer the new format
+    # Check whether the new-format fields exist (rather than whether the value is 0).
+    # If a field exists, a value of 0 is still valid and must not fall back to the old format.
+    has_new_format = (
+        "claude_cache_creation_5_m_tokens" in usage
+        or "claude_cache_creation_1_h_tokens" in usage
+    )
+
+    if has_new_format:
        cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
        cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
-    total = int(cache_5m) + int(cache_1h)
+        return int(cache_5m) + int(cache_1h)

-    # If the new format is absent (total == 0), fall back to the old format
-    return total if total > 0 else int(usage.get("cache_creation_input_tokens", 0))
+    # Fall back to the old format
+    return int(usage.get("cache_creation_input_tokens", 0))
+
+
+def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
+    """
+    Build the recommended response headers for SSE (text/event-stream), reducing the
+    stutter / burst output caused by proxy buffering.
+
+    Notes:
+    - Cache-Control: no-transform keeps some proxies from compressing/rewriting the stream, which would buffer it
+    - X-Accel-Buffering: no explicitly tells Nginx to disable buffering (harmless even if it is already off globally)
+    """
+    headers: Dict[str, str] = {
+        "Cache-Control": "no-cache, no-transform",
+        "X-Accel-Buffering": "no",
+    }
+    if extra_headers:
+        headers.update(extra_headers)
+    return headers
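For reference, this is how the new helper is consumed by a streaming endpoint (a minimal sketch: the event generator and the extra header are placeholders, while StreamingResponse and build_sse_headers match the diff above):

```python
from fastapi.responses import StreamingResponse

from src.api.handlers.base.utils import build_sse_headers


async def demo_events():
    # Placeholder SSE payload
    yield b"data: {}\n\n"


def make_sse_response() -> StreamingResponse:
    return StreamingResponse(
        demo_events(),
        media_type="text/event-stream",
        headers=build_sse_headers({"X-Request-Id": "demo"}),
    )
```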
tests/api/handlers/base/test_stream_context.py (new file, 117 lines)
@@ -0,0 +1,117 @@
+from src.api.handlers.base import stream_context
+from src.api.handlers.base.stream_context import StreamContext
+
+
+def test_collected_text_append_and_property() -> None:
+    ctx = StreamContext(model="test-model", api_format="OPENAI")
+    assert ctx.collected_text == ""
+
+    ctx.append_text("hello")
+    ctx.append_text(" ")
+    ctx.append_text("world")
+    assert ctx.collected_text == "hello world"
+
+
+def test_reset_for_retry_clears_state() -> None:
+    ctx = StreamContext(model="test-model", api_format="OPENAI")
+    ctx.append_text("x")
+    ctx.update_usage(input_tokens=10, output_tokens=5)
+    ctx.parsed_chunks.append({"type": "chunk"})
+    ctx.chunk_count = 3
+    ctx.data_count = 2
+    ctx.has_completion = True
+    ctx.status_code = 418
+    ctx.error_message = "boom"
+
+    ctx.reset_for_retry()
+
+    assert ctx.collected_text == ""
+    assert ctx.input_tokens == 0
+    assert ctx.output_tokens == 0
+    assert ctx.parsed_chunks == []
+    assert ctx.chunk_count == 0
+    assert ctx.data_count == 0
+    assert ctx.has_completion is False
+    assert ctx.status_code == 200
+    assert ctx.error_message is None
+
+
+def test_record_first_byte_time(monkeypatch) -> None:
+    """Recording the time to first byte."""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    start_time = 100.0
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.0123)  # 12.3ms
+
+    # Record the first-byte time
+    ctx.record_first_byte_time(start_time)
+
+    # The first-byte time should be recorded (truncated to whole milliseconds)
+    assert ctx.first_byte_time_ms == 12
+
+
+def test_record_first_byte_time_idempotent(monkeypatch) -> None:
+    """The first-byte time is recorded only once."""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    start_time = 100.0
+
+    # First recording
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
+    ctx.record_first_byte_time(start_time)
+    first_value = ctx.first_byte_time_ms
+
+    # Second recording (should be ignored)
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.020)
+    ctx.record_first_byte_time(start_time)
+    second_value = ctx.first_byte_time_ms
+
+    # The value must not have changed
+    assert first_value == second_value
+
+
+def test_reset_for_retry_clears_first_byte_time(monkeypatch) -> None:
+    """Retrying clears the first-byte time."""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    start_time = 100.0
+
+    # Record a first-byte time
+    monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
+    ctx.record_first_byte_time(start_time)
+    assert ctx.first_byte_time_ms is not None
+
+    # Reset
+    ctx.reset_for_retry()
+
+    # The first-byte time should be cleared
+    assert ctx.first_byte_time_ms is None
+
+
+def test_get_log_summary_with_first_byte_time() -> None:
+    """The log summary includes the first-byte time."""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    ctx.provider_name = "anthropic"
+    ctx.input_tokens = 100
+    ctx.output_tokens = 50
+    ctx.first_byte_time_ms = 123
+
+    summary = ctx.get_log_summary("request-id-123", 456)
+
+    # The summary contains the TTFB and the total time (upper-case labels)
+    assert "TTFB: 123ms" in summary
+    assert "Total: 456ms" in summary
+    assert "in:100 out:50" in summary
+
+
+def test_get_log_summary_without_first_byte_time() -> None:
+    """Log summary format when no first-byte time was recorded."""
+    ctx = StreamContext(model="claude-3", api_format="claude_messages")
+    ctx.provider_name = "anthropic"
+    ctx.input_tokens = 100
+    ctx.output_tokens = 50
+    # first_byte_time_ms stays None
+
+    summary = ctx.get_log_summary("request-id-123", 456)
+
+    # No TTFB marker, but the total time is present (upper-case TTFB and Total labels)
+    assert "TTFB:" not in summary
+    assert "Total: 456ms" in summary
+    assert "in:100 out:50" in summary
@@ -2,7 +2,7 @@

 import pytest

-from src.api.handlers.base.utils import extract_cache_creation_tokens
+from src.api.handlers.base.utils import build_sse_headers, extract_cache_creation_tokens


 class TestExtractCacheCreationTokens:
@@ -69,14 +69,16 @@ class TestExtractCacheCreationTokens:
         }
         assert extract_cache_creation_tokens(usage) == 123

-    def test_new_format_zero_fallback_to_old(self) -> None:
-        """When the new format is 0, fall back to the old format"""
+    def test_new_format_zero_should_not_fallback(self) -> None:
+        """When the new-format fields exist but are 0, do not fall back to the old format"""
         usage = {
             "claude_cache_creation_5_m_tokens": 0,
             "claude_cache_creation_1_h_tokens": 0,
             "cache_creation_input_tokens": 456,
         }
-        assert extract_cache_creation_tokens(usage) == 456
+        # The new-format fields exist, so the new format is used even with a value of 0
+        # (returns 0) instead of falling back to the old format (which would return 456)
+        assert extract_cache_creation_tokens(usage) == 0

     def test_unrelated_fields_ignored(self) -> None:
         """Unrelated fields are ignored"""
@@ -88,3 +90,15 @@ class TestExtractCacheCreationTokens:
             "claude_cache_creation_1_h_tokens": 75,
         }
         assert extract_cache_creation_tokens(usage) == 125
+
+
+class TestBuildSSEHeaders:
+    def test_default_headers(self) -> None:
+        headers = build_sse_headers()
+        assert headers["Cache-Control"] == "no-cache, no-transform"
+        assert headers["X-Accel-Buffering"] == "no"
+
+    def test_merge_extra_headers(self) -> None:
+        headers = build_sse_headers({"X-Test": "1", "Cache-Control": "custom"})
+        assert headers["X-Test"] == "1"
+        assert headers["Cache-Control"] == "custom"