mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
refactor(cli-handler): improve stream handling and response processing
- Refactor CLI handler base for better stream context management
- Optimize request/response handling for Claude, OpenAI, and Gemini CLI adapters
- Enhance telemetry tracking across CLI handlers
This commit is contained in:
@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import codecs
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
|
|||||||
)
|
)
|
||||||
from src.api.handlers.base.parsers import get_parser_for_format
|
from src.api.handlers.base.parsers import get_parser_for_format
|
||||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||||
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
|
from src.api.handlers.base.utils import build_sse_headers
|
||||||
|
|
||||||
# 直接从具体模块导入,避免循环依赖
|
# 直接从具体模块导入,避免循环依赖
|
||||||
from src.api.handlers.base.response_parser import (
|
from src.api.handlers.base.response_parser import (
|
||||||
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
|
|||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StreamContext:
|
|
||||||
"""流式请求的上下文信息"""
|
|
||||||
|
|
||||||
# 请求信息
|
|
||||||
model: str = "unknown" # 用户请求的原始模型名
|
|
||||||
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
|
|
||||||
api_format: str = ""
|
|
||||||
request_id: str = ""
|
|
||||||
|
|
||||||
# 用户信息(提前提取避免 Session detached)
|
|
||||||
user_id: int = 0
|
|
||||||
api_key_id: int = 0
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
input_tokens: int = 0
|
|
||||||
output_tokens: int = 0
|
|
||||||
cached_tokens: int = 0 # cache_read_input_tokens
|
|
||||||
cache_creation_tokens: int = 0 # cache_creation_input_tokens
|
|
||||||
collected_text: str = ""
|
|
||||||
response_id: Optional[str] = None
|
|
||||||
final_usage: Optional[Dict[str, Any]] = None
|
|
||||||
final_response: Optional[Dict[str, Any]] = None
|
|
||||||
parsed_chunks: list = field(default_factory=list)
|
|
||||||
|
|
||||||
# 流状态
|
|
||||||
start_time: float = field(default_factory=time.time)
|
|
||||||
chunk_count: int = 0
|
|
||||||
data_count: int = 0
|
|
||||||
has_completion: bool = False
|
|
||||||
|
|
||||||
# 响应信息
|
|
||||||
status_code: int = 200
|
|
||||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# 请求信息(发送给 Provider 的)
|
|
||||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
|
|
||||||
|
|
||||||
# Provider 信息
|
|
||||||
provider_name: Optional[str] = None
|
|
||||||
provider_id: Optional[str] = None # Provider ID(用于记录真实成本)
|
|
||||||
endpoint_id: Optional[str] = None
|
|
||||||
key_id: Optional[str] = None
|
|
||||||
attempt_id: Optional[str] = None
|
|
||||||
attempt_synced: bool = False
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
|
|
||||||
# 格式转换信息
|
|
||||||
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
|
|
||||||
client_api_format: str = "" # 客户端请求的 API 格式
|
|
||||||
|
|
||||||
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion)
|
|
||||||
response_metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class CliMessageHandlerBase(BaseMessageHandler):
|
class CliMessageHandlerBase(BaseMessageHandler):
|
||||||
"""
|
"""
|
||||||
CLI 格式消息处理器基类
|
CLI 格式消息处理器基类
|
||||||
@@ -409,6 +352,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
monitored_stream,
|
monitored_stream,
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
|
headers=build_sse_headers(),
|
||||||
background=background_tasks,
|
background=background_tasks,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -433,7 +377,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
ctx.chunk_count = 0
|
ctx.chunk_count = 0
|
||||||
ctx.data_count = 0
|
ctx.data_count = 0
|
||||||
ctx.has_completion = False
|
ctx.has_completion = False
|
||||||
ctx.collected_text = ""
|
ctx._collected_text_parts = [] # 重置文本收集
|
||||||
ctx.input_tokens = 0
|
ctx.input_tokens = 0
|
||||||
ctx.output_tokens = 0
|
ctx.output_tokens = 0
|
||||||
ctx.cached_tokens = 0
|
ctx.cached_tokens = 0
|
||||||
@@ -521,12 +465,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
stream_response.raise_for_status()
|
stream_response.raise_for_status()
|
||||||
|
|
||||||
# 创建行迭代器(只创建一次,后续会继续使用)
|
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||||
line_iterator = stream_response.aiter_lines()
|
byte_iterator = stream_response.aiter_raw()
|
||||||
|
|
||||||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||||||
prefetched_lines = await self._prefetch_and_check_embedded_error(
|
prefetched_chunks = await self._prefetch_and_check_embedded_error(
|
||||||
line_iterator, provider, endpoint, ctx
|
byte_iterator, provider, endpoint, ctx
|
||||||
)
|
)
|
||||||
|
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
@@ -551,10 +495,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 创建流生成器(带预读数据,使用同一个迭代器)
|
# 创建流生成器(带预读数据,使用同一个迭代器)
|
||||||
return self._create_response_stream_with_prefetch(
|
return self._create_response_stream_with_prefetch(
|
||||||
ctx,
|
ctx,
|
||||||
line_iterator,
|
byte_iterator,
|
||||||
response_ctx,
|
response_ctx,
|
||||||
http_client,
|
http_client,
|
||||||
prefetched_lines,
|
prefetched_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_response_stream(
|
async def _create_response_stream(
|
||||||
@@ -564,58 +508,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
response_ctx: Any,
|
response_ctx: Any,
|
||||||
http_client: httpx.AsyncClient,
|
http_client: httpx.AsyncClient,
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes, None]:
|
||||||
"""创建响应流生成器"""
|
"""创建响应流生成器(使用字节流)"""
|
||||||
try:
|
try:
|
||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
last_data_time = time.time()
|
last_data_time = time.time()
|
||||||
streaming_status_updated = False
|
streaming_status_updated = False
|
||||||
|
buffer = b""
|
||||||
|
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||||
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
async for line in stream_response.aiter_lines():
|
async for chunk in stream_response.aiter_raw():
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if not streaming_status_updated:
|
if not streaming_status_updated:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming(ctx.request_id)
|
||||||
streaming_status_updated = True
|
streaming_status_updated = True
|
||||||
|
|
||||||
normalized_line = line.rstrip("\r")
|
buffer += chunk
|
||||||
events = sse_parser.feed_line(normalized_line)
|
# 处理缓冲区中的完整行
|
||||||
|
while b"\n" in buffer:
|
||||||
if normalized_line == "":
|
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||||
for event in events:
|
try:
|
||||||
self._handle_sse_event(
|
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||||
ctx,
|
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||||
event.get("event"),
|
except Exception as e:
|
||||||
event.get("data") or "",
|
logger.warning(
|
||||||
|
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||||
|
f"bytes={line_bytes[:50]!r}"
|
||||||
)
|
)
|
||||||
yield b"\n"
|
continue
|
||||||
continue
|
|
||||||
|
|
||||||
ctx.chunk_count += 1
|
normalized_line = line.rstrip("\r")
|
||||||
|
events = sse_parser.feed_line(normalized_line)
|
||||||
|
|
||||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
if normalized_line == "":
|
||||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
for event in events:
|
||||||
elapsed = time.time() - last_data_time
|
self._handle_sse_event(
|
||||||
if elapsed > self.DATA_TIMEOUT:
|
ctx,
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
event.get("event"),
|
||||||
error_event = {
|
event.get("data") or "",
|
||||||
"type": "error",
|
)
|
||||||
"error": {
|
yield b"\n"
|
||||||
"type": "empty_stream_timeout",
|
continue
|
||||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
|
||||||
return # 结束生成器
|
|
||||||
|
|
||||||
# 格式转换或直接透传
|
ctx.chunk_count += 1
|
||||||
if needs_conversion:
|
|
||||||
converted_line = self._convert_sse_line(ctx, line, events)
|
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||||
if converted_line:
|
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||||
yield (converted_line + "\n").encode("utf-8")
|
elapsed = time.time() - last_data_time
|
||||||
else:
|
if elapsed > self.DATA_TIMEOUT:
|
||||||
yield (line + "\n").encode("utf-8")
|
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||||
|
error_event = {
|
||||||
|
"type": "error",
|
||||||
|
"error": {
|
||||||
|
"type": "empty_stream_timeout",
|
||||||
|
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
|
return # 结束生成器
|
||||||
|
|
||||||
|
# 格式转换或直接透传
|
||||||
|
if needs_conversion:
|
||||||
|
converted_line = self._convert_sse_line(ctx, line, events)
|
||||||
|
if converted_line:
|
||||||
|
yield (converted_line + "\n").encode("utf-8")
|
||||||
|
else:
|
||||||
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
@@ -689,7 +650,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
async def _prefetch_and_check_embedded_error(
|
async def _prefetch_and_check_embedded_error(
|
||||||
self,
|
self,
|
||||||
line_iterator: Any,
|
byte_iterator: Any,
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
endpoint: ProviderEndpoint,
|
endpoint: ProviderEndpoint,
|
||||||
ctx: StreamContext,
|
ctx: StreamContext,
|
||||||
@@ -703,20 +664,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
line_iterator: 行迭代器(aiter_lines() 返回的迭代器)
|
byte_iterator: 字节流迭代器
|
||||||
provider: Provider 对象
|
provider: Provider 对象
|
||||||
endpoint: Endpoint 对象
|
endpoint: Endpoint 对象
|
||||||
ctx: 流上下文
|
ctx: 流上下文
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
预读的行列表(需要在后续流中先输出)
|
预读的字节块列表(需要在后续流中先输出)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
EmbeddedErrorException: 如果检测到嵌套错误
|
EmbeddedErrorException: 如果检测到嵌套错误
|
||||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||||
"""
|
"""
|
||||||
prefetched_lines: list = []
|
prefetched_chunks: list = []
|
||||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||||
|
buffer = b""
|
||||||
|
line_count = 0
|
||||||
|
should_stop = False
|
||||||
|
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||||
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取对应格式的解析器
|
# 获取对应格式的解析器
|
||||||
@@ -729,69 +695,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
else:
|
else:
|
||||||
provider_parser = self.parser
|
provider_parser = self.parser
|
||||||
|
|
||||||
line_count = 0
|
async for chunk in byte_iterator:
|
||||||
async for line in line_iterator:
|
prefetched_chunks.append(chunk)
|
||||||
prefetched_lines.append(line)
|
buffer += chunk
|
||||||
line_count += 1
|
|
||||||
|
|
||||||
# 解析数据
|
# 尝试按行解析缓冲区
|
||||||
normalized_line = line.rstrip("\r")
|
while b"\n" in buffer:
|
||||||
|
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||||
|
try:
|
||||||
|
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||||
|
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
|
||||||
|
f"bytes={line_bytes[:50]!r}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
line_count += 1
|
||||||
lower_line = normalized_line.lower()
|
normalized_line = line.rstrip("\r")
|
||||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
|
||||||
logger.error(
|
|
||||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
|
||||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
|
||||||
f"base_url={endpoint.base_url}"
|
|
||||||
)
|
|
||||||
raise ProviderNotAvailableException(
|
|
||||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not normalized_line or normalized_line.startswith(":"):
|
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||||
# 空行或注释行,继续预读
|
lower_line = normalized_line.lower()
|
||||||
if line_count >= max_prefetch_lines:
|
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||||||
|
logger.error(
|
||||||
|
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||||
|
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||||
|
f"base_url={endpoint.base_url}"
|
||||||
|
)
|
||||||
|
raise ProviderNotAvailableException(
|
||||||
|
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not normalized_line or normalized_line.startswith(":"):
|
||||||
|
# 空行或注释行,继续预读
|
||||||
|
if line_count >= max_prefetch_lines:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 尝试解析 SSE 数据
|
||||||
|
data_str = normalized_line
|
||||||
|
if normalized_line.startswith("data: "):
|
||||||
|
data_str = normalized_line[6:]
|
||||||
|
|
||||||
|
if data_str == "[DONE]":
|
||||||
|
should_stop = True
|
||||||
break
|
break
|
||||||
continue
|
|
||||||
|
|
||||||
# 尝试解析 SSE 数据
|
try:
|
||||||
data_str = normalized_line
|
data = json.loads(data_str)
|
||||||
if normalized_line.startswith("data: "):
|
except json.JSONDecodeError:
|
||||||
data_str = normalized_line[6:]
|
# 不是有效 JSON,可能是部分数据,继续
|
||||||
|
if line_count >= max_prefetch_lines:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
if data_str == "[DONE]":
|
# 使用解析器检查是否为错误响应
|
||||||
|
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||||
|
# 提取错误信息
|
||||||
|
parsed = provider_parser.parse_response(data, 200)
|
||||||
|
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||||
|
f"Provider={provider.name}, "
|
||||||
|
f"error_type={parsed.error_type}, "
|
||||||
|
f"message={parsed.error_message}")
|
||||||
|
raise EmbeddedErrorException(
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
error_code=(
|
||||||
|
int(parsed.error_type)
|
||||||
|
if parsed.error_type and parsed.error_type.isdigit()
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
error_message=parsed.error_message,
|
||||||
|
error_status=parsed.error_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预读到有效数据,没有错误,停止预读
|
||||||
|
should_stop = True
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
data = json.loads(data_str)
|
break
|
||||||
except json.JSONDecodeError:
|
|
||||||
# 不是有效 JSON,可能是部分数据,继续
|
|
||||||
if line_count >= max_prefetch_lines:
|
|
||||||
break
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 使用解析器检查是否为错误响应
|
|
||||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
|
||||||
# 提取错误信息
|
|
||||||
parsed = provider_parser.parse_response(data, 200)
|
|
||||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
|
||||||
f"Provider={provider.name}, "
|
|
||||||
f"error_type={parsed.error_type}, "
|
|
||||||
f"message={parsed.error_message}")
|
|
||||||
raise EmbeddedErrorException(
|
|
||||||
provider_name=str(provider.name),
|
|
||||||
error_code=(
|
|
||||||
int(parsed.error_type)
|
|
||||||
if parsed.error_type and parsed.error_type.isdigit()
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
error_message=parsed.error_message,
|
|
||||||
error_status=parsed.error_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 预读到有效数据,没有错误,停止预读
|
|
||||||
break
|
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except EmbeddedErrorException:
|
||||||
# 重新抛出嵌套错误
|
# 重新抛出嵌套错误
|
||||||
@@ -800,112 +783,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
||||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||||
|
|
||||||
return prefetched_lines
|
return prefetched_chunks
|
||||||
|
|
||||||
async def _create_response_stream_with_prefetch(
|
async def _create_response_stream_with_prefetch(
|
||||||
self,
|
self,
|
||||||
ctx: StreamContext,
|
ctx: StreamContext,
|
||||||
line_iterator: Any,
|
byte_iterator: Any,
|
||||||
response_ctx: Any,
|
response_ctx: Any,
|
||||||
http_client: httpx.AsyncClient,
|
http_client: httpx.AsyncClient,
|
||||||
prefetched_lines: list,
|
prefetched_chunks: list,
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes, None]:
|
||||||
"""创建响应流生成器(带预读数据)"""
|
"""创建响应流生成器(带预读数据,使用字节流)"""
|
||||||
try:
|
try:
|
||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
last_data_time = time.time()
|
last_data_time = time.time()
|
||||||
|
buffer = b""
|
||||||
|
first_yield = True # 标记是否是第一次 yield
|
||||||
|
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||||
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if prefetched_lines:
|
if prefetched_chunks:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming(ctx.request_id)
|
||||||
|
|
||||||
# 先处理预读的数据
|
# 先处理预读的字节块
|
||||||
for line in prefetched_lines:
|
for chunk in prefetched_chunks:
|
||||||
normalized_line = line.rstrip("\r")
|
buffer += chunk
|
||||||
events = sse_parser.feed_line(normalized_line)
|
# 处理缓冲区中的完整行
|
||||||
|
while b"\n" in buffer:
|
||||||
|
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||||
|
try:
|
||||||
|
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||||
|
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||||
|
f"bytes={line_bytes[:50]!r}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_line = line.rstrip("\r")
|
||||||
|
events = sse_parser.feed_line(normalized_line)
|
||||||
|
|
||||||
|
if normalized_line == "":
|
||||||
|
for event in events:
|
||||||
|
self._handle_sse_event(
|
||||||
|
ctx,
|
||||||
|
event.get("event"),
|
||||||
|
event.get("data") or "",
|
||||||
|
)
|
||||||
|
# 记录首字时间 (第一次 yield)
|
||||||
|
if first_yield:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
first_yield = False
|
||||||
|
yield b"\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
ctx.chunk_count += 1
|
||||||
|
|
||||||
|
# 格式转换或直接透传
|
||||||
|
if needs_conversion:
|
||||||
|
converted_line = self._convert_sse_line(ctx, line, events)
|
||||||
|
if converted_line:
|
||||||
|
# 记录首字时间 (第一次 yield)
|
||||||
|
if first_yield:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
first_yield = False
|
||||||
|
yield (converted_line + "\n").encode("utf-8")
|
||||||
|
else:
|
||||||
|
# 记录首字时间 (第一次 yield)
|
||||||
|
if first_yield:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
first_yield = False
|
||||||
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
if normalized_line == "":
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data") or "",
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
yield b"\n"
|
|
||||||
continue
|
|
||||||
|
|
||||||
ctx.chunk_count += 1
|
if ctx.data_count > 0:
|
||||||
|
last_data_time = time.time()
|
||||||
# 格式转换或直接透传
|
|
||||||
if needs_conversion:
|
|
||||||
converted_line = self._convert_sse_line(ctx, line, events)
|
|
||||||
if converted_line:
|
|
||||||
yield (converted_line + "\n").encode("utf-8")
|
|
||||||
else:
|
|
||||||
yield (line + "\n").encode("utf-8")
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
self._handle_sse_event(
|
|
||||||
ctx,
|
|
||||||
event.get("event"),
|
|
||||||
event.get("data") or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
if ctx.data_count > 0:
|
|
||||||
last_data_time = time.time()
|
|
||||||
|
|
||||||
# 继续处理剩余的流数据(使用同一个迭代器)
|
# 继续处理剩余的流数据(使用同一个迭代器)
|
||||||
async for line in line_iterator:
|
async for chunk in byte_iterator:
|
||||||
normalized_line = line.rstrip("\r")
|
buffer += chunk
|
||||||
events = sse_parser.feed_line(normalized_line)
|
# 处理缓冲区中的完整行
|
||||||
|
while b"\n" in buffer:
|
||||||
|
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||||
|
try:
|
||||||
|
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||||
|
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||||
|
f"bytes={line_bytes[:50]!r}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_line = line.rstrip("\r")
|
||||||
|
events = sse_parser.feed_line(normalized_line)
|
||||||
|
|
||||||
|
if normalized_line == "":
|
||||||
|
for event in events:
|
||||||
|
self._handle_sse_event(
|
||||||
|
ctx,
|
||||||
|
event.get("event"),
|
||||||
|
event.get("data") or "",
|
||||||
|
)
|
||||||
|
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||||
|
if first_yield:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
first_yield = False
|
||||||
|
yield b"\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
ctx.chunk_count += 1
|
||||||
|
|
||||||
|
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||||
|
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||||
|
elapsed = time.time() - last_data_time
|
||||||
|
if elapsed > self.DATA_TIMEOUT:
|
||||||
|
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||||
|
error_event = {
|
||||||
|
"type": "error",
|
||||||
|
"error": {
|
||||||
|
"type": "empty_stream_timeout",
|
||||||
|
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 格式转换或直接透传
|
||||||
|
if needs_conversion:
|
||||||
|
converted_line = self._convert_sse_line(ctx, line, events)
|
||||||
|
if converted_line:
|
||||||
|
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||||
|
if first_yield:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
first_yield = False
|
||||||
|
yield (converted_line + "\n").encode("utf-8")
|
||||||
|
else:
|
||||||
|
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||||
|
if first_yield:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
first_yield = False
|
||||||
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
if normalized_line == "":
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self._handle_sse_event(
|
self._handle_sse_event(
|
||||||
ctx,
|
ctx,
|
||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data") or "",
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
yield b"\n"
|
|
||||||
continue
|
|
||||||
|
|
||||||
ctx.chunk_count += 1
|
if ctx.data_count > 0:
|
||||||
|
last_data_time = time.time()
|
||||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
|
||||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
|
||||||
elapsed = time.time() - last_data_time
|
|
||||||
if elapsed > self.DATA_TIMEOUT:
|
|
||||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
|
||||||
error_event = {
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "empty_stream_timeout",
|
|
||||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 格式转换或直接透传
|
|
||||||
if needs_conversion:
|
|
||||||
converted_line = self._convert_sse_line(ctx, line, events)
|
|
||||||
if converted_line:
|
|
||||||
yield (converted_line + "\n").encode("utf-8")
|
|
||||||
else:
|
|
||||||
yield (line + "\n").encode("utf-8")
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
self._handle_sse_event(
|
|
||||||
ctx,
|
|
||||||
event.get("event"),
|
|
||||||
event.get("data") or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
if ctx.data_count > 0:
|
|
||||||
last_data_time = time.time()
|
|
||||||
|
|
||||||
# 处理剩余事件
|
# 处理剩余事件
|
||||||
flushed_events = sse_parser.flush()
|
flushed_events = sse_parser.flush()
|
||||||
@@ -1034,7 +1073,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 提取文本内容
|
# 提取文本内容
|
||||||
text = self.parser.extract_text_content(data)
|
text = self.parser.extract_text_content(data)
|
||||||
if text:
|
if text:
|
||||||
ctx.collected_text += text
|
ctx.append_text(text)
|
||||||
|
|
||||||
# 检查完成事件
|
# 检查完成事件
|
||||||
if event_type in ("response.completed", "message_stop"):
|
if event_type in ("response.completed", "message_stop"):
|
||||||
@@ -1086,9 +1125,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""在流完成后记录统计信息"""
|
"""在流完成后记录统计信息"""
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(0.1)
|
# 使用 self.start_time 作为时间基准,与首字时间保持一致
|
||||||
|
# 注意:不要把统计延迟算进响应时间里
|
||||||
|
response_time_ms = int((time.time() - self.start_time) * 1000)
|
||||||
|
|
||||||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
if not ctx.provider_name:
|
if not ctx.provider_name:
|
||||||
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
|
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
|
||||||
@@ -1168,6 +1209,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
input_tokens=actual_input_tokens,
|
input_tokens=actual_input_tokens,
|
||||||
output_tokens=ctx.output_tokens,
|
output_tokens=ctx.output_tokens,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
|
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
|
||||||
status_code=ctx.status_code,
|
status_code=ctx.status_code,
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
@@ -1188,9 +1230,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
|
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
|
||||||
)
|
)
|
||||||
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
||||||
# 简洁的请求完成摘要
|
# 简洁的请求完成摘要(两行格式)
|
||||||
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
line1 = (
|
||||||
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
|
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
|
||||||
|
)
|
||||||
|
if ctx.first_byte_time_ms:
|
||||||
|
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
|
||||||
|
|
||||||
|
line2 = (
|
||||||
|
f" Total: {response_time_ms}ms | "
|
||||||
|
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
|
||||||
|
)
|
||||||
|
logger.info(f"{line1}\n{line2}")
|
||||||
|
|
||||||
# 更新候选记录的最终状态和延迟时间
|
# 更新候选记录的最终状态和延迟时间
|
||||||
# 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
|
# 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
|
||||||
@@ -1242,7 +1293,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
original_request_body: Dict[str, Any],
|
original_request_body: Dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录流式请求失败"""
|
"""记录流式请求失败"""
|
||||||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
# 使用 self.start_time 作为时间基准,与首字时间保持一致
|
||||||
|
response_time_ms = int((time.time() - self.start_time) * 1000)
|
||||||
|
|
||||||
status_code = 503
|
status_code = 503
|
||||||
if isinstance(error, ProviderAuthException):
|
if isinstance(error, ProviderAuthException):
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
|||||||
if delta.get("type") == "text_delta":
|
if delta.get("type") == "text_delta":
|
||||||
text = delta.get("text", "")
|
text = delta.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
ctx.collected_text += text
|
ctx.append_text(text)
|
||||||
|
|
||||||
# 处理消息增量(包含最终 usage)
|
# 处理消息增量(包含最终 usage)
|
||||||
elif event_type == "message_delta":
|
elif event_type == "message_delta":
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
|
|||||||
parts = content.get("parts", [])
|
parts = content.get("parts", [])
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if "text" in part:
|
if "text" in part:
|
||||||
ctx.collected_text += part["text"]
|
ctx.append_text(part["text"])
|
||||||
|
|
||||||
# 检查结束原因
|
# 检查结束原因
|
||||||
finish_reason = candidate.get("finishReason")
|
finish_reason = candidate.get("finishReason")
|
||||||
|
|||||||
@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
|
|||||||
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
|
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
|
||||||
delta = data.get("delta")
|
delta = data.get("delta")
|
||||||
if isinstance(delta, str):
|
if isinstance(delta, str):
|
||||||
ctx.collected_text += delta
|
ctx.append_text(delta)
|
||||||
elif isinstance(delta, dict) and "text" in delta:
|
elif isinstance(delta, dict) and "text" in delta:
|
||||||
ctx.collected_text += delta["text"]
|
ctx.append_text(delta["text"])
|
||||||
|
|
||||||
# 处理完成事件
|
# 处理完成事件
|
||||||
elif event_type == "response.completed":
|
elif event_type == "response.completed":
|
||||||
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
|
|||||||
if content_item.get("type") == "output_text":
|
if content_item.get("type") == "output_text":
|
||||||
text = content_item.get("text", "")
|
text = content_item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
ctx.collected_text += text
|
ctx.append_text(text)
|
||||||
|
|
||||||
# 备用:从顶层 usage 提取
|
# 备用:从顶层 usage 提取
|
||||||
usage_obj = data.get("usage")
|
usage_obj = data.get("usage")
|
||||||
|
|||||||
Reference in New Issue
Block a user