feat: add stream smoothing feature for improved user experience

- Implement StreamSmoother class to split large content chunks into smaller pieces with delay
- Support OpenAI, Claude, and Gemini API response formats for smooth playback
- Add stream smoothing configuration to system settings (enable, chunk size, delay)
- Create streamlined API for stream smoothing with StreamSmoothingConfig dataclass
- Add admin UI controls for configuring stream smoothing parameters
- Use batch configuration loading to minimize database queries
- Enable typing effect simulation for better user experience in streaming responses
This commit is contained in:
fawney19
2025-12-19 03:15:19 +08:00
parent daf8b870f0
commit 85fafeacb8
6 changed files with 466 additions and 5 deletions

View File

@@ -464,6 +464,78 @@
</div> </div>
</div> </div>
</CardSection> </CardSection>
<!-- 流式输出配置 -->
<CardSection
title="流式输出"
description="配置流式响应的输出效果"
>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div class="md:col-span-2">
<div class="flex items-center space-x-2">
<Checkbox
id="stream-smoothing-enabled"
v-model:checked="systemConfig.stream_smoothing_enabled"
/>
<div>
<Label
for="stream-smoothing-enabled"
class="cursor-pointer"
>
启用平滑输出
</Label>
<p class="text-xs text-muted-foreground">
将上游返回的大块内容拆分成小块模拟打字效果
</p>
</div>
</div>
</div>
<div>
<Label
for="stream-smoothing-chunk-size"
class="block text-sm font-medium"
>
每块字符数
</Label>
<Input
id="stream-smoothing-chunk-size"
v-model.number="systemConfig.stream_smoothing_chunk_size"
type="number"
min="1"
max="100"
placeholder="5"
class="mt-1"
:disabled="!systemConfig.stream_smoothing_enabled"
/>
<p class="mt-1 text-xs text-muted-foreground">
每次输出的字符数量1-100
</p>
</div>
<div>
<Label
for="stream-smoothing-delay-ms"
class="block text-sm font-medium"
>
输出间隔 (毫秒)
</Label>
<Input
id="stream-smoothing-delay-ms"
v-model.number="systemConfig.stream_smoothing_delay_ms"
type="number"
min="1"
max="500"
placeholder="15"
class="mt-1"
:disabled="!systemConfig.stream_smoothing_enabled"
/>
<p class="mt-1 text-xs text-muted-foreground">
每块之间的延迟毫秒数1-500
</p>
</div>
</div>
</CardSection>
</div> </div>
<!-- 导入配置对话框 --> <!-- 导入配置对话框 -->
@@ -812,6 +884,10 @@ interface SystemConfig {
log_retention_days: number log_retention_days: number
cleanup_batch_size: number cleanup_batch_size: number
audit_log_retention_days: number audit_log_retention_days: number
// 流式输出
stream_smoothing_enabled: boolean
stream_smoothing_chunk_size: number
stream_smoothing_delay_ms: number
} }
const loading = ref(false) const loading = ref(false)
@@ -861,6 +937,10 @@ const systemConfig = ref<SystemConfig>({
log_retention_days: 365, log_retention_days: 365,
cleanup_batch_size: 1000, cleanup_batch_size: 1000,
audit_log_retention_days: 30, audit_log_retention_days: 30,
// 流式输出
stream_smoothing_enabled: false,
stream_smoothing_chunk_size: 5,
stream_smoothing_delay_ms: 15,
}) })
// 计算属性KB 和 字节 之间的转换 // 计算属性KB 和 字节 之间的转换
@@ -917,6 +997,10 @@ async function loadSystemConfig() {
'log_retention_days', 'log_retention_days',
'cleanup_batch_size', 'cleanup_batch_size',
'audit_log_retention_days', 'audit_log_retention_days',
// 流式输出
'stream_smoothing_enabled',
'stream_smoothing_chunk_size',
'stream_smoothing_delay_ms',
] ]
for (const key of configs) { for (const key of configs) {
@@ -1024,6 +1108,22 @@ async function saveSystemConfig() {
value: systemConfig.value.audit_log_retention_days, value: systemConfig.value.audit_log_retention_days,
description: '审计日志保留天数' description: '审计日志保留天数'
}, },
// 流式输出
{
key: 'stream_smoothing_enabled',
value: systemConfig.value.stream_smoothing_enabled,
description: '是否启用流式平滑输出'
},
{
key: 'stream_smoothing_chunk_size',
value: systemConfig.value.stream_smoothing_chunk_size,
description: '流式平滑输出每个小块的字符数'
},
{
key: 'stream_smoothing_delay_ms',
value: systemConfig.value.stream_smoothing_delay_ms,
description: '流式平滑输出每个小块之间的延迟毫秒数'
},
] ]
const promises = configItems.map(item => const promises = configItems.map(item =>

View File

@@ -32,7 +32,7 @@ from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.response_parser import ResponseParser from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor from src.api.handlers.base.stream_processor import StreamProcessor, StreamSmoothingConfig
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config from src.config.settings import config
@@ -52,6 +52,7 @@ from src.models.database import (
User, User,
) )
from src.services.provider.transport import build_provider_url from src.services.provider.transport import build_provider_url
from src.services.system.config import SystemConfigService
@@ -297,11 +298,23 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
def update_streaming_status() -> None: def update_streaming_status() -> None:
self._update_usage_to_streaming_with_ctx(ctx) self._update_usage_to_streaming_with_ctx(ctx)
# 从数据库批量读取流式平滑输出配置(单次查询)
smoothing_cfg = SystemConfigService.get_configs(
self.db,
["stream_smoothing_enabled", "stream_smoothing_chunk_size", "stream_smoothing_delay_ms"],
)
smoothing_config = StreamSmoothingConfig(
enabled=bool(smoothing_cfg.get("stream_smoothing_enabled", False)),
chunk_size=int(smoothing_cfg.get("stream_smoothing_chunk_size") or 5),
delay_ms=int(smoothing_cfg.get("stream_smoothing_delay_ms") or 15),
)
# 创建流处理器 # 创建流处理器
stream_processor = StreamProcessor( stream_processor = StreamProcessor(
request_id=self.request_id, request_id=self.request_id,
default_parser=self.parser, default_parser=self.parser,
on_streaming_start=update_streaming_status, on_streaming_start=update_streaming_status,
smoothing_config=smoothing_config,
) )
# 定义请求函数 # 定义请求函数
@@ -379,8 +392,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
http_request.is_disconnected, http_request.is_disconnected,
) )
# 创建平滑输出流(如果启用)
smoothed_stream = stream_processor.create_smoothed_stream(monitored_stream)
return StreamingResponse( return StreamingResponse(
monitored_stream, smoothed_stream,
media_type="text/event-stream", media_type="text/event-stream",
headers=build_sse_headers(), headers=build_sse_headers(),
background=background_tasks, background=background_tasks,

View File

@@ -11,6 +11,7 @@
import asyncio import asyncio
import codecs import codecs
import json import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Optional from typing import Any, AsyncGenerator, Callable, Optional
import httpx import httpx
@@ -18,12 +19,22 @@ import httpx
from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_smoother import StreamSmoother
from src.core.exceptions import EmbeddedErrorException from src.core.exceptions import EmbeddedErrorException
from src.core.logger import logger from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser from src.utils.sse_parser import SSEEventParser
@dataclass
class StreamSmoothingConfig:
    """Stream smoothing output configuration."""

    # Whether smoothing (chunk splitting + inter-chunk delay) is applied.
    enabled: bool = False
    # Number of characters emitted per sub-chunk.
    chunk_size: int = 5
    # Delay in milliseconds between consecutive sub-chunks.
    delay_ms: int = 15
class StreamProcessor: class StreamProcessor:
""" """
流式响应处理器 流式响应处理器
@@ -39,6 +50,7 @@ class StreamProcessor:
on_streaming_start: Optional[Callable[[], None]] = None, on_streaming_start: Optional[Callable[[], None]] = None,
*, *,
collect_text: bool = False, collect_text: bool = False,
smoothing_config: Optional[StreamSmoothingConfig] = None,
): ):
""" """
初始化流处理器 初始化流处理器
@@ -47,11 +59,14 @@ class StreamProcessor:
request_id: 请求 ID用于日志 request_id: 请求 ID用于日志
default_parser: 默认响应解析器 default_parser: 默认响应解析器
on_streaming_start: 流开始时的回调(用于更新状态) on_streaming_start: 流开始时的回调(用于更新状态)
collect_text: 是否收集文本内容
smoothing_config: 流式平滑输出配置
""" """
self.request_id = request_id self.request_id = request_id
self.default_parser = default_parser self.default_parser = default_parser
self.on_streaming_start = on_streaming_start self.on_streaming_start = on_streaming_start
self.collect_text = collect_text self.collect_text = collect_text
self.smoothing_config = smoothing_config or StreamSmoothingConfig()
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser: def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
""" """
@@ -451,6 +466,36 @@ class StreamProcessor:
ctx.error_message = str(e) ctx.error_message = str(e)
raise raise
async def create_smoothed_stream(
    self,
    stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
    """
    Wrap *stream_generator* with smoothing when it is enabled.

    With smoothing on, large chunks are split into small pieces with a tiny
    delay between them; otherwise the original stream is forwarded untouched.

    Args:
        stream_generator: The raw upstream byte stream.

    Yields:
        Response byte chunks, smoothed when configured.
    """
    cfg = self.smoothing_config
    if cfg.enabled:
        # Smoothing enabled: route the stream through a StreamSmoother.
        smoother = StreamSmoother(chunk_size=cfg.chunk_size, delay_ms=cfg.delay_ms)
        source = smoother.smooth_stream(stream_generator)
    else:
        # Smoothing disabled: plain passthrough.
        source = stream_generator
    async for piece in source:
        yield piece
async def _cleanup( async def _cleanup(
self, self,
response_ctx: Any, response_ctx: Any,

View File

@@ -0,0 +1,257 @@
"""
流式平滑输出处理器
将上游返回的大 chunk 拆分成小块,模拟更流畅的打字效果。
支持 OpenAI、Claude、Gemini 格式的 SSE 事件。
"""
import asyncio
import copy
import json
from typing import AsyncGenerator, Optional, Tuple
class StreamSmoother:
    """
    Stream smoothing processor.

    Splits large text deltas inside SSE events into small chunks and emits
    them with a short delay in between, simulating a typing effect.

    Understands the OpenAI (``choices[0].delta.content``), Claude
    (``content_block_delta``/``text_delta``) and Gemini
    (``candidates[0].content.parts[0].text``) streaming formats. Any event
    it cannot confidently rewrite is passed through byte-for-byte.
    """

    def __init__(
        self,
        chunk_size: int = 5,
        delay_ms: int = 15,
    ):
        """
        Initialize the smoother.

        Args:
            chunk_size: Number of characters per emitted sub-chunk.
            delay_ms: Delay in milliseconds between consecutive sub-chunks.
        """
        self.chunk_size = chunk_size
        self.delay_ms = delay_ms
        self.delay_seconds = self.delay_ms / 1000.0

    def _split_content(self, content: str) -> list[str]:
        """
        Split *content* into pieces of at most ``chunk_size`` characters.

        Splitting is per character (not per byte), so multi-byte text such
        as Chinese is never cut in the middle of a character.
        """
        if len(content) <= self.chunk_size:
            return [content]
        return [
            content[i : i + self.chunk_size]
            for i in range(0, len(content), self.chunk_size)
        ]

    def _extract_content(self, data: dict) -> Tuple[Optional[str], str]:
        """
        Extract the splittable text content from a parsed SSE payload.

        Returns:
            (content, format): the text to split (``None`` when the event
            must not be rewritten) and the detected format, one of
            ``"openai"`` / ``"claude"`` / ``"gemini"`` / ``"unknown"``.
        """
        if not isinstance(data, dict):
            return None, "unknown"

        # OpenAI format: choices[0].delta.content.
        # Only split when the delta carries nothing but role/content, to
        # avoid corrupting tool_calls and similar structured deltas.
        choices = data.get("choices")
        if isinstance(choices, list) and len(choices) == 1:
            first_choice = choices[0]
            if isinstance(first_choice, dict):
                delta = first_choice.get("delta")
                if isinstance(delta, dict):
                    content = delta.get("content")
                    if isinstance(content, str):
                        allowed_keys = {"role", "content"}
                        if all(key in allowed_keys for key in delta.keys()):
                            return content, "openai"

        # Claude format: type=content_block_delta, delta.type=text_delta, delta.text
        if data.get("type") == "content_block_delta":
            delta = data.get("delta", {})
            if isinstance(delta, dict) and delta.get("type") == "text_delta":
                text = delta.get("text")
                if isinstance(text, str):
                    return text, "claude"

        # Gemini format: candidates[0].content.parts[0].text
        candidates = data.get("candidates")
        if isinstance(candidates, list) and len(candidates) == 1:
            first_candidate = candidates[0]
            if isinstance(first_candidate, dict):
                content = first_candidate.get("content", {})
                if isinstance(content, dict):
                    parts = content.get("parts", [])
                    if isinstance(parts, list) and len(parts) == 1:
                        first_part = parts[0]
                        if isinstance(first_part, dict):
                            text = first_part.get("text")
                            # Only a pure-text part is safe to split.
                            if isinstance(text, str) and len(first_part) == 1:
                                return text, "gemini"

        return None, "unknown"

    def _create_openai_chunk(
        self,
        original_data: dict,
        new_content: str,
        is_first: bool = False,
    ) -> bytes:
        """
        Serialize an OpenAI-format SSE chunk carrying *new_content*.

        Each choice's delta is rebuilt so only ``content`` (plus ``role`` on
        the very first emitted chunk) survives.
        """
        new_data = original_data.copy()
        if "choices" in new_data and new_data["choices"]:
            new_choices = []
            for choice in new_data["choices"]:
                new_choice = choice.copy()
                if "delta" in new_choice:
                    new_delta = {}
                    # Only the first chunk repeats the role.
                    if is_first and "role" in new_choice["delta"]:
                        new_delta["role"] = new_choice["delta"]["role"]
                    new_delta["content"] = new_content
                    new_choice["delta"] = new_delta
                new_choices.append(new_choice)
            new_data["choices"] = new_choices
        return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")

    def _create_claude_chunk(
        self,
        original_data: dict,
        new_content: str,
        event_type: str,
    ) -> bytes:
        """Serialize a Claude-format SSE chunk carrying *new_content*."""
        new_data = original_data.copy()
        if "delta" in new_data:
            new_delta = new_data["delta"].copy()
            new_delta["text"] = new_content
            new_data["delta"] = new_delta
        # Claude events require the "event:" prefix line.
        return f"event: {event_type}\ndata: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode(
            "utf-8"
        )

    def _create_gemini_chunk(
        self,
        original_data: dict,
        new_content: str,
    ) -> bytes:
        """Serialize a Gemini-format SSE chunk carrying *new_content*."""
        # Deep copy: the text lives in a nested list that must not be shared.
        new_data = copy.deepcopy(original_data)
        if "candidates" in new_data and new_data["candidates"]:
            first_candidate = new_data["candidates"][0]
            if "content" in first_candidate:
                content = first_candidate["content"]
                if "parts" in content and content["parts"]:
                    content["parts"][0]["text"] = new_content
        return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")

    def _find_event_boundary(self, buf: bytes) -> Optional[Tuple[int, int]]:
        """
        Locate the earliest SSE event terminator in *buf*.

        SSE events end with a blank line, which on the wire is either
        ``\\n\\n`` or ``\\r\\n\\r\\n`` (fix: the CRLF form was previously
        unrecognized, causing CRLF streams to buffer until end of stream).

        Returns:
            (index, separator_length), or ``None`` when no complete event
            is buffered yet.
        """
        lf = buf.find(b"\n\n")
        crlf = buf.find(b"\r\n\r\n")
        if crlf != -1 and (lf == -1 or crlf < lf):
            return crlf, 4
        if lf != -1:
            return lf, 2
        return None

    async def smooth_stream(
        self,
        byte_iterator: AsyncGenerator[bytes, None],
    ) -> AsyncGenerator[bytes, None]:
        """
        Smooth a raw SSE byte stream.

        Parses complete SSE events out of the byte stream, splits oversized
        text deltas into sub-chunks and sleeps ``delay_ms`` between them.
        Events without splittable content — including ``[DONE]``, non-JSON
        payloads and unknown formats — are forwarded unchanged, preserving
        their original event separator bytes.

        Args:
            byte_iterator: The upstream byte stream.

        Yields:
            Smoothed (or passed-through) byte chunks.
        """
        buffer = b""
        is_first_content = True

        async for chunk in byte_iterator:
            buffer += chunk

            # Consume every complete SSE event currently in the buffer.
            while (boundary := self._find_event_boundary(buffer)) is not None:
                end, sep_len = boundary
                event_block = buffer[:end]
                separator = buffer[end : end + sep_len]
                buffer = buffer[end + sep_len :]
                passthrough = event_block + separator

                event_str = event_block.decode("utf-8", errors="replace")

                # Parse the event fields. Per the SSE spec, the space after
                # the colon is optional and multiple "data:" lines are
                # joined with a newline (fix: only `data: ` was accepted
                # before, and only the last data line was kept).
                event_type = ""
                data_lines: list[str] = []
                for raw_line in event_str.split("\n"):
                    line = raw_line.rstrip("\r")
                    if line.startswith("event:"):
                        event_type = line[len("event:"):].strip()
                    elif line.startswith("data:"):
                        value = line[len("data:"):]
                        if value.startswith(" "):
                            value = value[1:]
                        data_lines.append(value)

                # No data field at all: forward verbatim.
                if not data_lines:
                    yield passthrough
                    continue
                data_str = "\n".join(data_lines)

                # Terminal sentinel: forward verbatim.
                if data_str.strip() == "[DONE]":
                    yield passthrough
                    continue

                # Non-JSON payload: forward verbatim.
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    yield passthrough
                    continue

                content, fmt = self._extract_content(data)
                if not content or len(content) <= self.chunk_size:
                    # Nothing to split: forward verbatim.
                    yield passthrough
                    if content:
                        is_first_content = False
                    continue

                content_chunks = self._split_content(content)
                for i, sub_content in enumerate(content_chunks):
                    is_first = is_first_content and i == 0
                    if fmt == "openai":
                        sse_chunk = self._create_openai_chunk(data, sub_content, is_first)
                    elif fmt == "claude":
                        sse_chunk = self._create_claude_chunk(
                            data, sub_content, event_type or "content_block_delta"
                        )
                    else:
                        # _extract_content only reports gemini beyond these.
                        sse_chunk = self._create_gemini_chunk(data, sub_content)
                    yield sse_chunk
                    if i < len(content_chunks) - 1:
                        # Small pause between sub-chunks (typing effect).
                        await asyncio.sleep(self.delay_seconds)
                    else:
                        is_first_content = False

        # Flush any trailing partial event.
        if buffer:
            yield buffer

View File

@@ -143,6 +143,7 @@ class Config:
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭 # STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5")) self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1")) self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
# 注流式平滑输出配置已移至数据库系统设置stream_smoothing_*
# 验证连接池配置 # 验证连接池配置
self._validate_pool_config() self._validate_pool_config()

View File

@@ -12,7 +12,6 @@ from src.core.logger import logger
from src.models.database import Provider, SystemConfig from src.models.database import Provider, SystemConfig
class LogLevel(str, Enum): class LogLevel(str, Enum):
"""日志记录级别""" """日志记录级别"""
@@ -79,6 +78,19 @@ class SystemConfigService:
"value": False, "value": False,
"description": "是否自动删除过期的API KeyTrue=物理删除False=仅禁用),仅管理员可配置", "description": "是否自动删除过期的API KeyTrue=物理删除False=仅禁用),仅管理员可配置",
}, },
# 流式平滑输出配置
"stream_smoothing_enabled": {
"value": False,
"description": "是否启用流式平滑输出,将大 chunk 拆分成小块模拟打字效果",
},
"stream_smoothing_chunk_size": {
"value": 5,
"description": "流式平滑输出每个小块的字符数",
},
"stream_smoothing_delay_ms": {
"value": 15,
"description": "流式平滑输出每个小块之间的延迟毫秒数",
},
} }
@classmethod @classmethod
@@ -94,6 +106,35 @@ class SystemConfigService:
return default return default
@classmethod
def get_configs(cls, db: Session, keys: List[str]) -> Dict[str, Any]:
    """
    Fetch several system configuration values with a single query.

    Args:
        db: Database session.
        keys: Configuration keys to look up.

    Returns:
        Mapping of key -> stored value. Keys missing from the table fall
        back to DEFAULT_CONFIGS; otherwise to ``None``.
    """
    # One query for all requested keys.
    rows = db.query(SystemConfig).filter(SystemConfig.key.in_(keys)).all()
    stored = {row.key: row.value for row in rows}

    resolved: Dict[str, Any] = {}
    for key in keys:
        if key in stored:
            resolved[key] = stored[key]
            continue
        # Fall back to the declared default, then to None.
        default_entry = cls.DEFAULT_CONFIGS.get(key)
        resolved[key] = default_entry["value"] if default_entry is not None else None
    return resolved
@staticmethod @staticmethod
def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig: def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig:
"""设置系统配置值""" """设置系统配置值"""
@@ -111,6 +152,7 @@ class SystemConfigService:
db.commit() db.commit()
db.refresh(config) db.refresh(config)
return config return config
@staticmethod @staticmethod
@@ -153,8 +195,8 @@ class SystemConfigService:
for config in configs for config in configs
] ]
@staticmethod @classmethod
def delete_config(db: Session, key: str) -> bool: def delete_config(cls, db: Session, key: str) -> bool:
"""删除系统配置""" """删除系统配置"""
config = db.query(SystemConfig).filter(SystemConfig.key == key).first() config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config: if config: