refactor: consolidate stream smoothing into StreamProcessor with intelligent timing
- Move StreamSmoother functionality directly into StreamProcessor for better integration
- Create a ContentExtractor strategy pattern for format-agnostic content extraction
- Implement intelligent dynamic delay calculation based on text length (illustrated below)
- Support three text-length tiers: short (char-by-char), medium (char-by-char with interpolated delay), long (chunked)
- Remove manual chunk_size and delay_ms configuration; timing is now auto-calculated
- Simplify the admin UI to a single toggle switch with automatic timing adjustment
- Extract format detection logic into a reusable content_extractors module
- Improve code maintainability with a cleaner architecture
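The tiered timing can be reproduced from the constants this diff introduces. A minimal standalone sketch, mirroring the `_calculate_delay` logic below rather than importing it:

    import math

    # Constants as introduced on StreamProcessor in this diff.
    SHORT, LONG = 10, 50      # text-length tier thresholds
    MIN_MS, MAX_MS = 15, 24   # per-chunk delay bounds (ms)

    def delay_ms(n: int) -> float:
        """Dynamic per-chunk delay for content of length n."""
        if n <= SHORT:
            return float(MAX_MS)   # short tier: slower, char-by-char
        if n >= LONG:
            return float(MIN_MS)   # long tier: faster, chunked
        # medium tier: logarithmic interpolation between the bounds
        ratio = math.log(n / SHORT) / math.log(LONG / SHORT)
        return MAX_MS - ratio * (MAX_MS - MIN_MS)

    print(delay_ms(8))             # 24.0
    print(round(delay_ms(25), 1))  # 18.9
    print(delay_ms(120))           # 15.0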
@@ -6,20 +6,26 @@
2. Response stream generation
3. Read-ahead and embedded error detection
4. Client disconnect detection
5. Smoothed streaming output
"""

import asyncio
import codecs
import json
import math
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Optional

import httpx

from src.api.handlers.base.content_extractors import (
    ContentExtractor,
    get_extractor,
    get_extractor_formats,
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
@@ -31,18 +37,23 @@ class StreamSmoothingConfig:
    """Streaming smoothing configuration."""

    enabled: bool = False


class StreamProcessor:
    """
    Streaming response processor.

    Handles SSE stream parsing, error detection, response generation, and
    smoothed output. Extracted from ChatHandlerBase to keep its
    responsibility focused.
    """

    # Smoothing parameters
    CHUNK_SIZE = 5  # characters per chunk for long text
    MIN_DELAY_MS = 15  # delay for long text (ms)
    MAX_DELAY_MS = 24  # delay for short text (ms)
    SHORT_TEXT_THRESHOLD = 10  # short-text threshold (char-by-char output)
    LONG_TEXT_THRESHOLD = 50  # long-text threshold (chunked output)

    def __init__(
        self,
        request_id: str,
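With chunk_size and delay_ms gone from the config, enabling smoothing is now a single flag; a minimal usage sketch (the remaining StreamProcessor constructor arguments are elided here):

    # Smoothing on: chunking and delays are derived automatically from text length.
    config = StreamSmoothingConfig(enabled=True)

    # Smoothing off (the default): events pass through untouched.
    default_config = StreamSmoothingConfig()
    assert default_config.enabled is False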
@@ -68,6 +79,9 @@ class StreamProcessor:
        self.collect_text = collect_text
        self.smoothing_config = smoothing_config or StreamSmoothingConfig()

        # Content extractor cache
        self._extractors: dict[str, ContentExtractor] = {}

    def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
        """
        Get a parser for the provider's response format.
@@ -390,7 +404,7 @@ class StreamProcessor:
            sse_parser: the SSE parser
            line: the raw line data
        """
        # SSEEventParser takes single lines with the newline already stripped;
        # strip CR/LF uniformly here so an empty line is not misread as "\n",
        # which would break event-boundary parsing.
        normalized_line = line.rstrip("\r\n")
        events = sse_parser.feed_line(normalized_line)
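A quick illustration (not from the repo) of the boundary case that comment guards against:

    # An SSE blank line arriving as "\r\n" must normalize to "", i.e. an event
    # boundary -- not to "\r", which the parser would treat as a non-empty line.
    print(repr("data: hello\r\n".rstrip("\r\n")))  # 'data: hello'
    print(repr("\r\n".rstrip("\r\n")))             # ''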
@@ -489,12 +503,153 @@ class StreamProcessor:
            return

        # Smoothing enabled
        buffer = b""
        is_first_content = True

        async for chunk in stream_generator:
            buffer += chunk

            # Split SSE events on the double newline (standard SSE framing)
            while b"\n\n" in buffer:
                event_block, buffer = buffer.split(b"\n\n", 1)
                event_str = event_block.decode("utf-8", errors="replace")

                # Parse the event block
                lines = event_str.strip().split("\n")
                data_str = None
                event_type = ""

                for line in lines:
                    line = line.rstrip("\r")
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                    elif line.startswith("data: "):
                        data_str = line[6:]

                # No data line: pass through unchanged
                if data_str is None:
                    yield event_block + b"\n\n"
                    continue

                # [DONE] passes through unchanged
                if data_str.strip() == "[DONE]":
                    yield event_block + b"\n\n"
                    continue

                # Try to parse the JSON payload
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield event_block + b"\n\n"
                    continue

                # Detect the format and extract the content
                content, extractor = self._detect_format_and_extract(data)

                # Only content longer than one character needs smoothing
                if content and len(content) > 1 and extractor:
                    # Compute the dynamic delay
                    delay_seconds = self._calculate_delay(len(content))

                    # Split intelligently
                    content_chunks = self._split_content(content)

                    for i, sub_content in enumerate(content_chunks):
                        is_first = is_first_content and i == 0

                        # Build a new chunk via the extractor
                        sse_chunk = extractor.create_chunk(
                            data,
                            sub_content,
                            event_type=event_type,
                            is_first=is_first,
                        )

                        yield sse_chunk

                        # Delay between chunks, except after the last one
                        if i < len(content_chunks) - 1:
                            await asyncio.sleep(delay_seconds)

                    is_first_content = False
                else:
                    # No splitting needed: pass through unchanged
                    yield event_block + b"\n\n"
                    if content:
                        is_first_content = False

        # Flush any remaining data
        if buffer:
            yield buffer

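For reference, a minimal standalone walk-through of the same event framing, using a hand-written sample event (the payload shape is illustrative, not a real provider response):

    raw = b'event: content_block_delta\ndata: {"delta": {"text": "hello"}}\n\n'
    event_block, rest = raw.split(b"\n\n", 1)

    event_type, data_str = "", None
    for line in event_block.decode("utf-8", errors="replace").strip().split("\n"):
        line = line.rstrip("\r")
        if line.startswith("event: "):
            event_type = line[7:].strip()
        elif line.startswith("data: "):
            data_str = line[6:]

    print(event_type)  # content_block_delta
    print(data_str)    # {"delta": {"text": "hello"}}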
    def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
        """Get or create the extractor for a format (cached)."""
        if format_name not in self._extractors:
            extractor = get_extractor(format_name)
            if extractor:
                self._extractors[format_name] = extractor
        return self._extractors.get(format_name)

    def _detect_format_and_extract(
        self, data: dict
    ) -> tuple[Optional[str], Optional[ContentExtractor]]:
        """
        Detect the data format and extract its content.

        Tries each format's extractor in turn and returns the first one that
        successfully extracts content.

        Returns:
            (content, extractor): the extracted content and the matching extractor
        """
        for format_name in get_extractor_formats():
            extractor = self._get_extractor(format_name)
            if extractor:
                content = extractor.extract_content(data)
                if content is not None:
                    return content, extractor

        return None, None

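The extractor side of this strategy pattern lives in content_extractors, which this diff does not show. A hedged sketch of what one strategy might look like, assuming the extract_content / create_chunk interface used above; the OpenAI-style field paths are illustrative only:

    import json
    from typing import Any, Optional

    class OpenAIChatExtractor:  # hypothetical strategy, mirroring the calls above
        def extract_content(self, data: dict) -> Optional[str]:
            # Return the text delta if this event matches the format, else None
            # so _detect_format_and_extract moves on to the next strategy.
            try:
                return data["choices"][0]["delta"].get("content")
            except (KeyError, IndexError, TypeError, AttributeError):
                return None

        def create_chunk(
            self, data: dict, sub_content: str, event_type: str = "", is_first: bool = False
        ) -> bytes:
            # Re-emit the original event with only the split-off text substituted.
            # is_first is accepted for interface parity; a real extractor might use
            # it to emit role/opening fields only once.
            out: dict[str, Any] = json.loads(json.dumps(data))  # cheap deep copy
            out["choices"][0]["delta"]["content"] = sub_content
            prefix = f"event: {event_type}\n" if event_type else ""
            return (prefix + "data: " + json.dumps(out) + "\n\n").encode("utf-8")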
    def _calculate_delay(self, text_length: int) -> float:
        """
        Compute the dynamic delay (in seconds) from the text length.

        Short text gets a larger delay (a stronger typing feel); long text gets
        a smaller delay (no stutter). Intermediate lengths are smoothed with
        logarithmic interpolation.
        """
        if text_length <= self.SHORT_TEXT_THRESHOLD:
            return self.MAX_DELAY_MS / 1000.0
        if text_length >= self.LONG_TEXT_THRESHOLD:
            return self.MIN_DELAY_MS / 1000.0

        # Logarithmic interpolation: smooth transition
        ratio = math.log(text_length / self.SHORT_TEXT_THRESHOLD) / math.log(
            self.LONG_TEXT_THRESHOLD / self.SHORT_TEXT_THRESHOLD
        )
        delay_ms = self.MAX_DELAY_MS - ratio * (self.MAX_DELAY_MS - self.MIN_DELAY_MS)
        return delay_ms / 1000.0

    def _split_content(self, content: str) -> list[str]:
        """
        Split text intelligently based on its length.

        Short text: per character (a more realistic typing effect).
        Long text: CHUNK_SIZE chunks (avoids piling up delays).
        """
        text_length = len(content)

        if text_length <= self.CHUNK_SIZE:
            return [content]

        # Long text: split into fixed-size chunks
        if text_length >= self.LONG_TEXT_THRESHOLD:
            chunks = []
            for i in range(0, text_length, self.CHUNK_SIZE):
                chunks.append(content[i : i + self.CHUNK_SIZE])
            return chunks

        # Short/medium text: split per character
        return list(content)

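Expected splits under the three tiers, derived by hand from the logic above; _LightweightSmoother (defined at the bottom of this diff) shares the same constants and split logic, so it serves as the demo object:

    sm = _LightweightSmoother()
    assert sm._split_content("hi") == ["hi"]                        # <= CHUNK_SIZE: kept whole
    assert sm._split_content("hello world") == list("hello world")  # medium: char-by-char
    assert sm._split_content("x" * 60) == ["xxxxx"] * 12            # long tier: 5-char chunks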
    async def _cleanup(
        self,
@@ -510,3 +665,137 @@ class StreamProcessor:
            await http_client.aclose()
        except Exception:
            pass


async def create_smoothed_stream(
    stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
    """
    Standalone smoothed-stream generator.

    For callers such as the CLI handler that do not need a full
    StreamProcessor instance.

    Args:
        stream_generator: the raw stream generator

    Yields:
        smoothed response chunks
    """
    processor = _LightweightSmoother()
    async for chunk in processor.smooth(stream_generator):
        yield chunk

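A usage sketch for a CLI-style caller; upstream() is a stand-in for whatever byte stream the handler already has. Whether this sample payload gets smoothed depends on the registered extractors; unmatched events simply pass through unchanged:

    import asyncio
    from typing import AsyncGenerator

    async def upstream() -> AsyncGenerator[bytes, None]:
        # Stand-in for a real provider stream.
        yield b'data: {"delta": {"text": "hello world"}}\n\n'
        yield b"data: [DONE]\n\n"

    async def handle() -> None:
        async for chunk in create_smoothed_stream(upstream()):
            print(chunk)

    # asyncio.run(handle())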
class _LightweightSmoother:
    """
    Lightweight smoothing processor.

    Carries only the minimal logic needed for smoothed output, with no
    dependency on the rest of StreamProcessor.
    """

    CHUNK_SIZE = 5
    MIN_DELAY_MS = 15
    MAX_DELAY_MS = 24
    SHORT_TEXT_THRESHOLD = 10
    LONG_TEXT_THRESHOLD = 50

    def __init__(self) -> None:
        self._extractors: dict[str, ContentExtractor] = {}

    def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
        if format_name not in self._extractors:
            extractor = get_extractor(format_name)
            if extractor:
                self._extractors[format_name] = extractor
        return self._extractors.get(format_name)

    def _detect_format_and_extract(
        self, data: dict
    ) -> tuple[Optional[str], Optional[ContentExtractor]]:
        for format_name in get_extractor_formats():
            extractor = self._get_extractor(format_name)
            if extractor:
                content = extractor.extract_content(data)
                if content is not None:
                    return content, extractor
        return None, None

    def _calculate_delay(self, text_length: int) -> float:
        if text_length <= self.SHORT_TEXT_THRESHOLD:
            return self.MAX_DELAY_MS / 1000.0
        if text_length >= self.LONG_TEXT_THRESHOLD:
            return self.MIN_DELAY_MS / 1000.0
        ratio = math.log(text_length / self.SHORT_TEXT_THRESHOLD) / math.log(
            self.LONG_TEXT_THRESHOLD / self.SHORT_TEXT_THRESHOLD
        )
        return (self.MAX_DELAY_MS - ratio * (self.MAX_DELAY_MS - self.MIN_DELAY_MS)) / 1000.0

    def _split_content(self, content: str) -> list[str]:
        text_length = len(content)
        if text_length <= self.CHUNK_SIZE:
            return [content]
        if text_length >= self.LONG_TEXT_THRESHOLD:
            return [content[i : i + self.CHUNK_SIZE] for i in range(0, text_length, self.CHUNK_SIZE)]
        return list(content)

    async def smooth(
        self, stream_generator: AsyncGenerator[bytes, None]
    ) -> AsyncGenerator[bytes, None]:
        buffer = b""
        is_first_content = True

        async for chunk in stream_generator:
            buffer += chunk

            while b"\n\n" in buffer:
                event_block, buffer = buffer.split(b"\n\n", 1)
                event_str = event_block.decode("utf-8", errors="replace")

                lines = event_str.strip().split("\n")
                data_str = None
                event_type = ""

                for line in lines:
                    line = line.rstrip("\r")
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                    elif line.startswith("data: "):
                        data_str = line[6:]

                if data_str is None:
                    yield event_block + b"\n\n"
                    continue

                if data_str.strip() == "[DONE]":
                    yield event_block + b"\n\n"
                    continue

                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield event_block + b"\n\n"
                    continue

                content, extractor = self._detect_format_and_extract(data)

                if content and len(content) > 1 and extractor:
                    delay_seconds = self._calculate_delay(len(content))
                    content_chunks = self._split_content(content)

                    for i, sub_content in enumerate(content_chunks):
                        is_first = is_first_content and i == 0
                        sse_chunk = extractor.create_chunk(
                            data, sub_content, event_type=event_type, is_first=is_first
                        )
                        yield sse_chunk
                        if i < len(content_chunks) - 1:
                            await asyncio.sleep(delay_seconds)

                    is_first_content = False
                else:
                    yield event_block + b"\n\n"
                    if content:
                        is_first_content = False

        if buffer:
            yield buffer