mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
- Check has_completion flag before marking client disconnection as failure - Allow graceful termination if response already completed when client disconnects - Change logging level to info for post-completion disconnections - Prevent false error reporting when client closes connection after receiving full response
449 lines
16 KiB
Python
449 lines
16 KiB
Python
"""
|
||
流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑
|
||
|
||
职责:
|
||
1. SSE 事件解析和处理
|
||
2. 响应流生成
|
||
3. 预读和嵌套错误检测
|
||
4. 客户端断开检测
|
||
"""
|
||
|
||
import asyncio
|
||
import codecs
|
||
import json
|
||
import time
|
||
from typing import Any, AsyncGenerator, Callable, Optional
|
||
|
||
import httpx
|
||
|
||
from src.api.handlers.base.parsers import get_parser_for_format
|
||
from src.api.handlers.base.response_parser import ResponseParser
|
||
from src.api.handlers.base.stream_context import StreamContext
|
||
from src.core.exceptions import EmbeddedErrorException
|
||
from src.core.logger import logger
|
||
from src.models.database import Provider, ProviderEndpoint
|
||
from src.utils.sse_parser import SSEEventParser
|
||
|
||
|
||
class StreamProcessor:
|
||
"""
|
||
流式响应处理器
|
||
|
||
负责处理 SSE 流的解析、错误检测和响应生成。
|
||
从 ChatHandlerBase 中提取,使其职责更加单一。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
request_id: str,
|
||
default_parser: ResponseParser,
|
||
on_streaming_start: Optional[Callable[[], None]] = None,
|
||
*,
|
||
collect_text: bool = False,
|
||
):
|
||
"""
|
||
初始化流处理器
|
||
|
||
Args:
|
||
request_id: 请求 ID(用于日志)
|
||
default_parser: 默认响应解析器
|
||
on_streaming_start: 流开始时的回调(用于更新状态)
|
||
"""
|
||
self.request_id = request_id
|
||
self.default_parser = default_parser
|
||
self.on_streaming_start = on_streaming_start
|
||
self.collect_text = collect_text
|
||
|
||
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
|
||
"""
|
||
获取 Provider 格式的解析器
|
||
|
||
根据 Provider 的 API 格式选择正确的解析器。
|
||
"""
|
||
if ctx.provider_api_format:
|
||
try:
|
||
return get_parser_for_format(ctx.provider_api_format)
|
||
except KeyError:
|
||
pass
|
||
return self.default_parser
|
||
|
||
def handle_sse_event(
|
||
self,
|
||
ctx: StreamContext,
|
||
event_name: Optional[str],
|
||
data_str: str,
|
||
) -> None:
|
||
"""
|
||
处理单个 SSE 事件
|
||
|
||
解析事件数据,提取 usage 信息和文本内容。
|
||
|
||
Args:
|
||
ctx: 流式上下文
|
||
event_name: 事件名称
|
||
data_str: 事件数据字符串
|
||
"""
|
||
if not data_str:
|
||
return
|
||
|
||
if data_str == "[DONE]":
|
||
ctx.has_completion = True
|
||
return
|
||
|
||
try:
|
||
data = json.loads(data_str)
|
||
except json.JSONDecodeError:
|
||
return
|
||
|
||
ctx.data_count += 1
|
||
|
||
if not isinstance(data, dict):
|
||
return
|
||
|
||
# 收集原始 chunk 数据
|
||
ctx.parsed_chunks.append(data)
|
||
|
||
# 根据 Provider 格式选择解析器
|
||
parser = self.get_parser_for_provider(ctx)
|
||
|
||
# 使用解析器提取 usage
|
||
usage = parser.extract_usage_from_response(data)
|
||
if usage:
|
||
ctx.update_usage(
|
||
input_tokens=usage.get("input_tokens"),
|
||
output_tokens=usage.get("output_tokens"),
|
||
cached_tokens=usage.get("cache_read_tokens"),
|
||
cache_creation_tokens=usage.get("cache_creation_tokens"),
|
||
)
|
||
|
||
# 提取文本
|
||
if self.collect_text:
|
||
text = parser.extract_text_content(data)
|
||
if text:
|
||
ctx.append_text(text)
|
||
|
||
# 检查完成
|
||
event_type = event_name or data.get("type", "")
|
||
if event_type in ("response.completed", "message_stop"):
|
||
ctx.has_completion = True
|
||
|
||
async def prefetch_and_check_error(
|
||
self,
|
||
byte_iterator: Any,
|
||
provider: Provider,
|
||
endpoint: ProviderEndpoint,
|
||
ctx: StreamContext,
|
||
max_prefetch_lines: int = 5,
|
||
) -> list:
|
||
"""
|
||
预读流的前几行,检测嵌套错误
|
||
|
||
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||
|
||
Args:
|
||
byte_iterator: 字节流迭代器
|
||
provider: Provider 对象
|
||
endpoint: Endpoint 对象
|
||
ctx: 流式上下文
|
||
max_prefetch_lines: 最多预读行数
|
||
|
||
Returns:
|
||
预读的字节块列表
|
||
|
||
Raises:
|
||
EmbeddedErrorException: 如果检测到嵌套错误
|
||
"""
|
||
prefetched_chunks: list = []
|
||
parser = self.get_parser_for_provider(ctx)
|
||
buffer = b""
|
||
line_count = 0
|
||
should_stop = False
|
||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||
|
||
try:
|
||
async for chunk in byte_iterator:
|
||
prefetched_chunks.append(chunk)
|
||
buffer += chunk
|
||
|
||
# 尝试按行解析缓冲区
|
||
while b"\n" in buffer:
|
||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||
try:
|
||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
|
||
except Exception as e:
|
||
logger.warning(
|
||
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
|
||
f"bytes={line_bytes[:50]!r}"
|
||
)
|
||
continue
|
||
|
||
line_count += 1
|
||
|
||
# 跳过空行和注释行
|
||
if not line or line.startswith(":"):
|
||
if line_count >= max_prefetch_lines:
|
||
should_stop = True
|
||
break
|
||
continue
|
||
|
||
# 尝试解析 SSE 数据
|
||
data_str = line
|
||
if line.startswith("data: "):
|
||
data_str = line[6:]
|
||
|
||
if data_str == "[DONE]":
|
||
should_stop = True
|
||
break
|
||
|
||
try:
|
||
data = json.loads(data_str)
|
||
except json.JSONDecodeError:
|
||
if line_count >= max_prefetch_lines:
|
||
should_stop = True
|
||
break
|
||
continue
|
||
|
||
# 使用解析器检查是否为错误响应
|
||
if isinstance(data, dict) and parser.is_error_response(data):
|
||
parsed = parser.parse_response(data, 200)
|
||
logger.warning(
|
||
f" [{self.request_id}] 检测到嵌套错误: "
|
||
f"Provider={provider.name}, "
|
||
f"error_type={parsed.error_type}, "
|
||
f"message={parsed.error_message}"
|
||
)
|
||
raise EmbeddedErrorException(
|
||
provider_name=str(provider.name),
|
||
error_code=(
|
||
int(parsed.error_type)
|
||
if parsed.error_type and parsed.error_type.isdigit()
|
||
else None
|
||
),
|
||
error_message=parsed.error_message,
|
||
error_status=parsed.error_type,
|
||
)
|
||
|
||
# 预读到有效数据,没有错误,停止预读
|
||
should_stop = True
|
||
break
|
||
|
||
if should_stop or line_count >= max_prefetch_lines:
|
||
break
|
||
|
||
except EmbeddedErrorException:
|
||
raise
|
||
except Exception as e:
|
||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||
|
||
return prefetched_chunks
|
||
|
||
async def create_response_stream(
|
||
self,
|
||
ctx: StreamContext,
|
||
byte_iterator: Any,
|
||
response_ctx: Any,
|
||
http_client: httpx.AsyncClient,
|
||
prefetched_chunks: Optional[list] = None,
|
||
*,
|
||
start_time: Optional[float] = None,
|
||
) -> AsyncGenerator[bytes, None]:
|
||
"""
|
||
创建响应流生成器
|
||
|
||
从字节流中解析 SSE 数据并转发,支持预读数据。
|
||
|
||
Args:
|
||
ctx: 流式上下文
|
||
byte_iterator: 字节流迭代器
|
||
response_ctx: HTTP 响应上下文管理器
|
||
http_client: HTTP 客户端
|
||
prefetched_chunks: 预读的字节块列表(可选)
|
||
start_time: 请求开始时间,用于计算 TTFB(可选)
|
||
|
||
Yields:
|
||
编码后的响应数据块
|
||
"""
|
||
try:
|
||
sse_parser = SSEEventParser()
|
||
streaming_started = False
|
||
buffer = b""
|
||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||
|
||
# 处理预读数据
|
||
if prefetched_chunks:
|
||
if not streaming_started and self.on_streaming_start:
|
||
self.on_streaming_start()
|
||
streaming_started = True
|
||
|
||
for chunk in prefetched_chunks:
|
||
# 记录首字时间 (TTFB) - 在 yield 之前记录
|
||
if start_time is not None:
|
||
ctx.record_first_byte_time(start_time)
|
||
start_time = None # 只记录一次
|
||
|
||
# 把原始数据转发给客户端
|
||
yield chunk
|
||
|
||
buffer += chunk
|
||
# 处理缓冲区中的完整行
|
||
while b"\n" in buffer:
|
||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||
try:
|
||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||
line = decoder.decode(line_bytes + b"\n", False)
|
||
self._process_line(ctx, sse_parser, line)
|
||
except Exception as e:
|
||
# 解码失败,记录警告但继续处理
|
||
logger.warning(
|
||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||
f"bytes={line_bytes[:50]!r}"
|
||
)
|
||
continue
|
||
|
||
# 处理剩余的流数据
|
||
async for chunk in byte_iterator:
|
||
if not streaming_started and self.on_streaming_start:
|
||
self.on_streaming_start()
|
||
streaming_started = True
|
||
|
||
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
|
||
if start_time is not None:
|
||
ctx.record_first_byte_time(start_time)
|
||
start_time = None # 只记录一次
|
||
|
||
# 原始数据透传
|
||
yield chunk
|
||
|
||
buffer += chunk
|
||
# 处理缓冲区中的完整行
|
||
while b"\n" in buffer:
|
||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||
try:
|
||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||
line = decoder.decode(line_bytes + b"\n", False)
|
||
self._process_line(ctx, sse_parser, line)
|
||
except Exception as e:
|
||
# 解码失败,记录警告但继续处理
|
||
logger.warning(
|
||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||
f"bytes={line_bytes[:50]!r}"
|
||
)
|
||
continue
|
||
|
||
# 处理剩余的缓冲区数据(如果有未完成的行)
|
||
if buffer:
|
||
try:
|
||
# 使用 final=True 处理最后的不完整字符
|
||
line = decoder.decode(buffer, True)
|
||
self._process_line(ctx, sse_parser, line)
|
||
except Exception as e:
|
||
logger.warning(
|
||
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
|
||
f"bytes={buffer[:50]!r}"
|
||
)
|
||
|
||
# 处理剩余事件
|
||
for event in sse_parser.flush():
|
||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||
|
||
except GeneratorExit:
|
||
raise
|
||
finally:
|
||
await self._cleanup(response_ctx, http_client)
|
||
|
||
def _process_line(
|
||
self,
|
||
ctx: StreamContext,
|
||
sse_parser: SSEEventParser,
|
||
line: str,
|
||
) -> None:
|
||
"""
|
||
处理单行数据
|
||
|
||
Args:
|
||
ctx: 流式上下文
|
||
sse_parser: SSE 解析器
|
||
line: 原始行数据
|
||
"""
|
||
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF,
|
||
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
|
||
normalized_line = line.rstrip("\r\n")
|
||
events = sse_parser.feed_line(normalized_line)
|
||
|
||
if normalized_line != "":
|
||
ctx.chunk_count += 1
|
||
|
||
for event in events:
|
||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||
|
||
async def create_monitored_stream(
|
||
self,
|
||
ctx: StreamContext,
|
||
stream_generator: AsyncGenerator[bytes, None],
|
||
is_disconnected: Callable[[], Any],
|
||
) -> AsyncGenerator[bytes, None]:
|
||
"""
|
||
创建带监控的流生成器
|
||
|
||
检测客户端断开连接并更新状态码。
|
||
|
||
Args:
|
||
ctx: 流式上下文
|
||
stream_generator: 原始流生成器
|
||
is_disconnected: 检查客户端是否断开的函数
|
||
|
||
Yields:
|
||
响应数据块
|
||
"""
|
||
try:
|
||
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
|
||
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
|
||
next_disconnect_check_at = 0.0
|
||
disconnect_check_interval_s = 0.25
|
||
|
||
async for chunk in stream_generator:
|
||
now = time.monotonic()
|
||
if now >= next_disconnect_check_at:
|
||
next_disconnect_check_at = now + disconnect_check_interval_s
|
||
if await is_disconnected():
|
||
# 如果响应已完成(收到 finish_reason),客户端断开不算失败
|
||
if ctx.has_completion:
|
||
logger.info(
|
||
f"ID:{self.request_id} | Client disconnected after completion"
|
||
)
|
||
else:
|
||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||
ctx.status_code = 499 # Client Closed Request
|
||
ctx.error_message = "client_disconnected"
|
||
break
|
||
yield chunk
|
||
except asyncio.CancelledError:
|
||
# 如果响应已完成,不标记为失败
|
||
if not ctx.has_completion:
|
||
ctx.status_code = 499
|
||
ctx.error_message = "client_disconnected"
|
||
raise
|
||
except Exception as e:
|
||
ctx.status_code = 500
|
||
ctx.error_message = str(e)
|
||
raise
|
||
|
||
async def _cleanup(
|
||
self,
|
||
response_ctx: Any,
|
||
http_client: httpx.AsyncClient,
|
||
) -> None:
|
||
"""清理资源"""
|
||
try:
|
||
await response_ctx.__aexit__(None, None, None)
|
||
except Exception:
|
||
pass
|
||
try:
|
||
await http_client.aclose()
|
||
except Exception:
|
||
pass
|