mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-04 00:32:26 +08:00
Compare commits
12 Commits
5f0c1fb347
...
v0.1.14
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f22a073fd9 | ||
|
|
5c7ad089d2 | ||
|
|
97425ac68f | ||
|
|
912f6643e2 | ||
|
|
6c0373fda6 | ||
|
|
070121717d | ||
|
|
85fafeacb8 | ||
|
|
daf8b870f0 | ||
|
|
880fb61c66 | ||
|
|
7e792dabfc | ||
|
|
cd06169b2f | ||
|
|
50ffd47546 |
@@ -70,6 +70,8 @@ RUN printf '%s\n' \
|
|||||||
' proxy_cache off;' \
|
' proxy_cache off;' \
|
||||||
' proxy_request_buffering off;' \
|
' proxy_request_buffering off;' \
|
||||||
' chunked_transfer_encoding on;' \
|
' chunked_transfer_encoding on;' \
|
||||||
|
' gzip off;' \
|
||||||
|
' add_header X-Accel-Buffering no;' \
|
||||||
' proxy_connect_timeout 600s;' \
|
' proxy_connect_timeout 600s;' \
|
||||||
' proxy_send_timeout 600s;' \
|
' proxy_send_timeout 600s;' \
|
||||||
' proxy_read_timeout 600s;' \
|
' proxy_read_timeout 600s;' \
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ RUN printf '%s\n' \
|
|||||||
' proxy_cache off;' \
|
' proxy_cache off;' \
|
||||||
' proxy_request_buffering off;' \
|
' proxy_request_buffering off;' \
|
||||||
' chunked_transfer_encoding on;' \
|
' chunked_transfer_encoding on;' \
|
||||||
|
' gzip off;' \
|
||||||
|
' add_header X-Accel-Buffering no;' \
|
||||||
' proxy_connect_timeout 600s;' \
|
' proxy_connect_timeout 600s;' \
|
||||||
' proxy_send_timeout 600s;' \
|
' proxy_send_timeout 600s;' \
|
||||||
' proxy_read_timeout 600s;' \
|
' proxy_read_timeout 600s;' \
|
||||||
|
|||||||
13
deploy.sh
13
deploy.sh
@@ -21,9 +21,9 @@ HASH_FILE=".deps-hash"
|
|||||||
CODE_HASH_FILE=".code-hash"
|
CODE_HASH_FILE=".code-hash"
|
||||||
MIGRATION_HASH_FILE=".migration-hash"
|
MIGRATION_HASH_FILE=".migration-hash"
|
||||||
|
|
||||||
# 计算依赖文件的哈希值
|
# 计算依赖文件的哈希值(包含 Dockerfile.base.local)
|
||||||
calc_deps_hash() {
|
calc_deps_hash() {
|
||||||
cat pyproject.toml frontend/package.json frontend/package-lock.json 2>/dev/null | md5sum | cut -d' ' -f1
|
cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算代码文件的哈希值
|
# 计算代码文件的哈希值
|
||||||
@@ -162,25 +162,32 @@ git pull
|
|||||||
|
|
||||||
# 标记是否需要重启
|
# 标记是否需要重启
|
||||||
NEED_RESTART=false
|
NEED_RESTART=false
|
||||||
|
BASE_REBUILT=false
|
||||||
|
|
||||||
# 检查基础镜像是否存在,或依赖是否变化
|
# 检查基础镜像是否存在,或依赖是否变化
|
||||||
if ! docker image inspect aether-base:latest >/dev/null 2>&1; then
|
if ! docker image inspect aether-base:latest >/dev/null 2>&1; then
|
||||||
echo ">>> Base image not found, building..."
|
echo ">>> Base image not found, building..."
|
||||||
build_base
|
build_base
|
||||||
|
BASE_REBUILT=true
|
||||||
NEED_RESTART=true
|
NEED_RESTART=true
|
||||||
elif check_deps_changed; then
|
elif check_deps_changed; then
|
||||||
echo ">>> Dependencies changed, rebuilding base image..."
|
echo ">>> Dependencies changed, rebuilding base image..."
|
||||||
build_base
|
build_base
|
||||||
|
BASE_REBUILT=true
|
||||||
NEED_RESTART=true
|
NEED_RESTART=true
|
||||||
else
|
else
|
||||||
echo ">>> Dependencies unchanged."
|
echo ">>> Dependencies unchanged."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 检查代码是否变化
|
# 检查代码是否变化,或者 base 重建了(app 依赖 base)
|
||||||
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
|
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
|
||||||
echo ">>> App image not found, building..."
|
echo ">>> App image not found, building..."
|
||||||
build_app
|
build_app
|
||||||
NEED_RESTART=true
|
NEED_RESTART=true
|
||||||
|
elif [ "$BASE_REBUILT" = true ]; then
|
||||||
|
echo ">>> Base image rebuilt, rebuilding app image..."
|
||||||
|
build_app
|
||||||
|
NEED_RESTART=true
|
||||||
elif check_code_changed; then
|
elif check_code_changed; then
|
||||||
echo ">>> Code changed, rebuilding app image..."
|
echo ">>> Code changed, rebuilding app image..."
|
||||||
build_app
|
build_app
|
||||||
|
|||||||
@@ -464,6 +464,78 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
|
|
||||||
|
<!-- 流式输出配置 -->
|
||||||
|
<CardSection
|
||||||
|
title="流式输出"
|
||||||
|
description="配置流式响应的输出效果"
|
||||||
|
>
|
||||||
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||||
|
<div class="md:col-span-2">
|
||||||
|
<div class="flex items-center space-x-2">
|
||||||
|
<Checkbox
|
||||||
|
id="stream-smoothing-enabled"
|
||||||
|
v-model:checked="systemConfig.stream_smoothing_enabled"
|
||||||
|
/>
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="stream-smoothing-enabled"
|
||||||
|
class="cursor-pointer"
|
||||||
|
>
|
||||||
|
启用平滑输出
|
||||||
|
</Label>
|
||||||
|
<p class="text-xs text-muted-foreground">
|
||||||
|
将上游返回的大块内容拆分成小块,模拟打字效果
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="stream-smoothing-chunk-size"
|
||||||
|
class="block text-sm font-medium"
|
||||||
|
>
|
||||||
|
每块字符数
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
id="stream-smoothing-chunk-size"
|
||||||
|
v-model.number="systemConfig.stream_smoothing_chunk_size"
|
||||||
|
type="number"
|
||||||
|
min="1"
|
||||||
|
max="100"
|
||||||
|
placeholder="20"
|
||||||
|
class="mt-1"
|
||||||
|
:disabled="!systemConfig.stream_smoothing_enabled"
|
||||||
|
/>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
每次输出的字符数量(1-100)
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="stream-smoothing-delay-ms"
|
||||||
|
class="block text-sm font-medium"
|
||||||
|
>
|
||||||
|
输出间隔 (毫秒)
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
id="stream-smoothing-delay-ms"
|
||||||
|
v-model.number="systemConfig.stream_smoothing_delay_ms"
|
||||||
|
type="number"
|
||||||
|
min="1"
|
||||||
|
max="100"
|
||||||
|
placeholder="8"
|
||||||
|
class="mt-1"
|
||||||
|
:disabled="!systemConfig.stream_smoothing_enabled"
|
||||||
|
/>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
每块之间的延迟毫秒数(1-100)
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</CardSection>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 导入配置对话框 -->
|
<!-- 导入配置对话框 -->
|
||||||
@@ -812,6 +884,10 @@ interface SystemConfig {
|
|||||||
log_retention_days: number
|
log_retention_days: number
|
||||||
cleanup_batch_size: number
|
cleanup_batch_size: number
|
||||||
audit_log_retention_days: number
|
audit_log_retention_days: number
|
||||||
|
// 流式输出
|
||||||
|
stream_smoothing_enabled: boolean
|
||||||
|
stream_smoothing_chunk_size: number
|
||||||
|
stream_smoothing_delay_ms: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -861,6 +937,10 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
log_retention_days: 365,
|
log_retention_days: 365,
|
||||||
cleanup_batch_size: 1000,
|
cleanup_batch_size: 1000,
|
||||||
audit_log_retention_days: 30,
|
audit_log_retention_days: 30,
|
||||||
|
// 流式输出
|
||||||
|
stream_smoothing_enabled: false,
|
||||||
|
stream_smoothing_chunk_size: 20,
|
||||||
|
stream_smoothing_delay_ms: 8,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算属性:KB 和 字节 之间的转换
|
// 计算属性:KB 和 字节 之间的转换
|
||||||
@@ -917,6 +997,10 @@ async function loadSystemConfig() {
|
|||||||
'log_retention_days',
|
'log_retention_days',
|
||||||
'cleanup_batch_size',
|
'cleanup_batch_size',
|
||||||
'audit_log_retention_days',
|
'audit_log_retention_days',
|
||||||
|
// 流式输出
|
||||||
|
'stream_smoothing_enabled',
|
||||||
|
'stream_smoothing_chunk_size',
|
||||||
|
'stream_smoothing_delay_ms',
|
||||||
]
|
]
|
||||||
|
|
||||||
for (const key of configs) {
|
for (const key of configs) {
|
||||||
@@ -1024,6 +1108,22 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.audit_log_retention_days,
|
value: systemConfig.value.audit_log_retention_days,
|
||||||
description: '审计日志保留天数'
|
description: '审计日志保留天数'
|
||||||
},
|
},
|
||||||
|
// 流式输出
|
||||||
|
{
|
||||||
|
key: 'stream_smoothing_enabled',
|
||||||
|
value: systemConfig.value.stream_smoothing_enabled,
|
||||||
|
description: '是否启用流式平滑输出'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'stream_smoothing_chunk_size',
|
||||||
|
value: systemConfig.value.stream_smoothing_chunk_size,
|
||||||
|
description: '流式平滑输出每个小块的字符数'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'stream_smoothing_delay_ms',
|
||||||
|
value: systemConfig.value.stream_smoothing_delay_ms,
|
||||||
|
description: '流式平滑输出每个小块之间的延迟毫秒数'
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const promises = configItems.map(item =>
|
const promises = configItems.map(item =>
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request
|
|||||||
from sqlalchemy.orm import Session, joinedload
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
from src.api.base.admin_adapter import AdminApiAdapter
|
from src.api.base.admin_adapter import AdminApiAdapter
|
||||||
|
from src.api.base.models_service import invalidate_models_list_cache
|
||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
@@ -419,4 +420,8 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
|
|||||||
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
|
f"Batch assigned {len(success)} GlobalModels to provider {provider.name} by {context.user.username}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 清除 /v1/models 列表缓存
|
||||||
|
if success:
|
||||||
|
await invalidate_models_list_cache()
|
||||||
|
|
||||||
return BatchAssignModelsToProviderResponse(success=success, errors=errors)
|
return BatchAssignModelsToProviderResponse(success=success, errors=errors)
|
||||||
|
|||||||
@@ -55,6 +55,23 @@ async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"])
|
|||||||
logger.warning(f"[ModelsService] 缓存写入失败: {e}")
|
logger.warning(f"[ModelsService] 缓存写入失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def invalidate_models_list_cache() -> None:
|
||||||
|
"""
|
||||||
|
清除所有 /v1/models 列表缓存
|
||||||
|
|
||||||
|
在模型创建、更新、删除时调用,确保模型列表实时更新
|
||||||
|
"""
|
||||||
|
# 清除所有格式的缓存
|
||||||
|
all_formats = ["CLAUDE", "OPENAI", "GEMINI"]
|
||||||
|
for fmt in all_formats:
|
||||||
|
cache_key = f"{_CACHE_KEY_PREFIX}:{fmt}"
|
||||||
|
try:
|
||||||
|
await CacheService.delete(cache_key)
|
||||||
|
logger.debug(f"[ModelsService] 已清除缓存: {cache_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[ModelsService] 清除缓存失败 {cache_key}: {e}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
"""统一的模型信息结构"""
|
"""统一的模型信息结构"""
|
||||||
|
|||||||
@@ -1114,8 +1114,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
async for chunk in stream_generator:
|
async for chunk in stream_generator:
|
||||||
yield chunk
|
yield chunk
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
ctx.status_code = 499
|
# 如果响应已完成,不标记为失败
|
||||||
ctx.error_message = "Client disconnected"
|
if not ctx.has_completion:
|
||||||
|
ctx.status_code = 499
|
||||||
|
ctx.error_message = "Client disconnected"
|
||||||
raise
|
raise
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
ctx.status_code = 504
|
ctx.status_code = 504
|
||||||
|
|||||||
274
src/api/handlers/base/content_extractors.py
Normal file
274
src/api/handlers/base/content_extractors.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""
|
||||||
|
流式内容提取器 - 策略模式实现
|
||||||
|
|
||||||
|
为不同 API 格式(OpenAI、Claude、Gemini)提供内容提取和 chunk 构造的抽象。
|
||||||
|
StreamSmoother 使用这些提取器来处理不同格式的 SSE 事件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ContentExtractor(ABC):
|
||||||
|
"""
|
||||||
|
流式内容提取器抽象基类
|
||||||
|
|
||||||
|
定义从 SSE 事件中提取文本内容和构造新 chunk 的接口。
|
||||||
|
每种 API 格式(OpenAI、Claude、Gemini)需要实现自己的提取器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def extract_content(self, data: dict) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
从 SSE 数据中提取可拆分的文本内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 解析后的 JSON 数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取的文本内容,如果无法提取则返回 None
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_chunk(
|
||||||
|
self,
|
||||||
|
original_data: dict,
|
||||||
|
new_content: str,
|
||||||
|
event_type: str = "",
|
||||||
|
is_first: bool = False,
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
|
使用新内容构造 SSE chunk
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_data: 原始 JSON 数据
|
||||||
|
new_content: 新的文本内容
|
||||||
|
event_type: SSE 事件类型(某些格式需要)
|
||||||
|
is_first: 是否是第一个 chunk(用于保留 role 等字段)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
编码后的 SSE 字节数据
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIContentExtractor(ContentExtractor):
|
||||||
|
"""
|
||||||
|
OpenAI 格式内容提取器
|
||||||
|
|
||||||
|
处理 OpenAI Chat Completions API 的流式响应格式:
|
||||||
|
- 数据结构: choices[0].delta.content
|
||||||
|
- 只在 delta 仅包含 role/content 时允许拆分,避免破坏 tool_calls 等结构
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_content(self, data: dict) -> Optional[str]:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
choices = data.get("choices")
|
||||||
|
if not isinstance(choices, list) or len(choices) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_choice = choices[0]
|
||||||
|
if not isinstance(first_choice, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
delta = first_choice.get("delta")
|
||||||
|
if not isinstance(delta, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = delta.get("content")
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 只有 delta 仅包含 role/content 时才允许拆分
|
||||||
|
# 避免破坏 tool_calls、function_call 等复杂结构
|
||||||
|
allowed_keys = {"role", "content"}
|
||||||
|
if not all(key in allowed_keys for key in delta.keys()):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def create_chunk(
|
||||||
|
self,
|
||||||
|
original_data: dict,
|
||||||
|
new_content: str,
|
||||||
|
event_type: str = "",
|
||||||
|
is_first: bool = False,
|
||||||
|
) -> bytes:
|
||||||
|
new_data = original_data.copy()
|
||||||
|
|
||||||
|
if "choices" in new_data and new_data["choices"]:
|
||||||
|
new_choices = []
|
||||||
|
for choice in new_data["choices"]:
|
||||||
|
new_choice = choice.copy()
|
||||||
|
if "delta" in new_choice:
|
||||||
|
new_delta = {}
|
||||||
|
# 只有第一个 chunk 保留 role
|
||||||
|
if is_first and "role" in new_choice["delta"]:
|
||||||
|
new_delta["role"] = new_choice["delta"]["role"]
|
||||||
|
new_delta["content"] = new_content
|
||||||
|
new_choice["delta"] = new_delta
|
||||||
|
new_choices.append(new_choice)
|
||||||
|
new_data["choices"] = new_choices
|
||||||
|
|
||||||
|
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeContentExtractor(ContentExtractor):
|
||||||
|
"""
|
||||||
|
Claude 格式内容提取器
|
||||||
|
|
||||||
|
处理 Claude Messages API 的流式响应格式:
|
||||||
|
- 事件类型: content_block_delta
|
||||||
|
- 数据结构: delta.type=text_delta, delta.text
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_content(self, data: dict) -> Optional[str]:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查事件类型
|
||||||
|
if data.get("type") != "content_block_delta":
|
||||||
|
return None
|
||||||
|
|
||||||
|
delta = data.get("delta", {})
|
||||||
|
if not isinstance(delta, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查 delta 类型
|
||||||
|
if delta.get("type") != "text_delta":
|
||||||
|
return None
|
||||||
|
|
||||||
|
text = delta.get("text")
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def create_chunk(
|
||||||
|
self,
|
||||||
|
original_data: dict,
|
||||||
|
new_content: str,
|
||||||
|
event_type: str = "",
|
||||||
|
is_first: bool = False,
|
||||||
|
) -> bytes:
|
||||||
|
new_data = original_data.copy()
|
||||||
|
|
||||||
|
if "delta" in new_data:
|
||||||
|
new_delta = new_data["delta"].copy()
|
||||||
|
new_delta["text"] = new_content
|
||||||
|
new_data["delta"] = new_delta
|
||||||
|
|
||||||
|
# Claude 格式需要 event: 前缀
|
||||||
|
event_name = event_type or "content_block_delta"
|
||||||
|
return f"event: {event_name}\ndata: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiContentExtractor(ContentExtractor):
|
||||||
|
"""
|
||||||
|
Gemini 格式内容提取器
|
||||||
|
|
||||||
|
处理 Gemini API 的流式响应格式:
|
||||||
|
- 数据结构: candidates[0].content.parts[0].text
|
||||||
|
- 只有纯文本块才拆分
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_content(self, data: dict) -> Optional[str]:
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
candidates = data.get("candidates")
|
||||||
|
if not isinstance(candidates, list) or len(candidates) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_candidate = candidates[0]
|
||||||
|
if not isinstance(first_candidate, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = first_candidate.get("content", {})
|
||||||
|
if not isinstance(content, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = content.get("parts", [])
|
||||||
|
if not isinstance(parts, list) or len(parts) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_part = parts[0]
|
||||||
|
if not isinstance(first_part, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
text = first_part.get("text")
|
||||||
|
# 只有纯文本块(只有 text 字段)才拆分
|
||||||
|
if not isinstance(text, str) or len(first_part) != 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def create_chunk(
|
||||||
|
self,
|
||||||
|
original_data: dict,
|
||||||
|
new_content: str,
|
||||||
|
event_type: str = "",
|
||||||
|
is_first: bool = False,
|
||||||
|
) -> bytes:
|
||||||
|
new_data = copy.deepcopy(original_data)
|
||||||
|
|
||||||
|
if "candidates" in new_data and new_data["candidates"]:
|
||||||
|
first_candidate = new_data["candidates"][0]
|
||||||
|
if "content" in first_candidate:
|
||||||
|
content = first_candidate["content"]
|
||||||
|
if "parts" in content and content["parts"]:
|
||||||
|
content["parts"][0]["text"] = new_content
|
||||||
|
|
||||||
|
return f"data: {json.dumps(new_data, ensure_ascii=False)}\n\n".encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
# 提取器注册表
|
||||||
|
_EXTRACTORS: dict[str, type[ContentExtractor]] = {
|
||||||
|
"openai": OpenAIContentExtractor,
|
||||||
|
"claude": ClaudeContentExtractor,
|
||||||
|
"gemini": GeminiContentExtractor,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_extractor(format_name: str) -> Optional[ContentExtractor]:
|
||||||
|
"""
|
||||||
|
根据格式名获取对应的内容提取器实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
format_name: 格式名称(openai, claude, gemini)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
对应的提取器实例,如果格式不支持则返回 None
|
||||||
|
"""
|
||||||
|
extractor_class = _EXTRACTORS.get(format_name.lower())
|
||||||
|
if extractor_class:
|
||||||
|
return extractor_class()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def register_extractor(format_name: str, extractor_class: type[ContentExtractor]) -> None:
|
||||||
|
"""
|
||||||
|
注册新的内容提取器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
format_name: 格式名称
|
||||||
|
extractor_class: 提取器类
|
||||||
|
"""
|
||||||
|
_EXTRACTORS[format_name.lower()] = extractor_class
|
||||||
|
|
||||||
|
|
||||||
|
def get_extractor_formats() -> list[str]:
|
||||||
|
"""
|
||||||
|
获取所有已注册的格式名称列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式名称列表
|
||||||
|
"""
|
||||||
|
return list(_EXTRACTORS.keys())
|
||||||
@@ -6,16 +6,22 @@
|
|||||||
2. 响应流生成
|
2. 响应流生成
|
||||||
3. 预读和嵌套错误检测
|
3. 预读和嵌套错误检测
|
||||||
4. 客户端断开检测
|
4. 客户端断开检测
|
||||||
|
5. 流式平滑输出
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import codecs
|
import codecs
|
||||||
import json
|
import json
|
||||||
import time
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Callable, Optional
|
from typing import Any, AsyncGenerator, Callable, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from src.api.handlers.base.content_extractors import (
|
||||||
|
ContentExtractor,
|
||||||
|
get_extractor,
|
||||||
|
get_extractor_formats,
|
||||||
|
)
|
||||||
from src.api.handlers.base.parsers import get_parser_for_format
|
from src.api.handlers.base.parsers import get_parser_for_format
|
||||||
from src.api.handlers.base.response_parser import ResponseParser
|
from src.api.handlers.base.response_parser import ResponseParser
|
||||||
from src.api.handlers.base.stream_context import StreamContext
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
@@ -25,11 +31,20 @@ from src.models.database import Provider, ProviderEndpoint
|
|||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamSmoothingConfig:
|
||||||
|
"""流式平滑输出配置"""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
chunk_size: int = 20
|
||||||
|
delay_ms: int = 8
|
||||||
|
|
||||||
|
|
||||||
class StreamProcessor:
|
class StreamProcessor:
|
||||||
"""
|
"""
|
||||||
流式响应处理器
|
流式响应处理器
|
||||||
|
|
||||||
负责处理 SSE 流的解析、错误检测和响应生成。
|
负责处理 SSE 流的解析、错误检测、响应生成和平滑输出。
|
||||||
从 ChatHandlerBase 中提取,使其职责更加单一。
|
从 ChatHandlerBase 中提取,使其职责更加单一。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -40,6 +55,7 @@ class StreamProcessor:
|
|||||||
on_streaming_start: Optional[Callable[[], None]] = None,
|
on_streaming_start: Optional[Callable[[], None]] = None,
|
||||||
*,
|
*,
|
||||||
collect_text: bool = False,
|
collect_text: bool = False,
|
||||||
|
smoothing_config: Optional[StreamSmoothingConfig] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化流处理器
|
初始化流处理器
|
||||||
@@ -48,11 +64,17 @@ class StreamProcessor:
|
|||||||
request_id: 请求 ID(用于日志)
|
request_id: 请求 ID(用于日志)
|
||||||
default_parser: 默认响应解析器
|
default_parser: 默认响应解析器
|
||||||
on_streaming_start: 流开始时的回调(用于更新状态)
|
on_streaming_start: 流开始时的回调(用于更新状态)
|
||||||
|
collect_text: 是否收集文本内容
|
||||||
|
smoothing_config: 流式平滑输出配置
|
||||||
"""
|
"""
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.default_parser = default_parser
|
self.default_parser = default_parser
|
||||||
self.on_streaming_start = on_streaming_start
|
self.on_streaming_start = on_streaming_start
|
||||||
self.collect_text = collect_text
|
self.collect_text = collect_text
|
||||||
|
self.smoothing_config = smoothing_config or StreamSmoothingConfig()
|
||||||
|
|
||||||
|
# 内容提取器缓存
|
||||||
|
self._extractors: dict[str, ContentExtractor] = {}
|
||||||
|
|
||||||
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
|
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
|
||||||
"""
|
"""
|
||||||
@@ -127,6 +149,13 @@ class StreamProcessor:
|
|||||||
if event_type in ("response.completed", "message_stop"):
|
if event_type in ("response.completed", "message_stop"):
|
||||||
ctx.has_completion = True
|
ctx.has_completion = True
|
||||||
|
|
||||||
|
# 检查 OpenAI 格式的 finish_reason
|
||||||
|
choices = data.get("choices", [])
|
||||||
|
if choices and isinstance(choices, list) and len(choices) > 0:
|
||||||
|
finish_reason = choices[0].get("finish_reason")
|
||||||
|
if finish_reason is not None:
|
||||||
|
ctx.has_completion = True
|
||||||
|
|
||||||
async def prefetch_and_check_error(
|
async def prefetch_and_check_error(
|
||||||
self,
|
self,
|
||||||
byte_iterator: Any,
|
byte_iterator: Any,
|
||||||
@@ -369,7 +398,7 @@ class StreamProcessor:
|
|||||||
sse_parser: SSE 解析器
|
sse_parser: SSE 解析器
|
||||||
line: 原始行数据
|
line: 原始行数据
|
||||||
"""
|
"""
|
||||||
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF,
|
# SSEEventParser 以"去掉换行符"的单行文本作为输入;这里统一剔除 CR/LF,
|
||||||
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
|
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
|
||||||
normalized_line = line.rstrip("\r\n")
|
normalized_line = line.rstrip("\r\n")
|
||||||
events = sse_parser.feed_line(normalized_line)
|
events = sse_parser.feed_line(normalized_line)
|
||||||
@@ -400,32 +429,201 @@ class StreamProcessor:
|
|||||||
响应数据块
|
响应数据块
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
|
# 使用后台任务检查断连,完全不阻塞流式传输
|
||||||
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
|
disconnected = False
|
||||||
next_disconnect_check_at = 0.0
|
|
||||||
disconnect_check_interval_s = 0.25
|
|
||||||
|
|
||||||
async for chunk in stream_generator:
|
async def check_disconnect_background() -> None:
|
||||||
now = time.monotonic()
|
nonlocal disconnected
|
||||||
if now >= next_disconnect_check_at:
|
while not disconnected and not ctx.has_completion:
|
||||||
next_disconnect_check_at = now + disconnect_check_interval_s
|
await asyncio.sleep(0.5)
|
||||||
if await is_disconnected():
|
if await is_disconnected():
|
||||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
disconnected = True
|
||||||
ctx.status_code = 499 # Client Closed Request
|
|
||||||
ctx.error_message = "client_disconnected"
|
|
||||||
|
|
||||||
break
|
break
|
||||||
yield chunk
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
ctx.status_code = 499
|
|
||||||
ctx.error_message = "client_disconnected"
|
|
||||||
|
|
||||||
|
# 启动后台检查任务
|
||||||
|
check_task = asyncio.create_task(check_disconnect_background())
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in stream_generator:
|
||||||
|
if disconnected:
|
||||||
|
# 如果响应已完成,客户端断开不算失败
|
||||||
|
if ctx.has_completion:
|
||||||
|
logger.info(
|
||||||
|
f"ID:{self.request_id} | Client disconnected after completion"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||||
|
ctx.status_code = 499
|
||||||
|
ctx.error_message = "client_disconnected"
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
finally:
|
||||||
|
check_task.cancel()
|
||||||
|
try:
|
||||||
|
await check_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# 如果响应已完成,不标记为失败
|
||||||
|
if not ctx.has_completion:
|
||||||
|
ctx.status_code = 499
|
||||||
|
ctx.error_message = "client_disconnected"
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
ctx.status_code = 500
|
ctx.status_code = 500
|
||||||
ctx.error_message = str(e)
|
ctx.error_message = str(e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def create_smoothed_stream(
|
||||||
|
self,
|
||||||
|
stream_generator: AsyncGenerator[bytes, None],
|
||||||
|
) -> AsyncGenerator[bytes, None]:
|
||||||
|
"""
|
||||||
|
创建平滑输出的流生成器
|
||||||
|
|
||||||
|
如果启用了平滑输出,将大 chunk 拆分成小块并添加微小延迟。
|
||||||
|
否则直接透传原始流。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_generator: 原始流生成器
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
平滑处理后的响应数据块
|
||||||
|
"""
|
||||||
|
if not self.smoothing_config.enabled:
|
||||||
|
# 未启用平滑输出,直接透传
|
||||||
|
async for chunk in stream_generator:
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
|
||||||
|
# 启用平滑输出
|
||||||
|
buffer = b""
|
||||||
|
is_first_content = True
|
||||||
|
|
||||||
|
async for chunk in stream_generator:
|
||||||
|
buffer += chunk
|
||||||
|
|
||||||
|
# 按双换行分割 SSE 事件(标准 SSE 格式)
|
||||||
|
while b"\n\n" in buffer:
|
||||||
|
event_block, buffer = buffer.split(b"\n\n", 1)
|
||||||
|
event_str = event_block.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
# 解析事件块
|
||||||
|
lines = event_str.strip().split("\n")
|
||||||
|
data_str = None
|
||||||
|
event_type = ""
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.rstrip("\r")
|
||||||
|
if line.startswith("event: "):
|
||||||
|
event_type = line[7:].strip()
|
||||||
|
elif line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
|
||||||
|
# 没有 data 行,直接透传
|
||||||
|
if data_str is None:
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
# [DONE] 直接透传
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试解析 JSON
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检测格式并提取内容
|
||||||
|
content, extractor = self._detect_format_and_extract(data)
|
||||||
|
|
||||||
|
# 只有内容长度大于 1 才需要平滑处理
|
||||||
|
if content and len(content) > 1 and extractor:
|
||||||
|
# 获取配置的延迟
|
||||||
|
delay_seconds = self._calculate_delay()
|
||||||
|
|
||||||
|
# 拆分内容
|
||||||
|
content_chunks = self._split_content(content)
|
||||||
|
|
||||||
|
for i, sub_content in enumerate(content_chunks):
|
||||||
|
is_first = is_first_content and i == 0
|
||||||
|
|
||||||
|
# 使用提取器创建新 chunk
|
||||||
|
sse_chunk = extractor.create_chunk(
|
||||||
|
data,
|
||||||
|
sub_content,
|
||||||
|
event_type=event_type,
|
||||||
|
is_first=is_first,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield sse_chunk
|
||||||
|
|
||||||
|
# 除了最后一个块,其他块之间加延迟
|
||||||
|
if i < len(content_chunks) - 1:
|
||||||
|
await asyncio.sleep(delay_seconds)
|
||||||
|
|
||||||
|
is_first_content = False
|
||||||
|
else:
|
||||||
|
# 不需要拆分,直接透传
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
if content:
|
||||||
|
is_first_content = False
|
||||||
|
|
||||||
|
# 处理剩余数据
|
||||||
|
if buffer:
|
||||||
|
yield buffer
|
||||||
|
|
||||||
|
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
|
||||||
|
"""获取或创建格式对应的提取器(带缓存)"""
|
||||||
|
if format_name not in self._extractors:
|
||||||
|
extractor = get_extractor(format_name)
|
||||||
|
if extractor:
|
||||||
|
self._extractors[format_name] = extractor
|
||||||
|
return self._extractors.get(format_name)
|
||||||
|
|
||||||
|
def _detect_format_and_extract(
|
||||||
|
self, data: dict
|
||||||
|
) -> tuple[Optional[str], Optional[ContentExtractor]]:
|
||||||
|
"""
|
||||||
|
检测数据格式并提取内容
|
||||||
|
|
||||||
|
依次尝试各格式的提取器,返回第一个成功提取内容的结果。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(content, extractor): 提取的内容和对应的提取器
|
||||||
|
"""
|
||||||
|
for format_name in get_extractor_formats():
|
||||||
|
extractor = self._get_extractor(format_name)
|
||||||
|
if extractor:
|
||||||
|
content = extractor.extract_content(data)
|
||||||
|
if content is not None:
|
||||||
|
return content, extractor
|
||||||
|
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _calculate_delay(self) -> float:
|
||||||
|
"""获取配置的延迟(秒)"""
|
||||||
|
return self.smoothing_config.delay_ms / 1000.0
|
||||||
|
|
||||||
|
def _split_content(self, content: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
按块拆分文本
|
||||||
|
"""
|
||||||
|
chunk_size = self.smoothing_config.chunk_size
|
||||||
|
text_length = len(content)
|
||||||
|
|
||||||
|
if text_length <= chunk_size:
|
||||||
|
return [content]
|
||||||
|
|
||||||
|
# 按块拆分
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, text_length, chunk_size):
|
||||||
|
chunks.append(content[i : i + chunk_size])
|
||||||
|
return chunks
|
||||||
|
|
||||||
async def _cleanup(
|
async def _cleanup(
|
||||||
self,
|
self,
|
||||||
response_ctx: Any,
|
response_ctx: Any,
|
||||||
@@ -440,3 +638,128 @@ class StreamProcessor:
|
|||||||
await http_client.aclose()
|
await http_client.aclose()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def create_smoothed_stream(
|
||||||
|
stream_generator: AsyncGenerator[bytes, None],
|
||||||
|
chunk_size: int = 20,
|
||||||
|
delay_ms: int = 8,
|
||||||
|
) -> AsyncGenerator[bytes, None]:
|
||||||
|
"""
|
||||||
|
独立的平滑流生成函数
|
||||||
|
|
||||||
|
供 CLI handler 等场景使用,无需创建完整的 StreamProcessor 实例。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stream_generator: 原始流生成器
|
||||||
|
chunk_size: 每块字符数
|
||||||
|
delay_ms: 每块之间的延迟毫秒数
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
平滑处理后的响应数据块
|
||||||
|
"""
|
||||||
|
processor = _LightweightSmoother(chunk_size=chunk_size, delay_ms=delay_ms)
|
||||||
|
async for chunk in processor.smooth(stream_generator):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class _LightweightSmoother:
|
||||||
|
"""
|
||||||
|
轻量级平滑处理器
|
||||||
|
|
||||||
|
只包含平滑输出所需的最小逻辑,不依赖 StreamProcessor 的其他功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, chunk_size: int = 20, delay_ms: int = 8) -> None:
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.delay_ms = delay_ms
|
||||||
|
self._extractors: dict[str, ContentExtractor] = {}
|
||||||
|
|
||||||
|
def _get_extractor(self, format_name: str) -> Optional[ContentExtractor]:
|
||||||
|
if format_name not in self._extractors:
|
||||||
|
extractor = get_extractor(format_name)
|
||||||
|
if extractor:
|
||||||
|
self._extractors[format_name] = extractor
|
||||||
|
return self._extractors.get(format_name)
|
||||||
|
|
||||||
|
def _detect_format_and_extract(
|
||||||
|
self, data: dict
|
||||||
|
) -> tuple[Optional[str], Optional[ContentExtractor]]:
|
||||||
|
for format_name in get_extractor_formats():
|
||||||
|
extractor = self._get_extractor(format_name)
|
||||||
|
if extractor:
|
||||||
|
content = extractor.extract_content(data)
|
||||||
|
if content is not None:
|
||||||
|
return content, extractor
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def _calculate_delay(self) -> float:
|
||||||
|
return self.delay_ms / 1000.0
|
||||||
|
|
||||||
|
def _split_content(self, content: str) -> list[str]:
|
||||||
|
text_length = len(content)
|
||||||
|
if text_length <= self.chunk_size:
|
||||||
|
return [content]
|
||||||
|
return [content[i : i + self.chunk_size] for i in range(0, text_length, self.chunk_size)]
|
||||||
|
|
||||||
|
async def smooth(
|
||||||
|
self, stream_generator: AsyncGenerator[bytes, None]
|
||||||
|
) -> AsyncGenerator[bytes, None]:
|
||||||
|
buffer = b""
|
||||||
|
is_first_content = True
|
||||||
|
|
||||||
|
async for chunk in stream_generator:
|
||||||
|
buffer += chunk
|
||||||
|
|
||||||
|
while b"\n\n" in buffer:
|
||||||
|
event_block, buffer = buffer.split(b"\n\n", 1)
|
||||||
|
event_str = event_block.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
lines = event_str.strip().split("\n")
|
||||||
|
data_str = None
|
||||||
|
event_type = ""
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
line = line.rstrip("\r")
|
||||||
|
if line.startswith("event: "):
|
||||||
|
event_type = line[7:].strip()
|
||||||
|
elif line.startswith("data: "):
|
||||||
|
data_str = line[6:]
|
||||||
|
|
||||||
|
if data_str is None:
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
if data_str.strip() == "[DONE]":
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(data_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
content, extractor = self._detect_format_and_extract(data)
|
||||||
|
|
||||||
|
if content and len(content) > 1 and extractor:
|
||||||
|
delay_seconds = self._calculate_delay()
|
||||||
|
content_chunks = self._split_content(content)
|
||||||
|
|
||||||
|
for i, sub_content in enumerate(content_chunks):
|
||||||
|
is_first = is_first_content and i == 0
|
||||||
|
sse_chunk = extractor.create_chunk(
|
||||||
|
data, sub_content, event_type=event_type, is_first=is_first
|
||||||
|
)
|
||||||
|
yield sse_chunk
|
||||||
|
if i < len(content_chunks) - 1:
|
||||||
|
await asyncio.sleep(delay_seconds)
|
||||||
|
|
||||||
|
is_first_content = False
|
||||||
|
else:
|
||||||
|
yield event_block + b"\n\n"
|
||||||
|
if content:
|
||||||
|
is_first_content = False
|
||||||
|
|
||||||
|
if buffer:
|
||||||
|
yield buffer
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ class Config:
|
|||||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||||
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
||||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||||
|
# 注:流式平滑输出配置已移至数据库系统设置(stream_smoothing_*)
|
||||||
|
|
||||||
# 验证连接池配置
|
# 验证连接池配置
|
||||||
self._validate_pool_config()
|
self._validate_pool_config()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from src.core.exceptions import InvalidRequestException, NotFoundException
|
|||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
|
from src.models.api import ModelCreate, ModelResponse, ModelUpdate
|
||||||
from src.models.database import Model, Provider
|
from src.models.database import Model, Provider
|
||||||
|
from src.api.base.models_service import invalidate_models_list_cache
|
||||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||||
from src.services.cache.model_cache import ModelCacheService
|
from src.services.cache.model_cache import ModelCacheService
|
||||||
|
|
||||||
@@ -75,6 +76,10 @@ class ModelService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")
|
logger.info(f"创建模型成功: provider={provider.name}, model={model.provider_model_name}, global_model_id={model.global_model_id}")
|
||||||
|
|
||||||
|
# 清除 /v1/models 列表缓存
|
||||||
|
asyncio.create_task(invalidate_models_list_cache())
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
@@ -197,6 +202,9 @@ class ModelService:
|
|||||||
cache_service = get_cache_invalidation_service()
|
cache_service = get_cache_invalidation_service()
|
||||||
cache_service.on_model_changed(model.provider_id, model.global_model_id)
|
cache_service.on_model_changed(model.provider_id, model.global_model_id)
|
||||||
|
|
||||||
|
# 清除 /v1/models 列表缓存
|
||||||
|
asyncio.create_task(invalidate_models_list_cache())
|
||||||
|
|
||||||
logger.info(f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
|
logger.info(f"更新模型成功: id={model_id}, 最终 supports_vision: {model.supports_vision}, supports_function_calling: {model.supports_function_calling}, supports_extended_thinking: {model.supports_extended_thinking}")
|
||||||
return model
|
return model
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
@@ -261,6 +269,9 @@ class ModelService:
|
|||||||
cache_service = get_cache_invalidation_service()
|
cache_service = get_cache_invalidation_service()
|
||||||
cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])
|
cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])
|
||||||
|
|
||||||
|
# 清除 /v1/models 列表缓存
|
||||||
|
asyncio.create_task(invalidate_models_list_cache())
|
||||||
|
|
||||||
logger.info(f"删除模型成功: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
|
logger.info(f"删除模型成功: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
|
||||||
f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")
|
f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -295,6 +306,9 @@ class ModelService:
|
|||||||
cache_service = get_cache_invalidation_service()
|
cache_service = get_cache_invalidation_service()
|
||||||
cache_service.on_model_changed(model.provider_id, model.global_model_id)
|
cache_service.on_model_changed(model.provider_id, model.global_model_id)
|
||||||
|
|
||||||
|
# 清除 /v1/models 列表缓存
|
||||||
|
asyncio.create_task(invalidate_models_list_cache())
|
||||||
|
|
||||||
status = "可用" if is_available else "不可用"
|
status = "可用" if is_available else "不可用"
|
||||||
logger.info(f"更新模型可用状态: id={model_id}, status={status}")
|
logger.info(f"更新模型可用状态: id={model_id}, status={status}")
|
||||||
return model
|
return model
|
||||||
@@ -358,6 +372,9 @@ class ModelService:
|
|||||||
for model in created_models:
|
for model in created_models:
|
||||||
db.refresh(model)
|
db.refresh(model)
|
||||||
logger.info(f"批量创建 {len(created_models)} 个模型成功")
|
logger.info(f"批量创建 {len(created_models)} 个模型成功")
|
||||||
|
|
||||||
|
# 清除 /v1/models 列表缓存
|
||||||
|
asyncio.create_task(invalidate_models_list_cache())
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
logger.error(f"批量创建模型失败: {str(e)}")
|
logger.error(f"批量创建模型失败: {str(e)}")
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from src.core.logger import logger
|
|||||||
from src.models.database import Provider, SystemConfig
|
from src.models.database import Provider, SystemConfig
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(str, Enum):
|
class LogLevel(str, Enum):
|
||||||
"""日志记录级别"""
|
"""日志记录级别"""
|
||||||
|
|
||||||
@@ -79,6 +78,19 @@ class SystemConfigService:
|
|||||||
"value": False,
|
"value": False,
|
||||||
"description": "是否自动删除过期的API Key(True=物理删除,False=仅禁用),仅管理员可配置",
|
"description": "是否自动删除过期的API Key(True=物理删除,False=仅禁用),仅管理员可配置",
|
||||||
},
|
},
|
||||||
|
# 流式平滑输出配置
|
||||||
|
"stream_smoothing_enabled": {
|
||||||
|
"value": False,
|
||||||
|
"description": "是否启用流式平滑输出,自动根据文本长度调整输出速度",
|
||||||
|
},
|
||||||
|
"stream_smoothing_chunk_size": {
|
||||||
|
"value": 20,
|
||||||
|
"description": "流式平滑输出每个小块的字符数",
|
||||||
|
},
|
||||||
|
"stream_smoothing_delay_ms": {
|
||||||
|
"value": 8,
|
||||||
|
"description": "流式平滑输出每个小块之间的延迟毫秒数",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -94,6 +106,35 @@ class SystemConfigService:
|
|||||||
|
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_configs(cls, db: Session, keys: List[str]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
批量获取系统配置值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
keys: 配置键列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置键值字典
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# 一次查询获取所有配置
|
||||||
|
configs = db.query(SystemConfig).filter(SystemConfig.key.in_(keys)).all()
|
||||||
|
config_map = {c.key: c.value for c in configs}
|
||||||
|
|
||||||
|
# 填充结果,不存在的使用默认值
|
||||||
|
for key in keys:
|
||||||
|
if key in config_map:
|
||||||
|
result[key] = config_map[key]
|
||||||
|
elif key in cls.DEFAULT_CONFIGS:
|
||||||
|
result[key] = cls.DEFAULT_CONFIGS[key]["value"]
|
||||||
|
else:
|
||||||
|
result[key] = None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig:
|
def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig:
|
||||||
"""设置系统配置值"""
|
"""设置系统配置值"""
|
||||||
@@ -111,6 +152,7 @@ class SystemConfigService:
|
|||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(config)
|
db.refresh(config)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -153,8 +195,8 @@ class SystemConfigService:
|
|||||||
for config in configs
|
for config in configs
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def delete_config(db: Session, key: str) -> bool:
|
def delete_config(cls, db: Session, key: str) -> bool:
|
||||||
"""删除系统配置"""
|
"""删除系统配置"""
|
||||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||||
if config:
|
if config:
|
||||||
|
|||||||
Reference in New Issue
Block a user