refactor: consolidate stream smoothing into StreamProcessor with intelligent timing

- Move StreamSmoother functionality directly into StreamProcessor for better integration
- Create ContentExtractor strategy pattern for format-agnostic content extraction
- Implement intelligent dynamic delay calculation based on text length
- Support three text-length tiers: short (char-by-char, max delay), medium (char-by-char, interpolated delay), long (fixed-size chunks, min delay); see the timing sketch below
- Remove manual chunk_size and delay_ms configuration - now auto-calculated
- Simplify admin UI to single toggle switch with auto timing adjustment
- Extract format detection logic to reusable content_extractors module
- Improve code maintainability with cleaner architecture
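
Timing at a glance: a minimal standalone sketch of the dynamic delay
calculation. The constants and the logarithmic interpolation mirror those
added to StreamProcessor in the diff below; the demo prints are illustrative.

    import math

    MIN_DELAY_MS = 15           # delay floor (long text)
    MAX_DELAY_MS = 24           # delay ceiling (short text)
    SHORT_TEXT_THRESHOLD = 10   # at or below: char-by-char, max delay
    LONG_TEXT_THRESHOLD = 50    # at or above: 5-char chunks, min delay

    def calculate_delay(text_length: int) -> float:
        """Seconds to sleep between emitted chunks, by source-text length."""
        if text_length <= SHORT_TEXT_THRESHOLD:
            return MAX_DELAY_MS / 1000.0
        if text_length >= LONG_TEXT_THRESHOLD:
            return MIN_DELAY_MS / 1000.0
        # Logarithmic interpolation between the two thresholds
        ratio = math.log(text_length / SHORT_TEXT_THRESHOLD) / math.log(
            LONG_TEXT_THRESHOLD / SHORT_TEXT_THRESHOLD
        )
        return (MAX_DELAY_MS - ratio * (MAX_DELAY_MS - MIN_DELAY_MS)) / 1000.0

    print(calculate_delay(5))    # 0.024   (short tier)
    print(calculate_delay(22))   # ~0.0196 (medium tier, interpolated)
    print(calculate_delay(80))   # 0.015   (long tier)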
fawney19
2025-12-19 09:46:22 +08:00
parent 85fafeacb8
commit 070121717d
7 changed files with 600 additions and 360 deletions


@@ -6,20 +6,26 @@
2. Response stream generation
3. Read-ahead and nested error detection
4. Client disconnect detection
5. Smoothed streaming output
"""

import asyncio
import codecs
import json
import math
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Optional
import httpx
from src.api.handlers.base.content_extractors import (
ContentExtractor,
get_extractor,
get_extractor_formats,
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
@@ -31,18 +37,23 @@ class StreamSmoothingConfig:
"""流式平滑输出配置"""
enabled: bool = False
chunk_size: int = 5
delay_ms: int = 15
class StreamProcessor:
    """
    Streaming response processor.

    Handles SSE stream parsing, error detection, response generation, and smoothed output.

    Extracted from ChatHandlerBase to keep its responsibilities focused.
    """

    # Smoothing parameters
    CHUNK_SIZE = 5  # Characters per chunk for long text
    MIN_DELAY_MS = 15  # Delay for long text (ms)
    MAX_DELAY_MS = 24  # Delay for short text (ms)
    SHORT_TEXT_THRESHOLD = 10  # Short-text threshold (emit char-by-char)
    LONG_TEXT_THRESHOLD = 50  # Long-text threshold (emit in chunks)

    def __init__(
        self,
        request_id: str,
@@ -68,6 +79,9 @@ class StreamProcessor:
        self.collect_text = collect_text
        self.smoothing_config = smoothing_config or StreamSmoothingConfig()

        # Content-extractor cache
        self._extractors: dict[str, ContentExtractor] = {}

    def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
        """
        Get the parser for the provider's format.
@@ -390,7 +404,7 @@ class StreamProcessor:
            sse_parser: the SSE parser
            line: the raw line data
        """
        # SSEEventParser expects single lines with the newline already stripped,
        # so strip CR/LF here to avoid misreading blank lines as "\n" and
        # breaking event-boundary parsing.
        normalized_line = line.rstrip("\r\n")
        events = sse_parser.feed_line(normalized_line)
@@ -489,12 +503,153 @@ class StreamProcessor:
            return

        # Smoothing is enabled
        buffer = b""
        is_first_content = True

        async for chunk in stream_generator:
            buffer += chunk

            # Split SSE events on blank lines (standard SSE framing)
            while b"\n\n" in buffer:
                event_block, buffer = buffer.split(b"\n\n", 1)
                event_str = event_block.decode("utf-8", errors="replace")

                # Parse the event block
                lines = event_str.strip().split("\n")
                data_str = None
                event_type = ""
                for line in lines:
                    line = line.rstrip("\r")
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                    elif line.startswith("data: "):
                        data_str = line[6:]

                # No data line: pass through as-is
                if data_str is None:
                    yield event_block + b"\n\n"
                    continue

                # [DONE]: pass through as-is
                if data_str.strip() == "[DONE]":
                    yield event_block + b"\n\n"
                    continue

                # Try to parse the JSON payload
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield event_block + b"\n\n"
                    continue

                # Detect the format and extract the content
                content, extractor = self._detect_format_and_extract(data)

                # Only content longer than one character needs smoothing
                if content and len(content) > 1 and extractor:
                    # Compute the dynamic delay
                    delay_seconds = self._calculate_delay(len(content))

                    # Split the content intelligently
                    content_chunks = self._split_content(content)
                    for i, sub_content in enumerate(content_chunks):
                        is_first = is_first_content and i == 0
                        # Build a new chunk via the extractor
                        sse_chunk = extractor.create_chunk(
                            data,
                            sub_content,
                            event_type=event_type,
                            is_first=is_first,
                        )
                        yield sse_chunk
                        # Sleep between chunks, but not after the last one
                        if i < len(content_chunks) - 1:
                            await asyncio.sleep(delay_seconds)
                    is_first_content = False
                else:
                    # Nothing to split: pass through as-is
                    yield event_block + b"\n\n"
                    if content:
                        is_first_content = False

        # Flush any remaining data
        if buffer:
            yield buffer

    def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
        """Get or create the extractor for a format (cached)."""
        if format_name not in self._extractors:
            extractor = get_extractor(format_name)
            if extractor:
                self._extractors[format_name] = extractor
        return self._extractors.get(format_name)

    def _detect_format_and_extract(
        self, data: dict
    ) -> tuple[Optional[str], Optional[ContentExtractor]]:
        """
        Detect the data format and extract its content.

        Tries each format's extractor in turn and returns the first one that
        successfully extracts content.

        Returns:
            (content, extractor): the extracted content and matching extractor
        """
        for format_name in get_extractor_formats():
            extractor = self._get_extractor(format_name)
            if extractor:
                content = extractor.extract_content(data)
                if content is not None:
                    return content, extractor
        return None, None

    def _calculate_delay(self, text_length: int) -> float:
        """
        Compute a dynamic delay (in seconds) from the text length.

        Short text uses a larger delay (stronger typing feel); long text uses
        a smaller delay (avoids stalling). Lengths in between transition
        smoothly via logarithmic interpolation.
        """
        if text_length <= self.SHORT_TEXT_THRESHOLD:
            return self.MAX_DELAY_MS / 1000.0
        if text_length >= self.LONG_TEXT_THRESHOLD:
            return self.MIN_DELAY_MS / 1000.0

        # Logarithmic interpolation for a smooth transition
        ratio = math.log(text_length / self.SHORT_TEXT_THRESHOLD) / math.log(
            self.LONG_TEXT_THRESHOLD / self.SHORT_TEXT_THRESHOLD
        )
        delay_ms = self.MAX_DELAY_MS - ratio * (self.MAX_DELAY_MS - self.MIN_DELAY_MS)
        return delay_ms / 1000.0

    def _split_content(self, content: str) -> list[str]:
        """
        Split content intelligently based on its length.

        Short/medium text: character by character (more realistic typing effect).
        Long text: in CHUNK_SIZE pieces (avoids excessive delay).
        """
        text_length = len(content)
        if text_length <= self.CHUNK_SIZE:
            return [content]

        # Long text: split into fixed-size chunks
        if text_length >= self.LONG_TEXT_THRESHOLD:
            chunks = []
            for i in range(0, text_length, self.CHUNK_SIZE):
                chunks.append(content[i : i + self.CHUNK_SIZE])
            return chunks

        # Short/medium text: split character by character
        return list(content)

    async def _cleanup(
        self,
@@ -510,3 +665,137 @@ class StreamProcessor:
            await http_client.aclose()
        except Exception:
            pass


async def create_smoothed_stream(
    stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
    """
    Standalone smoothed-stream generator function.

    For use by the CLI handler and similar callers that don't need a full
    StreamProcessor instance.

    Args:
        stream_generator: the raw stream generator

    Yields:
        smoothed response data chunks
    """
    processor = _LightweightSmoother()
    async for chunk in processor.smooth(stream_generator):
        yield chunk


class _LightweightSmoother:
    """
    Lightweight smoothing processor.

    Contains only the minimal logic needed for smoothed output; does not
    depend on the rest of StreamProcessor.
    """

    CHUNK_SIZE = 5
    MIN_DELAY_MS = 15
    MAX_DELAY_MS = 24
    SHORT_TEXT_THRESHOLD = 10
    LONG_TEXT_THRESHOLD = 50

    def __init__(self) -> None:
        self._extractors: dict[str, ContentExtractor] = {}

    def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
        if format_name not in self._extractors:
            extractor = get_extractor(format_name)
            if extractor:
                self._extractors[format_name] = extractor
        return self._extractors.get(format_name)

    def _detect_format_and_extract(
        self, data: dict
    ) -> tuple[Optional[str], Optional[ContentExtractor]]:
        for format_name in get_extractor_formats():
            extractor = self._get_extractor(format_name)
            if extractor:
                content = extractor.extract_content(data)
                if content is not None:
                    return content, extractor
        return None, None

    def _calculate_delay(self, text_length: int) -> float:
        if text_length <= self.SHORT_TEXT_THRESHOLD:
            return self.MAX_DELAY_MS / 1000.0
        if text_length >= self.LONG_TEXT_THRESHOLD:
            return self.MIN_DELAY_MS / 1000.0
        ratio = math.log(text_length / self.SHORT_TEXT_THRESHOLD) / math.log(
            self.LONG_TEXT_THRESHOLD / self.SHORT_TEXT_THRESHOLD
        )
        return (self.MAX_DELAY_MS - ratio * (self.MAX_DELAY_MS - self.MIN_DELAY_MS)) / 1000.0

    def _split_content(self, content: str) -> list[str]:
        text_length = len(content)
        if text_length <= self.CHUNK_SIZE:
            return [content]
        if text_length >= self.LONG_TEXT_THRESHOLD:
            return [content[i : i + self.CHUNK_SIZE] for i in range(0, text_length, self.CHUNK_SIZE)]
        return list(content)

    async def smooth(
        self, stream_generator: AsyncGenerator[bytes, None]
    ) -> AsyncGenerator[bytes, None]:
        buffer = b""
        is_first_content = True
        async for chunk in stream_generator:
            buffer += chunk
            while b"\n\n" in buffer:
                event_block, buffer = buffer.split(b"\n\n", 1)
                event_str = event_block.decode("utf-8", errors="replace")
                lines = event_str.strip().split("\n")
                data_str = None
                event_type = ""
                for line in lines:
                    line = line.rstrip("\r")
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                    elif line.startswith("data: "):
                        data_str = line[6:]
                if data_str is None:
                    yield event_block + b"\n\n"
                    continue
                if data_str.strip() == "[DONE]":
                    yield event_block + b"\n\n"
                    continue
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield event_block + b"\n\n"
                    continue
                content, extractor = self._detect_format_and_extract(data)
                if content and len(content) > 1 and extractor:
                    delay_seconds = self._calculate_delay(len(content))
                    content_chunks = self._split_content(content)
                    for i, sub_content in enumerate(content_chunks):
                        is_first = is_first_content and i == 0
                        sse_chunk = extractor.create_chunk(
                            data, sub_content, event_type=event_type, is_first=is_first
                        )
                        yield sse_chunk
                        if i < len(content_chunks) - 1:
                            await asyncio.sleep(delay_seconds)
                    is_first_content = False
                else:
                    yield event_block + b"\n\n"
                    if content:
                        is_first_content = False
        if buffer:
            yield buffer
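
For reference, a minimal usage sketch of the standalone entry point above.
The fake_stream generator and its OpenAI-style payload are hypothetical
stand-ins for a real provider stream; actual extraction depends on the
registered content extractors:

    import asyncio

    async def fake_stream():
        # Hypothetical OpenAI-style SSE event; real payloads come from the provider
        yield b'data: {"choices":[{"delta":{"content":"Hello, world!"}}]}\n\n'
        yield b"data: [DONE]\n\n"

    async def main():
        async for chunk in create_smoothed_stream(fake_stream()):
            print(chunk)

    asyncio.run(main())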