refactor: 重构流式处理模块,提取 StreamContext/Processor/Telemetry

- 将 chat_handler_base.py 中的流式处理逻辑拆分为三个独立模块:
  - StreamContext: 类型安全的流式上下文数据类,替代原有的 ctx dict
  - StreamProcessor: SSE 解析、预读、嵌套错误检测
  - StreamTelemetryRecorder: 统计记录(Usage/Audit/Candidate)
- 将硬编码配置外置到 settings.py,支持环境变量覆盖:
  - HTTP 超时配置(connect/write/pool)
  - 流式处理配置(预读行数、统计延迟)
  - 并发控制配置(槽位 TTL、缓存预留比例)
This commit is contained in:
fawney19
2025-12-12 15:42:45 +08:00
parent 39defce71c
commit 53bf74429e
6 changed files with 922 additions and 600 deletions

View File

@@ -12,10 +12,13 @@ Chat Handler Base - Chat API 格式的通用基类
- apply_mapped_model(): 模型映射
- get_model_for_url(): URL 模型名
- _extract_usage(): 使用量提取
重构说明:
- StreamContext: 类型安全的流式上下文,替代原有的 ctx dict
- StreamProcessor: 流式响应处理SSE 解析、预读、错误检测)
- StreamTelemetryRecorder: 统计记录Usage、Audit、Candidate
"""
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, Optional
@@ -24,13 +27,14 @@ from fastapi import BackgroundTasks, Request
from fastapi.responses import JSONResponse, StreamingResponse
from sqlalchemy.orm import Session
from src.api.handlers.base.base_handler import (
BaseMessageHandler,
MessageTelemetry,
)
from src.api.handlers.base.base_handler import BaseMessageHandler
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.config.settings import config
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -39,7 +43,6 @@ from src.core.exceptions import (
ProviderTimeoutException,
)
from src.core.logger import logger
from src.database import get_db
from src.models.database import (
ApiKey,
Provider,
@@ -48,7 +51,6 @@ from src.models.database import (
User,
)
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@@ -285,30 +287,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model = getattr(converted_request, "model", original_request_body.get("model", "unknown"))
api_format = self.allowed_api_formats[0]
# 用于跟踪的上下文
ctx = {
"model": model,
"api_format": api_format,
"provider_name": None,
"provider_id": None,
"endpoint_id": None,
"key_id": None,
"attempt_id": None,
"provider_api_format": None, # Provider 的响应格式(用于选择解析器)
"input_tokens": 0,
"output_tokens": 0,
"cached_tokens": 0,
"cache_creation_tokens": 0,
"collected_text": "",
"status_code": 200,
"response_headers": {},
"provider_request_headers": {},
"provider_request_body": None,
"data_count": 0,
"chunk_count": 0,
"has_completion": False,
"parsed_chunks": [], # 收集解析后的 chunks
}
# 创建类型安全的流式上下文
ctx = StreamContext(model=model, api_format=api_format)
# 创建流处理器
stream_processor = StreamProcessor(
request_id=self.request_id,
default_parser=self.parser,
on_streaming_start=self._update_usage_to_streaming,
)
# 定义请求函数
async def stream_request_func(
@@ -318,6 +305,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
stream_processor,
provider,
endpoint,
key,
@@ -350,23 +338,39 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True,
capability_requirements=capability_requirements or None,
)
ctx["attempt_id"] = attempt_id
ctx["provider_name"] = provider_name
ctx["provider_id"] = provider_id
ctx["endpoint_id"] = endpoint_id
ctx["key_id"] = key_id
# 更新上下文
ctx.attempt_id = attempt_id
ctx.provider_name = provider_name
ctx.provider_id = provider_id
ctx.endpoint_id = endpoint_id
ctx.key_id = key_id
# 创建遥测记录器
telemetry_recorder = StreamTelemetryRecorder(
request_id=self.request_id,
user_id=str(self.user.id),
api_key_id=str(self.api_key.id),
client_ip=self.client_ip,
format_id=self.FORMAT_ID,
)
# 创建后台任务记录统计
background_tasks = BackgroundTasks()
background_tasks.add_task(
self._record_stream_stats,
telemetry_recorder.record_stream_stats,
ctx,
original_headers,
original_request_body,
self.elapsed_ms(),
)
# 创建监控流
monitored_stream = self._create_monitored_stream(ctx, stream_generator, http_request)
monitored_stream = stream_processor.create_monitored_stream(
ctx,
stream_generator,
http_request.is_disconnected,
)
return StreamingResponse(
monitored_stream,
@@ -381,7 +385,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
async def _execute_stream_request(
self,
ctx: Dict,
ctx: StreamContext,
stream_processor: StreamProcessor,
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
@@ -390,37 +395,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
query_params: Optional[Dict[str, str]] = None,
) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据,避免累积
ctx["parsed_chunks"] = []
ctx["chunk_count"] = 0
ctx["data_count"] = 0
ctx["has_completion"] = False
ctx["collected_text"] = ""
ctx["input_tokens"] = 0
ctx["output_tokens"] = 0
ctx["cached_tokens"] = 0
ctx["cache_creation_tokens"] = 0
# 重置上下文状态(重试时清除之前的数据)
ctx.reset_for_retry()
ctx["provider_name"] = str(provider.name)
ctx["provider_id"] = str(provider.id)
ctx["endpoint_id"] = str(endpoint.id)
ctx["key_id"] = str(key.id)
ctx["provider_api_format"] = str(endpoint.api_format) if endpoint.api_format else ""
# 更新 Provider 信息
ctx.update_provider_info(
provider_name=str(provider.name),
provider_id=str(provider.id),
endpoint_id=str(endpoint.id),
key_id=str(key.id),
provider_api_format=str(endpoint.api_format) if endpoint.api_format else None,
)
# 获取模型映射
mapped_model = await self._get_mapped_model(
source_model=ctx["model"],
source_model=ctx.model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体
if mapped_model:
ctx["mapped_model"] = mapped_model # 保存映射后的模型名,用于 Usage 记录
ctx.mapped_model = mapped_model
request_body = self.apply_mapped_model(original_request_body, mapped_model)
else:
request_body = dict(original_request_body)
# 准备发送给 Provider 的请求体(子类可覆盖以移除不需要的字段)
# 准备发送给 Provider 的请求体
request_body = self.prepare_provider_request_body(request_body)
# 构建请求
@@ -432,11 +432,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True,
)
ctx["provider_request_headers"] = provider_headers
ctx["provider_request_body"] = provider_payload
ctx.provider_request_headers = provider_headers
ctx.provider_request_body = provider_payload
# 获取 URL 模型名(兜底使用 ctx 中的 model确保 Gemini 等格式能正确构建 URL
url_model = self.get_model_for_url(request_body, mapped_model) or ctx["model"]
# 获取 URL 模型名
url_model = self.get_model_for_url(request_body, mapped_model) or ctx.model
url = build_provider_url(
endpoint,
@@ -445,15 +445,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=True,
)
logger.debug(f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
f"模型={ctx['model']} -> {mapped_model or '无映射'}")
logger.debug(
f" [{self.request_id}] 发送流式请求: Provider={provider.name}, "
f"模型={ctx.model} -> {mapped_model or '无映射'}"
)
# 发送请求
# 发送请求(使用配置中的超时设置)
timeout_config = httpx.Timeout(
connect=10.0,
connect=config.http_connect_timeout,
read=float(endpoint.timeout),
write=60.0, # 写入超时增加到60秒支持大请求体如包含图片的长对话
pool=10.0,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
@@ -463,17 +465,21 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
stream_response = await response_ctx.__aenter__()
ctx["status_code"] = stream_response.status_code
ctx["response_headers"] = dict(stream_response.headers)
ctx.status_code = stream_response.status_code
ctx.response_headers = dict(stream_response.headers)
stream_response.raise_for_status()
# 创建行迭代器(只创建一次,后续会继续使用)
# 创建行迭代器
line_iterator = stream_response.aiter_lines()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_lines = await self._prefetch_and_check_embedded_error(
line_iterator, provider, endpoint, ctx
# 预读检测嵌套错误
prefetched_lines = await stream_processor.prefetch_and_check_error(
line_iterator,
provider,
endpoint,
ctx,
max_prefetch_lines=config.stream_prefetch_lines,
)
except httpx.HTTPStatusError as e:
@@ -483,7 +489,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
raise
except EmbeddedErrorException:
# 嵌套错误需要触发重试,关闭连接后重新抛出
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
@@ -495,8 +500,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose()
raise
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
# 创建流生成器
return stream_processor.create_response_stream(
ctx,
line_iterator,
response_ctx,
@@ -504,518 +509,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
prefetched_lines,
)
async def _create_response_stream(
self,
ctx: Dict,
stream_response: httpx.Response,
response_ctx: Any,
http_client: httpx.AsyncClient,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器"""
try:
sse_parser = SSEEventParser()
streaming_status_updated = False
async for line in stream_response.aiter_lines():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming()
streaming_status_updated = True
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
async def _prefetch_and_check_embedded_error(
self,
line_iterator: Any,
provider: Provider,
endpoint: ProviderEndpoint,
ctx: Dict,
) -> list:
"""
预读流的前几行,检测嵌套错误
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
Args:
line_iterator: 行迭代器aiter_lines() 返回的迭代器)
provider: Provider 对象
endpoint: Endpoint 对象
ctx: 上下文字典
Returns:
预读的行列表(需要在后续流中先输出)
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
"""
prefetched_lines: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
try:
# 获取对应格式的解析器
provider_parser = self._get_provider_parser(ctx)
line_count = 0
async for line in line_iterator:
prefetched_lines.append(line)
line_count += 1
# 解析数据
normalized_line = line.rstrip("\r")
if not normalized_line or normalized_line.startswith(":"):
# 空行或注释行,继续预读
if line_count >= max_prefetch_lines:
break
continue
# 尝试解析 SSE 数据
data_str = normalized_line
if normalized_line.startswith("data: "):
data_str = normalized_line[6:]
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
# 不是有效 JSON可能是部分数据继续
if line_count >= max_prefetch_lines:
break
continue
# 使用解析器检查是否为错误响应
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
# 预读到有效数据,没有错误,停止预读
break
except EmbeddedErrorException:
# 重新抛出嵌套错误
raise
except Exception as e:
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
return prefetched_lines
async def _create_response_stream_with_prefetch(
self,
ctx: Dict,
line_iterator: Any,
response_ctx: Any,
http_client: httpx.AsyncClient,
prefetched_lines: list,
) -> AsyncGenerator[bytes, None]:
"""创建响应流生成器(带预读数据)"""
try:
sse_parser = SSEEventParser()
# 在第一次输出数据前更新状态为 streaming
if prefetched_lines:
self._update_usage_to_streaming()
# 先输出预读的数据
for line in prefetched_lines:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 继续输出剩余的流数据(使用同一个迭代器)
async for line in line_iterator:
normalized_line = line.rstrip("\r")
events = sse_parser.feed_line(normalized_line)
if normalized_line == "":
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
yield b"\n"
continue
ctx["chunk_count"] += 1
yield (line + "\n").encode("utf-8")
for event in events:
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
# 处理剩余事件
for event in sse_parser.flush():
self._handle_sse_event(ctx, event.get("event"), event.get("data") or "")
except GeneratorExit:
raise
finally:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
try:
await http_client.aclose()
except Exception:
pass
def _get_provider_parser(self, ctx: Dict) -> ResponseParser:
"""
获取 Provider 格式的解析器
根据 Provider 的 API 格式选择正确的解析器,
而不是根据请求格式选择。
"""
provider_format = ctx.get("provider_api_format")
if provider_format:
try:
return get_parser_for_format(provider_format)
except KeyError:
pass
# 回退到默认解析器
return self.parser
def _handle_sse_event(
self,
ctx: Dict,
event_name: Optional[str],
data_str: str,
) -> None:
"""处理 SSE 事件"""
if not data_str:
return
if data_str == "[DONE]":
ctx["has_completion"] = True
return
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return
ctx["data_count"] += 1
if not isinstance(data, dict):
return
# 收集原始 chunk 数据
ctx["parsed_chunks"].append(data)
# 根据 Provider 格式选择解析器
provider_parser = self._get_provider_parser(ctx)
# 使用解析器提取 usage
usage = provider_parser.extract_usage_from_response(data)
if usage:
ctx["input_tokens"] = usage.get("input_tokens", ctx["input_tokens"])
ctx["output_tokens"] = usage.get("output_tokens", ctx["output_tokens"])
ctx["cached_tokens"] = usage.get("cache_read_tokens", ctx["cached_tokens"])
ctx["cache_creation_tokens"] = usage.get(
"cache_creation_tokens", ctx["cache_creation_tokens"]
)
# 提取文本
text = provider_parser.extract_text_content(data)
if text:
ctx["collected_text"] += text
# 检查完成
event_type = event_name or data.get("type", "")
if event_type in ("response.completed", "message_stop"):
ctx["has_completion"] = True
async def _create_monitored_stream(
self,
ctx: Dict,
stream_generator: AsyncGenerator[bytes, None],
http_request: Request,
) -> AsyncGenerator[bytes, None]:
"""创建带监控的流生成器"""
try:
async for chunk in stream_generator:
if await http_request.is_disconnected():
logger.warning(f"ID:{self.request_id} | Client disconnected")
# 客户端断开时设置 499 状态码Client Closed Request
# 注意Provider 可能已经成功返回数据,但客户端未完整接收
ctx["status_code"] = 499
break
yield chunk
except asyncio.CancelledError:
ctx["status_code"] = 499
raise
except Exception:
ctx["status_code"] = 500
raise
async def _record_stream_stats(
self,
ctx: Dict,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
) -> None:
"""记录流式统计信息"""
response_time_ms = self.elapsed_ms()
bg_db = None
try:
await asyncio.sleep(0.1)
if not ctx["provider_name"]:
# 即使没有 provider_name也要尝试更新状态为 failed
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message="Provider name not available",
)
return
db_gen = get_db()
bg_db = next(db_gen)
try:
from src.models.database import ApiKey as ApiKeyModel
user = bg_db.query(User).filter(User.id == self.user.id).first()
api_key_obj = (
bg_db.query(ApiKeyModel).filter(ApiKeyModel.id == self.api_key.id).first()
)
if not user or not api_key_obj:
logger.warning(f"[{self.request_id}] User or ApiKey not found, updating status directly")
await self._update_usage_status_directly(
bg_db,
status="completed" if ctx["status_code"] == 200 else "failed",
response_time_ms=response_time_ms,
status_code=ctx["status_code"],
)
return
bg_telemetry = MessageTelemetry(
bg_db, user, api_key_obj, self.request_id, self.client_ip
)
actual_request_body = ctx["provider_request_body"] or original_request_body
# 构建响应体(与 CLI 模式一致)
response_body = {
"chunks": ctx["parsed_chunks"],
"metadata": {
"stream": True,
"total_chunks": len(ctx["parsed_chunks"]),
"data_count": ctx["data_count"],
"has_completion": ctx["has_completion"],
"response_time_ms": response_time_ms,
},
}
# 根据状态码决定记录成功还是失败
# 499 = 客户端断开连接503 = 服务不可用(如流中断)
status_code: int = ctx.get("status_code") or 200
if status_code >= 400:
# 记录失败的 Usage但使用已收到的预估 token 信息(来自 message_start
# 这样即使请求中断,也能记录预估成本
await bg_telemetry.record_failure(
provider=ctx.get("provider_name") or "unknown",
model=ctx["model"],
response_time_ms=response_time_ms,
status_code=status_code,
error_message=ctx.get("error_message") or f"HTTP {status_code}",
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx["api_format"],
provider_request_headers=ctx["provider_request_headers"],
# 预估 token 信息(来自 message_start 事件)
input_tokens=ctx.get("input_tokens", 0),
output_tokens=ctx.get("output_tokens", 0),
cache_creation_tokens=ctx.get("cache_creation_tokens", 0),
cache_read_tokens=ctx.get("cached_tokens", 0),
response_body=response_body,
# 模型映射信息
target_model=ctx.get("mapped_model"),
)
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
# 简洁的请求失败摘要(包含预估 token 信息)
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
f"{status_code} | in:{ctx.get('input_tokens', 0)} out:{ctx.get('output_tokens', 0)} cache:{ctx.get('cached_tokens', 0)}")
else:
await bg_telemetry.record_success(
provider=ctx["provider_name"],
model=ctx["model"],
input_tokens=ctx["input_tokens"],
output_tokens=ctx["output_tokens"],
response_time_ms=response_time_ms,
status_code=status_code,
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx["response_headers"],
response_body=response_body,
cache_creation_tokens=ctx["cache_creation_tokens"],
cache_read_tokens=ctx["cached_tokens"],
is_stream=True,
provider_request_headers=ctx["provider_request_headers"],
api_format=ctx["api_format"],
provider_id=ctx["provider_id"],
provider_endpoint_id=ctx["endpoint_id"],
provider_api_key_id=ctx["key_id"],
# 模型映射信息
target_model=ctx.get("mapped_model"),
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {ctx['model']} | {ctx.get('provider_name', 'unknown')} | {response_time_ms}ms | "
f"in:{ctx.get('input_tokens', 0) or 0} out:{ctx.get('output_tokens', 0) or 0}")
# 更新候选记录的最终状态和延迟时间
# 注意RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
# 这里用流传输完成后的实际时间覆盖
if ctx.get("attempt_id"):
from src.services.request.candidate import RequestCandidateService
# 根据状态码决定是成功还是失败(复用上面已定义的 status_code
# 499 = 客户端断开连接,应标记为失败
# 503 = 服务不可用(如流中断),应标记为失败
if status_code and status_code >= 400:
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx["attempt_id"],
error_type="client_disconnected" if status_code == 499 else "stream_error",
error_message=ctx.get("error_message") or f"HTTP {status_code}",
status_code=status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": False,
"data_count": ctx.get("data_count", 0),
},
)
else:
RequestCandidateService.mark_candidate_success(
db=bg_db,
candidate_id=ctx["attempt_id"],
status_code=status_code,
latency_ms=response_time_ms,
extra_data={
"stream_completed": True,
"data_count": ctx.get("data_count", 0),
},
)
finally:
if bg_db:
bg_db.close()
except Exception as e:
logger.exception("记录流式统计信息时出错")
# 确保即使出错也要更新状态,避免 pending 状态卡住
await self._update_usage_status_on_error(
response_time_ms=response_time_ms,
error_message=f"记录统计信息失败: {str(e)[:200]}",
)
# _update_usage_to_streaming 方法已移至基类 BaseMessageHandler
async def _update_usage_status_on_error(
self,
response_time_ms: int,
error_message: str,
) -> None:
"""在记录失败时更新 Usage 状态,避免卡在 pending"""
try:
db_gen = get_db()
error_db = next(db_gen)
try:
await self._update_usage_status_directly(
error_db,
status="failed",
response_time_ms=response_time_ms,
status_code=500,
error_message=error_message,
)
finally:
error_db.close()
except Exception as inner_e:
logger.error(f"[{self.request_id}] 更新 Usage 状态失败: {inner_e}")
async def _update_usage_status_directly(
self,
db: Session,
status: str,
response_time_ms: int,
status_code: int = 200,
error_message: Optional[str] = None,
) -> None:
"""直接更新 Usage 表的状态字段"""
try:
from src.models.database import Usage
usage = db.query(Usage).filter(Usage.request_id == self.request_id).first()
if usage:
setattr(usage, "status", status)
setattr(usage, "status_code", status_code)
setattr(usage, "response_time_ms", response_time_ms)
if error_message:
setattr(usage, "error_message", error_message)
db.commit()
logger.debug(f"[{self.request_id}] Usage 状态已更新: {status}")
except Exception as e:
logger.error(f"[{self.request_id}] 直接更新 Usage 状态失败: {e}")
async def _record_stream_failure(
self,
ctx: Dict,
ctx: StreamContext,
error: Exception,
original_headers: Dict[str, str],
original_request_body: Dict[str, Any],
@@ -1031,21 +527,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
elif isinstance(error, ProviderTimeoutException):
status_code = 504
actual_request_body = ctx.get("provider_request_body") or original_request_body
actual_request_body = ctx.provider_request_body or original_request_body
await self.telemetry.record_failure(
provider=ctx.get("provider_name") or "unknown",
model=ctx["model"],
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx["api_format"],
provider_request_headers=ctx.get("provider_request_headers") or {},
# 模型映射信息
target_model=ctx.get("mapped_model"),
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
target_model=ctx.mapped_model,
)
# ==================== 非流式处理 ====================