mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
refactor: 重构流式处理模块,提取 StreamContext/Processor/Telemetry
- 将 chat_handler_base.py 中的流式处理逻辑拆分为三个独立模块:
  - StreamContext: 类型安全的流式上下文数据类,替代原有的 ctx dict
  - StreamProcessor: SSE 解析、预读、嵌套错误检测
  - StreamTelemetryRecorder: 统计记录(Usage/Audit/Candidate)
- 将硬编码配置外置到 settings.py,支持环境变量覆盖:
  - HTTP 超时配置(connect/write/pool)
  - 流式处理配置(预读行数、统计延迟)
  - 并发控制配置(槽位 TTL、缓存预留比例)
This commit is contained in:
@@ -12,10 +12,13 @@ Chat Handler Base - Chat API 格式的通用基类
|
||||
- apply_mapped_model(): 模型映射
|
||||
- get_model_for_url(): URL 模型名
|
||||
- _extract_usage(): 使用量提取
|
||||
|
||||
重构说明:
|
||||
- StreamContext: 类型安全的流式上下文,替代原有的 ctx dict
|
||||
- StreamProcessor: 流式响应处理(SSE 解析、预读、错误检测)
|
||||
- StreamTelemetryRecorder: 统计记录(Usage、Audit、Candidate)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Callable, Dict, Optional
|
||||
|
||||
@@ -24,13 +27,14 @@ from fastapi import BackgroundTasks, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.handlers.base.base_handler import (
|
||||
BaseMessageHandler,
|
||||
MessageTelemetry,
|
||||
)
|
||||
from src.api.handlers.base.base_handler import BaseMessageHandler
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.api.handlers.base.response_parser import ResponseParser
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.stream_processor import StreamProcessor
|
||||
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import (
|
||||
EmbeddedErrorException,
|
||||
ProviderAuthException,
|
||||
@@ -39,7 +43,6 @@ from src.core.exceptions import (
|
||||
ProviderTimeoutException,
|
||||
)
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
Provider,
|
||||
@@ -48,7 +51,6 @@ from src.models.database import (
|
||||
User,
|
||||
)
|
||||
from src.services.provider.transport import build_provider_url
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
|
||||
|
||||
|
||||
@@ -285,30 +287,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
model = getattr(converted_request, "model", original_request_body.get("model", "unknown"))
|
||||
api_format = self.allowed_api_formats[0]
|
||||
|
||||
# 用于跟踪的上下文
|
||||
ctx = {
|
||||
"model": model,
|
||||
"api_format": api_format,
|
||||
"provider_name": None,
|
||||
"provider_id": None,
|
||||
"endpoint_id": None,
|
||||
"key_id": None,
|
||||
"attempt_id": None,
|
||||
"provider_api_format": None, # Provider 的响应格式(用于选择解析器)
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cached_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
"collected_text": "",
|
||||
"status_code": 200,
|
||||
"response_headers": {},
|
||||
"provider_request_headers": {},
|
||||
"provider_request_body": None,
|
||||
"data_count": 0,
|
||||
"chunk_count": 0,
|
||||
"has_completion": False,
|
||||
"parsed_chunks": [], # 收集解析后的 chunks
|
||||
}
|
||||
# 创建类型安全的流式上下文
|
||||
ctx = StreamContext(model=model, api_format=api_format)
|
||||
|
||||
# 创建流处理器
|
||||
stream_processor = StreamProcessor(
|
||||
request_id=self.request_id,
|
||||
default_parser=self.parser,
|
||||
on_streaming_start=self._update_usage_to_streaming,
|
||||
)
|
||||
|
||||
# 定义请求函数
|
||||
async def stream_request_func(
|
||||
@@ -318,6 +305,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
return await self._execute_stream_request(
|
||||
ctx,
|
||||
stream_processor,
|
||||
provider,
|
||||
endpoint,
|
||||
key,
|
||||
@@ -350,23 +338,39 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
is_stream=True,
|
||||
capability_requirements=capability_requirements or None,
|
||||
)
|
||||
ctx["attempt_id"] = attempt_id
|
||||
ctx["provider_name"] = provider_name
|
||||
ctx["provider_id"] = provider_id
|
||||
ctx["endpoint_id"] = endpoint_id
|
||||
ctx["key_id"] = key_id
|
||||
|
||||
# 更新上下文
|
||||
ctx.attempt_id = attempt_id
|
||||
ctx.provider_name = provider_name
|
||||
ctx.provider_id = provider_id
|
||||
ctx.endpoint_id = endpoint_id
|
||||
ctx.key_id = key_id
|
||||
|
||||
# 创建遥测记录器
|
||||
telemetry_recorder = StreamTelemetryRecorder(
|
||||
request_id=self.request_id,
|
||||
user_id=str(self.user.id),
|
||||
api_key_id=str(self.api_key.id),
|
||||
client_ip=self.client_ip,
|
||||
format_id=self.FORMAT_ID,
|
||||
)
|
||||
|
||||
# 创建后台任务记录统计
|
||||
background_tasks = BackgroundTasks()
|
||||
background_tasks.add_task(
|
||||
self._record_stream_stats,
|
||||
telemetry_recorder.record_stream_stats,
|
||||
ctx,
|
||||
original_headers,
|
||||
original_request_body,
|
||||
self.elapsed_ms(),
|
||||
)
|
||||
|
||||
# 创建监控流
|
||||
monitored_stream = self._create_monitored_stream(ctx, stream_generator, http_request)
|
||||
monitored_stream = stream_processor.create_monitored_stream(
|
||||
ctx,
|
||||
stream_generator,
|
||||
http_request.is_disconnected,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
monitored_stream,
|
||||
@@ -381,7 +385,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
|
||||
async def _execute_stream_request(
|
||||
self,
|
||||
ctx: Dict,
|
||||
ctx: StreamContext,
|
||||
stream_processor: StreamProcessor,
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
key: ProviderAPIKey,
|
||||
@@ -390,37 +395,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
query_params: Optional[Dict[str, str]] = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""执行流式请求并返回流生成器"""
|
||||
# 重置上下文状态(重试时清除之前的数据,避免累积)
|
||||
ctx["parsed_chunks"] = []
|
||||
ctx["chunk_count"] = 0
|
||||
ctx["data_count"] = 0
|
||||
ctx["has_completion"] = False
|
||||
ctx["collected_text"] = ""
|
||||
ctx["input_tokens"] = 0
|
||||
ctx["output_tokens"] = 0
|
||||
ctx["cached_tokens"] = 0
|
||||
ctx["cache_creation_tokens"] = 0
|
||||
# 重置上下文状态(重试时清除之前的数据)
|
||||
ctx.reset_for_retry()
|
||||
|
||||
ctx["provider_name"] = str(provider.name)
|
||||
ctx["provider_id"] = str(provider.id)
|
||||
ctx["endpoint_id"] = str(endpoint.id)
|
||||
ctx["key_id"] = str(key.id)
|
||||
ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else ""
|
||||
# 更新 Provider 信息
|
||||
ctx.update_provider_info(
|
||||
provider_name=str(provider.name),
|
||||
provider_id=str(provider.id),
|
||||
endpoint_id=str(endpoint.id),
|
||||
key_id=str(key.id),
|
||||
provider_api_format=str(endpoint.api_format) if endpoint.api_format else None,
|
||||
)
|
||||
|
||||
# 获取模型映射
|
||||
mapped_model = await self._get_mapped_model(
|
||||
source_model=ctx["model"],
|
||||
source_model=ctx.model,
|
||||
provider_id=str(provider.id),
|
||||
)
|
||||
|
||||
# 应用模型映射到请求体
|
||||
if mapped_model:
|
||||
ctx["mapped_model"] = mapped_model # 保存映射后的模型名,用于 Usage 记录
|
||||
ctx.mapped_model = mapped_model
|
||||
request_body = self.apply_mapped_model(original_request_body, mapped_model)
|
||||
else:
|
||||
request_body = dict(original_request_body)
|
||||
|
||||
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
|
||||
# 准备发送给 Provider 的请求体
|
||||
request_body = self.prepare_provider_request_body(request_body)
|
||||
|
||||
# 构建请求
|
||||
@@ -432,11 +432,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
is_stream=True,
|
||||
)
|
||||
|
||||
ctx["provider_request_headers"] = provider_headers
|
||||
ctx["provider_request_body"] = provider_payload
|
||||
ctx.provider_request_headers = provider_headers
|
||||
ctx.provider_request_body = provider_payload
|
||||
|
||||
# 获取 URL 模型名(兜底使用 ctx 中的 model,确保 Gemini 等格式能正确构建 URL)
|
||||
url_model = self.get_model_for_url(request_body, mapped_model) or ctx["model"]
|
||||
# 获取 URL 模型名
|
||||
url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
|
||||
|
||||
url = build_provider_url(
|
||||
endpoint,
|
||||
@@ -445,15 +445,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
is_stream=True,
|
||||
)
|
||||
|
||||
logger.debug(f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
|
||||
f"模型={ctx['model']} -> {mapped_model or '无映射'}")
|
||||
logger.debug(
|
||||
f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
|
||||
f"模型={ctx.model} -> {mapped_model or '无映射'}"
|
||||
)
|
||||
|
||||
# 发送请求
|
||||
# 发送请求(使用配置中的超时设置)
|
||||
timeout_config = httpx.Timeout(
|
||||
connect=10.0,
|
||||
connect=config.http_connect_timeout,
|
||||
read=float(endpoint.timeout),
|
||||
write=60.0, # 写入超时增加到60秒,支持大请求体(如包含图片的长对话)
|
||||
pool=10.0,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
)
|
||||
|
||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
||||
@@ -463,17 +465,21 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
)
|
||||
stream_response = await response_ctx.__aenter__()
|
||||
|
||||
ctx["status_code"] = stream_response.status_code
|
||||
ctx["response_headers"] = dict(stream_response.headers)
|
||||
ctx.status_code = stream_response.status_code
|
||||
ctx.response_headers = dict(stream_response.headers)
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 创建行迭代器(只创建一次,后续会继续使用)
|
||||
# 创建行迭代器
|
||||
line_iterator = stream_response.aiter_lines()
|
||||
|
||||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||||
prefetched_lines = await self._prefetch_and_check_embedded_error(
|
||||
line_iterator, provider, endpoint, ctx
|
||||
# 预读检测嵌套错误
|
||||
prefetched_lines = await stream_processor.prefetch_and_check_error(
|
||||
line_iterator,
|
||||
provider,
|
||||
endpoint,
|
||||
ctx,
|
||||
max_prefetch_lines=config.stream_prefetch_lines,
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -483,7 +489,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
raise
|
||||
|
||||
except EmbeddedErrorException:
|
||||
# 嵌套错误需要触发重试,关闭连接后重新抛出
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
@@ -495,8 +500,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
await http_client.aclose()
|
||||
raise
|
||||
|
||||
# 创建流生成器(带预读数据,使用同一个迭代器)
|
||||
return self._create_response_stream_with_prefetch(
|
||||
# 创建流生成器
|
||||
return stream_processor.create_response_stream(
|
||||
ctx,
|
||||
line_iterator,
|
||||
response_ctx,
|
||||
@@ -504,518 +509,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
prefetched_lines,
|
||||
)
|
||||
|
||||
async def _create_response_stream(
|
||||
self,
|
||||
ctx: Dict,
|
||||
stream_response: httpx.Response,
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建响应流生成器"""
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
streaming_status_updated = False
|
||||
|
||||
async for line in stream_response.aiter_lines():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming()
|
||||
streaming_status_updated = True
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx["chunk_count"] += 1
|
||||
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
# 处理剩余事件
|
||||
for event in sse_parser.flush():
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
except GeneratorExit:
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _prefetch_and_check_embedded_error(
|
||||
self,
|
||||
line_iterator: Any,
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
ctx: Dict,
|
||||
) -> list:
|
||||
"""
|
||||
预读流的前几行,检测嵌套错误
|
||||
|
||||
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
||||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||||
|
||||
Args:
|
||||
line_iterator: 行迭代器(aiter_lines() 返回的迭代器)
|
||||
provider: Provider 对象
|
||||
endpoint: Endpoint 对象
|
||||
ctx: 上下文字典
|
||||
|
||||
Returns:
|
||||
预读的行列表(需要在后续流中先输出)
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
"""
|
||||
prefetched_lines: list = []
|
||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||
|
||||
try:
|
||||
# 获取对应格式的解析器
|
||||
provider_parser = self._get_provider_parser(ctx)
|
||||
|
||||
line_count = 0
|
||||
async for line in line_iterator:
|
||||
prefetched_lines.append(line)
|
||||
line_count += 1
|
||||
|
||||
# 解析数据
|
||||
normalized_line = line.rstrip("\r")
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
# 空行或注释行,继续预读
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
# 不是有效 JSON,可能是部分数据,继续
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||
# 提取错误信息
|
||||
parsed = provider_parser.parse_response(data, 200)
|
||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
# 重新抛出嵌套错误
|
||||
raise
|
||||
except Exception as e:
|
||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
|
||||
return prefetched_lines
|
||||
|
||||
async def _create_response_stream_with_prefetch(
|
||||
self,
|
||||
ctx: Dict,
|
||||
line_iterator: Any,
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
prefetched_lines: list,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建响应流生成器(带预读数据)"""
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_lines:
|
||||
self._update_usage_to_streaming()
|
||||
|
||||
# 先输出预读的数据
|
||||
for line in prefetched_lines:
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx["chunk_count"] += 1
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
# 继续输出剩余的流数据(使用同一个迭代器)
|
||||
async for line in line_iterator:
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx["chunk_count"] += 1
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
# 处理剩余事件
|
||||
for event in sse_parser.flush():
|
||||
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
except GeneratorExit:
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
await response_ctx.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_provider_parser(self, ctx: Dict) -> ResponseParser:
|
||||
"""
|
||||
获取 Provider 格式的解析器
|
||||
|
||||
根据 Provider 的 API 格式选择正确的解析器,
|
||||
而不是根据请求格式选择。
|
||||
"""
|
||||
provider_format = ctx.get("provider_api_format")
|
||||
if provider_format:
|
||||
try:
|
||||
return get_parser_for_format(provider_format)
|
||||
except KeyError:
|
||||
pass
|
||||
# 回退到默认解析器
|
||||
return self.parser
|
||||
|
||||
def _handle_sse_event(
|
||||
self,
|
||||
ctx: Dict,
|
||||
event_name: Optional[str],
|
||||
data_str: str,
|
||||
) -> None:
|
||||
"""处理 SSE 事件"""
|
||||
if not data_str:
|
||||
return
|
||||
|
||||
if data_str == "[DONE]":
|
||||
ctx["has_completion"] = True
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
return
|
||||
|
||||
ctx["data_count"] += 1
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
|
||||
# 收集原始 chunk 数据
|
||||
ctx["parsed_chunks"].append(data)
|
||||
|
||||
# 根据 Provider 格式选择解析器
|
||||
provider_parser = self._get_provider_parser(ctx)
|
||||
|
||||
# 使用解析器提取 usage
|
||||
usage = provider_parser.extract_usage_from_response(data)
|
||||
if usage:
|
||||
ctx["input_tokens"] = usage.get("input_tokens", ctx["input_tokens"])
|
||||
ctx["output_tokens"] = usage.get("output_tokens", ctx["output_tokens"])
|
||||
ctx["cached_tokens"] = usage.get("cache_read_tokens", ctx["cached_tokens"])
|
||||
ctx["cache_creation_tokens"] = usage.get(
|
||||
"cache_creation_tokens", ctx["cache_creation_tokens"]
|
||||
)
|
||||
|
||||
# 提取文本
|
||||
text = provider_parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx["collected_text"] += text
|
||||
|
||||
# 检查完成
|
||||
event_type = event_name or data.get("type", "")
|
||||
if event_type in ("response.completed", "message_stop"):
|
||||
ctx["has_completion"] = True
|
||||
|
||||
async def _create_monitored_stream(
|
||||
self,
|
||||
ctx: Dict,
|
||||
stream_generator: AsyncGenerator[bytes, None],
|
||||
http_request: Request,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建带监控的流生成器"""
|
||||
try:
|
||||
async for chunk in stream_generator:
|
||||
if await http_request.is_disconnected():
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
# 客户端断开时设置 499 状态码(Client Closed Request)
|
||||
# 注意:Provider 可能已经成功返回数据,但客户端未完整接收
|
||||
ctx["status_code"] = 499
|
||||
break
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
ctx["status_code"] = 499
|
||||
raise
|
||||
except Exception:
|
||||
ctx["status_code"] = 500
|
||||
raise
|
||||
|
||||
async def _record_stream_stats(
|
||||
self,
|
||||
ctx: Dict,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
) -> None:
|
||||
"""记录流式统计信息"""
|
||||
response_time_ms = self.elapsed_ms()
|
||||
bg_db = None
|
||||
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if not ctx["provider_name"]:
|
||||
# 即使没有 provider_name,也要尝试更新状态为 failed
|
||||
await self._update_usage_status_on_error(
|
||||
response_time_ms=response_time_ms,
|
||||
error_message="Provider name not available",
|
||||
)
|
||||
return
|
||||
|
||||
db_gen = get_db()
|
||||
bg_db = next(db_gen)
|
||||
|
||||
try:
|
||||
from src.models.database import ApiKey as ApiKeyModel
|
||||
|
||||
user = bg_db.query(User).filter(User.id == self.user.id).first()
|
||||
api_key_obj = (
|
||||
bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == self.api_key.id).first()
|
||||
)
|
||||
|
||||
if not user or not api_key_obj:
|
||||
logger.warning(f"[{self.request_id}] User or ApiKey not found, updating status directly")
|
||||
await self._update_usage_status_directly(
|
||||
bg_db,
|
||||
status="completed" if ctx["status_code"] == 200 else "failed",
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=ctx["status_code"],
|
||||
)
|
||||
return
|
||||
|
||||
bg_telemetry = MessageTelemetry(
|
||||
bg_db, user, api_key_obj, self.request_id, self.client_ip
|
||||
)
|
||||
|
||||
actual_request_body = ctx["provider_request_body"] or original_request_body
|
||||
|
||||
# 构建响应体(与 CLI 模式一致)
|
||||
response_body = {
|
||||
"chunks": ctx["parsed_chunks"],
|
||||
"metadata": {
|
||||
"stream": True,
|
||||
"total_chunks": len(ctx["parsed_chunks"]),
|
||||
"data_count": ctx["data_count"],
|
||||
"has_completion": ctx["has_completion"],
|
||||
"response_time_ms": response_time_ms,
|
||||
},
|
||||
}
|
||||
|
||||
# 根据状态码决定记录成功还是失败
|
||||
# 499 = 客户端断开连接,503 = 服务不可用(如流中断)
|
||||
status_code: int = ctx.get("status_code") or 200
|
||||
if status_code >= 400:
|
||||
# 记录失败的 Usage,但使用已收到的预估 token 信息(来自 message_start)
|
||||
# 这样即使请求中断,也能记录预估成本
|
||||
await bg_telemetry.record_failure(
|
||||
provider=ctx.get("provider_name") or "unknown",
|
||||
model=ctx["model"],
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=ctx.get("error_message") or f"HTTP {status_code}",
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
is_stream=True,
|
||||
api_format=ctx["api_format"],
|
||||
provider_request_headers=ctx["provider_request_headers"],
|
||||
# 预估 token 信息(来自 message_start 事件)
|
||||
input_tokens=ctx.get("input_tokens", 0),
|
||||
output_tokens=ctx.get("output_tokens", 0),
|
||||
cache_creation_tokens=ctx.get("cache_creation_tokens", 0),
|
||||
cache_read_tokens=ctx.get("cached_tokens", 0),
|
||||
response_body=response_body,
|
||||
# 模型映射信息
|
||||
target_model=ctx.get("mapped_model"),
|
||||
)
|
||||
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
|
||||
# 简洁的请求失败摘要(包含预估 token 信息)
|
||||
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
|
||||
f"{status_code} | in:{ctx.get('input_tokens', 0)} out:{ctx.get('output_tokens', 0)} cache:{ctx.get('cached_tokens', 0)}")
|
||||
else:
|
||||
await bg_telemetry.record_success(
|
||||
provider=ctx["provider_name"],
|
||||
model=ctx["model"],
|
||||
input_tokens=ctx["input_tokens"],
|
||||
output_tokens=ctx["output_tokens"],
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
response_headers=ctx["response_headers"],
|
||||
response_body=response_body,
|
||||
cache_creation_tokens=ctx["cache_creation_tokens"],
|
||||
cache_read_tokens=ctx["cached_tokens"],
|
||||
is_stream=True,
|
||||
provider_request_headers=ctx["provider_request_headers"],
|
||||
api_format=ctx["api_format"],
|
||||
provider_id=ctx["provider_id"],
|
||||
provider_endpoint_id=ctx["endpoint_id"],
|
||||
provider_api_key_id=ctx["key_id"],
|
||||
# 模型映射信息
|
||||
target_model=ctx.get("mapped_model"),
|
||||
)
|
||||
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
||||
# 简洁的请求完成摘要
|
||||
logger.info(f"[OK] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
|
||||
f"in:{ctx.get('input_tokens', 0) or 0} out:{ctx.get('output_tokens', 0) or 0}")
|
||||
|
||||
# 更新候选记录的最终状态和延迟时间
|
||||
# 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
|
||||
# 这里用流传输完成后的实际时间覆盖
|
||||
if ctx.get("attempt_id"):
|
||||
from src.services.request.candidate import RequestCandidateService
|
||||
|
||||
# 根据状态码决定是成功还是失败(复用上面已定义的 status_code)
|
||||
# 499 = 客户端断开连接,应标记为失败
|
||||
# 503 = 服务不可用(如流中断),应标记为失败
|
||||
if status_code and status_code >= 400:
|
||||
RequestCandidateService.mark_candidate_failed(
|
||||
db=bg_db,
|
||||
candidate_id=ctx["attempt_id"],
|
||||
error_type="client_disconnected" if status_code == 499 else "stream_error",
|
||||
error_message=ctx.get("error_message") or f"HTTP {status_code}",
|
||||
status_code=status_code,
|
||||
latency_ms=response_time_ms,
|
||||
extra_data={
|
||||
"stream_completed": False,
|
||||
"data_count": ctx.get("data_count", 0),
|
||||
},
|
||||
)
|
||||
else:
|
||||
RequestCandidateService.mark_candidate_success(
|
||||
db=bg_db,
|
||||
candidate_id=ctx["attempt_id"],
|
||||
status_code=status_code,
|
||||
latency_ms=response_time_ms,
|
||||
extra_data={
|
||||
"stream_completed": True,
|
||||
"data_count": ctx.get("data_count", 0),
|
||||
},
|
||||
)
|
||||
|
||||
finally:
|
||||
if bg_db:
|
||||
bg_db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("记录流式统计信息时出错")
|
||||
# 确保即使出错也要更新状态,避免 pending 状态卡住
|
||||
await self._update_usage_status_on_error(
|
||||
response_time_ms=response_time_ms,
|
||||
error_message=f"记录统计信息失败: {str(e)[:200]}",
|
||||
)
|
||||
|
||||
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
|
||||
|
||||
async def _update_usage_status_on_error(
|
||||
self,
|
||||
response_time_ms: int,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""在记录失败时更新 Usage 状态,避免卡在 pending"""
|
||||
try:
|
||||
db_gen = get_db()
|
||||
error_db = next(db_gen)
|
||||
try:
|
||||
await self._update_usage_status_directly(
|
||||
error_db,
|
||||
status="failed",
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=500,
|
||||
error_message=error_message,
|
||||
)
|
||||
finally:
|
||||
error_db.close()
|
||||
except Exception as inner_e:
|
||||
logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")
|
||||
|
||||
async def _update_usage_status_directly(
|
||||
self,
|
||||
db: Session,
|
||||
status: str,
|
||||
response_time_ms: int,
|
||||
status_code: int = 200,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""直接更新 Usage 表的状态字段"""
|
||||
try:
|
||||
from src.models.database import Usage
|
||||
|
||||
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
|
||||
if usage:
|
||||
setattr(usage, "status", status)
|
||||
setattr(usage, "status_code", status_code)
|
||||
setattr(usage, "response_time_ms", response_time_ms)
|
||||
if error_message:
|
||||
setattr(usage, "error_message", error_message)
|
||||
db.commit()
|
||||
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
|
||||
|
||||
async def _record_stream_failure(
|
||||
self,
|
||||
ctx: Dict,
|
||||
ctx: StreamContext,
|
||||
error: Exception,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
@@ -1031,21 +527,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
elif isinstance(error, ProviderTimeoutException):
|
||||
status_code = 504
|
||||
|
||||
actual_request_body = ctx.get("provider_request_body") or original_request_body
|
||||
actual_request_body = ctx.provider_request_body or original_request_body
|
||||
|
||||
await self.telemetry.record_failure(
|
||||
provider=ctx.get("provider_name") or "unknown",
|
||||
model=ctx["model"],
|
||||
provider=ctx.provider_name or "unknown",
|
||||
model=ctx.model,
|
||||
response_time_ms=response_time_ms,
|
||||
status_code=status_code,
|
||||
error_message=str(error),
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
is_stream=True,
|
||||
api_format=ctx["api_format"],
|
||||
provider_request_headers=ctx.get("provider_request_headers") or {},
|
||||
# 模型映射信息
|
||||
target_model=ctx.get("mapped_model"),
|
||||
api_format=ctx.api_format,
|
||||
provider_request_headers=ctx.provider_request_headers,
|
||||
target_model=ctx.mapped_model,
|
||||
)
|
||||
|
||||
# ==================== 非流式处理 ====================
|
||||
|
||||
154
src/api/handlers/base/stream_context.py
Normal file
154
src/api/handlers/base/stream_context.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
流式处理上下文 - 类型安全的数据类替代 dict
|
||||
|
||||
提供流式请求处理过程中的状态跟踪,包括:
|
||||
- Provider/Endpoint/Key 信息
|
||||
- Token 统计
|
||||
- 响应状态
|
||||
- 请求/响应数据
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext:
|
||||
"""
|
||||
流式处理上下文
|
||||
|
||||
用于在流式请求处理过程中跟踪状态,替代原有的 ctx dict。
|
||||
所有字段都有类型注解,提供更好的 IDE 支持和运行时类型安全。
|
||||
"""
|
||||
|
||||
# 请求基本信息
|
||||
model: str
|
||||
api_format: str
|
||||
|
||||
# Provider 信息(在请求执行时填充)
|
||||
provider_name: Optional[str] = None
|
||||
provider_id: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
attempt_id: Optional[str] = None
|
||||
provider_api_format: Optional[str] = None # Provider 的响应格式
|
||||
|
||||
# 模型映射
|
||||
mapped_model: Optional[str] = None
|
||||
|
||||
# Token 统计
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
# 响应内容
|
||||
collected_text: str = ""
|
||||
|
||||
# 响应状态
|
||||
status_code: int = 200
|
||||
error_message: Optional[str] = None
|
||||
has_completion: bool = False
|
||||
|
||||
# 请求/响应数据
|
||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_request_body: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 流式处理统计
|
||||
data_count: int = 0
|
||||
chunk_count: int = 0
|
||||
parsed_chunks: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def reset_for_retry(self) -> None:
|
||||
"""
|
||||
重试时重置状态
|
||||
|
||||
在故障转移重试时调用,清除之前的数据避免累积。
|
||||
保留 model 和 api_format,重置其他所有状态。
|
||||
"""
|
||||
self.parsed_chunks = []
|
||||
self.chunk_count = 0
|
||||
self.data_count = 0
|
||||
self.has_completion = False
|
||||
self.collected_text = ""
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.cached_tokens = 0
|
||||
self.cache_creation_tokens = 0
|
||||
self.error_message = None
|
||||
self.status_code = 200
|
||||
self.response_headers = {}
|
||||
self.provider_request_headers = {}
|
||||
self.provider_request_body = None
|
||||
|
||||
def update_provider_info(
|
||||
self,
|
||||
provider_name: str,
|
||||
provider_id: str,
|
||||
endpoint_id: str,
|
||||
key_id: str,
|
||||
provider_api_format: Optional[str] = None,
|
||||
) -> None:
|
||||
"""更新 Provider 信息"""
|
||||
self.provider_name = provider_name
|
||||
self.provider_id = provider_id
|
||||
self.endpoint_id = endpoint_id
|
||||
self.key_id = key_id
|
||||
self.provider_api_format = provider_api_format
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
input_tokens: Optional[int] = None,
|
||||
output_tokens: Optional[int] = None,
|
||||
cached_tokens: Optional[int] = None,
|
||||
cache_creation_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
"""更新 Token 使用统计"""
|
||||
if input_tokens is not None:
|
||||
self.input_tokens = input_tokens
|
||||
if output_tokens is not None:
|
||||
self.output_tokens = output_tokens
|
||||
if cached_tokens is not None:
|
||||
self.cached_tokens = cached_tokens
|
||||
if cache_creation_tokens is not None:
|
||||
self.cache_creation_tokens = cache_creation_tokens
|
||||
|
||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
||||
"""标记请求失败"""
|
||||
self.status_code = status_code
|
||||
self.error_message = error_message
|
||||
|
||||
def is_success(self) -> bool:
|
||||
"""检查请求是否成功"""
|
||||
return self.status_code < 400
|
||||
|
||||
def build_response_body(self, response_time_ms: int) -> Dict[str, Any]:
|
||||
"""
|
||||
构建响应体元数据
|
||||
|
||||
用于记录到 Usage 表的 response_body 字段。
|
||||
"""
|
||||
return {
|
||||
"chunks": self.parsed_chunks,
|
||||
"metadata": {
|
||||
"stream": True,
|
||||
"total_chunks": len(self.parsed_chunks),
|
||||
"data_count": self.data_count,
|
||||
"has_completion": self.has_completion,
|
||||
"response_time_ms": response_time_ms,
|
||||
},
|
||||
}
|
||||
|
||||
def get_log_summary(self, request_id: str, response_time_ms: int) -> str:
    """
    Build the one-line summary logged when a request completes or fails.

    The line contains the OK/FAIL status, the first eight characters of the
    request id, the model, the provider name (``unknown`` when unset), the
    response time and the input/output token counts.
    """
    if self.is_success():
        status = "OK"
    else:
        status = "FAIL"

    summary = (
        f"[{status}] {request_id[:8]} | {self.model} | "
        f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
        f"in:{self.input_tokens} out:{self.output_tokens}"
    )
    return summary
|
||||
349
src/api/handlers/base/stream_processor.py
Normal file
349
src/api/handlers/base/stream_processor.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
流式处理器 - 从 ChatHandlerBase 提取的流式响应处理逻辑
|
||||
|
||||
职责:
|
||||
1. SSE 事件解析和处理
|
||||
2. 响应流生成
|
||||
3. 预读和嵌套错误检测
|
||||
4. 客户端断开检测
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Callable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.response_parser import ResponseParser
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.core.exceptions import EmbeddedErrorException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderEndpoint
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
|
||||
|
||||
class StreamProcessor:
    """
    Streaming response processor.

    Parses an upstream SSE stream, detects errors embedded in the body of an
    HTTP 200 response, and produces the byte chunks that are forwarded to the
    client. Extracted from ChatHandlerBase so that it has a single
    responsibility.
    """

    def __init__(
        self,
        request_id: str,
        default_parser: ResponseParser,
        on_streaming_start: Optional[Callable[[], None]] = None,
    ):
        """
        Initialise the stream processor.

        Args:
            request_id: Request ID (used in log messages).
            default_parser: Response parser used when the provider format
                has no dedicated parser.
            on_streaming_start: Optional callback invoked once, when the
                stream starts producing output (used to update status).
        """
        self.request_id = request_id
        self.default_parser = default_parser
        self.on_streaming_start = on_streaming_start

    def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
        """
        Return the response parser matching the provider's API format.

        Falls back to ``self.default_parser`` when the context carries no
        provider format, or when the parser lookup raises ``KeyError``.
        """
        if ctx.provider_api_format:
            try:
                return get_parser_for_format(ctx.provider_api_format)
            except KeyError:
                # Deliberate fallback: a KeyError from the lookup is treated
                # as "no dedicated parser for this format".
                pass
        return self.default_parser

    def handle_sse_event(
        self,
        ctx: StreamContext,
        event_name: Optional[str],
        data_str: str,
    ) -> None:
        """
        Handle a single SSE event.

        Parses the event payload and records usage, text content and
        completion state on the context. Payloads that are empty or not valid
        JSON are ignored.

        Args:
            ctx: Streaming context that is updated in place.
            event_name: SSE event name, if any.
            data_str: Raw payload of the event's data field.
        """
        if not data_str:
            return

        # Compare the sentinel on the trimmed payload so that surrounding
        # whitespace does not hide the end-of-stream marker.
        if data_str.strip() == "[DONE]":
            ctx.has_completion = True
            return

        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            return

        # Every JSON payload counts, even when it is not an object.
        ctx.data_count += 1

        if not isinstance(data, dict):
            return

        # Keep the raw chunk for the response-body record.
        ctx.parsed_chunks.append(data)

        # Pick the parser that matches the provider's format.
        parser = self.get_parser_for_provider(ctx)

        usage = parser.extract_usage_from_response(data)
        if usage:
            ctx.update_usage(
                input_tokens=usage.get("input_tokens"),
                output_tokens=usage.get("output_tokens"),
                cached_tokens=usage.get("cache_read_tokens"),
                cache_creation_tokens=usage.get("cache_creation_tokens"),
            )

        text = parser.extract_text_content(data)
        if text:
            ctx.collected_text += text

        # The explicit SSE event name wins over the payload's "type" field.
        event_type = event_name or data.get("type", "")
        if event_type in ("response.completed", "message_stop"):
            ctx.has_completion = True

    @staticmethod
    def _sse_payload(line: str) -> str:
        """
        Return the payload of an SSE ``data`` line, or the line unchanged.

        The SSE format allows ``data:value`` as well as ``data: value`` (the
        space after the colon is optional), so both spellings are accepted.
        """
        if line.startswith("data:"):
            return line[len("data:"):].strip()
        return line

    async def prefetch_and_check_error(
        self,
        line_iterator: Any,
        provider: Provider,
        endpoint: ProviderEndpoint,
        ctx: StreamContext,
        max_prefetch_lines: int = 5,
    ) -> list:
        """
        Pre-read the first lines of the stream and detect embedded errors.

        Some providers (e.g. Gemini) may answer with HTTP 200 while the
        response body carries an error. That has to be detected before any
        output is streamed, so that the retry logic can still be triggered.

        Reading stops at the first line that parses as JSON, at a ``[DONE]``
        marker, or once ``max_prefetch_lines`` lines have been read without
        finding parseable data.

        Args:
            line_iterator: Async iterator over the response lines.
            provider: Provider object (its name is used for reporting).
            endpoint: Endpoint object (not used by this method).
            ctx: Streaming context (used to select the parser).
            max_prefetch_lines: Maximum number of lines to pre-read.

        Returns:
            The lines consumed from the iterator; the caller must replay them.

        Raises:
            EmbeddedErrorException: If an embedded error is detected.
        """
        prefetched_lines: list = []
        parser = self.get_parser_for_provider(ctx)

        try:
            async for line in line_iterator:
                prefetched_lines.append(line)
                normalized_line = line.rstrip("\r")

                data: Any = None
                parsed_ok = False
                # Blank lines and SSE comments (":...") carry no data.
                if normalized_line and not normalized_line.startswith(":"):
                    data_str = self._sse_payload(normalized_line)

                    if data_str == "[DONE]":
                        break

                    try:
                        data = json.loads(data_str)
                        parsed_ok = True
                    except json.JSONDecodeError:
                        # e.g. "event: ..." lines; keep looking.
                        pass

                if not parsed_ok:
                    if len(prefetched_lines) >= max_prefetch_lines:
                        break
                    continue

                # Let the parser decide whether this payload is an error.
                if isinstance(data, dict) and parser.is_error_response(data):
                    parsed = parser.parse_response(data, 200)
                    logger.warning(
                        f" [{self.request_id}] 检测到嵌套错误: "
                        f"Provider={provider.name}, "
                        f"error_type={parsed.error_type}, "
                        f"message={parsed.error_message}"
                    )
                    error_type = parsed.error_type
                    # str() keeps a non-string error_type from raising here;
                    # an exception at this point would be swallowed by the
                    # handler below and the embedded error would go unnoticed.
                    error_code = (
                        int(error_type)
                        if error_type and str(error_type).isdigit()
                        else None
                    )
                    raise EmbeddedErrorException(
                        provider_name=str(provider.name),
                        error_code=error_code,
                        error_message=parsed.error_message,
                        error_status=parsed.error_type,
                    )

                # Valid data without an error: stop pre-reading.
                break

        except EmbeddedErrorException:
            raise
        except Exception as e:
            # Best effort: a failure while pre-reading must not abort the
            # request; the lines read so far are still returned.
            logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")

        return prefetched_lines

    async def create_response_stream(
        self,
        ctx: StreamContext,
        line_iterator: Any,
        response_ctx: Any,
        http_client: httpx.AsyncClient,
        prefetched_lines: Optional[list] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Create the response stream generator.

        A single generator for both cases: with and without pre-read lines.
        Pre-read lines are replayed first, then the rest of the iterator is
        consumed. The response context and the HTTP client are always
        released when the generator finishes or is closed.

        Args:
            ctx: Streaming context.
            line_iterator: Async iterator over the remaining response lines.
            response_ctx: HTTP response context manager.
            http_client: HTTP client.
            prefetched_lines: Lines already consumed while pre-reading.

        Yields:
            Encoded response chunks.
        """
        try:
            sse_parser = SSEEventParser()
            streaming_started = False

            # Replay the lines that were consumed while pre-reading.
            if prefetched_lines:
                if self.on_streaming_start:
                    self.on_streaming_start()
                    streaming_started = True

                for line in prefetched_lines:
                    for chunk in self._process_line(ctx, sse_parser, line):
                        yield chunk

            # Stream the rest of the response.
            async for line in line_iterator:
                if not streaming_started and self.on_streaming_start:
                    self.on_streaming_start()
                    streaming_started = True

                for chunk in self._process_line(ctx, sse_parser, line):
                    yield chunk

            # Handle events still buffered in the SSE parser.
            for event in sse_parser.flush():
                self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")

        finally:
            # Runs on normal completion, on errors and when the generator is
            # closed early.
            await self._cleanup(response_ctx, http_client)

    def _process_line(
        self,
        ctx: StreamContext,
        sse_parser: SSEEventParser,
        line: str,
    ) -> list[bytes]:
        """
        Process a single line of the stream.

        Feeds the line to the SSE parser, handles every event the parser
        returns exactly once, and returns the bytes to forward for this line.

        Args:
            ctx: Streaming context.
            sse_parser: SSE parser.
            line: Raw line.

        Returns:
            List of chunks to send.
        """
        normalized_line = line.rstrip("\r")
        events = sse_parser.feed_line(normalized_line)

        if normalized_line == "":
            # Blank line: forwarded as a bare newline.
            outgoing = b"\n"
        else:
            ctx.chunk_count += 1
            # Forward the original line (a trailing "\r" is preserved).
            outgoing = (line + "\n").encode("utf-8")

        for event in events:
            self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")

        return [outgoing]

    async def create_monitored_stream(
        self,
        ctx: StreamContext,
        stream_generator: AsyncGenerator[bytes, None],
        is_disconnected: Callable[[], Any],
    ) -> AsyncGenerator[bytes, None]:
        """
        Create a stream generator that monitors the client connection.

        Stops streaming when the client disconnects and records the outcome
        on the context: 499 for a disconnect or cancellation, 500 for any
        other exception (which is re-raised).

        Args:
            ctx: Streaming context.
            stream_generator: The underlying stream generator.
            is_disconnected: Awaitable callable that reports whether the
                client has disconnected.

        Yields:
            Response chunks.
        """
        try:
            async for chunk in stream_generator:
                if await is_disconnected():
                    logger.warning(f"ID:{self.request_id} | Client disconnected")
                    ctx.status_code = 499  # Client Closed Request
                    ctx.error_message = "client_disconnected"
                    break
                yield chunk
        except asyncio.CancelledError:
            ctx.status_code = 499
            ctx.error_message = "client_disconnected"
            raise
        except Exception as e:
            ctx.status_code = 500
            ctx.error_message = str(e)
            raise
        finally:
            # Leaving the loop early (disconnect, or this generator being
            # closed) would otherwise leave the wrapped generator suspended,
            # and its cleanup would only run whenever it is finalised. Close
            # it now; closing an already finished generator is a no-op.
            aclose = getattr(stream_generator, "aclose", None)
            if aclose is not None:
                try:
                    await aclose()
                except Exception as exc:
                    logger.debug(
                        f"ID:{self.request_id} | Failed to close upstream stream: {exc}"
                    )

    async def _cleanup(
        self,
        response_ctx: Any,
        http_client: httpx.AsyncClient,
    ) -> None:
        """
        Release the response context and the HTTP client.

        Best effort: failures are logged at debug level and never raised.
        The client is closed even when exiting the response context fails.
        """
        try:
            await response_ctx.__aexit__(None, None, None)
        except Exception as exc:
            logger.debug(
                f"ID:{self.request_id} | Failed to exit response context: {exc}"
            )
        finally:
            # In a ``finally`` so that the client is also closed when the
            # call above is interrupted by a non-``Exception`` error.
            try:
                await http_client.aclose()
            except Exception as exc:
                logger.debug(
                    f"ID:{self.request_id} | Failed to close HTTP client: {exc}"
                )
|
||||
293
src/api/handlers/base/stream_telemetry.py
Normal file
293
src/api/handlers/base/stream_telemetry.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
流式遥测记录器 - 从 ChatHandlerBase 提取的统计记录逻辑
|
||||
|
||||
职责:
|
||||
1. 记录流式请求的成功/失败统计
|
||||
2. 更新 Usage 状态
|
||||
3. 更新候选记录状态
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.handlers.base.base_handler import MessageTelemetry
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.config.settings import config
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
|
||||
|
||||
class StreamTelemetryRecorder:
    """
    Telemetry recorder for streaming requests.

    Records statistics once a streaming request has finished: the usage
    record (success or failure) and the status of the candidate record.
    Holds the logic that used to live in ChatHandlerBase._record_stream_stats.
    """

    def __init__(
        self,
        request_id: str,
        user_id: str,
        api_key_id: str,
        client_ip: str,
        format_id: str,
    ):
        """
        Initialise the telemetry recorder.

        Args:
            request_id: Request ID.
            user_id: User ID.
            api_key_id: API key ID.
            client_ip: Client IP address.
            format_id: API format identifier (used in log messages).
        """
        self.request_id = request_id
        self.user_id = user_id
        self.api_key_id = api_key_id
        self.client_ip = client_ip
        self.format_id = format_id

    async def record_stream_stats(
        self,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        original_request_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """
        Record the statistics of a finished streaming request.

        Never raises for ordinary errors: any failure is logged and the Usage
        row is marked as failed instead.

        Args:
            ctx: Streaming context.
            original_headers: Original request headers.
            original_request_body: Original request body.
            response_time_ms: Response time in milliseconds.
        """
        session = None

        try:
            # Wait for the stream to be closed completely.
            await asyncio.sleep(config.stream_stats_delay)

            if not ctx.provider_name:
                await self._update_usage_status_on_error(
                    response_time_ms=response_time_ms,
                    error_message="Provider name not available",
                )
                return

            # Keep a reference to the generator while the session is in use.
            session_source = get_db()
            session = next(session_source)

            try:
                user = session.query(User).filter(User.id == self.user_id).first()
                api_key_obj = (
                    session.query(ApiKey).filter(ApiKey.id == self.api_key_id).first()
                )

                if not (user and api_key_obj):
                    logger.warning(
                        f"[{self.request_id}] User or ApiKey not found, updating status directly"
                    )
                    final_status = "completed" if ctx.is_success() else "failed"
                    await self._update_usage_status_directly(
                        session,
                        status=final_status,
                        response_time_ms=response_time_ms,
                        status_code=ctx.status_code,
                    )
                    return

                telemetry = MessageTelemetry(
                    session, user, api_key_obj, self.request_id, self.client_ip
                )

                # Prefer the body that was actually sent to the provider.
                request_payload = ctx.provider_request_body or original_request_body
                response_payload = ctx.build_response_body(response_time_ms)

                record = (
                    self._record_success if ctx.is_success() else self._record_failure
                )
                await record(
                    telemetry,
                    ctx,
                    original_headers,
                    request_payload,
                    response_payload,
                    response_time_ms,
                )

                # Update the status of the candidate record.
                await self._update_candidate_status(session, ctx, response_time_ms)

            finally:
                if session:
                    session.close()

        except Exception as exc:
            logger.exception("记录流式统计信息时出错")
            await self._update_usage_status_on_error(
                response_time_ms=response_time_ms,
                error_message=f"记录统计信息失败: {str(exc)[:200]}",
            )

    async def _record_success(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Record a successful request and log its summary."""
        fields: Dict[str, Any] = {
            "provider": ctx.provider_name or "unknown",
            "model": ctx.model,
            "input_tokens": ctx.input_tokens,
            "output_tokens": ctx.output_tokens,
            "response_time_ms": response_time_ms,
            "status_code": ctx.status_code,
            "request_headers": original_headers,
            "request_body": actual_request_body,
            "response_headers": ctx.response_headers,
            "response_body": response_body,
            "cache_creation_tokens": ctx.cache_creation_tokens,
            "cache_read_tokens": ctx.cached_tokens,
            "is_stream": True,
            "provider_request_headers": ctx.provider_request_headers,
            "api_format": ctx.api_format,
            "provider_id": ctx.provider_id,
            "provider_endpoint_id": ctx.endpoint_id,
            "provider_api_key_id": ctx.key_id,
            "target_model": ctx.mapped_model,
        }
        await telemetry.record_success(**fields)

        logger.debug(f"{self.format_id} 流式响应完成")
        logger.info(ctx.get_log_summary(self.request_id, response_time_ms))

    async def _record_failure(
        self,
        telemetry: MessageTelemetry,
        ctx: StreamContext,
        original_headers: Dict[str, str],
        actual_request_body: Dict[str, Any],
        response_body: Dict[str, Any],
        response_time_ms: int,
    ) -> None:
        """Record a failed request and log its summary with cache info."""
        fields: Dict[str, Any] = {
            "provider": ctx.provider_name or "unknown",
            "model": ctx.model,
            "response_time_ms": response_time_ms,
            "status_code": ctx.status_code,
            "error_message": ctx.error_message or f"HTTP {ctx.status_code}",
            "request_headers": original_headers,
            "request_body": actual_request_body,
            "is_stream": True,
            "api_format": ctx.api_format,
            "provider_request_headers": ctx.provider_request_headers,
            "input_tokens": ctx.input_tokens,
            "output_tokens": ctx.output_tokens,
            "cache_creation_tokens": ctx.cache_creation_tokens,
            "cache_read_tokens": ctx.cached_tokens,
            "response_body": response_body,
            "target_model": ctx.mapped_model,
        }
        await telemetry.record_failure(**fields)

        logger.debug(f"{self.format_id} 流式响应中断")
        log_summary = ctx.get_log_summary(self.request_id, response_time_ms)
        # The failure log line additionally carries the cached token count.
        logger.info(f"{log_summary} cache:{ctx.cached_tokens}")

    async def _update_candidate_status(
        self,
        db: Session,
        ctx: StreamContext,
        response_time_ms: int,
    ) -> None:
        """Mark the candidate record as succeeded or failed.

        Does nothing when the context carries no attempt id.
        """
        if not ctx.attempt_id:
            return

        # Imported here rather than at module level, as in the original.
        from src.services.request.candidate import RequestCandidateService

        if not ctx.is_success():
            # 499 means the client went away; everything else is a stream error.
            error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
            RequestCandidateService.mark_candidate_failed(
                db=db,
                candidate_id=ctx.attempt_id,
                error_type=error_type,
                error_message=ctx.error_message or f"HTTP {ctx.status_code}",
                status_code=ctx.status_code,
                latency_ms=response_time_ms,
                extra_data={
                    "stream_completed": False,
                    "data_count": ctx.data_count,
                },
            )
            return

        RequestCandidateService.mark_candidate_success(
            db=db,
            candidate_id=ctx.attempt_id,
            status_code=ctx.status_code,
            latency_ms=response_time_ms,
            extra_data={
                "stream_completed": True,
                "data_count": ctx.data_count,
            },
        )

    async def _update_usage_status_on_error(
        self,
        response_time_ms: int,
        error_message: str,
    ) -> None:
        """Mark the Usage row as failed (status code 500) using a fresh session.

        Failures are logged and never raised.
        """
        try:
            # Keep a reference to the generator while the session is in use.
            session_source = get_db()
            fallback_session = next(session_source)
            try:
                await self._update_usage_status_directly(
                    fallback_session,
                    status="failed",
                    response_time_ms=response_time_ms,
                    status_code=500,
                    error_message=error_message,
                )
            finally:
                fallback_session.close()
        except Exception as inner_e:
            logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")

    async def _update_usage_status_directly(
        self,
        db: Session,
        status: str,
        response_time_ms: int,
        status_code: int = 200,
        error_message: Optional[str] = None,
    ) -> None:
        """Update the status fields of this request's Usage row directly.

        Does nothing when no Usage row exists for the request id. Failures
        are logged and never raised.
        """
        try:
            from src.models.database import Usage

            usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
            if not usage:
                return

            usage.status = status
            usage.status_code = status_code
            usage.response_time_ms = response_time_ms
            if error_message:
                usage.error_message = error_message
            db.commit()
            logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
        except Exception as e:
            logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
|
||||
@@ -120,6 +120,23 @@ class Config:
|
||||
self.db_pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))
|
||||
self.db_pool_warn_threshold = int(os.getenv("DB_POOL_WARN_THRESHOLD", "70"))
|
||||
|
||||
# 并发控制配置
|
||||
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL(秒),防止死锁
|
||||
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%)
|
||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
|
||||
|
||||
# HTTP 请求超时配置(秒)
|
||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
||||
|
||||
# 流式处理配置
|
||||
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||
|
||||
# 验证连接池配置
|
||||
self._validate_pool_config()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import asyncio
|
||||
import math
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from datetime import timedelta # noqa: F401 - kept for potential future use
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
@@ -185,8 +185,8 @@ class ConcurrencyManager:
|
||||
key_id: str,
|
||||
key_max_concurrent: Optional[int],
|
||||
is_cached_user: bool = False, # 新增:是否是缓存用户
|
||||
cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例
|
||||
ttl_seconds: int = 600, # 10分钟 TTL,防止死锁
|
||||
cache_reservation_ratio: Optional[float] = None, # 缓存预留比例,None 时从配置读取
|
||||
ttl_seconds: Optional[int] = None, # TTL 秒数,None 时从配置读取
|
||||
) -> bool:
|
||||
"""
|
||||
尝试获取并发槽位(支持缓存用户优先级)
|
||||
@@ -197,8 +197,8 @@ class ConcurrencyManager:
|
||||
key_id: ProviderAPIKey ID
|
||||
key_max_concurrent: Key 最大并发数(None 表示不限制)
|
||||
is_cached_user: 是否是缓存用户(缓存用户可使用全部槽位)
|
||||
cache_reservation_ratio: 缓存预留比例(默认30%,只对新用户生效)
|
||||
ttl_seconds: TTL 秒数,防止异常情况下的死锁
|
||||
cache_reservation_ratio: 缓存预留比例,None 时从配置读取
|
||||
ttl_seconds: TTL 秒数,None 时从配置读取
|
||||
|
||||
Returns:
|
||||
是否成功获取(True/False)
|
||||
@@ -209,6 +209,14 @@ class ConcurrencyManager:
|
||||
- 缓存用户最多使用: 10个槽位(全部)
|
||||
- 预留的3个槽位专门给缓存用户,保证他们的请求优先
|
||||
"""
|
||||
# 从配置读取默认值
|
||||
from src.config.settings import config
|
||||
|
||||
if cache_reservation_ratio is None:
|
||||
cache_reservation_ratio = config.cache_reservation_ratio
|
||||
if ttl_seconds is None:
|
||||
ttl_seconds = config.concurrency_slot_ttl
|
||||
|
||||
if self._redis is None:
|
||||
async with self._memory_lock:
|
||||
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
|
||||
@@ -426,7 +434,7 @@ class ConcurrencyManager:
|
||||
key_id: str,
|
||||
key_max_concurrent: Optional[int],
|
||||
is_cached_user: bool = False, # 新增:是否是缓存用户
|
||||
cache_reservation_ratio: float = 0.3, # 新增:缓存预留比例
|
||||
cache_reservation_ratio: Optional[float] = None, # 缓存预留比例,None 时从配置读取
|
||||
):
|
||||
"""
|
||||
并发控制上下文管理器(支持缓存用户优先级)
|
||||
@@ -441,6 +449,12 @@ class ConcurrencyManager:
|
||||
|
||||
如果获取失败,会抛出 ConcurrencyLimitError 异常
|
||||
"""
|
||||
# 从配置读取默认值
|
||||
from src.config.settings import config
|
||||
|
||||
if cache_reservation_ratio is None:
|
||||
cache_reservation_ratio = config.cache_reservation_ratio
|
||||
|
||||
# 尝试获取槽位(传递缓存用户参数)
|
||||
acquired = await self.acquire_slot(
|
||||
endpoint_id,
|
||||
|
||||
Reference in New Issue
Block a user