diff --git a/src/api/handlers/base/chat_handler_base.py b/src/api/handlers/base/chat_handler_base.py index 3884ff2..97b5166 100644 --- a/src/api/handlers/base/chat_handler_base.py +++ b/src/api/handlers/base/chat_handler_base.py @@ -691,64 +691,70 @@ class ChatHandlerBase(BaseMessageHandler, ABC): f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}" ) - # 创建 HTTP 客户端(支持代理配置) - # endpoint.timeout 作为整体请求超时 + # 获取复用的 HTTP 客户端(支持代理配置) + # 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端 from src.clients.http_client import HTTPClientPool request_timeout = float(endpoint.timeout or 300) - http_client = HTTPClientPool.create_client_with_proxy( + http_client = await HTTPClientPool.get_proxy_client( proxy_config=endpoint.proxy, + ) + + # 注意:不使用 async with,因为复用的客户端不应该被关闭 + # 超时通过 timeout 参数控制 + resp = await http_client.post( + url, + json=provider_payload, + headers=provider_hdrs, timeout=httpx.Timeout(request_timeout), ) - async with http_client: - resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs) - status_code = resp.status_code - response_headers = dict(resp.headers) + status_code = resp.status_code + response_headers = dict(resp.headers) - if resp.status_code == 401: - raise ProviderAuthException(f"提供商认证失败: {provider.name}") - elif resp.status_code == 429: - raise ProviderRateLimitException( - f"提供商速率限制: {provider.name}", - provider_name=str(provider.name), - response_headers=response_headers, + if resp.status_code == 401: + raise ProviderAuthException(f"提供商认证失败: {provider.name}") + elif resp.status_code == 429: + raise ProviderRateLimitException( + f"提供商速率限制: {provider.name}", + provider_name=str(provider.name), + response_headers=response_headers, + ) + elif resp.status_code >= 500: + # 记录响应体以便调试 + error_body = "" + try: + error_body = resp.text[:1000] + logger.error( + f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}" ) - elif resp.status_code >= 500: - # 记录响应体以便调试 - error_body = "" - try: - error_body = resp.text[:1000] - logger.error( - f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}" - ) - except Exception: - pass - raise ProviderNotAvailableException( - f"提供商服务不可用: {provider.name}", - provider_name=str(provider.name), - upstream_status=resp.status_code, - upstream_response=error_body, - ) - elif resp.status_code != 200: - # 记录非200响应以便调试 - error_body = "" - try: - error_body = resp.text[:1000] - logger.warning( - f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}" - ) - except Exception: - pass - raise ProviderNotAvailableException( - f"提供商返回错误: {provider.name}, 状态: {resp.status_code}", - provider_name=str(provider.name), - upstream_status=resp.status_code, - upstream_response=error_body, + except Exception: + pass + raise ProviderNotAvailableException( + f"提供商服务不可用: {provider.name}", + provider_name=str(provider.name), + upstream_status=resp.status_code, + upstream_response=error_body, + ) + elif resp.status_code != 200: + # 记录非200响应以便调试 + error_body = "" + try: + error_body = resp.text[:1000] + logger.warning( + f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}" ) + except Exception: + pass + raise ProviderNotAvailableException( + f"提供商返回错误: {provider.name}, 状态: {resp.status_code}", + provider_name=str(provider.name), + upstream_status=resp.status_code, + upstream_response=error_body, + ) - response_json = resp.json() - return response_json if isinstance(response_json, dict) else {} + response_json = resp.json() + return response_json if isinstance(response_json, dict) else {} try: # 解析能力需求 diff --git a/src/api/handlers/base/cli_handler_base.py b/src/api/handlers/base/cli_handler_base.py index 49e3e50..7bb9898 100644 --- a/src/api/handlers/base/cli_handler_base.py +++ b/src/api/handlers/base/cli_handler_base.py @@ -1534,72 +1534,78 @@ class CliMessageHandlerBase(BaseMessageHandler): f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}" ) - # 创建 HTTP 客户端(支持代理配置) - # endpoint.timeout 作为整体请求超时 + # 获取复用的 HTTP 客户端(支持代理配置) + # 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端 from src.clients.http_client import HTTPClientPool request_timeout = float(endpoint.timeout or 300) - http_client = HTTPClientPool.create_client_with_proxy( + http_client = await HTTPClientPool.get_proxy_client( proxy_config=endpoint.proxy, + ) + + # 注意:不使用 async with,因为复用的客户端不应该被关闭 + # 超时通过 timeout 参数控制 + resp = await http_client.post( + url, + json=provider_payload, + headers=provider_headers, timeout=httpx.Timeout(request_timeout), ) - async with http_client: - resp = await http_client.post(url, json=provider_payload, headers=provider_headers) - status_code = resp.status_code - response_headers = dict(resp.headers) + status_code = resp.status_code + response_headers = dict(resp.headers) - if resp.status_code == 401: - raise ProviderAuthException(f"提供商认证失败: {provider.name}") - elif resp.status_code == 429: - raise ProviderRateLimitException( - f"提供商速率限制: {provider.name}", - provider_name=str(provider.name), - response_headers=response_headers, - retry_after=int(resp.headers.get("retry-after", 0)) or None, - ) - elif resp.status_code >= 500: - error_text = resp.text - raise ProviderNotAvailableException( - f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}", - provider_name=str(provider.name), - upstream_status=resp.status_code, - upstream_response=error_text, - ) - elif 300 <= resp.status_code < 400: - redirect_url = resp.headers.get("location", "unknown") - raise ProviderNotAvailableException( - f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}" - ) - elif resp.status_code != 200: - error_text = resp.text - raise ProviderNotAvailableException( - f"提供商返回错误: {provider.name}, 状态: {resp.status_code}", - provider_name=str(provider.name), - upstream_status=resp.status_code, - upstream_response=error_text, - ) + if resp.status_code == 401: + raise ProviderAuthException(f"提供商认证失败: {provider.name}") + elif resp.status_code == 429: + raise ProviderRateLimitException( + f"提供商速率限制: {provider.name}", + provider_name=str(provider.name), + response_headers=response_headers, + retry_after=int(resp.headers.get("retry-after", 0)) or None, + ) + elif resp.status_code >= 500: + error_text = resp.text + raise ProviderNotAvailableException( + f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}", + provider_name=str(provider.name), + upstream_status=resp.status_code, + upstream_response=error_text, + ) + elif 300 <= resp.status_code < 400: + redirect_url = resp.headers.get("location", "unknown") + raise ProviderNotAvailableException( + f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}" + ) + elif resp.status_code != 200: + error_text = resp.text + raise ProviderNotAvailableException( + f"提供商返回错误: {provider.name}, 状态: {resp.status_code}", + provider_name=str(provider.name), + upstream_status=resp.status_code, + upstream_response=error_text, + ) - # 安全解析 JSON 响应,处理可能的编码错误 - try: - response_json = resp.json() - except (UnicodeDecodeError, json.JSONDecodeError) as e: - # 记录原始响应信息用于调试 - content_type = resp.headers.get("content-type", "unknown") - content_encoding = resp.headers.get("content-encoding", "none") - logger.error( - f"[{self.request_id}] 无法解析响应 JSON: {e}, " - f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, " - f"响应长度: {len(resp.content)} bytes" - ) - raise ProviderNotAvailableException( - f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}" - ) + # 安全解析 JSON 响应,处理可能的编码错误 + try: + response_json = resp.json() + except (UnicodeDecodeError, json.JSONDecodeError) as e: + # 记录原始响应信息用于调试 + content_type = resp.headers.get("content-type", "unknown") + content_encoding = resp.headers.get("content-encoding", "none") + logger.error( + f"[{self.request_id}] 无法解析响应 JSON: {e}, " + f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, " + f"响应长度: {len(resp.content)} bytes" + ) + raise ProviderNotAvailableException( + f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}" + ) - # 提取 Provider 响应元数据(子类可覆盖) - response_metadata_result = self._extract_response_metadata(response_json) + # 提取 Provider 响应元数据(子类可覆盖) + response_metadata_result = self._extract_response_metadata(response_json) - return response_json if isinstance(response_json, dict) else {} + return response_json if isinstance(response_json, dict) else {} try: # 解析能力需求 diff --git a/src/clients/http_client.py b/src/clients/http_client.py index bf33e4f..4ba8cf3 100644 --- a/src/clients/http_client.py +++ b/src/clients/http_client.py @@ -1,10 +1,18 @@ """ 全局HTTP客户端池管理 避免每次请求都创建新的AsyncClient,提高性能 + +性能优化说明: +1. 默认客户端:无代理场景,全局复用单一客户端 +2. 代理客户端缓存:相同代理配置复用同一客户端,避免重复创建 +3. 连接池复用:Keep-alive 连接减少 TCP 握手开销 """ +import asyncio +import hashlib +import time from contextlib import asynccontextmanager -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from urllib.parse import quote, urlparse import httpx @@ -12,6 +20,32 @@ import httpx from src.config import config from src.core.logger import logger +# 模块级锁,避免类属性延迟初始化的竞态条件 +_proxy_clients_lock = asyncio.Lock() +_default_client_lock = asyncio.Lock() + + +def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str: + """ + 计算代理配置的缓存键 + + Args: + proxy_config: 代理配置字典 + + Returns: + 缓存键字符串,无代理时返回 "__no_proxy__" + """ + if not proxy_config: + return "__no_proxy__" + + # 构建代理 URL 作为缓存键的基础 + proxy_url = build_proxy_url(proxy_config) + if not proxy_url: + return "__no_proxy__" + + # 使用 MD5 哈希来避免过长的键名 + return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}" + def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]: """ @@ -61,11 +95,20 @@ class HTTPClientPool: 全局HTTP客户端池单例 管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接 + + 性能优化: + 1. 默认客户端:无代理场景复用 + 2. 代理客户端缓存:相同代理配置复用同一客户端 + 3. LRU 淘汰:代理客户端超过上限时淘汰最久未使用的 """ _instance: Optional["HTTPClientPool"] = None _default_client: Optional[httpx.AsyncClient] = None _clients: Dict[str, httpx.AsyncClient] = {} + # 代理客户端缓存:{cache_key: (client, last_used_time)} + _proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {} + # 代理客户端缓存上限(避免内存泄漏) + _max_proxy_clients: int = 50 def __new__(cls): if cls._instance is None: @@ -73,12 +116,50 @@ class HTTPClientPool: return cls._instance @classmethod - def get_default_client(cls) -> httpx.AsyncClient: + async def get_default_client_async(cls) -> httpx.AsyncClient: """ - 获取默认的HTTP客户端 + 获取默认的HTTP客户端(异步线程安全版本) 用于大多数HTTP请求,具有合理的默认配置 """ + if cls._default_client is not None: + return cls._default_client + + async with _default_client_lock: + # 双重检查,避免重复创建 + if cls._default_client is None: + cls._default_client = httpx.AsyncClient( + http2=False, # 暂时禁用HTTP/2以提高兼容性 + verify=True, # 启用SSL验证 + timeout=httpx.Timeout( + connect=config.http_connect_timeout, + read=config.http_read_timeout, + write=config.http_write_timeout, + pool=config.http_pool_timeout, + ), + limits=httpx.Limits( + max_connections=config.http_max_connections, + max_keepalive_connections=config.http_keepalive_connections, + keepalive_expiry=config.http_keepalive_expiry, + ), + follow_redirects=True, # 跟随重定向 + ) + logger.info( + f"全局HTTP客户端池已初始化: " + f"max_connections={config.http_max_connections}, " + f"keepalive={config.http_keepalive_connections}, " + f"keepalive_expiry={config.http_keepalive_expiry}s" + ) + return cls._default_client + + @classmethod + def get_default_client(cls) -> httpx.AsyncClient: + """ + 获取默认的HTTP客户端(同步版本,向后兼容) + + ⚠️ 注意:此方法在高并发首次调用时可能存在竞态条件, + 推荐使用 get_default_client_async() 异步版本。 + """ if cls._default_client is None: cls._default_client = httpx.AsyncClient( http2=False, # 暂时禁用HTTP/2以提高兼容性 @@ -135,6 +216,101 @@ class HTTPClientPool: return cls._clients[name] + @classmethod + def _get_proxy_clients_lock(cls) -> asyncio.Lock: + """获取代理客户端缓存锁(模块级单例,避免竞态条件)""" + return _proxy_clients_lock + + @classmethod + async def _evict_lru_proxy_client(cls) -> None: + """淘汰最久未使用的代理客户端""" + if len(cls._proxy_clients) < cls._max_proxy_clients: + return + + # 找到最久未使用的客户端 + oldest_key = min(cls._proxy_clients.keys(), key=lambda k: cls._proxy_clients[k][1]) + old_client, _ = cls._proxy_clients.pop(oldest_key) + + # 异步关闭旧客户端 + try: + await old_client.aclose() + logger.debug(f"淘汰代理客户端: {oldest_key}") + except Exception as e: + logger.warning(f"关闭代理客户端失败: {e}") + + @classmethod + async def get_proxy_client( + cls, + proxy_config: Optional[Dict[str, Any]] = None, + ) -> httpx.AsyncClient: + """ + 获取代理客户端(带缓存复用) + + 相同代理配置会复用同一个客户端,大幅减少连接建立开销。 + 注意:返回的客户端使用默认超时配置,如需自定义超时请在请求时传递 timeout 参数。 + + Args: + proxy_config: 代理配置字典,包含 url, username, password + + Returns: + 可复用的 httpx.AsyncClient 实例 + """ + cache_key = _compute_proxy_cache_key(proxy_config) + + # 无代理时返回默认客户端 + if cache_key == "__no_proxy__": + return await cls.get_default_client_async() + + lock = cls._get_proxy_clients_lock() + async with lock: + # 检查缓存 + if cache_key in cls._proxy_clients: + client, _ = cls._proxy_clients[cache_key] + # 健康检查:如果客户端已关闭,移除并重新创建 + if client.is_closed: + del cls._proxy_clients[cache_key] + logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}") + else: + # 更新最后使用时间 + cls._proxy_clients[cache_key] = (client, time.time()) + return client + + # 淘汰旧客户端(如果超过上限) + await cls._evict_lru_proxy_client() + + # 创建新客户端(使用默认超时,请求时可覆盖) + client_config: Dict[str, Any] = { + "http2": False, + "verify": True, + "follow_redirects": True, + "limits": httpx.Limits( + max_connections=config.http_max_connections, + max_keepalive_connections=config.http_keepalive_connections, + keepalive_expiry=config.http_keepalive_expiry, + ), + "timeout": httpx.Timeout( + connect=config.http_connect_timeout, + read=config.http_read_timeout, + write=config.http_write_timeout, + pool=config.http_pool_timeout, + ), + } + + # 添加代理配置 + proxy_url = build_proxy_url(proxy_config) if proxy_config else None + if proxy_url: + client_config["proxy"] = proxy_url + + client = httpx.AsyncClient(**client_config) + cls._proxy_clients[cache_key] = (client, time.time()) + + logger.debug( + f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, " + f"缓存数量: {len(cls._proxy_clients)}" + ) + + return client + @classmethod async def close_all(cls): """关闭所有HTTP客户端""" @@ -148,6 +324,16 @@ class HTTPClientPool: logger.debug(f"命名HTTP客户端已关闭: {name}") cls._clients.clear() + + # 关闭代理客户端缓存 + for cache_key, (client, _) in cls._proxy_clients.items(): + try: + await client.aclose() + logger.debug(f"代理客户端已关闭: {cache_key}") + except Exception as e: + logger.warning(f"关闭代理客户端失败: {e}") + + cls._proxy_clients.clear() logger.info("所有HTTP客户端已关闭") @classmethod @@ -190,13 +376,15 @@ class HTTPClientPool: """ 创建带代理配置的HTTP客户端 + ⚠️ 性能警告:此方法每次都创建新客户端,推荐使用 get_proxy_client() 复用连接。 + Args: proxy_config: 代理配置字典,包含 url, username, password timeout: 超时配置 **kwargs: 其他 httpx.AsyncClient 配置参数 Returns: - 配置好的 httpx.AsyncClient 实例 + 配置好的 httpx.AsyncClient 实例(调用者需要负责关闭) """ client_config: Dict[str, Any] = { "http2": False, @@ -218,11 +406,21 @@ class HTTPClientPool: proxy_url = build_proxy_url(proxy_config) if proxy_config else None if proxy_url: client_config["proxy"] = proxy_url - logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") + logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}") client_config.update(kwargs) return httpx.AsyncClient(**client_config) + @classmethod + def get_pool_stats(cls) -> Dict[str, Any]: + """获取连接池统计信息""" + return { + "default_client_active": cls._default_client is not None, + "named_clients_count": len(cls._clients), + "proxy_clients_count": len(cls._proxy_clients), + "max_proxy_clients": cls._max_proxy_clients, + } + # 便捷访问函数 def get_http_client() -> httpx.AsyncClient: diff --git a/src/services/rate_limit/concurrency_manager.py b/src/services/rate_limit/concurrency_manager.py index b1af9b1..2f33a83 100644 --- a/src/services/rate_limit/concurrency_manager.py +++ b/src/services/rate_limit/concurrency_manager.py @@ -85,6 +85,8 @@ class ConcurrencyManager: """ 获取当前并发数 + 性能优化:使用 MGET 批量获取,减少 Redis 往返次数 + Args: endpoint_id: Endpoint ID(可选) key_id: ProviderAPIKey ID(可选) @@ -104,15 +106,21 @@ class ConcurrencyManager: key_count = 0 try: + # 使用 MGET 批量获取,减少 Redis 往返(2 次 GET -> 1 次 MGET) + keys_to_fetch = [] if endpoint_id: - endpoint_key = self._get_endpoint_key(endpoint_id) - result = await self._redis.get(endpoint_key) - endpoint_count = int(result) if result else 0 - + keys_to_fetch.append(self._get_endpoint_key(endpoint_id)) if key_id: - key_key = self._get_key_key(key_id) - result = await self._redis.get(key_key) - key_count = int(result) if result else 0 + keys_to_fetch.append(self._get_key_key(key_id)) + + if keys_to_fetch: + results = await self._redis.mget(keys_to_fetch) + idx = 0 + if endpoint_id: + endpoint_count = int(results[idx]) if results[idx] else 0 + idx += 1 + if key_id: + key_count = int(results[idx]) if results[idx] else 0 except Exception as e: logger.error(f"获取并发数失败: {e}")