refactor(cli-handler): improve stream handling and response processing

- Refactor CLI handler base for better stream context management
- Optimize request/response handling for Claude, OpenAI, and Gemini CLI adapters
- Enhance telemetry tracking across CLI handlers
Author: fawney19
Date: 2025-12-16 02:39:20 +08:00
Parent: ad1c8c394c
Commit: a3df41d63d

4 changed files with 296 additions and 244 deletions


@@ -11,17 +11,15 @@ CLI Message Handler common base class
 """

 import asyncio
+import codecs
 import json
 import time
-from abc import abstractmethod
-from dataclasses import dataclass, field
 from typing import (
     Any,
     AsyncGenerator,
     Callable,
     Dict,
     Optional,
-    Tuple,
 )

 import httpx
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
 )
 from src.api.handlers.base.parsers import get_parser_for_format
 from src.api.handlers.base.request_builder import PassthroughRequestBuilder
+from src.api.handlers.base.stream_context import StreamContext
+from src.api.handlers.base.utils import build_sse_headers

 # Import directly from the concrete modules to avoid circular imports
 from src.api.handlers.base.response_parser import (
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
 from src.utils.sse_parser import SSEEventParser

-@dataclass
-class StreamContext:
-    """Context for a streaming request"""
-
-    # Request info
-    model: str = "unknown"  # the original model name requested by the user
-    mapped_model: Optional[str] = None  # the mapped target model name (if a mapping occurred)
-    api_format: str = ""
-    request_id: str = ""
-
-    # User info (extracted up front to avoid a detached Session)
-    user_id: int = 0
-    api_key_id: int = 0
-
-    # Statistics
-    input_tokens: int = 0
-    output_tokens: int = 0
-    cached_tokens: int = 0  # cache_read_input_tokens
-    cache_creation_tokens: int = 0  # cache_creation_input_tokens
-    collected_text: str = ""
-    response_id: Optional[str] = None
-    final_usage: Optional[Dict[str, Any]] = None
-    final_response: Optional[Dict[str, Any]] = None
-    parsed_chunks: list = field(default_factory=list)
-
-    # Stream state
-    start_time: float = field(default_factory=time.time)
-    chunk_count: int = 0
-    data_count: int = 0
-    has_completion: bool = False
-
-    # Response info
-    status_code: int = 200
-    response_headers: Dict[str, str] = field(default_factory=dict)
-
-    # Request info (as sent to the provider)
-    provider_request_headers: Dict[str, str] = field(default_factory=dict)
-    provider_request_body: Optional[Dict[str, Any]] = None  # the request body actually sent
-
-    # Provider info
-    provider_name: Optional[str] = None
-    provider_id: Optional[str] = None  # provider ID, used to record the real cost
-    endpoint_id: Optional[str] = None
-    key_id: Optional[str] = None
-    attempt_id: Optional[str] = None
-    attempt_synced: bool = False
-    error_message: Optional[str] = None
-
-    # Format conversion info
-    provider_api_format: str = ""  # the provider's API format (used for response conversion)
-    client_api_format: str = ""  # the API format requested by the client
-
-    # Provider response metadata (extra info returned by the provider, e.g. Gemini's modelVersion)
-    response_metadata: Dict[str, Any] = field(default_factory=dict)

 class CliMessageHandlerBase(BaseMessageHandler):
     """
     CLI-format message handler base class
@@ -409,6 +352,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
         return StreamingResponse(
             monitored_stream,
             media_type="text/event-stream",
+            headers=build_sse_headers(),
             background=background_tasks,
         )
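
The `build_sse_headers()` helper itself is not part of this diff. As context, here is a minimal sketch of what such a helper typically returns — the exact header set in `src.api.handlers.base.utils` is an assumption, not taken from this commit:

```python
# Hypothetical sketch only -- build_sse_headers is not shown in this commit.
from typing import Dict

def build_sse_headers() -> Dict[str, str]:
    """Headers commonly attached to Server-Sent Events responses (assumed set)."""
    return {
        "Cache-Control": "no-cache",  # the stream must not be cached by intermediaries
        "Connection": "keep-alive",   # hold the connection open for the stream's lifetime
        "X-Accel-Buffering": "no",    # ask nginx-style proxies not to buffer chunks
    }
```

Disabling proxy buffering is usually the important one: a buffering reverse proxy would otherwise hold SSE chunks and destroy streaming latency.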
@@ -433,7 +377,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
         ctx.chunk_count = 0
         ctx.data_count = 0
         ctx.has_completion = False
-        ctx.collected_text = ""
+        ctx._collected_text_parts = []  # reset text collection
         ctx.input_tokens = 0
         ctx.output_tokens = 0
         ctx.cached_tokens = 0
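
The reset of `ctx._collected_text_parts` (together with the `ctx.append_text(...)` calls later in this diff) implies that the relocated `StreamContext` now accumulates text in a list rather than a string. A plausible sketch of that API — the actual `stream_context.py` is not shown in this commit:

```python
# Plausible sketch of the accumulation API implied by this diff;
# the real StreamContext (moved to stream_context.py) is not shown here.
from dataclasses import dataclass, field
from typing import List

@dataclass
class StreamContext:
    _collected_text_parts: List[str] = field(default_factory=list)

    def append_text(self, text: str) -> None:
        # list.append is O(1); repeated `collected_text += text` re-copies the
        # whole accumulated string on every delta, O(n^2) over a long stream.
        self._collected_text_parts.append(text)

    @property
    def collected_text(self) -> str:
        # Join once, only when the full text is actually needed.
        return "".join(self._collected_text_parts)
```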
@@ -521,12 +465,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
             stream_response.raise_for_status()

-            # Create the line iterator (created once; reused later)
-            line_iterator = stream_response.aiter_lines()
+            # Use a byte-stream iterator (avoids the performance problems of aiter_lines)
+            byte_iterator = stream_response.aiter_raw()

             # Prefetch the first data chunk to detect embedded errors (HTTP 200 but the body contains an error)
-            prefetched_lines = await self._prefetch_and_check_embedded_error(
-                line_iterator, provider, endpoint, ctx
+            prefetched_chunks = await self._prefetch_and_check_embedded_error(
+                byte_iterator, provider, endpoint, ctx
             )

         except httpx.HTTPStatusError as e:
@@ -551,10 +495,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
         # Create the stream generator (with prefetched data, reusing the same iterator)
         return self._create_response_stream_with_prefetch(
             ctx,
-            line_iterator,
+            byte_iterator,
             response_ctx,
             http_client,
-            prefetched_lines,
+            prefetched_chunks,
         )

     async def _create_response_stream(
@@ -564,58 +508,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
         response_ctx: Any,
         http_client: httpx.AsyncClient,
     ) -> AsyncGenerator[bytes, None]:
-        """Create the response stream generator"""
+        """Create the response stream generator (byte-stream based)"""
         try:
             sse_parser = SSEEventParser()
             last_data_time = time.time()
             streaming_status_updated = False
+            buffer = b""
+            # Use an incremental decoder for UTF-8 characters that span chunks
+            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

             # Check whether format conversion is needed
             needs_conversion = self._needs_format_conversion(ctx)

-            async for line in stream_response.aiter_lines():
+            async for chunk in stream_response.aiter_raw():
                 # Update status to "streaming" before the first output
                 if not streaming_status_updated:
                     self._update_usage_to_streaming(ctx.request_id)
                     streaming_status_updated = True

-                normalized_line = line.rstrip("\r")
-                events = sse_parser.feed_line(normalized_line)
-
-                if normalized_line == "":
-                    for event in events:
-                        self._handle_sse_event(
-                            ctx,
-                            event.get("event"),
-                            event.get("data") or "",
-                        )
-                    yield b"\n"
-                    continue
-
-                ctx.chunk_count += 1
-
-                # Empty-stream detection: past the threshold with no data, emit an error event and stop
-                if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
-                    elapsed = time.time() - last_data_time
-                    if elapsed > self.DATA_TIMEOUT:
-                        logger.warning(f"Provider '{ctx.provider_name}' stream timed out with no data")
-                        error_event = {
-                            "type": "error",
-                            "error": {
-                                "type": "empty_stream_timeout",
-                                "message": f"Provider '{ctx.provider_name}' stream timed out without returning valid data",
-                            },
-                        }
-                        yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
-                        return  # end the generator
-
-                # Format conversion, or pass through unchanged
-                if needs_conversion:
-                    converted_line = self._convert_sse_line(ctx, line, events)
-                    if converted_line:
-                        yield (converted_line + "\n").encode("utf-8")
-                else:
-                    yield (line + "\n").encode("utf-8")
-
-                for event in events:
-                    self._handle_sse_event(
+                buffer += chunk
+                # Process complete lines in the buffer
+                while b"\n" in buffer:
+                    line_bytes, buffer = buffer.split(b"\n", 1)
+                    try:
+                        # The incremental decoder correctly handles multi-byte characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
+                    except Exception as e:
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decode failed: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
+                    normalized_line = line.rstrip("\r")
+                    events = sse_parser.feed_line(normalized_line)
+
+                    if normalized_line == "":
+                        for event in events:
+                            self._handle_sse_event(
+                                ctx,
+                                event.get("event"),
+                                event.get("data") or "",
+                            )
+                        yield b"\n"
+                        continue
+
+                    ctx.chunk_count += 1
+
+                    # Empty-stream detection: past the threshold with no data, emit an error event and stop
+                    if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
+                        elapsed = time.time() - last_data_time
+                        if elapsed > self.DATA_TIMEOUT:
+                            logger.warning(f"Provider '{ctx.provider_name}' stream timed out with no data")
+                            error_event = {
+                                "type": "error",
+                                "error": {
+                                    "type": "empty_stream_timeout",
+                                    "message": f"Provider '{ctx.provider_name}' stream timed out without returning valid data",
+                                },
+                            }
+                            yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
+                            return  # end the generator
+
+                    # Format conversion, or pass through unchanged
+                    if needs_conversion:
+                        converted_line = self._convert_sse_line(ctx, line, events)
+                        if converted_line:
+                            yield (converted_line + "\n").encode("utf-8")
+                    else:
+                        yield (line + "\n").encode("utf-8")

-                for event in events:
-                    self._handle_sse_event(
+                    for event in events:
+                        self._handle_sse_event(
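
The switch from `aiter_lines()` to `aiter_raw()` plus `codecs.getincrementaldecoder` matters because a multi-byte UTF-8 character can be split across two network chunks. A minimal, self-contained demonstration of the failure mode this guards against:

```python
import codecs

# "提供商" is 9 UTF-8 bytes; split it so a multi-byte character straddles two chunks.
raw = "提供商".encode("utf-8")
chunks = [raw[:4], raw[4:]]

# Decoding chunk by chunk corrupts the split character even with errors="replace".
print("".join(c.decode("utf-8", errors="replace") for c in chunks))  # 提���商 (corrupted)

# An incremental decoder buffers the trailing partial sequence until it completes.
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
print("".join(decoder.decode(c) for c in chunks))  # 提供商
```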
@@ -689,7 +650,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
     async def _prefetch_and_check_embedded_error(
         self,
-        line_iterator: Any,
+        byte_iterator: Any,
         provider: Provider,
         endpoint: ProviderEndpoint,
         ctx: StreamContext,
@@ -703,20 +664,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
         Also detects HTML responses (usually a misconfigured base_url returning a web page).

         Args:
-            line_iterator: the line iterator returned by aiter_lines()
+            byte_iterator: the byte-stream iterator
             provider: the Provider object
             endpoint: the Endpoint object
             ctx: the stream context

         Returns:
-            The list of prefetched lines (to be re-emitted at the start of the stream)
+            The list of prefetched byte chunks (to be re-emitted at the start of the stream)

         Raises:
             EmbeddedErrorException: if an embedded error is detected
             ProviderNotAvailableException: if an HTML response is detected (configuration error)
         """
-        prefetched_lines: list = []
+        prefetched_chunks: list = []
         max_prefetch_lines = 5  # prefetch at most 5 lines to detect errors
+        buffer = b""
+        line_count = 0
+        should_stop = False
+        # Use an incremental decoder for UTF-8 characters that span chunks
+        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

         try:
             # Get the parser for the matching format
@@ -729,69 +695,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
else: else:
provider_parser = self.parser provider_parser = self.parser
line_count = 0 async for chunk in byte_iterator:
async for line in line_iterator: prefetched_chunks.append(chunk)
prefetched_lines.append(line) buffer += chunk
line_count += 1
# 解析数据 # 尝试按行解析缓冲区
normalized_line = line.rstrip("\r") while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
except Exception as e:
logger.warning(
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
f"bytes={line_bytes[:50]!r}"
)
continue
# 检测 HTML 响应base_url 配置错误的常见症状) line_count += 1
lower_line = normalized_line.lower() normalized_line = line.rstrip("\r")
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"): # 检测 HTML 响应base_url 配置错误的常见症状)
# 空行或注释行,继续预读 lower_line = normalized_line.lower()
if line_count >= max_prefetch_lines: if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
should_stop = True
break break
continue
# 尝试解析 SSE 数据 try:
data_str = normalized_line data = json.loads(data_str)
if normalized_line.startswith("data: "): except json.JSONDecodeError:
data_str = normalized_line[6:] # 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
if data_str == "[DONE]": # 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
should_stop = True
break break
try: if should_stop or line_count >= max_prefetch_lines:
data = json.loads(data_str) break
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException: except EmbeddedErrorException:
# 重新抛出嵌套错误 # 重新抛出嵌套错误
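
The prefetch-and-replay pattern above — read a few chunks to sniff for embedded errors, then re-emit them ahead of the live stream — generalizes beyond this handler. A compact sketch under simplified assumptions (no SSE parsing, error detection reduced to a caller-supplied predicate; names are illustrative, not from this commit):

```python
# Simplified sketch of prefetch-and-replay over an async byte iterator.
from typing import AsyncGenerator, AsyncIterator, Callable, List

async def prefetch(
    it: AsyncIterator[bytes],
    is_error: Callable[[bytes], bool],
    max_chunks: int = 5,
) -> List[bytes]:
    """Consume up to max_chunks from `it`; raise if an embedded error is sniffed."""
    seen: List[bytes] = []
    async for chunk in it:
        seen.append(chunk)
        if is_error(chunk):
            raise RuntimeError("embedded error detected during prefetch")
        if len(seen) >= max_chunks:
            break
    return seen

async def replay_then_stream(
    prefetched: List[bytes],
    it: AsyncIterator[bytes],
) -> AsyncGenerator[bytes, None]:
    """Re-emit the prefetched chunks, then continue the same iterator."""
    for chunk in prefetched:
        yield chunk
    async for chunk in it:  # resumes exactly where prefetch() stopped
        yield chunk
```

Reusing the same iterator object in both phases is the crucial detail: creating a second iterator over the response would either fail or drop the bytes already consumed.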
@@ -800,112 +783,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
             # Other exceptions (e.g. network errors) during prefetch: log them, but don't abort
             logger.debug(f" [{self.request_id}] exception while prefetching the stream: {e}")

-        return prefetched_lines
+        return prefetched_chunks

     async def _create_response_stream_with_prefetch(
         self,
         ctx: StreamContext,
-        line_iterator: Any,
+        byte_iterator: Any,
         response_ctx: Any,
         http_client: httpx.AsyncClient,
-        prefetched_lines: list,
+        prefetched_chunks: list,
     ) -> AsyncGenerator[bytes, None]:
-        """Create the response stream generator (with prefetched data)"""
+        """Create the response stream generator (with prefetched data, byte-stream based)"""
         try:
             sse_parser = SSEEventParser()
             last_data_time = time.time()
+            buffer = b""
+            first_yield = True  # marks whether this is the first yield
+            # Use an incremental decoder for UTF-8 characters that span chunks
+            decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

             # Check whether format conversion is needed
             needs_conversion = self._needs_format_conversion(ctx)

             # Update status to "streaming" before the first output
-            if prefetched_lines:
+            if prefetched_chunks:
                 self._update_usage_to_streaming(ctx.request_id)

-            # Emit the prefetched data first
-            for line in prefetched_lines:
-                normalized_line = line.rstrip("\r")
-                events = sse_parser.feed_line(normalized_line)
-
-                if normalized_line == "":
-                    for event in events:
-                        self._handle_sse_event(
-                            ctx,
-                            event.get("event"),
-                            event.get("data") or "",
-                        )
-                    yield b"\n"
-                    continue
-
-                ctx.chunk_count += 1
-
-                # Format conversion, or pass through unchanged
-                if needs_conversion:
-                    converted_line = self._convert_sse_line(ctx, line, events)
-                    if converted_line:
-                        yield (converted_line + "\n").encode("utf-8")
-                else:
-                    yield (line + "\n").encode("utf-8")
-
-                for event in events:
-                    self._handle_sse_event(
-                        ctx,
-                        event.get("event"),
-                        event.get("data") or "",
-                    )
-
-                if ctx.data_count > 0:
-                    last_data_time = time.time()
+            # Process the prefetched byte chunks first
+            for chunk in prefetched_chunks:
+                buffer += chunk
+                # Process complete lines in the buffer
+                while b"\n" in buffer:
+                    line_bytes, buffer = buffer.split(b"\n", 1)
+                    try:
+                        # The incremental decoder correctly handles multi-byte characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
+                    except Exception as e:
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decode failed: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
+                    normalized_line = line.rstrip("\r")
+                    events = sse_parser.feed_line(normalized_line)
+
+                    if normalized_line == "":
+                        for event in events:
+                            self._handle_sse_event(
+                                ctx,
+                                event.get("event"),
+                                event.get("data") or "",
+                            )
+                        # Record time to first byte (first yield)
+                        if first_yield:
+                            ctx.record_first_byte_time(self.start_time)
+                            first_yield = False
+                        yield b"\n"
+                        continue
+
+                    ctx.chunk_count += 1
+
+                    # Format conversion, or pass through unchanged
+                    if needs_conversion:
+                        converted_line = self._convert_sse_line(ctx, line, events)
+                        if converted_line:
+                            # Record time to first byte (first yield)
+                            if first_yield:
+                                ctx.record_first_byte_time(self.start_time)
+                                first_yield = False
+                            yield (converted_line + "\n").encode("utf-8")
+                    else:
+                        # Record time to first byte (first yield)
+                        if first_yield:
+                            ctx.record_first_byte_time(self.start_time)
+                            first_yield = False
+                        yield (line + "\n").encode("utf-8")
+
+                    for event in events:
+                        self._handle_sse_event(
+                            ctx,
+                            event.get("event"),
+                            event.get("data") or "",
+                        )
+
+                    if ctx.data_count > 0:
+                        last_data_time = time.time()

             # Continue with the rest of the stream (using the same iterator)
-            async for line in line_iterator:
-                normalized_line = line.rstrip("\r")
-                events = sse_parser.feed_line(normalized_line)
-
-                if normalized_line == "":
-                    for event in events:
-                        self._handle_sse_event(
-                            ctx,
-                            event.get("event"),
-                            event.get("data") or "",
-                        )
-                    yield b"\n"
-                    continue
-
-                ctx.chunk_count += 1
-
-                # Empty-stream detection: past the threshold with no data, emit an error event and stop
-                if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
-                    elapsed = time.time() - last_data_time
-                    if elapsed > self.DATA_TIMEOUT:
-                        logger.warning(f"Provider '{ctx.provider_name}' stream timed out with no data")
-                        error_event = {
-                            "type": "error",
-                            "error": {
-                                "type": "empty_stream_timeout",
-                                "message": f"Provider '{ctx.provider_name}' stream timed out without returning valid data",
-                            },
-                        }
-                        yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
-                        return
-
-                # Format conversion, or pass through unchanged
-                if needs_conversion:
-                    converted_line = self._convert_sse_line(ctx, line, events)
-                    if converted_line:
-                        yield (converted_line + "\n").encode("utf-8")
-                else:
-                    yield (line + "\n").encode("utf-8")
-
-                for event in events:
-                    self._handle_sse_event(
-                        ctx,
-                        event.get("event"),
-                        event.get("data") or "",
-                    )
-
-                if ctx.data_count > 0:
-                    last_data_time = time.time()
+            async for chunk in byte_iterator:
+                buffer += chunk
+                # Process complete lines in the buffer
+                while b"\n" in buffer:
+                    line_bytes, buffer = buffer.split(b"\n", 1)
+                    try:
+                        # The incremental decoder correctly handles multi-byte characters split across chunks
+                        line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
+                    except Exception as e:
+                        logger.warning(
+                            f"[{self.request_id}] UTF-8 decode failed: {e}, "
+                            f"bytes={line_bytes[:50]!r}"
+                        )
+                        continue
+
+                    normalized_line = line.rstrip("\r")
+                    events = sse_parser.feed_line(normalized_line)
+
+                    if normalized_line == "":
+                        for event in events:
+                            self._handle_sse_event(
+                                ctx,
+                                event.get("event"),
+                                event.get("data") or "",
+                            )
+                        # Record time to first byte (first yield) - in case the prefetched data was empty
+                        if first_yield:
+                            ctx.record_first_byte_time(self.start_time)
+                            first_yield = False
+                        yield b"\n"
+                        continue
+
+                    ctx.chunk_count += 1
+
+                    # Empty-stream detection: past the threshold with no data, emit an error event and stop
+                    if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
+                        elapsed = time.time() - last_data_time
+                        if elapsed > self.DATA_TIMEOUT:
+                            logger.warning(f"Provider '{ctx.provider_name}' stream timed out with no data")
+                            error_event = {
+                                "type": "error",
+                                "error": {
+                                    "type": "empty_stream_timeout",
+                                    "message": f"Provider '{ctx.provider_name}' stream timed out without returning valid data",
+                                },
+                            }
+                            yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
+                            return
+
+                    # Format conversion, or pass through unchanged
+                    if needs_conversion:
+                        converted_line = self._convert_sse_line(ctx, line, events)
+                        if converted_line:
+                            # Record time to first byte (first yield) - in case the prefetched data was empty
+                            if first_yield:
+                                ctx.record_first_byte_time(self.start_time)
+                                first_yield = False
+                            yield (converted_line + "\n").encode("utf-8")
+                    else:
+                        # Record time to first byte (first yield) - in case the prefetched data was empty
+                        if first_yield:
+                            ctx.record_first_byte_time(self.start_time)
+                            first_yield = False
+                        yield (line + "\n").encode("utf-8")
+
+                    for event in events:
+                        self._handle_sse_event(
+                            ctx,
+                            event.get("event"),
+                            event.get("data") or "",
+                        )
+
+                    if ctx.data_count > 0:
+                        last_data_time = time.time()

             # Handle any remaining events
             flushed_events = sse_parser.flush()
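
The `SSEEventParser` consumed above is internal to this project; roughly, a feed-line SSE parser accumulates `event:`/`data:` fields and releases a complete event at each blank line, with `flush()` draining whatever is pending when the stream ends. A rough sketch of that contract — an assumption, not the project's implementation:

```python
# Rough sketch of the feed_line/flush contract; not the project's SSEEventParser.
from typing import Dict, List, Optional

class SSELineParser:
    def __init__(self) -> None:
        self._event: Optional[str] = None
        self._data: List[str] = []

    def feed_line(self, line: str) -> List[Dict[str, str]]:
        if line == "":  # a blank line terminates the pending event
            return self.flush()
        if line.startswith("event: "):
            self._event = line[7:]
        elif line.startswith("data: "):
            self._data.append(line[6:])
        return []

    def flush(self) -> List[Dict[str, str]]:
        if self._event is None and not self._data:
            return []
        event = {"event": self._event or "", "data": "\n".join(self._data)}
        self._event, self._data = None, []
        return [event]
```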
@@ -1034,7 +1073,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
             # Extract the text content
             text = self.parser.extract_text_content(data)
             if text:
-                ctx.collected_text += text
+                ctx.append_text(text)

         # Check for completion events
         if event_type in ("response.completed", "message_stop"):
@@ -1086,9 +1125,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
     ) -> None:
         """Record statistics after the stream completes"""
         try:
-            await asyncio.sleep(0.1)
-
-            response_time_ms = int((time.time() - ctx.start_time) * 1000)
+            # Use self.start_time as the time base, consistent with the first-byte time
+            # Note: don't count the stats delay in the response time
+            response_time_ms = int((time.time() - self.start_time) * 1000)
+
+            await asyncio.sleep(0.1)

             if not ctx.provider_name:
                 logger.warning(f"[{ctx.request_id}] stream request failed; no provider was selected")
@@ -1168,6 +1209,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
                 input_tokens=actual_input_tokens,
                 output_tokens=ctx.output_tokens,
                 response_time_ms=response_time_ms,
+                first_byte_time_ms=ctx.first_byte_time_ms,  # pass the first-byte time through
                 status_code=ctx.status_code,
                 request_headers=original_headers,
                 request_body=actual_request_body,
@@ -1188,9 +1230,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
                 response_metadata=ctx.response_metadata if ctx.response_metadata else None,
             )
             logger.debug(f"{self.FORMAT_ID} streaming response completed")

-            # Concise request-completion summary
-            logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
-                        f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
+            # Concise request-completion summary (two-line format)
+            line1 = (
+                f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
+            )
+            if ctx.first_byte_time_ms:
+                line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
+            line2 = (
+                f" Total: {response_time_ms}ms | "
+                f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
+            )
+            logger.info(f"{line1}\n{line2}")

         # Update the candidate record's final status and latency
         # Note: RequestExecutor marks success too early, when the stream starts (it only records connection-setup time)
@@ -1242,7 +1293,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
         original_request_body: Dict[str, Any],
     ) -> None:
         """Record a failed streaming request"""
-        response_time_ms = int((time.time() - ctx.start_time) * 1000)
+        # Use self.start_time as the time base, consistent with the first-byte time
+        response_time_ms = int((time.time() - self.start_time) * 1000)

         status_code = 503
         if isinstance(error, ProviderAuthException):
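
Both the success and failure paths now measure `response_time_ms` from `self.start_time`, the same baseline as the first-byte time, so TTFB and total latency are directly comparable in the logs. A minimal illustration of the two measurements around an async generator (`Timing.record_first_byte_time` here is a stand-in for the context method used above):

```python
import asyncio
import time
from typing import Optional

class Timing:
    def __init__(self) -> None:
        self.first_byte_time_ms: Optional[int] = None

    def record_first_byte_time(self, start_time: float) -> None:
        if self.first_byte_time_ms is None:  # only the first yield counts
            self.first_byte_time_ms = int((time.time() - start_time) * 1000)

async def fake_stream():
    await asyncio.sleep(0.05)  # connection + time-to-first-token -> TTFB
    yield b"hello"
    await asyncio.sleep(0.10)  # generation time -> total minus TTFB
    yield b"world"

async def main() -> None:
    start_time = time.time()
    timing = Timing()
    async for _ in fake_stream():
        timing.record_first_byte_time(start_time)
    total_ms = int((time.time() - start_time) * 1000)
    print(f"TTFB: {timing.first_byte_time_ms}ms | Total: {total_ms}ms")

asyncio.run(main())
```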


@@ -111,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
             if delta.get("type") == "text_delta":
                 text = delta.get("text", "")
                 if text:
-                    ctx.collected_text += text
+                    ctx.append_text(text)

         # Handle message deltas (they contain the final usage)
         elif event_type == "message_delta":


@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
                 parts = content.get("parts", [])
                 for part in parts:
                     if "text" in part:
-                        ctx.collected_text += part["text"]
+                        ctx.append_text(part["text"])

             # Check the finish reason
             finish_reason = candidate.get("finishReason")


@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
         if event_type in ["response.output_text.delta", "response.outtext.delta"]:
             delta = data.get("delta")
             if isinstance(delta, str):
-                ctx.collected_text += delta
+                ctx.append_text(delta)
             elif isinstance(delta, dict) and "text" in delta:
-                ctx.collected_text += delta["text"]
+                ctx.append_text(delta["text"])

         # Handle completion events
         elif event_type == "response.completed":
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
                     if content_item.get("type") == "output_text":
                         text = content_item.get("text", "")
                         if text:
-                            ctx.collected_text += text
+                            ctx.append_text(text)

         # Fallback: extract from the top-level usage
         usage_obj = data.get("usage")