refactor: make stream smoothing parameters configurable and add models cache invalidation

- Move stream smoothing parameters (chunk_size, delay_ms) to database config
- Remove hardcoded stream smoothing constants from StreamProcessor
- Simplify dynamic delay calculation by using config values directly
- Add invalidate_models_list_cache() function to clear /v1/models endpoint cache
- Call cache invalidation on model create, update, delete, and bulk operations
- Update admin UI to allow runtime configuration of smoothing parameters
- Improve model listing freshness when models are modified
Author: fawney19
Date: 2025-12-19 11:03:46 +08:00
Parent: 912f6643e2
Commit: 97425ac68f
8 changed files with 150 additions and 90 deletions
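Note: the new chunk_size and delay_ms values live in the database-backed system config rather than in code. This excerpt does not show where StreamProcessor now reads them, but a minimal sketch of the lookup — assuming SystemConfigService.get_config keeps the (db, key, default) signature seen in the removed handler code; load_smoothing_config is a hypothetical helper name — might look like:

```python
from src.api.handlers.base.stream_processor import StreamSmoothingConfig
from src.services.system.config import SystemConfigService


def load_smoothing_config(db) -> StreamSmoothingConfig:
    """Hypothetical helper: read the smoothing settings, falling back to the
    same defaults that SystemConfigService registers in this commit."""
    return StreamSmoothingConfig(
        enabled=bool(SystemConfigService.get_config(db, "stream_smoothing_enabled", False)),
        chunk_size=int(SystemConfigService.get_config(db, "stream_smoothing_chunk_size", 20)),
        delay_ms=int(SystemConfigService.get_config(db, "stream_smoothing_delay_ms", 8)),
    )
```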

View File

@@ -470,6 +470,8 @@
     title="Streaming Output"
     description="Configure how streaming responses are emitted"
   >
+    <div class="grid grid-cols-1 md:grid-cols-2 gap-6">
+      <div class="md:col-span-2">
     <div class="flex items-center space-x-2">
       <Checkbox
         id="stream-smoothing-enabled"
@@ -483,7 +485,53 @@
           Enable smooth output
         </Label>
         <p class="text-xs text-muted-foreground">
-          Automatically adjusts output speed by text length: short text is emitted character by character for a stronger typing feel, long text in chunks to avoid stuttering
+          Splits large chunks returned by the upstream into smaller pieces to simulate a typing effect
+        </p>
+      </div>
+    </div>
+      </div>
+      <div>
+        <Label
+          for="stream-smoothing-chunk-size"
+          class="block text-sm font-medium"
+        >
+          Characters per chunk
+        </Label>
+        <Input
+          id="stream-smoothing-chunk-size"
+          v-model.number="systemConfig.stream_smoothing_chunk_size"
+          type="number"
+          min="1"
+          max="100"
+          placeholder="20"
+          class="mt-1"
+          :disabled="!systemConfig.stream_smoothing_enabled"
+        />
+        <p class="mt-1 text-xs text-muted-foreground">
+          Number of characters emitted per chunk (1-100)
+        </p>
+      </div>
+      <div>
+        <Label
+          for="stream-smoothing-delay-ms"
+          class="block text-sm font-medium"
+        >
+          Output interval (ms)
+        </Label>
+        <Input
+          id="stream-smoothing-delay-ms"
+          v-model.number="systemConfig.stream_smoothing_delay_ms"
+          type="number"
+          min="1"
+          max="100"
+          placeholder="8"
+          class="mt-1"
+          :disabled="!systemConfig.stream_smoothing_enabled"
+        />
+        <p class="mt-1 text-xs text-muted-foreground">
+          Delay in milliseconds between chunks (1-100)
         </p>
       </div>
     </div>
@@ -838,6 +886,8 @@ interface SystemConfig {
   audit_log_retention_days: number
   // Streaming output
   stream_smoothing_enabled: boolean
+  stream_smoothing_chunk_size: number
+  stream_smoothing_delay_ms: number
 }

 const loading = ref(false)
@@ -889,6 +939,8 @@ const systemConfig = ref<SystemConfig>({
   audit_log_retention_days: 30,
   // Streaming output
   stream_smoothing_enabled: false,
+  stream_smoothing_chunk_size: 20,
+  stream_smoothing_delay_ms: 8,
 })

 // Computed property: conversion between KB and bytes
@@ -947,6 +999,8 @@ async function loadSystemConfig() {
   'audit_log_retention_days',
   // Streaming output
   'stream_smoothing_enabled',
+  'stream_smoothing_chunk_size',
+  'stream_smoothing_delay_ms',
 ]

 for (const key of configs) {
@@ -1060,6 +1114,16 @@ async function saveSystemConfig() {
     value: systemConfig.value.stream_smoothing_enabled,
     description: 'Whether smooth streaming output is enabled'
   },
+  {
+    key: 'stream_smoothing_chunk_size',
+    value: systemConfig.value.stream_smoothing_chunk_size,
+    description: 'Characters per chunk for smooth streaming output'
+  },
+  {
+    key: 'stream_smoothing_delay_ms',
+    value: systemConfig.value.stream_smoothing_delay_ms,
+    description: 'Delay in milliseconds between chunks for smooth streaming output'
+  },
 ]

 const promises = configItems.map(item =>

View File

@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request
 from sqlalchemy.orm import Session, joinedload

 from src.api.base.admin_adapter import AdminApiAdapter
+from src.api.base.models_service import invalidate_models_list_cache
 from src.api.base.pipeline import ApiRequestPipeline
 from src.core.exceptions import InvalidRequestException, NotFoundException
 from src.core.logger import logger
@@ -419,4 +420,8 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
             f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
         )

+        # Invalidate the /v1/models list cache
+        if success:
+            await invalidate_models_list_cache()
+
         return BatchAssignModelsToProviderResponse(success=success, errors=errors)
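Because the admin adapter's handler is async, it can await the invalidation directly; the synchronous ModelService methods further down instead schedule it with asyncio.create_task. A trivial runnable illustration of the direct-await style used here (invalidate is a stand-in for invalidate_models_list_cache):

```python
import asyncio


async def invalidate() -> None:
    """Stand-in for invalidate_models_list_cache()."""
    print("models list cache cleared")


async def admin_handler() -> None:
    success = ["model-a"]
    if success:  # mirror the guard above: skip the call when nothing changed
        await invalidate()


asyncio.run(admin_handler())
```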

View File

@@ -55,6 +55,23 @@ async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"])
         logger.warning(f"[ModelsService] Cache write failed: {e}")


+async def invalidate_models_list_cache() -> None:
+    """
+    Invalidate every cached /v1/models list.
+
+    Called on model create, update, and delete so the model list stays current.
+    """
+    # Clear the cache for every API format
+    all_formats = ["CLAUDE", "OPENAI", "GEMINI"]
+    for fmt in all_formats:
+        cache_key = f"{_CACHE_KEY_PREFIX}:{fmt}"
+        try:
+            await CacheService.delete(cache_key)
+            logger.debug(f"[ModelsService] Cache cleared: {cache_key}")
+        except Exception as e:
+            logger.warning(f"[ModelsService] Failed to clear cache {cache_key}: {e}")
+
+
 @dataclass
 class ModelInfo:
     """Unified model information structure"""

View File

@@ -32,7 +32,7 @@ from src.api.handlers.base.parsers import get_parser_for_format
 from src.api.handlers.base.request_builder import PassthroughRequestBuilder
 from src.api.handlers.base.response_parser import ResponseParser
 from src.api.handlers.base.stream_context import StreamContext
-from src.api.handlers.base.stream_processor import StreamProcessor, StreamSmoothingConfig
+from src.api.handlers.base.stream_processor import StreamProcessor
 from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
 from src.api.handlers.base.utils import build_sse_headers
 from src.config.settings import config
@@ -52,7 +52,6 @@ from src.models.database import (
     User,
 )
 from src.services.provider.transport import build_provider_url
-from src.services.system.config import SystemConfigService
@@ -298,18 +297,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         def update_streaming_status() -> None:
             self._update_usage_to_streaming_with_ctx(ctx)

-        # Read the stream smoothing toggle
-        smoothing_enabled = bool(
-            SystemConfigService.get_config(self.db, "stream_smoothing_enabled", False)
-        )
-        smoothing_config = StreamSmoothingConfig(enabled=smoothing_enabled)
-
         # Create the stream processor
         stream_processor = StreamProcessor(
             request_id=self.request_id,
             default_parser=self.parser,
             on_streaming_start=update_streaming_status,
-            smoothing_config=smoothing_config,
         )

         # Define the request function
@@ -387,11 +379,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
             http_request.is_disconnected,
         )

-        # Create the smoothed output stream (if enabled)
-        smoothed_stream = stream_processor.create_smoothed_stream(monitored_stream)
-
         return StreamingResponse(
-            smoothed_stream,
+            monitored_stream,
             media_type="text/event-stream",
             headers=build_sse_headers(),
             background=background_tasks,

View File

@@ -34,9 +34,7 @@ from src.api.handlers.base.base_handler import (
 from src.api.handlers.base.parsers import get_parser_for_format
 from src.api.handlers.base.request_builder import PassthroughRequestBuilder
 from src.api.handlers.base.stream_context import StreamContext
-from src.api.handlers.base.stream_processor import create_smoothed_stream
 from src.api.handlers.base.utils import build_sse_headers
-from src.services.system.config import SystemConfigService

 # Import directly from the concrete module to avoid a circular dependency
 from src.api.handlers.base.response_parser import (
@@ -354,17 +352,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
         # Create the monitored stream
         monitored_stream = self._create_monitored_stream(ctx, stream_generator)

-        # Create the smoothed output stream (if enabled)
-        smoothing_enabled = bool(
-            SystemConfigService.get_config(self.db, "stream_smoothing_enabled", False)
-        )
-        if smoothing_enabled:
-            final_stream = create_smoothed_stream(monitored_stream)
-        else:
-            final_stream = monitored_stream
-
         return StreamingResponse(
-            final_stream,
+            monitored_stream,
             media_type="text/event-stream",
             headers=build_sse_headers(),
             background=background_tasks,

View File

@@ -12,7 +12,6 @@
 import asyncio
 import codecs
 import json
-import math
 from dataclasses import dataclass
 from typing import Any, AsyncGenerator, Callable, Optional
@@ -37,6 +36,8 @@ class StreamSmoothingConfig:
     """Stream smoothing output configuration"""

     enabled: bool = False
+    chunk_size: int = 20
+    delay_ms: int = 8


 class StreamProcessor:
@@ -47,13 +48,6 @@ class StreamProcessor:
     Extracted from ChatHandlerBase so its responsibility stays focused.
     """

-    # Smoothing parameters
-    CHUNK_SIZE = 20  # characters per chunk
-    MIN_DELAY_MS = 8  # delay for long text (ms)
-    MAX_DELAY_MS = 15  # delay for short text (ms)
-    SHORT_TEXT_THRESHOLD = 20  # short-text threshold
-    LONG_TEXT_THRESHOLD = 100  # long-text threshold
-
     def __init__(
         self,
         request_id: str,
@@ -548,10 +542,10 @@
         # Only content longer than one character needs smoothing
         if content and len(content) > 1 and extractor:
-            # Calculate the dynamic delay
-            delay_seconds = self._calculate_delay(len(content))
+            # Use the configured delay
+            delay_seconds = self._calculate_delay()

-            # Smart splitting
+            # Split the content
             content_chunks = self._split_content(content)

             for i, sub_content in enumerate(content_chunks):
@@ -610,40 +604,24 @@
         return None, None

-    def _calculate_delay(self, text_length: int) -> float:
-        """
-        Compute a dynamic delay (in seconds) from the text length.
-
-        Short text gets a larger delay (stronger typing feel), long text a
-        smaller one (avoids stuttering); intermediate lengths transition
-        smoothly via logarithmic interpolation.
-        """
-        if text_length <= self.SHORT_TEXT_THRESHOLD:
-            return self.MAX_DELAY_MS / 1000.0
-        if text_length >= self.LONG_TEXT_THRESHOLD:
-            return self.MIN_DELAY_MS / 1000.0
-        # Logarithmic interpolation for a smooth transition
-        ratio = math.log(text_length / self.SHORT_TEXT_THRESHOLD) / math.log(
-            self.LONG_TEXT_THRESHOLD / self.SHORT_TEXT_THRESHOLD
-        )
-        delay_ms = self.MAX_DELAY_MS - ratio * (self.MAX_DELAY_MS - self.MIN_DELAY_MS)
-        return delay_ms / 1000.0
+    def _calculate_delay(self) -> float:
+        """Return the configured delay in seconds."""
+        return self.smoothing_config.delay_ms / 1000.0

     def _split_content(self, content: str) -> list[str]:
         """
         Split text into chunks.
-
-        Always splits by CHUNK_SIZE; the typing feel is controlled by the dynamic delay.
         """
+        chunk_size = self.smoothing_config.chunk_size
         text_length = len(content)
-        if text_length <= self.CHUNK_SIZE:
+        if text_length <= chunk_size:
             return [content]

-        # Split uniformly into chunks
+        # Split into chunks
         chunks = []
-        for i in range(0, text_length, self.CHUNK_SIZE):
-            chunks.append(content[i : i + self.CHUNK_SIZE])
+        for i in range(0, text_length, chunk_size):
+            chunks.append(content[i : i + chunk_size])
         return chunks

     async def _cleanup(
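With the logarithmic interpolation gone, the smoothing cost is easy to reason about: a payload of N characters becomes ceil(N / chunk_size) chunks, with delay_ms of sleep between them. A quick standalone check of that arithmetic (mirroring _split_content rather than importing it; whether a delay also trails the final chunk is not visible in this hunk):

```python
import math

chunk_size = 20  # stream_smoothing_chunk_size default
delay_ms = 8     # stream_smoothing_delay_ms default

content = "x" * 45
chunks = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]

assert [len(c) for c in chunks] == [20, 20, 5]
assert len(chunks) == math.ceil(len(content) / chunk_size)

# Added latency if the delay runs between consecutive chunks only:
print((len(chunks) - 1) * delay_ms)  # 16 ms
```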
@@ -664,6 +642,8 @@

 async def create_smoothed_stream(
     stream_generator: AsyncGenerator[bytes, None],
+    chunk_size: int = 20,
+    delay_ms: int = 8,
 ) -> AsyncGenerator[bytes, None]:
     """
     Standalone smoothed-stream generator function
@@ -672,11 +652,13 @@
     Args:
         stream_generator: the original stream generator
+        chunk_size: characters per chunk
+        delay_ms: delay in milliseconds between chunks

     Yields:
         Smoothed response data chunks
     """
-    processor = _LightweightSmoother()
+    processor = _LightweightSmoother(chunk_size=chunk_size, delay_ms=delay_ms)
     async for chunk in processor.smooth(stream_generator):
         yield chunk
@@ -688,13 +670,9 @@
     Contains only the minimal logic needed for smoothing, without depending on
     the rest of StreamProcessor.
     """

-    CHUNK_SIZE = 20
-    MIN_DELAY_MS = 8
-    MAX_DELAY_MS = 15
-    SHORT_TEXT_THRESHOLD = 20
-    LONG_TEXT_THRESHOLD = 100
-
-    def __init__(self) -> None:
+    def __init__(self, chunk_size: int = 20, delay_ms: int = 8) -> None:
+        self.chunk_size = chunk_size
+        self.delay_ms = delay_ms
         self._extractors: dict[str, ContentExtractor] = {}

     def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
@@ -715,21 +693,14 @@
             return content, extractor
         return None, None

-    def _calculate_delay(self, text_length: int) -> float:
-        if text_length <= self.SHORT_TEXT_THRESHOLD:
-            return self.MAX_DELAY_MS / 1000.0
-        if text_length >= self.LONG_TEXT_THRESHOLD:
-            return self.MIN_DELAY_MS / 1000.0
-        ratio = math.log(text_length / self.SHORT_TEXT_THRESHOLD) / math.log(
-            self.LONG_TEXT_THRESHOLD / self.SHORT_TEXT_THRESHOLD
-        )
-        return (self.MAX_DELAY_MS - ratio * (self.MAX_DELAY_MS - self.MIN_DELAY_MS)) / 1000.0
+    def _calculate_delay(self) -> float:
+        return self.delay_ms / 1000.0

     def _split_content(self, content: str) -> list[str]:
         text_length = len(content)
-        if text_length <= self.CHUNK_SIZE:
+        if text_length <= self.chunk_size:
             return [content]
-        return [content[i : i + self.CHUNK_SIZE] for i in range(0, text_length, self.CHUNK_SIZE)]
+        return [content[i : i + self.chunk_size] for i in range(0, text_length, self.chunk_size)]

     async def smooth(
         self, stream_generator: AsyncGenerator[bytes, None]
@@ -772,7 +743,7 @@
         content, extractor = self._detect_format_and_extract(data)
         if content and len(content) > 1 and extractor:
-            delay_seconds = self._calculate_delay(len(content))
+            delay_seconds = self._calculate_delay()
             content_chunks = self._split_content(content)

             for i, sub_content in enumerate(content_chunks):
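For reference, the smoother's end-to-end behavior reduces to "split, then sleep between sub-chunks". A minimal runnable analogue over plain strings (the real _LightweightSmoother additionally parses SSE frames and re-serializes them through format extractors):

```python
import asyncio
from typing import AsyncGenerator


async def smooth(
    source: AsyncGenerator[str, None], chunk_size: int = 20, delay_ms: int = 8
) -> AsyncGenerator[str, None]:
    # Split each piece of content and pause between the resulting sub-chunks.
    async for content in source:
        pieces = [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]
        for i, piece in enumerate(pieces):
            if i:
                await asyncio.sleep(delay_ms / 1000.0)
            yield piece


async def upstream() -> AsyncGenerator[str, None]:
    yield "x" * 45  # one large upstream chunk


async def demo() -> None:
    out = [p async for p in smooth(upstream())]
    assert [len(p) for p in out] == [20, 20, 5]


asyncio.run(demo())
```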

View File

@@ -13,6 +13,7 @@ from src.core.exceptions import InvalidRequestException, NotFoundException
 from src.core.logger import logger
 from src.models.api import ModelCreate, ModelResponse, ModelUpdate
 from src.models.database import Model, Provider
+from src.api.base.models_service import invalidate_models_list_cache
 from src.services.cache.invalidation import get_cache_invalidation_service
 from src.services.cache.model_cache import ModelCacheService
@@ -75,6 +76,10 @@ class ModelService:
             )
             logger.info(f"Model created: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")

+            # Invalidate the /v1/models list cache
+            asyncio.create_task(invalidate_models_list_cache())
+
             return model

         except IntegrityError as e:
@@ -197,6 +202,9 @@ class ModelService:
             cache_service = get_cache_invalidation_service()
             cache_service.on_model_changed(model.provider_id, model.global_model_id)

+            # Invalidate the /v1/models list cache
+            asyncio.create_task(invalidate_models_list_cache())
+
             logger.info(f"Model updated: id={model_id}, final supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
             return model
         except IntegrityError as e:
@@ -261,6 +269,9 @@ class ModelService:
             cache_service = get_cache_invalidation_service()
             cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])

+            # Invalidate the /v1/models list cache
+            asyncio.create_task(invalidate_models_list_cache())
+
             logger.info(f"Model deleted: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
                         f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")
         except Exception as e:
@@ -295,6 +306,9 @@ class ModelService:
             cache_service = get_cache_invalidation_service()
             cache_service.on_model_changed(model.provider_id, model.global_model_id)

+            # Invalidate the /v1/models list cache
+            asyncio.create_task(invalidate_models_list_cache())
+
             status = "available" if is_available else "unavailable"
             logger.info(f"Model availability updated: id={model_id}, status={status}")
             return model
@@ -358,6 +372,9 @@ class ModelService:
             for model in created_models:
                 db.refresh(model)
             logger.info(f"Batch-created {len(created_models)} models")

+            # Invalidate the /v1/models list cache
+            asyncio.create_task(invalidate_models_list_cache())
+
         except IntegrityError as e:
             db.rollback()
             logger.error(f"Batch model creation failed: {str(e)}")

View File

@@ -83,6 +83,14 @@ class SystemConfigService:
         "value": False,
         "description": "Whether to enable smooth streaming output (automatically adjusts output speed by text length)",
     },
+    "stream_smoothing_chunk_size": {
+        "value": 20,
+        "description": "Characters per chunk for smooth streaming output",
+    },
+    "stream_smoothing_delay_ms": {
+        "value": 8,
+        "description": "Delay in milliseconds between chunks for smooth streaming output",
+    },
 }

 @classmethod