"""
|
||
流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑
|
||
|
||
职责:
|
||
1. SSE 事件解析和处理
|
||
2. 响应流生成
|
||
3. 预读和嵌套错误检测
|
||
4. 客户端断开检测
|
||
5. 流式平滑输出
|
||
"""
|
||
|
||
import asyncio
|
||
import codecs
|
||
import json
|
||
from dataclasses import dataclass
|
||
from typing import Any, AsyncGenerator, Callable, Optional
|
||
|
||
import httpx
|
||
|
||
from src.api.handlers.base.content_extractors import (
|
||
ContentExtractor,
|
||
get_extractor,
|
||
get_extractor_formats,
|
||
)
|
||
from src.api.handlers.base.parsers import get_parser_for_format
|
||
from src.api.handlers.base.response_parser import ResponseParser
|
||
from src.api.handlers.base.stream_context import StreamContext
|
||
from src.config.settings import config
|
||
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
|
||
from src.core.logger import logger
|
||
from src.models.database import Provider, ProviderEndpoint
|
||
from src.utils.sse_parser import SSEEventParser
|
||
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||
|
||
|
||
@dataclass
|
||
class StreamSmoothingConfig:
|
||
"""流式平滑输出配置"""
|
||
|
||
enabled: bool = False
|
||
chunk_size: int = 20
|
||
delay_ms: int = 8
|
||
|
||
|
||
class StreamProcessor:
    """
    Streaming response processor

    Handles SSE stream parsing, error detection, response generation, and smoothed output.
    Extracted from ChatHandlerBase to keep its responsibilities focused.
    """

    def __init__(
        self,
        request_id: str,
        default_parser: ResponseParser,
        on_streaming_start: Optional[Callable[[], None]] = None,
        *,
        collect_text: bool = False,
        smoothing_config: Optional[StreamSmoothingConfig] = None,
    ):
        """
        Initialize the stream processor

        Args:
            request_id: Request ID (used for logging)
            default_parser: Default response parser
            on_streaming_start: Callback invoked when streaming starts (used to update status)
            collect_text: Whether to collect text content
            smoothing_config: Smoothed streaming output configuration
        """
        self.request_id = request_id
        self.default_parser = default_parser
        self.on_streaming_start = on_streaming_start
        self.collect_text = collect_text
        self.smoothing_config = smoothing_config or StreamSmoothingConfig()

        # Content extractor cache
        self._extractors: dict[str, ContentExtractor] = {}

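    # Illustrative construction sketch (not part of the original module): how a chat
    # handler might instantiate the processor. The handler attribute names used here
    # (self.request_id, self.parser, self._mark_streaming) are assumptions for the
    # example only.
    #
    #     processor = StreamProcessor(
    #         request_id=self.request_id,
    #         default_parser=self.parser,
    #         on_streaming_start=self._mark_streaming,
    #         collect_text=True,
    #         smoothing_config=StreamSmoothingConfig(enabled=True, chunk_size=20, delay_ms=8),
    #     )
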
    def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
        """
        Get a parser for the provider's format

        Selects the correct parser based on the provider's API format.
        """
        if ctx.provider_api_format:
            try:
                return get_parser_for_format(ctx.provider_api_format)
            except KeyError:
                pass
        return self.default_parser

    def handle_sse_event(
        self,
        ctx: StreamContext,
        event_name: Optional[str],
        data_str: str,
    ) -> None:
        """
        Handle a single SSE event

        Parses the event data and extracts usage information and text content.

        Args:
            ctx: Stream context
            event_name: Event name
            data_str: Event data string
        """
        if not data_str:
            return

        if data_str == "[DONE]":
            ctx.has_completion = True
            return

        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return

        ctx.data_count += 1

        if not isinstance(data, dict):
            return

        # Collect the raw chunk data
        ctx.parsed_chunks.append(data)

        # Select a parser based on the provider's format
        parser = self.get_parser_for_provider(ctx)

        # Extract usage via the parser
        usage = parser.extract_usage_from_response(data)
        if usage:
            ctx.update_usage(
                input_tokens=usage.get("input_tokens"),
                output_tokens=usage.get("output_tokens"),
                cached_tokens=usage.get("cache_read_tokens"),
                cache_creation_tokens=usage.get("cache_creation_tokens"),
            )

        # Extract text
        if self.collect_text:
            text = parser.extract_text_content(data)
            if text:
                ctx.append_text(text)

        # Check for completion
        event_type = event_name or data.get("type", "")
        if event_type in ("response.completed", "message_stop"):
            ctx.has_completion = True

        # Check the OpenAI-style finish_reason
        choices = data.get("choices", [])
        if choices and isinstance(choices, list) and len(choices) > 0:
            finish_reason = choices[0].get("finish_reason")
            if finish_reason is not None:
                ctx.has_completion = True

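    # Illustrative input sketch (an assumption, not from the original module): an
    # OpenAI-style streaming chunk as it might reach handle_sse_event. Exactly which
    # usage keys get extracted depends on the selected parser.
    #
    #     processor.handle_sse_event(
    #         ctx,
    #         None,
    #         '{"choices": [{"delta": {"content": "Hi"}, "finish_reason": "stop"}],'
    #         ' "usage": {"input_tokens": 12, "output_tokens": 3}}',
    #     )
    #     # -> the chunk is appended to ctx.parsed_chunks, any usage the parser finds
    #     #    is recorded via ctx.update_usage, and the non-null finish_reason marks
    #     #    ctx.has_completion = True.
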
    async def prefetch_and_check_error(
        self,
        byte_iterator: Any,
        provider: Provider,
        endpoint: ProviderEndpoint,
        ctx: StreamContext,
        max_prefetch_lines: int = 5,
    ) -> list:
        """
        Prefetch the first few lines of the stream and detect embedded errors

        Some providers (such as Gemini) may return HTTP 200 while embedding an error
        in the response body. Such cases must be detected before the stream starts
        being forwarded, so that retry logic can be triggered.

        The first read applies TTFB (time-to-first-byte) timeout detection; a timeout
        triggers failover.

        Args:
            byte_iterator: Byte stream iterator
            provider: Provider object
            endpoint: Endpoint object
            ctx: Stream context
            max_prefetch_lines: Maximum number of lines to prefetch

        Returns:
            List of prefetched byte chunks

        Raises:
            EmbeddedErrorException: If an embedded error is detected
            ProviderTimeoutException: If the first byte times out (TTFB timeout)
        """
        prefetched_chunks: list = []
        parser = self.get_parser_for_provider(ctx)
        buffer = b""
        line_count = 0
        should_stop = False
        # Use an incremental decoder to handle UTF-8 characters split across chunks
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

        try:
            # Read the first byte using the shared TTFB timeout helper
            ttfb_timeout = config.stream_first_byte_timeout
            first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
                byte_iterator,
                timeout=ttfb_timeout,
                request_id=self.request_id,
                provider_name=str(provider.name),
            )
            prefetched_chunks.append(first_chunk)
            buffer += first_chunk

            # Continue reading the remaining prefetch data
            async for chunk in aiter:
                prefetched_chunks.append(chunk)
                buffer += chunk

                # Try to parse the buffer line by line
                while b"\n" in buffer:
                    line_bytes, buffer = buffer.split(b"\n", 1)
                    try:
                        # The incremental decoder correctly handles multi-byte characters split across chunks
                        line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
                    except Exception as e:
                        logger.warning(
                            f"[{self.request_id}] UTF-8 decode failed during prefetch: {e}, "
                            f"bytes={line_bytes[:50]!r}"
                        )
                        continue

                    line_count += 1

                    # Skip empty lines and comment lines
                    if not line or line.startswith(":"):
                        if line_count >= max_prefetch_lines:
                            should_stop = True
                            break
                        continue

                    # Try to parse the SSE data
                    data_str = line
                    if line.startswith("data: "):
                        data_str = line[6:]

                    if data_str == "[DONE]":
                        should_stop = True
                        break

                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        if line_count >= max_prefetch_lines:
                            should_stop = True
                            break
                        continue

                    # Use the parser to check whether this is an error response
                    if isinstance(data, dict) and parser.is_error_response(data):
                        parsed = parser.parse_response(data, 200)
                        logger.warning(
                            f" [{self.request_id}] Embedded error detected: "
                            f"Provider={provider.name}, "
                            f"error_type={parsed.error_type}, "
                            f"message={parsed.error_message}"
                        )
                        raise EmbeddedErrorException(
                            provider_name=str(provider.name),
                            error_code=(
                                int(parsed.error_type)
                                if parsed.error_type and parsed.error_type.isdigit()
                                else None
                            ),
                            error_message=parsed.error_message,
                            error_status=parsed.error_type,
                        )

                    # Valid data prefetched and no error found: stop prefetching
                    should_stop = True
                    break

                if should_stop or line_count >= max_prefetch_lines:
                    break

        except (EmbeddedErrorException, ProviderTimeoutException):
            # Re-raise retryable provider exceptions to trigger failover
            raise
        except (OSError, IOError) as e:
            # Network I/O exception: log a warning; a retry may be needed
            logger.warning(
                f" [{self.request_id}] Network error while prefetching stream: {type(e).__name__}: {e}"
            )
        except Exception as e:
            # Unexpected serious exception: log and re-raise so the problem is not masked
            logger.error(
                f" [{self.request_id}] Serious error while prefetching stream: {type(e).__name__}: {e}",
                exc_info=True
            )
            raise

        return prefetched_chunks

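    # Illustrative call-sequence sketch (not part of the original module): how a
    # handler might run the prefetch check before forwarding the stream. The names
    # byte_iterator, response_ctx, http_client, and start_time are assumptions
    # standing in for whatever the caller already has.
    #
    #     prefetched = await processor.prefetch_and_check_error(
    #         byte_iterator, provider, endpoint, ctx
    #     )  # may raise EmbeddedErrorException / ProviderTimeoutException -> failover
    #     stream = processor.create_response_stream(
    #         ctx, byte_iterator, response_ctx, http_client,
    #         prefetched_chunks=prefetched, start_time=start_time,
    #     )
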
    async def create_response_stream(
        self,
        ctx: StreamContext,
        byte_iterator: Any,
        response_ctx: Any,
        http_client: httpx.AsyncClient,
        prefetched_chunks: Optional[list] = None,
        *,
        start_time: Optional[float] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Create the response stream generator

        Parses SSE data from the byte stream and forwards it; supports prefetched data.

        Args:
            ctx: Stream context
            byte_iterator: Byte stream iterator
            response_ctx: HTTP response context manager
            http_client: HTTP client
            prefetched_chunks: List of prefetched byte chunks (optional)
            start_time: Request start time, used to compute TTFB (optional)

        Yields:
            Encoded response data chunks
        """
        try:
            sse_parser = SSEEventParser()
            streaming_started = False
            buffer = b""
            # Use an incremental decoder to handle UTF-8 characters split across chunks
            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

            # Handle prefetched data
            if prefetched_chunks:
                if not streaming_started and self.on_streaming_start:
                    self.on_streaming_start()
                    streaming_started = True

                for chunk in prefetched_chunks:
                    # Record time to first byte (TTFB) - before yielding
                    if start_time is not None:
                        ctx.record_first_byte_time(start_time)
                        start_time = None  # Record only once

                    # Forward the raw data to the client
                    yield chunk

                    buffer += chunk
                    # Process complete lines in the buffer
                    while b"\n" in buffer:
                        line_bytes, buffer = buffer.split(b"\n", 1)
                        try:
                            # The incremental decoder correctly handles multi-byte characters split across chunks
                            line = decoder.decode(line_bytes + b"\n", False)
                            self._process_line(ctx, sse_parser, line)
                        except Exception as e:
                            # Decoding failed: log a warning and keep going
                            logger.warning(
                                f"[{self.request_id}] UTF-8 decode failed: {e}, "
                                f"bytes={line_bytes[:50]!r}"
                            )
                            continue

            # Handle the rest of the stream
            async for chunk in byte_iterator:
                if not streaming_started and self.on_streaming_start:
                    self.on_streaming_start()
                    streaming_started = True

                # Record time to first byte (TTFB) - before yielding (if there was no prefetched data)
                if start_time is not None:
                    ctx.record_first_byte_time(start_time)
                    start_time = None  # Record only once

                # Pass the raw data through
                yield chunk

                buffer += chunk
                # Process complete lines in the buffer
                while b"\n" in buffer:
                    line_bytes, buffer = buffer.split(b"\n", 1)
                    try:
                        # The incremental decoder correctly handles multi-byte characters split across chunks
                        line = decoder.decode(line_bytes + b"\n", False)
                        self._process_line(ctx, sse_parser, line)
                    except Exception as e:
                        # Decoding failed: log a warning and keep going
                        logger.warning(
                            f"[{self.request_id}] UTF-8 decode failed: {e}, "
                            f"bytes={line_bytes[:50]!r}"
                        )
                        continue

            # Handle remaining buffered data (if there is an unfinished line)
            if buffer:
                try:
                    # Use final=True to handle a trailing incomplete character
                    line = decoder.decode(buffer, True)
                    self._process_line(ctx, sse_parser, line)
                except Exception as e:
                    logger.warning(
                        f"[{self.request_id}] Failed to process remaining buffer: {e}, "
                        f"bytes={buffer[:50]!r}"
                    )

            # Handle remaining events
            for event in sse_parser.flush():
                self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")

        except GeneratorExit:
            raise
        finally:
            await self._cleanup(response_ctx, http_client)

    def _process_line(
        self,
        ctx: StreamContext,
        sse_parser: SSEEventParser,
        line: str,
    ) -> None:
        """
        Process a single line of data

        Args:
            ctx: Stream context
            sse_parser: SSE parser
            line: Raw line data
        """
        # SSEEventParser expects single lines with the newline already stripped; strip CR/LF
        # here so an empty line is not misread as "\n", which would break event boundary parsing.
        normalized_line = line.rstrip("\r\n")
        events = sse_parser.feed_line(normalized_line)

        if normalized_line != "":
            ctx.chunk_count += 1

        for event in events:
            self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")

    async def create_monitored_stream(
        self,
        ctx: StreamContext,
        stream_generator: AsyncGenerator[bytes, None],
        is_disconnected: Callable[[], Any],
    ) -> AsyncGenerator[bytes, None]:
        """
        Create a monitored stream generator

        Detects client disconnects and updates the status code.

        Args:
            ctx: Stream context
            stream_generator: Original stream generator
            is_disconnected: Function that checks whether the client has disconnected

        Yields:
            Response data chunks
        """
        try:
            # Check for disconnects in a background task so streaming is never blocked
            disconnected = False

            async def check_disconnect_background() -> None:
                nonlocal disconnected
                while not disconnected and not ctx.has_completion:
                    await asyncio.sleep(0.5)
                    if await is_disconnected():
                        disconnected = True
                        break

            # Start the background check task
            check_task = asyncio.create_task(check_disconnect_background())

            try:
                async for chunk in stream_generator:
                    if disconnected:
                        # If the response has already completed, a client disconnect is not a failure
                        if ctx.has_completion:
                            logger.info(
                                f"ID:{self.request_id} | Client disconnected after completion"
                            )
                        else:
                            logger.warning(f"ID:{self.request_id} | Client disconnected")
                            ctx.status_code = 499
                            ctx.error_message = "client_disconnected"
                        break
                    yield chunk
            finally:
                check_task.cancel()
                try:
                    await check_task
                except asyncio.CancelledError:
                    pass
        except asyncio.CancelledError:
            # If the response has already completed, do not mark it as failed
            if not ctx.has_completion:
                ctx.status_code = 499
                ctx.error_message = "client_disconnected"
            raise
        except Exception as e:
            ctx.status_code = 500
            ctx.error_message = str(e)
            raise

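    # Illustrative sketch (an assumption, not from the original module): with
    # FastAPI/Starlette, the is_disconnected argument can be the bound coroutine
    # method of the incoming Request, and smoothing can wrap the monitored stream.
    #
    #     stream = processor.create_monitored_stream(ctx, stream, request.is_disconnected)
    #     stream = processor.create_smoothed_stream(stream)
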
    async def create_smoothed_stream(
        self,
        stream_generator: AsyncGenerator[bytes, None],
    ) -> AsyncGenerator[bytes, None]:
        """
        Create a smoothed output stream generator

        If smoothing is enabled, large chunks are split into small pieces with a tiny
        delay between them. Otherwise the original stream is passed through unchanged.

        Args:
            stream_generator: Original stream generator

        Yields:
            Smoothed response data chunks
        """
        if not self.smoothing_config.enabled:
            # Smoothing disabled: pass through directly
            async for chunk in stream_generator:
                yield chunk
            return

        # Smoothing enabled
        buffer = b""
        is_first_content = True

        async for chunk in stream_generator:
            buffer += chunk

            # Split SSE events on double newlines (standard SSE framing)
            while b"\n\n" in buffer:
                event_block, buffer = buffer.split(b"\n\n", 1)
                event_str = event_block.decode("utf-8", errors="replace")

                # Parse the event block
                lines = event_str.strip().split("\n")
                data_str = None
                event_type = ""

                for line in lines:
                    line = line.rstrip("\r")
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                    elif line.startswith("data: "):
                        data_str = line[6:]

                # No data line: pass through directly
                if data_str is None:
                    yield event_block + b"\n\n"
                    continue

                # [DONE]: pass through directly
                if data_str.strip() == "[DONE]":
                    yield event_block + b"\n\n"
                    continue

                # Try to parse the JSON
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield event_block + b"\n\n"
                    continue

                # Detect the format and extract the content
                content, extractor = self._detect_format_and_extract(data)

                # Only content longer than one character needs smoothing
                if content and len(content) > 1 and extractor:
                    # Get the configured delay
                    delay_seconds = self._calculate_delay()

                    # Split the content
                    content_chunks = self._split_content(content)

                    for i, sub_content in enumerate(content_chunks):
                        is_first = is_first_content and i == 0

                        # Create a new chunk via the extractor
                        sse_chunk = extractor.create_chunk(
                            data,
                            sub_content,
                            event_type=event_type,
                            is_first=is_first,
                        )

                        yield sse_chunk

                        # Add a delay between chunks, except after the last one
                        if i < len(content_chunks) - 1:
                            await asyncio.sleep(delay_seconds)

                    is_first_content = False
                else:
                    # No splitting needed: pass through directly
                    yield event_block + b"\n\n"
                    if content:
                        is_first_content = False

        # Handle any remaining data
        if buffer:
            yield buffer

    def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
        """Get or create the extractor for a format (cached)"""
        if format_name not in self._extractors:
            extractor = get_extractor(format_name)
            if extractor:
                self._extractors[format_name] = extractor
        return self._extractors.get(format_name)

    def _detect_format_and_extract(
        self, data: dict
    ) -> tuple[Optional[str], Optional[ContentExtractor]]:
        """
        Detect the data format and extract the content

        Tries each format's extractor in turn and returns the first result that
        successfully extracts content.

        Returns:
            (content, extractor): The extracted content and the matching extractor
        """
        for format_name in get_extractor_formats():
            extractor = self._get_extractor(format_name)
            if extractor:
                content = extractor.extract_content(data)
                if content is not None:
                    return content, extractor

        return None, None

    def _calculate_delay(self) -> float:
        """Get the configured delay (in seconds)"""
        return self.smoothing_config.delay_ms / 1000.0

    def _split_content(self, content: str) -> list[str]:
        """
        Split text into fixed-size chunks
        """
        chunk_size = self.smoothing_config.chunk_size
        text_length = len(content)

        if text_length <= chunk_size:
            return [content]

        # Split into chunks
        chunks = []
        for i in range(0, text_length, chunk_size):
            chunks.append(content[i : i + chunk_size])
        return chunks

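    # Worked example (illustrative): with the default chunk_size of 20, a 45-character
    # delta is split into pieces of 20, 20, and 5 characters; create_smoothed_stream
    # then emits each piece as its own SSE chunk with a short delay between pieces.
    #
    #     self._split_content("a" * 45)  # -> ["a" * 20, "a" * 20, "a" * 5]
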
    async def _cleanup(
        self,
        response_ctx: Any,
        http_client: httpx.AsyncClient,
    ) -> None:
        """Clean up resources"""
        try:
            await response_ctx.__aexit__(None, None, None)
        except Exception:
            pass
        try:
            await http_client.aclose()
        except Exception:
            pass


async def create_smoothed_stream(
    stream_generator: AsyncGenerator[bytes, None],
    chunk_size: int = 20,
    delay_ms: int = 8,
) -> AsyncGenerator[bytes, None]:
    """
    Standalone smoothed stream generator function

    For use by CLI handlers and similar callers that do not need a full
    StreamProcessor instance.

    Args:
        stream_generator: Original stream generator
        chunk_size: Number of characters per chunk
        delay_ms: Delay in milliseconds between chunks

    Yields:
        Smoothed response data chunks
    """
    processor = _LightweightSmoother(chunk_size=chunk_size, delay_ms=delay_ms)
    async for chunk in processor.smooth(stream_generator):
        yield chunk


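# Illustrative usage sketch (not part of the original module): a CLI handler could
# wrap an upstream byte generator before handing it to its streaming response. The
# names upstream_stream and StreamingResponse stand in for whatever the caller uses.
#
#     smoothed = create_smoothed_stream(upstream_stream, chunk_size=20, delay_ms=8)
#     return StreamingResponse(smoothed, media_type="text/event-stream")
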
class _LightweightSmoother:
    """
    Lightweight smoothing processor

    Contains only the minimal logic needed for smoothed output; it does not depend
    on the rest of StreamProcessor.
    """

    def __init__(self, chunk_size: int = 20, delay_ms: int = 8) -> None:
        self.chunk_size = chunk_size
        self.delay_ms = delay_ms
        self._extractors: dict[str, ContentExtractor] = {}

    def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
        if format_name not in self._extractors:
            extractor = get_extractor(format_name)
            if extractor:
                self._extractors[format_name] = extractor
        return self._extractors.get(format_name)

    def _detect_format_and_extract(
        self, data: dict
    ) -> tuple[Optional[str], Optional[ContentExtractor]]:
        for format_name in get_extractor_formats():
            extractor = self._get_extractor(format_name)
            if extractor:
                content = extractor.extract_content(data)
                if content is not None:
                    return content, extractor
        return None, None

    def _calculate_delay(self) -> float:
        return self.delay_ms / 1000.0

    def _split_content(self, content: str) -> list[str]:
        text_length = len(content)
        if text_length <= self.chunk_size:
            return [content]
        return [content[i : i + self.chunk_size] for i in range(0, text_length, self.chunk_size)]

    async def smooth(
        self, stream_generator: AsyncGenerator[bytes, None]
    ) -> AsyncGenerator[bytes, None]:
        buffer = b""
        is_first_content = True

        async for chunk in stream_generator:
            buffer += chunk

            while b"\n\n" in buffer:
                event_block, buffer = buffer.split(b"\n\n", 1)
                event_str = event_block.decode("utf-8", errors="replace")

                lines = event_str.strip().split("\n")
                data_str = None
                event_type = ""

                for line in lines:
                    line = line.rstrip("\r")
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                    elif line.startswith("data: "):
                        data_str = line[6:]

                if data_str is None:
                    yield event_block + b"\n\n"
                    continue

                if data_str.strip() == "[DONE]":
                    yield event_block + b"\n\n"
                    continue

                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield event_block + b"\n\n"
                    continue

                content, extractor = self._detect_format_and_extract(data)

                if content and len(content) > 1 and extractor:
                    delay_seconds = self._calculate_delay()
                    content_chunks = self._split_content(content)

                    for i, sub_content in enumerate(content_chunks):
                        is_first = is_first_content and i == 0
                        sse_chunk = extractor.create_chunk(
                            data, sub_content, event_type=event_type, is_first=is_first
                        )
                        yield sse_chunk
                        if i < len(content_chunks) - 1:
                            await asyncio.sleep(delay_seconds)

                    is_first_content = False
                else:
                    yield event_block + b"\n\n"
                    if content:
                        is_first_content = False

        if buffer:
            yield buffer