Aether/src/api/handlers/base/stream_processor.py
fawney19 50ffd47546 fix: handle client disconnection after stream completion gracefully
- Check has_completion flag before marking client disconnection as failure
- Allow graceful termination if response already completed when client disconnects
- Change logging level to info for post-completion disconnections
- Prevent false error reporting when client closes connection after receiving full response
2025-12-19 01:36:20 +08:00

"""
流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑
职责:
1. SSE 事件解析和处理
2. 响应流生成
3. 预读和嵌套错误检测
4. 客户端断开检测
"""
import asyncio
import codecs
import json
import time
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
class StreamProcessor:
"""
流式响应处理器
负责处理 SSE 流的解析、错误检测和响应生成。
从 ChatHandlerBase 中提取,使其职责更加单一。
"""
def __init__(
self,
request_id: str,
default_parser: ResponseParser,
on_streaming_start: Optional[Callable[[], None]] = None,
*,
collect_text: bool = False,
):
"""
初始化流处理器
Args:
request_id: 请求 ID用于日志
default_parser: 默认响应解析器
on_streaming_start: 流开始时的回调(用于更新状态)
"""
self.request_id = request_id
self.default_parser = default_parser
self.on_streaming_start = on_streaming_start
self.collect_text = collect_text
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
"""
获取 Provider 格式的解析器
根据 Provider 的 API 格式选择正确的解析器。
"""
if ctx.provider_api_format:
try:
return get_parser_for_format(ctx.provider_api_format)
except KeyError:
pass
return self.default_parser
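    # Note: get_parser_for_format() raising KeyError for an unrecognized format
    # string (the exact format identifiers live in parsers.py, not shown here)
    # is treated as "fall back to the default parser" rather than as an error.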
def handle_sse_event(
self,
ctx: StreamContext,
event_name: Optional[str],
data_str: str,
) -> None:
"""
处理单个 SSE 事件
解析事件数据,提取 usage 信息和文本内容。
Args:
ctx: 流式上下文
event_name: 事件名称
data_str: 事件数据字符串
"""
if not data_str:
return
if data_str == "[DONE]":
ctx.has_completion = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx.data_count += 1
if not isinstance(data, dict):
return
        # Collect the raw parsed chunk
ctx.parsed_chunks.append(data)
        # Pick the parser matching the provider's format
parser = self.get_parser_for_provider(ctx)
        # Extract usage via the parser
usage = parser.extract_usage_from_response(data)
if usage:
ctx.update_usage(
input_tokens=usage.get("input_tokens"),
output_tokens=usage.get("output_tokens"),
cached_tokens=usage.get("cache_read_tokens"),
cache_creation_tokens=usage.get("cache_creation_tokens"),
)
        # Extract text content
if self.collect_text:
text = parser.extract_text_content(data)
if text:
ctx.append_text(text)
        # Check for completion
event_type = event_name or data.get("type", "")
if event_type in ("response.completed", "message_stop"):
ctx.has_completion = True
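    # Illustrative examples (assumptions, not taken from this codebase) of the
    # SSE payloads this method consumes: an OpenAI-style terminator
    #     data: [DONE]
    # an Anthropic-style completion event
    #     event: message_stop
    #     data: {"type": "message_stop"}
    # and a usage-bearing chunk such as
    #     data: {"type": "message_delta", "usage": {"output_tokens": 42}}
    # The exact field shapes are whatever the configured ResponseParser expects.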
async def prefetch_and_check_error(
self,
byte_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: StreamContext,
max_prefetch_lines: int = 5,
) -> list:
"""
预读流的前几行,检测嵌套错误
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
byte_iterator: 字节流迭代器
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
Returns:
预读的字节块列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx)
buffer = b""
line_count = 0
should_stop = False
        # Use an incremental decoder to handle UTF-8 characters split across chunks
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try:
async for chunk in byte_iterator:
prefetched_chunks.append(chunk)
buffer += chunk
                # Try to parse the buffer line by line
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
                        # The incremental decoder correctly handles multi-byte characters split across chunks
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
line_count += 1
                    # Skip empty lines and SSE comment lines
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
                    # Try to parse the SSE data
data_str = line
if line.startswith("data: "):
data_str = line[6:]
if data_str == "[DONE]":
should_stop = True
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
                    # Use the parser to check for an error response
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
                    # Prefetched valid data with no error; stop prefetching
should_stop = True
break
if should_stop or line_count >= max_prefetch_lines:
break
except EmbeddedErrorException:
raise
except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_chunks
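    # A sketch (assumption, not taken from this codebase) of the kind of
    # HTTP-200 body that prefetching is meant to catch, in Gemini's error
    # envelope style:
    #     data: {"error": {"code": 429, "message": "Resource has been exhausted",
    #                      "status": "RESOURCE_EXHAUSTED"}}
    # is_error_response() would flag it, and EmbeddedErrorException would be
    # raised before any bytes are forwarded to the client.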
async def create_response_stream(
self,
ctx: StreamContext,
byte_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_chunks: Optional[list] = None,
*,
start_time: Optional[float] = None,
) -> AsyncGenerator[bytes, None]:
"""
创建响应流生成器
从字节流中解析 SSE 数据并转发,支持预读数据。
Args:
ctx: 流式上下文
byte_iterator: 字节流迭代器
response_ctx: HTTP 响应上下文管理器
http_client: HTTP 客户端
prefetched_chunks: 预读的字节块列表(可选)
start_time: 请求开始时间,用于计算 TTFB可选
Yields:
编码后的响应数据块
"""
try:
sse_parser = SSEEventParser()
streaming_started = False
buffer = b""
            # Use an incremental decoder to handle UTF-8 characters split across chunks
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
            # Process the prefetched data first
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in prefetched_chunks:
                    # Record time to first byte (TTFB) before yielding
if start_time is not None:
ctx.record_first_byte_time(start_time)
                        start_time = None  # Record only once
                    # Forward the raw data to the client
yield chunk
buffer += chunk
                    # Process complete lines in the buffer
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
                            # The incremental decoder correctly handles multi-byte characters split across chunks
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
                            # Decode failed; log a warning and keep going
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
            # Process the remaining stream data
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
                # Record TTFB before yielding (covers the no-prefetched-data case)
if start_time is not None:
ctx.record_first_byte_time(start_time)
                    start_time = None  # Record only once
                # Pass the raw data straight through
yield chunk
buffer += chunk
                # Process complete lines in the buffer
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
                        # The incremental decoder correctly handles multi-byte characters split across chunks
line = decoder.decode(line_bytes + b"\n", False)
self._process_line(ctx, sse_parser, line)
except Exception as e:
                        # Decode failed; log a warning and keep going
logger.warning(
f"[{self.request_id}] UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
            # Process any remaining buffered data (an unterminated final line)
if buffer:
try:
                    # final=True flushes any trailing incomplete character
line = decoder.decode(buffer, True)
self._process_line(ctx, sse_parser, line)
except Exception as e:
logger.warning(
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
f"bytes={buffer[:50]!r}"
)
            # Flush any remaining events
for event in sse_parser.flush():
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
await self._cleanup(response_ctx, http_client)
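    # Design note: chunks are forwarded to the client byte-for-byte before any
    # parsing happens, and the line-based parse only feeds bookkeeping state on
    # StreamContext (usage, text, completion). A parse or decode failure can
    # therefore never corrupt what the client receives.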
def _process_line(
self,
ctx: StreamContext,
sse_parser: SSEEventParser,
line: str,
) -> None:
"""
处理单行数据
Args:
ctx: 流式上下文
sse_parser: SSE 解析器
line: 原始行数据
"""
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
normalized_line = line.rstrip("\r\n")
events = sse_parser.feed_line(normalized_line)
if normalized_line != "":
ctx.chunk_count += 1
for event in events:
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
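    # For illustration, a well-formed SSE event arrives as consecutive lines:
    #     event: message_delta
    #     data: {"type": "message_delta"}
    #     (blank line terminating the event)
    # feed_line() is assumed to buffer the field lines and emit the event dict
    # once it sees the blank line, which is why blank lines are fed through but
    # excluded from chunk_count.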
async def create_monitored_stream(
self,
ctx: StreamContext,
stream_generator: AsyncGenerator[bytes, None],
is_disconnected: Callable[[], Any],
) -> AsyncGenerator[bytes, None]:
"""
创建带监控的流生成器
检测客户端断开连接并更新状态码。
Args:
ctx: 流式上下文
stream_generator: 原始流生成器
is_disconnected: 检查客户端是否断开的函数
Yields:
响应数据块
"""
try:
            # Disconnect-check frequency: every await adds scheduling overhead, and
            # checking too often makes the stream stutter. Throttle the checks by
            # time interval to balance stopping upstream reads promptly against
            # smooth throughput.
next_disconnect_check_at = 0.0
disconnect_check_interval_s = 0.25
async for chunk in stream_generator:
now = time.monotonic()
if now >= next_disconnect_check_at:
next_disconnect_check_at = now + disconnect_check_interval_s
if await is_disconnected():
                        # If the response already completed (a completion event
                        # was received), a client disconnect is not a failure
if ctx.has_completion:
logger.info(
f"ID:{self.request_id} | Client disconnected after completion"
)
else:
logger.warning(f"ID:{self.request_id} | Client disconnected")
ctx.status_code = 499 # Client Closed Request
ctx.error_message = "client_disconnected"
break
yield chunk
except asyncio.CancelledError:
            # If the response already completed, do not mark this as a failure
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "client_disconnected"
raise
except Exception as e:
ctx.status_code = 500
ctx.error_message = str(e)
raise
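    # Per the commit message above, ctx.has_completion is what distinguishes a
    # benign post-completion disconnect (logged at info level, no 499) from a
    # genuine client abort (status 499 "Client Closed Request",
    # error_message "client_disconnected").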
async def _cleanup(
self,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> None:
"""清理资源"""
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
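

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal wiring example under assumptions: `parser` is a concrete
# ResponseParser for the provider's API format, `ctx` is a StreamContext built
# however the real handlers build one, and the FastAPI-style `request` and
# StreamingResponse objects are stand-ins.
#
#     processor = StreamProcessor(
#         request_id="req-123",
#         default_parser=parser,
#         on_streaming_start=lambda: logger.info("streaming started"),
#         collect_text=True,
#     )
#     # Raises EmbeddedErrorException on an embedded error so the caller can retry
#     prefetched = await processor.prefetch_and_check_error(
#         byte_iterator, provider, endpoint, ctx
#     )
#     stream = processor.create_response_stream(
#         ctx, byte_iterator, response_ctx, http_client,
#         prefetched_chunks=prefetched, start_time=start_time,
#     )
#     monitored = processor.create_monitored_stream(
#         ctx, stream, request.is_disconnected
#     )
#     return StreamingResponse(monitored, media_type="text/event-stream")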