refactor: 使用 asyncio.wait_for 控制请求整体超时

将 endpoint.timeout 从 httpx 的 read 超时改为 asyncio.wait_for 控制,
更精确地管理"建立连接 + 获取首字节"阶段的整体超时。

主要改动:
- HTTP 超时配置改用全局 config 参数
- endpoint.timeout 作为 asyncio.wait_for 的整体超时
- 增加 asyncio.TimeoutError 处理和连接清理逻辑
- 增加防御性空值检查
This commit is contained in:
fawney19
2026-01-07 18:17:35 +08:00
parent 00f6fafcfc
commit 6885cf1f6d
2 changed files with 179 additions and 62 deletions

View File

@@ -19,6 +19,7 @@ Chat Handler Base - Chat API 格式的通用基类
- StreamTelemetryRecorder: 统计记录Usage、Audit、Candidate
"""
import asyncio
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, Optional
@@ -55,7 +56,6 @@ from src.models.database import (
from src.services.provider.transport import build_provider_url
class ChatHandlerBase(BaseMessageHandler, ABC):
"""
Chat Handler 基类
@@ -89,7 +89,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
adapter_detector: Optional[
Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
super().__init__(
@@ -459,14 +461,19 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f"模型={ctx.model} -> {mapped_model or '无映射'}"
)
# 发送请求(使用配置中的超时设置)
# 配置 HTTP 超时
# 注意read timeout 用于检测连接断开,不是整体请求超时
# 整体请求超时由 asyncio.wait_for 控制,使用 endpoint.timeout
timeout_config = httpx.Timeout(
connect=config.http_connect_timeout,
read=float(endpoint.timeout),
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
request_timeout = float(endpoint.timeout or 300)
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
@@ -474,7 +481,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
# 用于存储内部函数的结果(必须在函数定义前声明,供 nonlocal 使用)
byte_iterator: Any = None
prefetched_chunks: Any = None
response_ctx: Any = None
async def _connect_and_prefetch() -> None:
"""建立连接并预读首字节(受整体超时控制)"""
nonlocal byte_iterator, prefetched_chunks, response_ctx
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
)
@@ -497,6 +512,28 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
max_prefetch_lines=config.stream_prefetch_lines,
)
try:
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
# endpoint.timeout 控制整体超时,避免上游长时间无响应
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
except asyncio.TimeoutError:
# 整体请求超时(建立连接 + 获取首字节)
# 清理可能已建立的连接上下文
if response_ctx is not None:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
await http_client.aclose()
logger.warning(
f" [{self.request_id}] 请求超时: Provider={provider.name}, timeout={request_timeout}s"
)
raise ProviderTimeoutException(
provider_name=str(provider.name),
timeout=int(request_timeout),
)
except httpx.HTTPStatusError as e:
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
@@ -507,7 +544,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
except EmbeddedErrorException:
try:
await response_ctx.__aexit__(None, None, None)
if response_ctx is not None:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
await http_client.aclose()
@@ -517,6 +555,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
await http_client.aclose()
raise
# 类型断言:成功执行后这些变量不会为 None
assert byte_iterator is not None
assert prefetched_chunks is not None
assert response_ctx is not None
# 创建流生成器(传入字节流迭代器)
return stream_processor.create_response_stream(
ctx,
@@ -639,17 +682,23 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
is_stream=False,
)
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
f"模型={model} -> {mapped_model or '无映射'}")
logger.info(
f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
f"模型={model} -> {mapped_model or '无映射'}"
)
logger.debug(f" [{self.request_id}] 请求URL: {url}")
logger.debug(f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}")
logger.debug(
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
)
# 创建 HTTP 客户端(支持代理配置)
# endpoint.timeout 作为整体请求超时
from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
timeout=httpx.Timeout(request_timeout),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
@@ -670,7 +719,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}")
logger.error(
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
@@ -684,7 +735,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}")
logger.warning(
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
@@ -765,8 +818,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
logger.debug(f"{self.FORMAT_ID} 非流式响应完成")
# 简洁的请求完成摘要
logger.info(f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | "
f"in:{input_tokens or 0} out:{output_tokens or 0}")
logger.info(
f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | "
f"in:{input_tokens or 0} out:{output_tokens or 0}"
)
return JSONResponse(status_code=status_code, content=response_json)
@@ -807,8 +862,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
error_bytes = await e.response.aread()
return error_bytes.decode("utf-8", errors="replace")
else:
return (
e.response.text if hasattr(e.response, "_content") else "Unable to read"
)
return e.response.text if hasattr(e.response, "_content") else "Unable to read"
except Exception as decode_error:
return f"Unable to read error: {decode_error}"

View File

@@ -33,19 +33,21 @@ from src.api.handlers.base.base_handler import (
)
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import (
build_sse_headers,
check_html_response,
check_prefetched_response_error,
)
from src.core.error_utils import extract_error_message
# 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import (
ResponseParser,
StreamStats,
)
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import (
build_sse_headers,
check_html_response,
check_prefetched_response_error,
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.core.error_utils import extract_error_message
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -62,8 +64,6 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
@@ -100,7 +100,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
adapter_detector: Optional[
Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]
] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
super().__init__(
@@ -158,7 +160,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
mapper = ModelMapperMiddleware(self.db)
mapping = await mapper.get_mapping(source_model, provider_id)
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
logger.debug(
f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}"
)
if mapping and mapping.model:
# 使用 select_provider_model_name 支持模型映射功能
@@ -168,7 +172,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
logger.debug(
f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)"
)
return mapped_name
logger.debug(f"[CLI] 无模型映射,使用原始名称: {source_model}")
@@ -459,18 +465,26 @@ class CliMessageHandlerBase(BaseMessageHandler):
is_stream=True, # CLI handler 处理流式请求
)
# 配置超时
# 配置 HTTP 超时
# 注意read timeout 用于检测连接断开,不是整体请求超时
# 整体请求超时由 _connect_and_prefetch 内部的 asyncio.wait_for 控制
timeout_config = httpx.Timeout(
connect=10.0,
read=float(endpoint.timeout),
write=60.0, # 写入超时增加到60秒支持大请求体如包含图片的长对话
pool=10.0,
connect=config.http_connect_timeout,
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
logger.debug(f" └─ [{self.request_id}] 发送流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
request_timeout = float(endpoint.timeout or 300)
logger.debug(
f" └─ [{self.request_id}] 发送流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8] if endpoint.id else 'N/A'}..., "
f"Key=***{key.api_key[-4:] if key.api_key else 'N/A'}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}, "
f"timeout={request_timeout}s"
)
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
@@ -479,7 +493,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
# 用于存储内部函数的结果(必须在函数定义前声明,供 nonlocal 使用)
byte_iterator: Any = None
prefetched_chunks: Any = None
response_ctx: Any = None
async def _connect_and_prefetch() -> None:
"""建立连接并预读首字节(受整体超时控制)"""
nonlocal byte_iterator, prefetched_chunks, response_ctx
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
)
@@ -500,9 +522,33 @@ class CliMessageHandlerBase(BaseMessageHandler):
byte_iterator, provider, endpoint, ctx
)
try:
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
# endpoint.timeout 控制整体超时,避免上游长时间无响应
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
except asyncio.TimeoutError:
# 整体请求超时(建立连接 + 获取首字节)
# 清理可能已建立的连接上下文
if response_ctx is not None:
try:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
await http_client.aclose()
logger.warning(
f" [{self.request_id}] 请求超时: Provider={provider.name}, timeout={request_timeout}s"
)
raise ProviderTimeoutException(
provider_name=str(provider.name),
timeout=int(request_timeout),
)
except httpx.HTTPStatusError as e:
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
logger.error(
f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}"
)
await http_client.aclose()
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
e.upstream_response = error_text # type: ignore[attr-defined]
@@ -511,7 +557,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
except EmbeddedErrorException:
# 嵌套错误需要触发重试,关闭连接后重新抛出
try:
await response_ctx.__aexit__(None, None, None)
if response_ctx is not None:
await response_ctx.__aexit__(None, None, None)
except Exception:
pass
await http_client.aclose()
@@ -521,6 +568,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
await http_client.aclose()
raise
# 类型断言:成功执行后这些变量不会为 None
assert byte_iterator is not None
assert prefetched_chunks is not None
assert response_ctx is not None
# 创建流生成器(带预读数据,使用同一个迭代器)
return self._create_response_stream_with_prefetch(
ctx,
@@ -593,7 +645,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode(
"utf-8"
)
return # 结束生成器
# 格式转换或直接透传
@@ -801,10 +855,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
if isinstance(data, dict) and provider_parser.is_error_response(data):
# 提取错误信息
parsed = provider_parser.parse_response(data, 200)
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
logger.warning(
f" [{self.request_id}] 检测到嵌套错误: "
f"Provider={provider.name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}")
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=str(provider.name),
error_code=(
@@ -849,14 +905,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
logger.warning(f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}")
except Exception as e:
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
exc_info=True,
)
raise
@@ -979,7 +1033,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode(
"utf-8"
)
return
# 格式转换或直接透传
@@ -1255,8 +1311,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
)
logger.debug(f"{self.FORMAT_ID} 流式响应中断")
# 简洁的请求失败摘要(包含预估 token 信息)
logger.info(f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}")
logger.info(
f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}"
)
else:
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
self._finalize_stream_metadata(ctx)
@@ -1289,9 +1347,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
)
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
# 简洁的请求完成摘要(两行格式)
line1 = (
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
)
line1 = f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
if ctx.first_byte_time_ms:
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
@@ -1314,7 +1370,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx.attempt_id,
error_type="client_disconnected" if ctx.status_code == 499 else "stream_error",
error_type=(
"client_disconnected" if ctx.status_code == 499 else "stream_error"
),
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
status_code=ctx.status_code,
latency_ms=response_time_ms,
@@ -1469,17 +1527,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
is_stream=False, # 非流式请求
)
logger.info(f" └─ [{self.request_id}] 发送非流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"Key=***{key.api_key[-4:]}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
logger.info(
f" └─ [{self.request_id}] 发送非流式请求: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8] if endpoint.id else 'N/A'}..., "
f"Key=***{key.api_key[-4:] if key.api_key else 'N/A'}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
)
# 创建 HTTP 客户端(支持代理配置)
# endpoint.timeout 作为整体请求超时
from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
timeout=httpx.Timeout(request_timeout),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
@@ -1525,9 +1587,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 记录原始响应信息用于调试
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
logger.error(f"[{self.request_id}] 无法解析响应 JSON: {e}, "
logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes")
f"响应长度: {len(resp.content)} bytes"
)
raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
)