mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-05 09:12:27 +08:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c42ebdd0ee | ||
|
|
f1e3c2ab11 | ||
|
|
4e2ba0e57f | ||
|
|
a3df41d63d | ||
|
|
ad1c8c394c | ||
|
|
9b496abb73 | ||
|
|
f3a69a6160 |
@@ -0,0 +1,28 @@
|
||||
"""add first_byte_time_ms to usage table
|
||||
|
||||
Revision ID: 180e63a9c83a
|
||||
Revises: e9b3d63f0cbf
|
||||
Create Date: 2025-12-15 17:07:44.631032+00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '180e63a9c83a'
|
||||
down_revision = 'e9b3d63f0cbf'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""应用迁移:升级到新版本"""
|
||||
# 添加首字时间字段到 usage 表
|
||||
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""回滚迁移:降级到旧版本"""
|
||||
# 删除首字时间字段
|
||||
op.drop_column('usage', 'first_byte_time_ms')
|
||||
@@ -479,10 +479,25 @@ const groupedTimeline = computed<NodeGroup[]>(() => {
|
||||
return groups
|
||||
})
|
||||
|
||||
// 计算链路总耗时(从第一个节点开始到最后一个节点结束)
|
||||
// 计算链路总耗时(使用成功候选的 latency_ms 字段)
|
||||
// 优先使用 latency_ms,因为它与 Usage.response_time_ms 使用相同的时间基准
|
||||
// 避免 finished_at - started_at 带来的额外延迟(数据库操作时间)
|
||||
const totalTraceLatency = computed(() => {
|
||||
if (!timeline.value || timeline.value.length === 0) return 0
|
||||
|
||||
// 查找成功的候选,使用其 latency_ms
|
||||
const successCandidate = timeline.value.find(c => c.status === 'success')
|
||||
if (successCandidate?.latency_ms != null) {
|
||||
return successCandidate.latency_ms
|
||||
}
|
||||
|
||||
// 如果没有成功的候选,查找失败但有 latency_ms 的候选
|
||||
const failedWithLatency = timeline.value.find(c => c.status === 'failed' && c.latency_ms != null)
|
||||
if (failedWithLatency?.latency_ms != null) {
|
||||
return failedWithLatency.latency_ms
|
||||
}
|
||||
|
||||
// 回退:使用 finished_at - started_at 计算
|
||||
let earliestStart: number | null = null
|
||||
let latestEnd: number | null = null
|
||||
|
||||
|
||||
@@ -177,8 +177,9 @@
|
||||
费用
|
||||
</TableHead>
|
||||
<TableHead class="h-12 font-semibold w-[70px] text-right">
|
||||
<div class="inline-block max-w-[2rem] leading-tight">
|
||||
响应时间
|
||||
<div class="flex flex-col items-end text-xs gap-0.5">
|
||||
<span>首字</span>
|
||||
<span class="text-muted-foreground font-normal">总耗时</span>
|
||||
</div>
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
@@ -356,15 +357,28 @@
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="text-right py-4 w-[70px]">
|
||||
<span
|
||||
<div
|
||||
v-if="record.status === 'pending' || record.status === 'streaming'"
|
||||
class="text-primary tabular-nums"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
>
|
||||
{{ getElapsedTime(record) }}
|
||||
</span>
|
||||
<span v-else-if="record.response_time_ms">
|
||||
{{ (record.response_time_ms / 1000).toFixed(2) }}s
|
||||
</span>
|
||||
<span class="text-primary tabular-nums">
|
||||
{{ getElapsedTime(record) }}
|
||||
</span>
|
||||
</div>
|
||||
<div
|
||||
v-else-if="record.response_time_ms != null"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
>
|
||||
<span
|
||||
v-if="record.first_byte_time_ms != null"
|
||||
class="tabular-nums"
|
||||
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
|
||||
<span
|
||||
v-else
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
<span class="text-muted-foreground tabular-nums">{{ (record.response_time_ms / 1000).toFixed(2) }}s</span>
|
||||
</div>
|
||||
<span
|
||||
v-else
|
||||
class="text-muted-foreground"
|
||||
|
||||
@@ -78,6 +78,7 @@ export interface UsageRecord {
|
||||
cost: number
|
||||
actual_cost?: number
|
||||
response_time_ms?: number
|
||||
first_byte_time_ms?: number // 首字时间 (TTFB)
|
||||
is_stream: boolean
|
||||
status_code?: number
|
||||
error_message?: string
|
||||
|
||||
@@ -628,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
"actual_cost": actual_cost,
|
||||
"rate_multiplier": rate_multiplier,
|
||||
"response_time_ms": usage.response_time_ms,
|
||||
"first_byte_time_ms": usage.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
"created_at": usage.created_at.isoformat(),
|
||||
"is_stream": usage.is_stream,
|
||||
"input_price_per_1m": usage.input_price_per_1m,
|
||||
@@ -738,6 +739,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
|
||||
"status_code": usage_record.status_code,
|
||||
"error_message": usage_record.error_message,
|
||||
"response_time_ms": usage_record.response_time_ms,
|
||||
"first_byte_time_ms": usage_record.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
|
||||
"request_headers": usage_record.request_headers,
|
||||
"request_body": usage_record.get_request_body(),
|
||||
|
||||
@@ -100,6 +100,8 @@ class MessageTelemetry:
|
||||
cache_read_tokens: int = 0,
|
||||
is_stream: bool = False,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# 时间指标
|
||||
first_byte_time_ms: Optional[int] = None, # 首字时间/TTFB
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id: Optional[str] = None,
|
||||
provider_endpoint_id: Optional[str] = None,
|
||||
@@ -133,6 +135,7 @@ class MessageTelemetry:
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=first_byte_time_ms, # 传递首字时间
|
||||
status_code=status_code,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
|
||||
@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.stream_processor import StreamProcessor
|
||||
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
||||
from src.api.handlers.base.utils import build_sse_headers
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import (
|
||||
EmbeddedErrorException,
|
||||
@@ -365,7 +366,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
ctx,
|
||||
original_headers,
|
||||
original_request_body,
|
||||
self.elapsed_ms(),
|
||||
self.start_time, # 传入开始时间,让 telemetry 在流结束后计算响应时间
|
||||
)
|
||||
|
||||
# 创建监控流
|
||||
@@ -378,6 +379,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
return StreamingResponse(
|
||||
monitored_stream,
|
||||
media_type="text/event-stream",
|
||||
headers=build_sse_headers(),
|
||||
background=background_tasks,
|
||||
)
|
||||
|
||||
@@ -473,12 +475,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 创建行迭代器
|
||||
line_iterator = stream_response.aiter_lines()
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
|
||||
byte_iterator = stream_response.aiter_raw()
|
||||
|
||||
# 预读检测嵌套错误
|
||||
prefetched_lines = await stream_processor.prefetch_and_check_error(
|
||||
line_iterator,
|
||||
prefetched_chunks = await stream_processor.prefetch_and_check_error(
|
||||
byte_iterator,
|
||||
provider,
|
||||
endpoint,
|
||||
ctx,
|
||||
@@ -503,13 +506,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
await http_client.aclose()
|
||||
raise
|
||||
|
||||
# 创建流生成器
|
||||
# 创建流生成器(传入字节流迭代器)
|
||||
return stream_processor.create_response_stream(
|
||||
ctx,
|
||||
line_iterator,
|
||||
byte_iterator,
|
||||
response_ctx,
|
||||
http_client,
|
||||
prefetched_lines,
|
||||
prefetched_chunks,
|
||||
start_time=self.start_time,
|
||||
)
|
||||
|
||||
async def _record_stream_failure(
|
||||
|
||||
@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import httpx
|
||||
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
|
||||
)
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.utils import build_sse_headers
|
||||
|
||||
# 直接从具体模块导入,避免循环依赖
|
||||
from src.api.handlers.base.response_parser import (
|
||||
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext:
|
||||
"""流式请求的上下文信息"""
|
||||
|
||||
# 请求信息
|
||||
model: str = "unknown" # 用户请求的原始模型名
|
||||
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
|
||||
api_format: str = ""
|
||||
request_id: str = ""
|
||||
|
||||
# 用户信息(提前提取避免 Session detached)
|
||||
user_id: int = 0
|
||||
api_key_id: int = 0
|
||||
|
||||
# 统计信息
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0 # cache_read_input_tokens
|
||||
cache_creation_tokens: int = 0 # cache_creation_input_tokens
|
||||
collected_text: str = ""
|
||||
response_id: Optional[str] = None
|
||||
final_usage: Optional[Dict[str, Any]] = None
|
||||
final_response: Optional[Dict[str, Any]] = None
|
||||
parsed_chunks: list = field(default_factory=list)
|
||||
|
||||
# 流状态
|
||||
start_time: float = field(default_factory=time.time)
|
||||
chunk_count: int = 0
|
||||
data_count: int = 0
|
||||
has_completion: bool = False
|
||||
|
||||
# 响应信息
|
||||
status_code: int = 200
|
||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# 请求信息(发送给 Provider 的)
|
||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
|
||||
|
||||
# Provider 信息
|
||||
provider_name: Optional[str] = None
|
||||
provider_id: Optional[str] = None # Provider ID(用于记录真实成本)
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
attempt_id: Optional[str] = None
|
||||
attempt_synced: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# 格式转换信息
|
||||
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
|
||||
client_api_format: str = "" # 客户端请求的 API 格式
|
||||
|
||||
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion)
|
||||
response_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"""
|
||||
CLI 格式消息处理器基类
|
||||
@@ -409,6 +352,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
return StreamingResponse(
|
||||
monitored_stream,
|
||||
media_type="text/event-stream",
|
||||
headers=build_sse_headers(),
|
||||
background=background_tasks,
|
||||
)
|
||||
|
||||
@@ -433,7 +377,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
ctx.chunk_count = 0
|
||||
ctx.data_count = 0
|
||||
ctx.has_completion = False
|
||||
ctx.collected_text = ""
|
||||
ctx._collected_text_parts = [] # 重置文本收集
|
||||
ctx.input_tokens = 0
|
||||
ctx.output_tokens = 0
|
||||
ctx.cached_tokens = 0
|
||||
@@ -521,12 +465,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 创建行迭代器(只创建一次,后续会继续使用)
|
||||
line_iterator = stream_response.aiter_lines()
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||
byte_iterator = stream_response.aiter_raw()
|
||||
|
||||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||||
prefetched_lines = await self._prefetch_and_check_embedded_error(
|
||||
line_iterator, provider, endpoint, ctx
|
||||
prefetched_chunks = await self._prefetch_and_check_embedded_error(
|
||||
byte_iterator, provider, endpoint, ctx
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -551,10 +495,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 创建流生成器(带预读数据,使用同一个迭代器)
|
||||
return self._create_response_stream_with_prefetch(
|
||||
ctx,
|
||||
line_iterator,
|
||||
byte_iterator,
|
||||
response_ctx,
|
||||
http_client,
|
||||
prefetched_lines,
|
||||
prefetched_chunks,
|
||||
)
|
||||
|
||||
async def _create_response_stream(
|
||||
@@ -564,58 +508,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建响应流生成器"""
|
||||
"""创建响应流生成器(使用字节流)"""
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
streaming_status_updated = False
|
||||
buffer = b""
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
async for line in stream_response.aiter_lines():
|
||||
async for chunk in stream_response.aiter_raw():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
streaming_status_updated = True
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return # 结束生成器
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return # 结束生成器
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
@@ -689,7 +650,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
async def _prefetch_and_check_embedded_error(
|
||||
self,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
ctx: StreamContext,
|
||||
@@ -703,20 +664,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||||
|
||||
Args:
|
||||
line_iterator: 行迭代器(aiter_lines() 返回的迭代器)
|
||||
byte_iterator: 字节流迭代器
|
||||
provider: Provider 对象
|
||||
endpoint: Endpoint 对象
|
||||
ctx: 流上下文
|
||||
|
||||
Returns:
|
||||
预读的行列表(需要在后续流中先输出)
|
||||
预读的字节块列表(需要在后续流中先输出)
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||
"""
|
||||
prefetched_lines: list = []
|
||||
prefetched_chunks: list = []
|
||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||
buffer = b""
|
||||
line_count = 0
|
||||
should_stop = False
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
try:
|
||||
# 获取对应格式的解析器
|
||||
@@ -729,69 +695,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
else:
|
||||
provider_parser = self.parser
|
||||
|
||||
line_count = 0
|
||||
async for line in line_iterator:
|
||||
prefetched_lines.append(line)
|
||||
line_count += 1
|
||||
async for chunk in byte_iterator:
|
||||
prefetched_chunks.append(chunk)
|
||||
buffer += chunk
|
||||
|
||||
# 解析数据
|
||||
normalized_line = line.rstrip("\r")
|
||||
# 尝试按行解析缓冲区
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||
lower_line = normalized_line.lower()
|
||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||||
logger.error(
|
||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"base_url={endpoint.base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
line_count += 1
|
||||
normalized_line = line.rstrip("\r")
|
||||
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
# 空行或注释行,继续预读
|
||||
if line_count >= max_prefetch_lines:
|
||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||
lower_line = normalized_line.lower()
|
||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||||
logger.error(
|
||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"base_url={endpoint.base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
# 空行或注释行,继续预读
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
|
||||
if data_str == "[DONE]":
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
# 不是有效 JSON,可能是部分数据,继续
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
if data_str == "[DONE]":
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||
# 提取错误信息
|
||||
parsed = provider_parser.parse_response(data, 200)
|
||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
# 不是有效 JSON,可能是部分数据,继续
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||
# 提取错误信息
|
||||
parsed = provider_parser.parse_response(data, 200)
|
||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
break
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
# 重新抛出嵌套错误
|
||||
@@ -800,112 +783,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
|
||||
return prefetched_lines
|
||||
return prefetched_chunks
|
||||
|
||||
async def _create_response_stream_with_prefetch(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
prefetched_lines: list,
|
||||
prefetched_chunks: list,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建响应流生成器(带预读数据)"""
|
||||
"""创建响应流生成器(带预读数据,使用字节流)"""
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
buffer = b""
|
||||
first_yield = True # 标记是否是第一次 yield
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_lines:
|
||||
if prefetched_chunks:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
|
||||
# 先处理预读的数据
|
||||
for line in prefetched_lines:
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
# 先处理预读的字节块
|
||||
for chunk in prefetched_chunks:
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
|
||||
# 继续处理剩余的流数据(使用同一个迭代器)
|
||||
async for line in line_iterator:
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
async for chunk in byte_iterator:
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
|
||||
# 处理剩余事件
|
||||
flushed_events = sse_parser.flush()
|
||||
@@ -1034,7 +1073,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 提取文本内容
|
||||
text = self.parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
ctx.append_text(text)
|
||||
|
||||
# 检查完成事件
|
||||
if event_type in ("response.completed", "message_stop"):
|
||||
@@ -1086,9 +1125,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
) -> None:
|
||||
"""在流完成后记录统计信息"""
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
# 使用 self.start_time 作为时间基准,与首字时间保持一致
|
||||
# 注意:不要把统计延迟算进响应时间里
|
||||
response_time_ms = int((time.time() - self.start_time) * 1000)
|
||||
|
||||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if not ctx.provider_name:
|
||||
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
|
||||
@@ -1168,6 +1209,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
input_tokens=actual_input_tokens,
|
||||
output_tokens=ctx.output_tokens,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
|
||||
status_code=ctx.status_code,
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
@@ -1188,9 +1230,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
|
||||
)
|
||||
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
||||
# 简洁的请求完成摘要
|
||||
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
||||
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
|
||||
# 简洁的请求完成摘要(两行格式)
|
||||
line1 = (
|
||||
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
|
||||
)
|
||||
if ctx.first_byte_time_ms:
|
||||
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
|
||||
|
||||
line2 = (
|
||||
f" Total: {response_time_ms}ms | "
|
||||
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
|
||||
)
|
||||
logger.info(f"{line1}\n{line2}")
|
||||
|
||||
# 更新候选记录的最终状态和延迟时间
|
||||
# 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
|
||||
@@ -1242,7 +1293,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
original_request_body: Dict[str, Any],
|
||||
) -> None:
|
||||
"""记录流式请求失败"""
|
||||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
||||
# 使用 self.start_time 作为时间基准,与首字时间保持一致
|
||||
response_time_ms = int((time.time() - self.start_time) * 1000)
|
||||
|
||||
status_code = 503
|
||||
if isinstance(error, ProviderAuthException):
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
|
||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 检查错误(支持嵌套错误格式)
|
||||
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
# 对于 message_start 事件,usage 在 message.usage 路径下
|
||||
# 对于其他响应,usage 在顶层
|
||||
usage = response.get("usage", {})
|
||||
if not usage and "message" in response:
|
||||
usage = response.get("message", {}).get("usage", {})
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
- 请求/响应数据
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -25,12 +26,18 @@ class StreamContext:
|
||||
model: str
|
||||
api_format: str
|
||||
|
||||
# 请求标识信息(CLI handler 需要)
|
||||
request_id: str = ""
|
||||
user_id: int = 0
|
||||
api_key_id: int = 0
|
||||
|
||||
# Provider 信息(在请求执行时填充)
|
||||
provider_name: Optional[str] = None
|
||||
provider_id: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
attempt_id: Optional[str] = None
|
||||
attempt_synced: bool = False
|
||||
provider_api_format: Optional[str] = None # Provider 的响应格式
|
||||
|
||||
# 模型映射
|
||||
@@ -43,7 +50,14 @@ class StreamContext:
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
# 响应内容
|
||||
collected_text: str = ""
|
||||
_collected_text_parts: List[str] = field(default_factory=list, repr=False)
|
||||
response_id: Optional[str] = None
|
||||
final_usage: Optional[Dict[str, Any]] = None
|
||||
final_response: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 时间指标
|
||||
first_byte_time_ms: Optional[int] = None # 首字时间 (TTFB - Time To First Byte)
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
# 响应状态
|
||||
status_code: int = 200
|
||||
@@ -55,6 +69,12 @@ class StreamContext:
|
||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_request_body: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 格式转换信息(CLI handler 需要)
|
||||
client_api_format: str = ""
|
||||
|
||||
# Provider 响应元数据(CLI handler 需要)
|
||||
response_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 流式处理统计
|
||||
data_count: int = 0
|
||||
chunk_count: int = 0
|
||||
@@ -71,16 +91,30 @@ class StreamContext:
|
||||
self.chunk_count = 0
|
||||
self.data_count = 0
|
||||
self.has_completion = False
|
||||
self.collected_text = ""
|
||||
self._collected_text_parts = []
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.cached_tokens = 0
|
||||
self.cache_creation_tokens = 0
|
||||
self.error_message = None
|
||||
self.status_code = 200
|
||||
self.first_byte_time_ms = None
|
||||
self.response_headers = {}
|
||||
self.provider_request_headers = {}
|
||||
self.provider_request_body = None
|
||||
self.response_id = None
|
||||
self.final_usage = None
|
||||
self.final_response = None
|
||||
|
||||
@property
|
||||
def collected_text(self) -> str:
|
||||
"""已收集的文本内容(按需拼接,避免在流式过程中频繁做字符串拷贝)"""
|
||||
return "".join(self._collected_text_parts)
|
||||
|
||||
def append_text(self, text: str) -> None:
|
||||
"""追加文本内容(仅在需要收集文本时调用)"""
|
||||
if text:
|
||||
self._collected_text_parts.append(text)
|
||||
|
||||
def update_provider_info(
|
||||
self,
|
||||
@@ -104,14 +138,40 @@ class StreamContext:
|
||||
cached_tokens: Optional[int] = None,
|
||||
cache_creation_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
"""更新 Token 使用统计"""
|
||||
if input_tokens is not None:
|
||||
"""
|
||||
更新 Token 使用统计
|
||||
|
||||
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
|
||||
|
||||
设计原理:
|
||||
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
|
||||
- 后续事件可能会提供完整的统计数据
|
||||
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
|
||||
|
||||
示例场景:
|
||||
- message_start 事件:input_tokens=100, output_tokens=0
|
||||
- message_delta 事件:input_tokens=0, output_tokens=50
|
||||
- 最终结果:input_tokens=100, output_tokens=50
|
||||
|
||||
注意事项:
|
||||
- 此策略假设初始值为 0 是正确的默认状态
|
||||
- 如果需要将已有值重置为 0,请直接修改实例属性(不使用此方法)
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 tokens 数量
|
||||
output_tokens: 输出 tokens 数量
|
||||
cached_tokens: 缓存命中 tokens 数量
|
||||
cache_creation_tokens: 缓存创建 tokens 数量
|
||||
"""
|
||||
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
|
||||
self.input_tokens = input_tokens
|
||||
if output_tokens is not None:
|
||||
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
|
||||
self.output_tokens = output_tokens
|
||||
if cached_tokens is not None:
|
||||
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
|
||||
self.cached_tokens = cached_tokens
|
||||
if cache_creation_tokens is not None:
|
||||
if cache_creation_tokens is not None and (
|
||||
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
|
||||
):
|
||||
self.cache_creation_tokens = cache_creation_tokens
|
||||
|
||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
||||
@@ -119,6 +179,19 @@ class StreamContext:
|
||||
self.status_code = status_code
|
||||
self.error_message = error_message
|
||||
|
||||
def record_first_byte_time(self, start_time: float) -> None:
|
||||
"""
|
||||
记录首字时间 (TTFB - Time To First Byte)
|
||||
|
||||
应在第一次向客户端发送数据时调用。
|
||||
如果已记录过,则不会覆盖(避免重试时重复记录)。
|
||||
|
||||
Args:
|
||||
start_time: 请求开始时间 (time.time())
|
||||
"""
|
||||
if self.first_byte_time_ms is None:
|
||||
self.first_byte_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
def is_success(self) -> bool:
|
||||
"""检查请求是否成功"""
|
||||
return self.status_code < 400
|
||||
@@ -145,10 +218,22 @@ class StreamContext:
|
||||
获取日志摘要
|
||||
|
||||
用于请求完成/失败时的日志输出。
|
||||
包含首字时间 (TTFB) 和总响应时间,分两行显示。
|
||||
"""
|
||||
status = "OK" if self.is_success() else "FAIL"
|
||||
return (
|
||||
|
||||
# 第一行:基本信息 + 首字时间
|
||||
line1 = (
|
||||
f"[{status}] {request_id[:8]} | {self.model} | "
|
||||
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
|
||||
f"{self.provider_name or 'unknown'}"
|
||||
)
|
||||
if self.first_byte_time_ms is not None:
|
||||
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
|
||||
|
||||
# 第二行:总响应时间 + tokens
|
||||
line2 = (
|
||||
f" Total: {response_time_ms}ms | "
|
||||
f"in:{self.input_tokens} out:{self.output_tokens}"
|
||||
)
|
||||
|
||||
return f"{line1}\n{line2}"
|
||||
|
||||
@@ -9,7 +9,9 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Callable, Optional
|
||||
|
||||
import httpx
|
||||
@@ -36,6 +38,8 @@ class StreamProcessor:
|
||||
request_id: str,
|
||||
default_parser: ResponseParser,
|
||||
on_streaming_start: Optional[Callable[[], None]] = None,
|
||||
*,
|
||||
collect_text: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化流处理器
|
||||
@@ -48,6 +52,7 @@ class StreamProcessor:
|
||||
self.request_id = request_id
|
||||
self.default_parser = default_parser
|
||||
self.on_streaming_start = on_streaming_start
|
||||
self.collect_text = collect_text
|
||||
|
||||
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
|
||||
"""
|
||||
@@ -112,9 +117,10 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
# 提取文本
|
||||
text = parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
if self.collect_text:
|
||||
text = parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx.append_text(text)
|
||||
|
||||
# 检查完成
|
||||
event_type = event_name or data.get("type", "")
|
||||
@@ -123,7 +129,7 @@ class StreamProcessor:
|
||||
|
||||
async def prefetch_and_check_error(
|
||||
self,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
ctx: StreamContext,
|
||||
@@ -136,97 +142,126 @@ class StreamProcessor:
|
||||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||||
|
||||
Args:
|
||||
line_iterator: 行迭代器
|
||||
byte_iterator: 字节流迭代器
|
||||
provider: Provider 对象
|
||||
endpoint: Endpoint 对象
|
||||
ctx: 流式上下文
|
||||
max_prefetch_lines: 最多预读行数
|
||||
|
||||
Returns:
|
||||
预读的行列表
|
||||
预读的字节块列表
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
"""
|
||||
prefetched_lines: list = []
|
||||
prefetched_chunks: list = []
|
||||
parser = self.get_parser_for_provider(ctx)
|
||||
buffer = b""
|
||||
line_count = 0
|
||||
should_stop = False
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
try:
|
||||
line_count = 0
|
||||
async for line in line_iterator:
|
||||
prefetched_lines.append(line)
|
||||
line_count += 1
|
||||
async for chunk in byte_iterator:
|
||||
prefetched_chunks.append(chunk)
|
||||
buffer += chunk
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
if line_count >= max_prefetch_lines:
|
||||
# 尝试按行解析缓冲区
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
line_count += 1
|
||||
|
||||
# 跳过空行和注释行
|
||||
if not line or line.startswith(":"):
|
||||
if line_count >= max_prefetch_lines:
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = line
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
if data_str == "[DONE]":
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
if line_count >= max_prefetch_lines:
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
if data_str == "[DONE]":
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and parser.is_error_response(data):
|
||||
parsed = parser.parse_response(data, 200)
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}"
|
||||
)
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and parser.is_error_response(data):
|
||||
parsed = parser.parse_response(data, 200)
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}"
|
||||
)
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
break
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
|
||||
return prefetched_lines
|
||||
return prefetched_chunks
|
||||
|
||||
async def create_response_stream(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
prefetched_lines: Optional[list] = None,
|
||||
prefetched_chunks: Optional[list] = None,
|
||||
*,
|
||||
start_time: Optional[float] = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
创建响应流生成器
|
||||
|
||||
统一的流生成器,支持带预读数据和不带预读数据两种情况。
|
||||
从字节流中解析 SSE 数据并转发,支持预读数据。
|
||||
|
||||
Args:
|
||||
ctx: 流式上下文
|
||||
line_iterator: 行迭代器
|
||||
byte_iterator: 字节流迭代器
|
||||
response_ctx: HTTP 响应上下文管理器
|
||||
http_client: HTTP 客户端
|
||||
prefetched_lines: 预读的行列表(可选)
|
||||
prefetched_chunks: 预读的字节块列表(可选)
|
||||
start_time: 请求开始时间,用于计算 TTFB(可选)
|
||||
|
||||
Yields:
|
||||
编码后的响应数据块
|
||||
@@ -234,25 +269,82 @@ class StreamProcessor:
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
streaming_started = False
|
||||
buffer = b""
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 处理预读数据
|
||||
if prefetched_lines:
|
||||
if prefetched_chunks:
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
for line in prefetched_lines:
|
||||
for chunk in self._process_line(ctx, sse_parser, line):
|
||||
yield chunk
|
||||
for chunk in prefetched_chunks:
|
||||
# 记录首字时间 (TTFB) - 在 yield 之前记录
|
||||
if start_time is not None:
|
||||
ctx.record_first_byte_time(start_time)
|
||||
start_time = None # 只记录一次
|
||||
|
||||
# 把原始数据转发给客户端
|
||||
yield chunk
|
||||
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False)
|
||||
self._process_line(ctx, sse_parser, line)
|
||||
except Exception as e:
|
||||
# 解码失败,记录警告但继续处理
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 处理剩余的流数据
|
||||
async for line in line_iterator:
|
||||
async for chunk in byte_iterator:
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
for chunk in self._process_line(ctx, sse_parser, line):
|
||||
yield chunk
|
||||
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
|
||||
if start_time is not None:
|
||||
ctx.record_first_byte_time(start_time)
|
||||
start_time = None # 只记录一次
|
||||
|
||||
# 原始数据透传
|
||||
yield chunk
|
||||
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False)
|
||||
self._process_line(ctx, sse_parser, line)
|
||||
except Exception as e:
|
||||
# 解码失败,记录警告但继续处理
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 处理剩余的缓冲区数据(如果有未完成的行)
|
||||
if buffer:
|
||||
try:
|
||||
# 使用 final=True 处理最后的不完整字符
|
||||
line = decoder.decode(buffer, True)
|
||||
self._process_line(ctx, sse_parser, line)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
|
||||
f"bytes={buffer[:50]!r}"
|
||||
)
|
||||
|
||||
# 处理剩余事件
|
||||
for event in sse_parser.flush():
|
||||
@@ -268,7 +360,7 @@ class StreamProcessor:
|
||||
ctx: StreamContext,
|
||||
sse_parser: SSEEventParser,
|
||||
line: str,
|
||||
) -> list[bytes]:
|
||||
) -> None:
|
||||
"""
|
||||
处理单行数据
|
||||
|
||||
@@ -276,26 +368,17 @@ class StreamProcessor:
|
||||
ctx: 流式上下文
|
||||
sse_parser: SSE 解析器
|
||||
line: 原始行数据
|
||||
|
||||
Returns:
|
||||
要发送的数据块列表
|
||||
"""
|
||||
result: list[bytes] = []
|
||||
normalized_line = line.rstrip("\r")
|
||||
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF,
|
||||
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
|
||||
normalized_line = line.rstrip("\r\n")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
result.append(b"\n")
|
||||
else:
|
||||
if normalized_line != "":
|
||||
ctx.chunk_count += 1
|
||||
result.append((line + "\n").encode("utf-8"))
|
||||
|
||||
for event in events:
|
||||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
return result
|
||||
for event in events:
|
||||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
async def create_monitored_stream(
|
||||
self,
|
||||
@@ -317,16 +400,26 @@ class StreamProcessor:
|
||||
响应数据块
|
||||
"""
|
||||
try:
|
||||
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
|
||||
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
|
||||
next_disconnect_check_at = 0.0
|
||||
disconnect_check_interval_s = 0.25
|
||||
|
||||
async for chunk in stream_generator:
|
||||
if await is_disconnected():
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
ctx.status_code = 499 # Client Closed Request
|
||||
ctx.error_message = "client_disconnected"
|
||||
break
|
||||
now = time.monotonic()
|
||||
if now >= next_disconnect_check_at:
|
||||
next_disconnect_check_at = now + disconnect_check_interval_s
|
||||
if await is_disconnected():
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
ctx.status_code = 499 # Client Closed Request
|
||||
ctx.error_message = "client_disconnected"
|
||||
|
||||
break
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "client_disconnected"
|
||||
|
||||
raise
|
||||
except Exception as e:
|
||||
ctx.status_code = 500
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
|
||||
ctx: StreamContext,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
response_time_ms: int,
|
||||
start_time: float,
|
||||
) -> None:
|
||||
"""
|
||||
记录流式统计信息
|
||||
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
|
||||
ctx: 流式上下文
|
||||
original_headers: 原始请求头
|
||||
original_request_body: 原始请求体
|
||||
response_time_ms: 响应时间(毫秒)
|
||||
start_time: 请求开始时间 (time.time())
|
||||
"""
|
||||
bg_db = None
|
||||
|
||||
try:
|
||||
# 在流结束后计算响应时间,与首字时间使用相同的时间基准
|
||||
# 注意:不要把统计延迟(stream_stats_delay)算进响应时间里
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
|
||||
|
||||
if not ctx.provider_name:
|
||||
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
|
||||
input_tokens=ctx.input_tokens,
|
||||
output_tokens=ctx.output_tokens,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
|
||||
status_code=ctx.status_code,
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
|
||||
55
src/api/handlers/base/utils.py
Normal file
55
src/api/handlers/base/utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Handler 基础工具函数
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
||||
"""
|
||||
提取缓存创建 tokens(兼容新旧格式)
|
||||
|
||||
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens:
|
||||
- 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和
|
||||
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
|
||||
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
|
||||
|
||||
此函数自动检测并适配两种格式,优先使用新格式。
|
||||
|
||||
Args:
|
||||
usage: API 响应中的 usage 字典
|
||||
|
||||
Returns:
|
||||
缓存创建 tokens 总数
|
||||
"""
|
||||
# 检查新格式字段是否存在(而非值是否为 0)
|
||||
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
|
||||
has_new_format = (
|
||||
"claude_cache_creation_5_m_tokens" in usage
|
||||
or "claude_cache_creation_1_h_tokens" in usage
|
||||
)
|
||||
|
||||
if has_new_format:
|
||||
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
|
||||
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
|
||||
return int(cache_5m) + int(cache_1h)
|
||||
|
||||
# 回退到旧格式
|
||||
return int(usage.get("cache_creation_input_tokens", 0))
|
||||
|
||||
|
||||
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
"""
|
||||
构建 SSE(text/event-stream)推荐响应头,用于减少代理缓冲带来的卡顿/成段输出。
|
||||
|
||||
说明:
|
||||
- Cache-Control: no-transform 可避免部分代理对流做压缩/改写导致缓冲
|
||||
- X-Accel-Buffering: no 可显式提示 Nginx 关闭缓冲(即使全局已关闭也无害)
|
||||
"""
|
||||
headers: Dict[str, str] = {
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
}
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
return headers
|
||||
@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeChatHandler(ChatHandlerBase):
|
||||
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
async def _convert_request(self, request: Any) -> Any:
|
||||
"""
|
||||
将请求转换为 Claude 格式
|
||||
|
||||
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
Claude 格式使用:
|
||||
- input_tokens / output_tokens
|
||||
- cache_creation_input_tokens / cache_read_input_tokens
|
||||
- 新格式:claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
input_tokens = usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("output_tokens", 0)
|
||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 处理新的 cache_creation 格式
|
||||
if "cache_creation" in usage:
|
||||
cache_creation_data = usage.get("cache_creation", {})
|
||||
if not cache_creation_input_tokens:
|
||||
cache_creation_input_tokens = cache_creation_data.get(
|
||||
"ephemeral_5m_input_tokens", 0
|
||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
||||
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": cache_read_input_tokens,
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化 Claude 响应
|
||||
|
||||
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_claude_response(
|
||||
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
)
|
||||
return result
|
||||
return response
|
||||
|
||||
@@ -9,6 +9,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeStreamParser:
|
||||
"""
|
||||
@@ -193,7 +195,7 @@ class ClaudeStreamParser:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
@@ -204,7 +206,7 @@ class ClaudeStreamParser:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||
# Claude 的缓存 tokens 使用不同的字段名
|
||||
|
||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||
if cache_read:
|
||||
ctx.cached_tokens = cache_read
|
||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
||||
|
||||
cache_creation = extract_cache_creation_tokens(usage)
|
||||
if cache_creation:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
@@ -109,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
if delta.get("type") == "text_delta":
|
||||
text = delta.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
ctx.append_text(text)
|
||||
|
||||
# 处理消息增量(包含最终 usage)
|
||||
elif event_type == "message_delta":
|
||||
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
# 更新缓存 tokens
|
||||
|
||||
# 更新缓存读取 tokens
|
||||
if "cache_read_input_tokens" in usage:
|
||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||
if "cache_creation_input_tokens" in usage:
|
||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
||||
|
||||
# 更新缓存创建 tokens
|
||||
cache_creation = extract_cache_creation_tokens(usage)
|
||||
if cache_creation > 0:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
# 检查是否结束
|
||||
delta = data.get("delta", {})
|
||||
|
||||
@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
ctx.collected_text += part["text"]
|
||||
ctx.append_text(part["text"])
|
||||
|
||||
# 检查结束原因
|
||||
finish_reason = candidate.get("finishReason")
|
||||
|
||||
@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
|
||||
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, str):
|
||||
ctx.collected_text += delta
|
||||
ctx.append_text(delta)
|
||||
elif isinstance(delta, dict) and "text" in delta:
|
||||
ctx.collected_text += delta["text"]
|
||||
ctx.append_text(delta["text"])
|
||||
|
||||
# 处理完成事件
|
||||
elif event_type == "response.completed":
|
||||
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
|
||||
if content_item.get("type") == "output_text":
|
||||
text = content_item.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
ctx.append_text(text)
|
||||
|
||||
# 备用:从顶层 usage 提取
|
||||
usage_obj = data.get("usage")
|
||||
|
||||
@@ -307,7 +307,8 @@ class Usage(Base):
|
||||
is_stream = Column(Boolean, default=False) # 是否为流式请求
|
||||
status_code = Column(Integer)
|
||||
error_message = Column(Text, nullable=True)
|
||||
response_time_ms = Column(Integer) # 响应时间(毫秒)
|
||||
response_time_ms = Column(Integer) # 总响应时间(毫秒)
|
||||
first_byte_time_ms = Column(Integer, nullable=True) # 首字时间/TTFB(毫秒)
|
||||
|
||||
# 请求状态追踪
|
||||
# pending: 请求开始处理中
|
||||
|
||||
@@ -157,6 +157,7 @@ class UsageService:
|
||||
api_format: Optional[str] = None,
|
||||
is_stream: bool = False,
|
||||
response_time_ms: Optional[int] = None,
|
||||
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
|
||||
status_code: int = 200,
|
||||
error_message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
@@ -368,6 +369,7 @@ class UsageService:
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
|
||||
status=status, # 请求状态追踪
|
||||
request_metadata=metadata,
|
||||
request_headers=processed_request_headers,
|
||||
@@ -419,6 +421,7 @@ class UsageService:
|
||||
api_format: Optional[str] = None,
|
||||
is_stream: bool = False,
|
||||
response_time_ms: Optional[int] = None,
|
||||
first_byte_time_ms: Optional[int] = None, # 首字时间 (TTFB)
|
||||
status_code: int = 200,
|
||||
error_message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
@@ -629,6 +632,7 @@ class UsageService:
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=first_byte_time_ms, # 首字时间 (TTFB)
|
||||
status=status, # 请求状态追踪
|
||||
request_metadata=metadata,
|
||||
request_headers=processed_request_headers,
|
||||
@@ -649,6 +653,7 @@ class UsageService:
|
||||
existing_usage.status_code = status_code
|
||||
existing_usage.error_message = error_message
|
||||
existing_usage.response_time_ms = response_time_ms
|
||||
existing_usage.first_byte_time_ms = first_byte_time_ms # 更新首字时间
|
||||
# 更新请求头和请求体(如果有新值)
|
||||
if processed_request_headers is not None:
|
||||
existing_usage.request_headers = processed_request_headers
|
||||
@@ -1315,11 +1320,11 @@ class UsageService:
|
||||
default_timeout_seconds: int = 300,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending 请求
|
||||
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
|
||||
|
||||
与 get_active_requests 不同,此方法:
|
||||
1. 返回轻量级的状态字典而非完整 Usage 对象
|
||||
2. 自动检测并清理超时的 pending 请求
|
||||
2. 自动检测并清理超时的 pending/streaming 请求
|
||||
3. 支持按 ID 列表查询特定请求
|
||||
|
||||
Args:
|
||||
@@ -1343,6 +1348,7 @@ class UsageService:
|
||||
Usage.output_tokens,
|
||||
Usage.total_cost_usd,
|
||||
Usage.response_time_ms,
|
||||
Usage.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
Usage.created_at,
|
||||
Usage.provider_endpoint_id,
|
||||
ProviderEndpoint.timeout.label("endpoint_timeout"),
|
||||
@@ -1361,10 +1367,10 @@ class UsageService:
|
||||
|
||||
records = query.all()
|
||||
|
||||
# 检查超时的 pending 请求
|
||||
# 检查超时的 pending/streaming 请求
|
||||
timeout_ids = []
|
||||
for r in records:
|
||||
if r.status == "pending" and r.created_at:
|
||||
if r.status in ("pending", "streaming") and r.created_at:
|
||||
# 使用端点配置的超时时间,若无则使用默认值
|
||||
timeout_seconds = r.endpoint_timeout or default_timeout_seconds
|
||||
|
||||
@@ -1392,6 +1398,7 @@ class UsageService:
|
||||
"output_tokens": r.output_tokens,
|
||||
"cost": float(r.total_cost_usd) if r.total_cost_usd else 0,
|
||||
"response_time_ms": r.response_time_ms,
|
||||
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
}
|
||||
for r in records
|
||||
]
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""测试模块"""
|
||||
117
tests/api/handlers/base/test_stream_context.py
Normal file
117
tests/api/handlers/base/test_stream_context.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from src.api.handlers.base import stream_context
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
|
||||
|
||||
def test_collected_text_append_and_property() -> None:
|
||||
ctx = StreamContext(model="test-model", api_format="OPENAI")
|
||||
assert ctx.collected_text == ""
|
||||
|
||||
ctx.append_text("hello")
|
||||
ctx.append_text(" ")
|
||||
ctx.append_text("world")
|
||||
assert ctx.collected_text == "hello world"
|
||||
|
||||
|
||||
def test_reset_for_retry_clears_state() -> None:
|
||||
ctx = StreamContext(model="test-model", api_format="OPENAI")
|
||||
ctx.append_text("x")
|
||||
ctx.update_usage(input_tokens=10, output_tokens=5)
|
||||
ctx.parsed_chunks.append({"type": "chunk"})
|
||||
ctx.chunk_count = 3
|
||||
ctx.data_count = 2
|
||||
ctx.has_completion = True
|
||||
ctx.status_code = 418
|
||||
ctx.error_message = "boom"
|
||||
|
||||
ctx.reset_for_retry()
|
||||
|
||||
assert ctx.collected_text == ""
|
||||
assert ctx.input_tokens == 0
|
||||
assert ctx.output_tokens == 0
|
||||
assert ctx.parsed_chunks == []
|
||||
assert ctx.chunk_count == 0
|
||||
assert ctx.data_count == 0
|
||||
assert ctx.has_completion is False
|
||||
assert ctx.status_code == 200
|
||||
assert ctx.error_message is None
|
||||
|
||||
|
||||
def test_record_first_byte_time(monkeypatch) -> None:
|
||||
"""测试记录首字时间"""
|
||||
ctx = StreamContext(model="claude-3", api_format="claude_messages")
|
||||
start_time = 100.0
|
||||
monkeypatch.setattr(stream_context.time, "time", lambda: 100.0123) # 12.3ms
|
||||
|
||||
# 记录首字时间
|
||||
ctx.record_first_byte_time(start_time)
|
||||
|
||||
# 验证首字时间已记录
|
||||
assert ctx.first_byte_time_ms == 12
|
||||
|
||||
|
||||
def test_record_first_byte_time_idempotent(monkeypatch) -> None:
|
||||
"""测试首字时间只记录一次"""
|
||||
ctx = StreamContext(model="claude-3", api_format="claude_messages")
|
||||
start_time = 100.0
|
||||
|
||||
# 第一次记录
|
||||
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
|
||||
ctx.record_first_byte_time(start_time)
|
||||
first_value = ctx.first_byte_time_ms
|
||||
|
||||
# 第二次记录(应该被忽略)
|
||||
monkeypatch.setattr(stream_context.time, "time", lambda: 100.020)
|
||||
ctx.record_first_byte_time(start_time)
|
||||
second_value = ctx.first_byte_time_ms
|
||||
|
||||
# 验证值没有改变
|
||||
assert first_value == second_value
|
||||
|
||||
|
||||
def test_reset_for_retry_clears_first_byte_time(monkeypatch) -> None:
|
||||
"""测试重试时清除首字时间"""
|
||||
ctx = StreamContext(model="claude-3", api_format="claude_messages")
|
||||
start_time = 100.0
|
||||
|
||||
# 记录首字时间
|
||||
monkeypatch.setattr(stream_context.time, "time", lambda: 100.010)
|
||||
ctx.record_first_byte_time(start_time)
|
||||
assert ctx.first_byte_time_ms is not None
|
||||
|
||||
# 重置
|
||||
ctx.reset_for_retry()
|
||||
|
||||
# 验证首字时间已清除
|
||||
assert ctx.first_byte_time_ms is None
|
||||
|
||||
|
||||
def test_get_log_summary_with_first_byte_time() -> None:
|
||||
"""测试日志摘要包含首字时间"""
|
||||
ctx = StreamContext(model="claude-3", api_format="claude_messages")
|
||||
ctx.provider_name = "anthropic"
|
||||
ctx.input_tokens = 100
|
||||
ctx.output_tokens = 50
|
||||
ctx.first_byte_time_ms = 123
|
||||
|
||||
summary = ctx.get_log_summary("request-id-123", 456)
|
||||
|
||||
# 验证包含首字时间和总时间(大写格式)
|
||||
assert "TTFB: 123ms" in summary
|
||||
assert "Total: 456ms" in summary
|
||||
assert "in:100 out:50" in summary
|
||||
|
||||
|
||||
def test_get_log_summary_without_first_byte_time() -> None:
|
||||
"""测试日志摘要在没有首字时间时的格式"""
|
||||
ctx = StreamContext(model="claude-3", api_format="claude_messages")
|
||||
ctx.provider_name = "anthropic"
|
||||
ctx.input_tokens = 100
|
||||
ctx.output_tokens = 50
|
||||
# first_byte_time_ms 保持为 None
|
||||
|
||||
summary = ctx.get_log_summary("request-id-123", 456)
|
||||
|
||||
# 验证不包含首字时间标记,但有总时间(使用大写 TTFB 和 Total)
|
||||
assert "TTFB:" not in summary
|
||||
assert "Total: 456ms" in summary
|
||||
assert "in:100 out:50" in summary
|
||||
32
tests/api/handlers/base/test_stream_processor.py
Normal file
32
tests/api/handlers/base/test_stream_processor.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.response_parser import ParsedChunk, ParsedResponse, ResponseParser, StreamStats
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.stream_processor import StreamProcessor
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
|
||||
|
||||
class DummyParser(ResponseParser):
|
||||
def parse_sse_line(self, line: str, stats: StreamStats) -> Optional[ParsedChunk]:
|
||||
return None
|
||||
|
||||
def parse_response(self, response: Dict[str, Any], status_code: int) -> ParsedResponse:
|
||||
return ParsedResponse(raw_response=response, status_code=status_code)
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
return {}
|
||||
|
||||
def extract_text_content(self, response: Dict[str, Any]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def test_process_line_strips_newlines_and_finalizes_event() -> None:
|
||||
ctx = StreamContext(model="test-model", api_format="OPENAI")
|
||||
processor = StreamProcessor(request_id="test-request", default_parser=DummyParser())
|
||||
sse_parser = SSEEventParser()
|
||||
|
||||
processor._process_line(ctx, sse_parser, 'data: {"type":"response.completed"}\n')
|
||||
processor._process_line(ctx, sse_parser, "\n")
|
||||
|
||||
assert ctx.has_completion is True
|
||||
|
||||
104
tests/api/handlers/base/test_utils.py
Normal file
104
tests/api/handlers/base/test_utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""测试 handler 基础工具函数"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.api.handlers.base.utils import build_sse_headers, extract_cache_creation_tokens
|
||||
|
||||
|
||||
class TestExtractCacheCreationTokens:
|
||||
"""测试 extract_cache_creation_tokens 函数"""
|
||||
|
||||
def test_new_format_only(self) -> None:
|
||||
"""测试只有新格式字段"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 100,
|
||||
"claude_cache_creation_1_h_tokens": 200,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 300
|
||||
|
||||
def test_new_format_5m_only(self) -> None:
|
||||
"""测试只有 5 分钟缓存"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 150,
|
||||
"claude_cache_creation_1_h_tokens": 0,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 150
|
||||
|
||||
def test_new_format_1h_only(self) -> None:
|
||||
"""测试只有 1 小时缓存"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 0,
|
||||
"claude_cache_creation_1_h_tokens": 250,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 250
|
||||
|
||||
def test_old_format_only(self) -> None:
|
||||
"""测试只有旧格式字段"""
|
||||
usage = {
|
||||
"cache_creation_input_tokens": 500,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 500
|
||||
|
||||
def test_both_formats_prefers_new(self) -> None:
|
||||
"""测试同时存在时优先使用新格式"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 100,
|
||||
"claude_cache_creation_1_h_tokens": 200,
|
||||
"cache_creation_input_tokens": 999, # 应该被忽略
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 300
|
||||
|
||||
def test_empty_usage(self) -> None:
|
||||
"""测试空字典"""
|
||||
usage = {}
|
||||
assert extract_cache_creation_tokens(usage) == 0
|
||||
|
||||
def test_all_zeros(self) -> None:
|
||||
"""测试所有字段都为 0"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 0,
|
||||
"claude_cache_creation_1_h_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 0
|
||||
|
||||
def test_partial_new_format_with_old_format_fallback(self) -> None:
|
||||
"""测试新格式字段不存在时回退到旧格式"""
|
||||
usage = {
|
||||
"cache_creation_input_tokens": 123,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 123
|
||||
|
||||
def test_new_format_zero_should_not_fallback(self) -> None:
|
||||
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
|
||||
usage = {
|
||||
"claude_cache_creation_5_m_tokens": 0,
|
||||
"claude_cache_creation_1_h_tokens": 0,
|
||||
"cache_creation_input_tokens": 456,
|
||||
}
|
||||
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0)
|
||||
# 而不是 fallback 到旧格式(返回 456)
|
||||
assert extract_cache_creation_tokens(usage) == 0
|
||||
|
||||
def test_unrelated_fields_ignored(self) -> None:
|
||||
"""测试忽略无关字段"""
|
||||
usage = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 2000,
|
||||
"cache_read_input_tokens": 300,
|
||||
"claude_cache_creation_5_m_tokens": 50,
|
||||
"claude_cache_creation_1_h_tokens": 75,
|
||||
}
|
||||
assert extract_cache_creation_tokens(usage) == 125
|
||||
|
||||
|
||||
class TestBuildSSEHeaders:
|
||||
def test_default_headers(self) -> None:
|
||||
headers = build_sse_headers()
|
||||
assert headers["Cache-Control"] == "no-cache, no-transform"
|
||||
assert headers["X-Accel-Buffering"] == "no"
|
||||
|
||||
def test_merge_extra_headers(self) -> None:
|
||||
headers = build_sse_headers({"X-Test": "1", "Cache-Control": "custom"})
|
||||
assert headers["X-Test"] == "1"
|
||||
assert headers["Cache-Control"] == "custom"
|
||||
Reference in New Issue
Block a user