From 6885cf1f6dad2c8772bab6b071d5af357420f0ee Mon Sep 17 00:00:00 2001 From: fawney19 Date: Wed, 7 Jan 2026 18:17:35 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BD=BF=E7=94=A8=20asyncio.wait?= =?UTF-8?q?=5Ffor=20=E6=8E=A7=E5=88=B6=E8=AF=B7=E6=B1=82=E6=95=B4=E4=BD=93?= =?UTF-8?q?=E8=B6=85=E6=97=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 endpoint.timeout 从 httpx 的 read 超时改为 asyncio.wait_for 控制, 更精确地管理"建立连接 + 获取首字节"阶段的整体超时。 主要改动: - HTTP 超时配置改用全局 config 参数 - endpoint.timeout 作为 asyncio.wait_for 的整体超时 - 增加 asyncio.TimeoutError 处理和连接清理逻辑 - 增加防御性空值检查 --- src/api/handlers/base/chat_handler_base.py | 87 +++++++++--- src/api/handlers/base/cli_handler_base.py | 154 +++++++++++++++------ 2 files changed, 179 insertions(+), 62 deletions(-) diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py index 46fd472..3884ff2 100644 --- a/src/api/handlers/base/chat_handler_base.py +++ b/src/api/handlers/base/chat_handler_base.py @@ -19,6 +19,7 @@ Chat Handler Base - Chat API 格式的通用基类 - StreamTelemetryRecorder: 统计记录(Usage、Audit、Candidate) """ +import asyncio from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Callable, Dict, Optional @@ -55,7 +56,6 @@ from src.models.database import ( from src.services.provider.transport import build_provider_url - class ChatHandlerBase(BaseMessageHandler, ABC): """ Chat Handler 基类 @@ -89,7 +89,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC): user_agent: str, start_time: float, allowed_api_formats: Optional[list] = None, - adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None, + adapter_detector: Optional[ + Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]] + ] = None, ): allowed = allowed_api_formats or [self.FORMAT_ID] super().__init__( @@ -459,14 +461,19 @@ class ChatHandlerBase(BaseMessageHandler, ABC): f"模型={ctx.model} -> {mapped_model or '无映射'}" ) - # 发送请求(使用配置中的超时设置) + # 配置 HTTP 超时 + # 注意:read timeout 用于检测连接断开,不是整体请求超时 + # 整体请求超时由 asyncio.wait_for 控制,使用 endpoint.timeout timeout_config = httpx.Timeout( connect=config.http_connect_timeout, - read=float(endpoint.timeout), + read=config.http_read_timeout, # 使用全局配置,用于检测连接断开 write=config.http_write_timeout, pool=config.http_pool_timeout, ) + # endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节) + request_timeout = float(endpoint.timeout or 300) + # 创建 HTTP 客户端(支持代理配置) from src.clients.http_client import HTTPClientPool @@ -474,7 +481,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC): proxy_config=endpoint.proxy, timeout=timeout_config, ) - try: + + # 用于存储内部函数的结果(必须在函数定义前声明,供 nonlocal 使用) + byte_iterator: Any = None + prefetched_chunks: Any = None + response_ctx: Any = None + + async def _connect_and_prefetch() -> None: + """建立连接并预读首字节(受整体超时控制)""" + nonlocal byte_iterator, prefetched_chunks, response_ctx response_ctx = http_client.stream( "POST", url, json=provider_payload, headers=provider_headers ) @@ -497,6 +512,28 @@ class ChatHandlerBase(BaseMessageHandler, ABC): max_prefetch_lines=config.stream_prefetch_lines, ) + try: + # 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段 + # endpoint.timeout 控制整体超时,避免上游长时间无响应 + await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout) + + except asyncio.TimeoutError: + # 整体请求超时(建立连接 + 获取首字节) + # 清理可能已建立的连接上下文 + if response_ctx is not None: + try: + await response_ctx.__aexit__(None, None, None) + except Exception: + pass + await http_client.aclose() + logger.warning( + f" [{self.request_id}] 请求超时: Provider={provider.name}, timeout={request_timeout}s" + ) + raise ProviderTimeoutException( + provider_name=str(provider.name), + timeout=int(request_timeout), + ) + except httpx.HTTPStatusError as e: error_text = await self._extract_error_text(e) logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}") @@ -507,7 +544,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC): except EmbeddedErrorException: try: - await response_ctx.__aexit__(None, None, None) + if response_ctx is not None: + await response_ctx.__aexit__(None, None, None) except Exception: pass await http_client.aclose() @@ -517,6 +555,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC): await http_client.aclose() raise + # 类型断言:成功执行后这些变量不会为 None + assert byte_iterator is not None + assert prefetched_chunks is not None + assert response_ctx is not None + # 创建流生成器(传入字节流迭代器) return stream_processor.create_response_stream( ctx, @@ -639,17 +682,23 @@ class ChatHandlerBase(BaseMessageHandler, ABC): is_stream=False, ) - logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, " - f"模型={model} -> {mapped_model or '无映射'}") + logger.info( + f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, " + f"模型={model} -> {mapped_model or '无映射'}" + ) logger.debug(f" [{self.request_id}] 请求URL: {url}") - logger.debug(f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}") + logger.debug( + f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}" + ) # 创建 HTTP 客户端(支持代理配置) + # endpoint.timeout 作为整体请求超时 from src.clients.http_client import HTTPClientPool + request_timeout = float(endpoint.timeout or 300) http_client = HTTPClientPool.create_client_with_proxy( proxy_config=endpoint.proxy, - timeout=httpx.Timeout(float(endpoint.timeout)), + timeout=httpx.Timeout(request_timeout), ) async with http_client: resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs) @@ -670,7 +719,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC): error_body = "" try: error_body = resp.text[:1000] - logger.error(f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}") + logger.error( + f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}" + ) except Exception: pass raise ProviderNotAvailableException( @@ -684,7 +735,9 @@ class ChatHandlerBase(BaseMessageHandler, ABC): error_body = "" try: error_body = resp.text[:1000] - logger.warning(f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}") + logger.warning( + f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}" + ) except Exception: pass raise ProviderNotAvailableException( @@ -765,8 +818,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC): logger.debug(f"{self.FORMAT_ID} 非流式响应完成") # 简洁的请求完成摘要 - logger.info(f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | " - f"in:{input_tokens or 0} out:{output_tokens or 0}") + logger.info( + f"[OK] {self.request_id[:8]} | {model} | {provider_name or 'unknown'} | {response_time_ms}ms | " + f"in:{input_tokens or 0} out:{output_tokens or 0}" + ) return JSONResponse(status_code=status_code, content=response_json) @@ -807,8 +862,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC): error_bytes = await e.response.aread() return error_bytes.decode("utf-8", errors="replace") else: - return ( - e.response.text if hasattr(e.response, "_content") else "Unable to read" - ) + return e.response.text if hasattr(e.response, "_content") else "Unable to read" except Exception as decode_error: return f"Unable to read error: {decode_error}" diff --git a/src/api/handlers/base/cli_handler_base.py b/src/api/handlers/base/cli_handler_base.py index 16ff18a..49e3e50 100644 --- a/src/api/handlers/base/cli_handler_base.py +++ b/src/api/handlers/base/cli_handler_base.py @@ -33,19 +33,21 @@ from src.api.handlers.base.base_handler import ( ) from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.request_builder import PassthroughRequestBuilder -from src.api.handlers.base.stream_context import StreamContext -from src.api.handlers.base.utils import ( - build_sse_headers, - check_html_response, - check_prefetched_response_error, -) -from src.core.error_utils import extract_error_message # 直接从具体模块导入,避免循环依赖 from src.api.handlers.base.response_parser import ( ResponseParser, StreamStats, ) +from src.api.handlers.base.stream_context import StreamContext +from src.api.handlers.base.utils import ( + build_sse_headers, + check_html_response, + check_prefetched_response_error, +) +from src.config.constants import StreamDefaults +from src.config.settings import config +from src.core.error_utils import extract_error_message from src.core.exceptions import ( EmbeddedErrorException, ProviderAuthException, @@ -62,8 +64,6 @@ from src.models.database import ( ProviderEndpoint, User, ) -from src.config.constants import StreamDefaults -from src.config.settings import config from src.services.provider.transport import build_provider_url from src.utils.sse_parser import SSEEventParser from src.utils.timeout import read_first_chunk_with_ttfb_timeout @@ -100,7 +100,9 @@ class CliMessageHandlerBase(BaseMessageHandler): user_agent: str, start_time: float, allowed_api_formats: Optional[list] = None, - adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None, + adapter_detector: Optional[ + Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]] + ] = None, ): allowed = allowed_api_formats or [self.FORMAT_ID] super().__init__( @@ -158,7 +160,9 @@ class CliMessageHandlerBase(BaseMessageHandler): mapper = ModelMapperMiddleware(self.db) mapping = await mapper.get_mapping(source_model, provider_id) - logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}") + logger.debug( + f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}" + ) if mapping and mapping.model: # 使用 select_provider_model_name 支持模型映射功能 @@ -168,7 +172,9 @@ class CliMessageHandlerBase(BaseMessageHandler): mapped_name = mapping.model.select_provider_model_name( affinity_key, api_format=self.FORMAT_ID ) - logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)") + logger.debug( + f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)" + ) return mapped_name logger.debug(f"[CLI] 无模型映射,使用原始名称: {source_model}") @@ -459,18 +465,26 @@ class CliMessageHandlerBase(BaseMessageHandler): is_stream=True, # CLI handler 处理流式请求 ) - # 配置超时 + # 配置 HTTP 超时 + # 注意:read timeout 用于检测连接断开,不是整体请求超时 + # 整体请求超时由 _connect_and_prefetch 内部的 asyncio.wait_for 控制 timeout_config = httpx.Timeout( - connect=10.0, - read=float(endpoint.timeout), - write=60.0, # 写入超时增加到60秒,支持大请求体(如包含图片的长对话) - pool=10.0, + connect=config.http_connect_timeout, + read=config.http_read_timeout, # 使用全局配置,用于检测连接断开 + write=config.http_write_timeout, + pool=config.http_pool_timeout, ) - logger.debug(f" └─ [{self.request_id}] 发送流式请求: " - f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., " - f"Key=***{key.api_key[-4:]}, " - f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}") + # endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节) + request_timeout = float(endpoint.timeout or 300) + + logger.debug( + f" └─ [{self.request_id}] 发送流式请求: " + f"Provider={provider.name}, Endpoint={endpoint.id[:8] if endpoint.id else 'N/A'}..., " + f"Key=***{key.api_key[-4:] if key.api_key else 'N/A'}, " + f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}, " + f"timeout={request_timeout}s" + ) # 创建 HTTP 客户端(支持代理配置) from src.clients.http_client import HTTPClientPool @@ -479,7 +493,15 @@ class CliMessageHandlerBase(BaseMessageHandler): proxy_config=endpoint.proxy, timeout=timeout_config, ) - try: + + # 用于存储内部函数的结果(必须在函数定义前声明,供 nonlocal 使用) + byte_iterator: Any = None + prefetched_chunks: Any = None + response_ctx: Any = None + + async def _connect_and_prefetch() -> None: + """建立连接并预读首字节(受整体超时控制)""" + nonlocal byte_iterator, prefetched_chunks, response_ctx response_ctx = http_client.stream( "POST", url, json=provider_payload, headers=provider_headers ) @@ -500,9 +522,33 @@ class CliMessageHandlerBase(BaseMessageHandler): byte_iterator, provider, endpoint, ctx ) + try: + # 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段 + # endpoint.timeout 控制整体超时,避免上游长时间无响应 + await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout) + + except asyncio.TimeoutError: + # 整体请求超时(建立连接 + 获取首字节) + # 清理可能已建立的连接上下文 + if response_ctx is not None: + try: + await response_ctx.__aexit__(None, None, None) + except Exception: + pass + await http_client.aclose() + logger.warning( + f" [{self.request_id}] 请求超时: Provider={provider.name}, timeout={request_timeout}s" + ) + raise ProviderTimeoutException( + provider_name=str(provider.name), + timeout=int(request_timeout), + ) + except httpx.HTTPStatusError as e: error_text = await self._extract_error_text(e) - logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}") + logger.error( + f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}" + ) await http_client.aclose() # 将上游错误信息附加到异常,以便故障转移时能够返回给客户端 e.upstream_response = error_text # type: ignore[attr-defined] @@ -511,7 +557,8 @@ class CliMessageHandlerBase(BaseMessageHandler): except EmbeddedErrorException: # 嵌套错误需要触发重试,关闭连接后重新抛出 try: - await response_ctx.__aexit__(None, None, None) + if response_ctx is not None: + await response_ctx.__aexit__(None, None, None) except Exception: pass await http_client.aclose() @@ -521,6 +568,11 @@ class CliMessageHandlerBase(BaseMessageHandler): await http_client.aclose() raise + # 类型断言:成功执行后这些变量不会为 None + assert byte_iterator is not None + assert prefetched_chunks is not None + assert response_ctx is not None + # 创建流生成器(带预读数据,使用同一个迭代器) return self._create_response_stream_with_prefetch( ctx, @@ -593,7 +645,9 @@ class CliMessageHandlerBase(BaseMessageHandler): }, } self._mark_first_output(ctx, output_state) - yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8") + yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode( + "utf-8" + ) return # 结束生成器 # 格式转换或直接透传 @@ -801,10 +855,12 @@ class CliMessageHandlerBase(BaseMessageHandler): if isinstance(data, dict) and provider_parser.is_error_response(data): # 提取错误信息 parsed = provider_parser.parse_response(data, 200) - logger.warning(f" [{self.request_id}] 检测到嵌套错误: " + logger.warning( + f" [{self.request_id}] 检测到嵌套错误: " f"Provider={provider.name}, " f"error_type={parsed.error_type}, " - f"message={parsed.error_message}") + f"message={parsed.error_message}" + ) raise EmbeddedErrorException( provider_name=str(provider.name), error_code=( @@ -849,14 +905,12 @@ class CliMessageHandlerBase(BaseMessageHandler): raise except (OSError, IOError) as e: # 网络 I/O 异常:记录警告,可能需要重试 - logger.warning( - f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}" - ) + logger.warning(f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}") except Exception as e: # 未预期的严重异常:记录错误并重新抛出,避免掩盖问题 logger.error( f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}", - exc_info=True + exc_info=True, ) raise @@ -979,7 +1033,9 @@ class CliMessageHandlerBase(BaseMessageHandler): }, } self._mark_first_output(ctx, output_state) - yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8") + yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode( + "utf-8" + ) return # 格式转换或直接透传 @@ -1255,8 +1311,10 @@ class CliMessageHandlerBase(BaseMessageHandler): ) logger.debug(f"{self.FORMAT_ID} 流式响应中断") # 简洁的请求失败摘要(包含预估 token 信息) - logger.info(f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | " - f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}") + logger.info( + f"[FAIL] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | " + f"{ctx.status_code} | in:{actual_input_tokens} out:{ctx.output_tokens} cache:{ctx.cached_tokens}" + ) else: # 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据 self._finalize_stream_metadata(ctx) @@ -1289,9 +1347,7 @@ class CliMessageHandlerBase(BaseMessageHandler): ) logger.debug(f"{self.FORMAT_ID} 流式响应完成") # 简洁的请求完成摘要(两行格式) - line1 = ( - f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}" - ) + line1 = f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}" if ctx.first_byte_time_ms: line1 += f" | TTFB: {ctx.first_byte_time_ms}ms" @@ -1314,7 +1370,9 @@ class CliMessageHandlerBase(BaseMessageHandler): RequestCandidateService.mark_candidate_failed( db=bg_db, candidate_id=ctx.attempt_id, - error_type="client_disconnected" if ctx.status_code == 499 else "stream_error", + error_type=( + "client_disconnected" if ctx.status_code == 499 else "stream_error" + ), error_message=ctx.error_message or f"HTTP {ctx.status_code}", status_code=ctx.status_code, latency_ms=response_time_ms, @@ -1469,17 +1527,21 @@ class CliMessageHandlerBase(BaseMessageHandler): is_stream=False, # 非流式请求 ) - logger.info(f" └─ [{self.request_id}] 发送非流式请求: " - f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., " - f"Key=***{key.api_key[-4:]}, " - f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}") + logger.info( + f" └─ [{self.request_id}] 发送非流式请求: " + f"Provider={provider.name}, Endpoint={endpoint.id[:8] if endpoint.id else 'N/A'}..., " + f"Key=***{key.api_key[-4:] if key.api_key else 'N/A'}, " + f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}" + ) # 创建 HTTP 客户端(支持代理配置) + # endpoint.timeout 作为整体请求超时 from src.clients.http_client import HTTPClientPool + request_timeout = float(endpoint.timeout or 300) http_client = HTTPClientPool.create_client_with_proxy( proxy_config=endpoint.proxy, - timeout=httpx.Timeout(float(endpoint.timeout)), + timeout=httpx.Timeout(request_timeout), ) async with http_client: resp = await http_client.post(url, json=provider_payload, headers=provider_headers) @@ -1525,9 +1587,11 @@ class CliMessageHandlerBase(BaseMessageHandler): # 记录原始响应信息用于调试 content_type = resp.headers.get("content-type", "unknown") content_encoding = resp.headers.get("content-encoding", "none") - logger.error(f"[{self.request_id}] 无法解析响应 JSON: {e}, " + logger.error( + f"[{self.request_id}] 无法解析响应 JSON: {e}, " f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, " - f"响应长度: {len(resp.content)} bytes") + f"响应长度: {len(resp.content)} bytes" + ) raise ProviderNotAvailableException( f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}" )