feat: 优化首字时间和 streaming 状态的记录时序

改进 streaming 状态更新机制:
- 统一在首次输出时记录 TTFB 并更新 streaming 状态
- 重构 CliMessageHandlerBase 中的状态更新逻辑,消除重复
- 确保 provider/key 信息在 streaming 状态更新时已可用

前端改进:
- 添加 first_byte_time_ms 字段支持
- 管理员接口支持返回 provider/api_key_name 字段
- 优化活跃请求轮询逻辑,更准确地判断是否需要刷新完整数据

数据库与 API:
- UsageService.get_active_requests_status 添加 include_admin_fields 参数
- 管理员接口调用时启用该参数以获取额外信息
This commit is contained in:
fawney19
2026-01-05 10:31:34 +08:00
parent 43f349d415
commit 4fa9a1303a
7 changed files with 97 additions and 72 deletions

View File

@@ -203,11 +203,12 @@ export const meApi = {
async getActiveRequests(ids?: string): Promise<{
requests: Array<{
id: string
status: string
status: 'pending' | 'streaming' | 'completed' | 'failed'
input_tokens: number
output_tokens: number
cost: number
response_time_ms: number | null
first_byte_time_ms: number | null
}>
}> {
const params = ids ? { ids } : {}

View File

@@ -193,6 +193,9 @@ export const usageApi = {
output_tokens: number
cost: number
response_time_ms: number | null
first_byte_time_ms: number | null
provider?: string | null
api_key_name?: string | null
}>
}> {
const params = ids?.length ? { ids: ids.join(',') } : {}

View File

@@ -259,27 +259,40 @@ async function pollActiveRequests() {
? await usageApi.getActiveRequests(activeRequestIds.value)
: await meApi.getActiveRequests(idsParam)
// 检查是否有状态变化
let hasChanges = false
let shouldRefresh = false
for (const update of requests) {
const record = currentRecords.value.find(r => r.id === update.id)
if (record && record.status !== update.status) {
hasChanges = true
// 如果状态变为 completed 或 failed,需要刷新获取完整数据
if (update.status === 'completed' || update.status === 'failed') {
break
}
// 否则只更新状态和 token 信息
if (!record) {
// 后端返回了未知的活跃请求,触发刷新以获取完整数据
shouldRefresh = true
continue
}
// 状态变化:completed/failed 需要刷新获取完整数据
if (record.status !== update.status) {
record.status = update.status
record.input_tokens = update.input_tokens
record.output_tokens = update.output_tokens
record.cost = update.cost
record.response_time_ms = update.response_time_ms ?? undefined
}
if (update.status === 'completed' || update.status === 'failed') {
shouldRefresh = true
}
// 进行中状态也需要持续更新(provider/key/TTFB 可能在 streaming 后才落库)
record.input_tokens = update.input_tokens
record.output_tokens = update.output_tokens
record.cost = update.cost
record.response_time_ms = update.response_time_ms ?? undefined
record.first_byte_time_ms = update.first_byte_time_ms ?? undefined
// 管理员接口返回额外字段
if ('provider' in update && typeof update.provider === 'string') {
record.provider = update.provider
}
if ('api_key_name' in update) {
record.api_key_name = typeof update.api_key_name === 'string' ? update.api_key_name : undefined
}
}
// 如果有请求完成或失败,刷新整个列表获取完整数据
if (hasChanges && requests.some(r => r.status === 'completed' || r.status === 'failed')) {
if (shouldRefresh) {
await refreshData()
}
} catch (error) {

View File

@@ -690,7 +690,9 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
if not id_list:
return {"requests": []}
requests = UsageService.get_active_requests_status(db=db, ids=id_list)
requests = UsageService.get_active_requests_status(
db=db, ids=id_list, include_admin_fields=True
)
return {"requests": requests}

View File

@@ -536,8 +536,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
output_state = {"first_yield": True, "streaming_updated": False}
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
@@ -545,11 +545,6 @@ class CliMessageHandlerBase(BaseMessageHandler):
needs_conversion = self._needs_format_conversion(ctx)
async for chunk in stream_response.aiter_bytes():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
@@ -574,6 +569,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -591,6 +587,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
@@ -598,8 +595,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -650,7 +649,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError as e:
except httpx.RemoteProtocolError:
if ctx.data_count > 0:
error_event = {
"type": "error",
@@ -846,19 +845,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
streaming_status_updated = False # 标记状态是否已更新
output_state = {"first_yield": True, "streaming_updated": False}
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_chunks:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
@@ -885,10 +878,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -898,16 +888,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -922,11 +906,6 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 继续处理剩余的流数据(使用同一个迭代器)
async for chunk in byte_iterator:
# 如果预读数据为空,在收到第一个 chunk 时更新状态
if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
@@ -951,10 +930,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -972,6 +948,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
@@ -979,16 +956,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -1685,6 +1656,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
return False
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
def _mark_first_output(self, ctx: StreamContext, state: Dict[str, bool]) -> None:
    """Record TTFB and flip the usage record to ``streaming`` on first output.

    Invoked right before each yield; the flags in ``state`` ensure the work
    happens at most once per stream:

    1. the first-byte time (TTFB) is written into ``ctx``;
    2. the usage row is updated to ``streaming`` (so provider/key/TTFB
       information is available to pollers as soon as output starts).

    Args:
        ctx: Active stream context for the current request.
        state: Mutable flag dict holding ``first_yield`` and
            ``streaming_updated`` booleans shared across the stream loop.
    """
    needs_ttfb = state["first_yield"]
    if needs_ttfb:
        # TTFB must be in ctx before the streaming update below can read it.
        ctx.record_first_byte_time(self.start_time)
        state["first_yield"] = False
    streaming_done = state["streaming_updated"]
    if not streaming_done:
        self._update_usage_to_streaming_with_ctx(ctx)
        state["streaming_updated"] = True
def _convert_sse_line(
self,
ctx: StreamContext,

View File

@@ -332,15 +332,15 @@ class StreamProcessor:
# 处理预读数据
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in prefetched_chunks:
# 记录首字时间 (TTFB) - 在 yield 之前记录
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 把原始数据转发给客户端
yield chunk
@@ -363,14 +363,14 @@ class StreamProcessor:
# 处理剩余的流数据
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 原始数据透传
yield chunk

View File

@@ -1636,6 +1636,8 @@ class UsageService:
ids: Optional[List[str]] = None,
user_id: Optional[str] = None,
default_timeout_seconds: int = 300,
*,
include_admin_fields: bool = False,
) -> List[Dict[str, Any]]:
"""
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
@@ -1672,6 +1674,15 @@ class UsageService:
ProviderEndpoint.timeout.label("endpoint_timeout"),
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
# 管理员轮询:可附带 provider 与上游 key 名称(注意:不要在普通用户接口暴露上游 key 信息)
if include_admin_fields:
from src.models.database import ProviderAPIKey
query = query.add_columns(
Usage.provider,
ProviderAPIKey.name.label("api_key_name"),
).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
if ids:
query = query.filter(Usage.id.in_(ids))
if user_id:
@@ -1708,8 +1719,9 @@ class UsageService:
)
db.commit()
return [
{
result: List[Dict[str, Any]] = []
for r in records:
item: Dict[str, Any] = {
"id": r.id,
"status": "failed" if r.id in timeout_ids else r.status,
"input_tokens": r.input_tokens,
@@ -1718,8 +1730,12 @@ class UsageService:
"response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
}
for r in records
]
if include_admin_fields:
item["provider"] = r.provider
item["api_key_name"] = r.api_key_name
result.append(item)
return result
# ========== 缓存亲和性分析方法 ==========