From 2395093394c476ae211ff53a86c7bc79c3d08b3f Mon Sep 17 00:00:00 2001 From: fawney19 Date: Tue, 6 Jan 2026 16:29:03 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=20IP=20=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E9=80=BB=E8=BE=91=E5=B9=B6=E5=B0=86=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E4=BD=93=E8=B6=85=E6=97=B6=E9=85=8D=E7=BD=AE=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 TRUSTED_PROXY_COUNT 配置,改为优先使用 X-Real-IP 头 - 添加 REQUEST_BODY_TIMEOUT 环境变量,默认 60 秒 - 统一 get_client_ip 逻辑,优先级:X-Real-IP > X-Forwarded-For > 直连 IP --- src/api/base/context.py | 3 +- src/api/base/pipeline.py | 13 ++++-- src/config/settings.py | 12 +++-- src/middleware/plugin_middleware.py | 27 +++++------ src/utils/request_utils.py | 69 ++++++++++------------------- 5 files changed, 49 insertions(+), 75 deletions(-) diff --git a/src/api/base/context.py b/src/api/base/context.py index a1da13c..9372159 100644 --- a/src/api/base/context.py +++ b/src/api/base/context.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from src.core.logger import logger from src.models.database import ApiKey, User +from src.utils.request_utils import get_client_ip @@ -86,7 +87,7 @@ class ApiRequestContext: setattr(request.state, "request_id", request_id) start_time = time.time() - client_ip = request.client.host if request.client else "unknown" + client_ip = get_client_ip(request) user_agent = request.headers.get("user-agent", "unknown") context = cls( diff --git a/src/api/base/pipeline.py b/src/api/base/pipeline.py index 1418040..0b94b2d 100644 --- a/src/api/base/pipeline.py +++ b/src/api/base/pipeline.py @@ -7,6 +7,7 @@ from typing import Any, Optional, Tuple from fastapi import HTTPException, Request from sqlalchemy.orm import Session +from src.config.settings import config from src.core.exceptions import QuotaExceededException from src.core.logger import logger from src.models.database import ApiKey, AuditEventType, User, UserRole @@ -64,13 +65,17 @@ class ApiRequestPipeline: try: import asyncio - # 添加30秒超时防止卡死 - raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0) + # 添加超时防止卡死 + raw_body = await asyncio.wait_for( + http_request.body(), timeout=config.request_body_timeout + ) logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes") except asyncio.TimeoutError: - logger.error("读取请求体超时(30s),可能客户端未发送完整请求体") + timeout_sec = int(config.request_body_timeout) + logger.error(f"读取请求体超时({timeout_sec}s),可能客户端未发送完整请求体") raise HTTPException( - status_code=408, detail="Request timeout: body not received within 30 seconds" + status_code=408, + detail=f"Request timeout: body not received within {timeout_sec} seconds", ) else: logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}") diff --git a/src/config/settings.py b/src/config/settings.py index e66d017..0a98c06 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -106,13 +106,6 @@ class Config: self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100")) self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60")) - # 可信代理配置 - # TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理) - # 设置为 0 表示不信任任何代理头,直接使用连接 IP - # 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数 - # 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造 - self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1")) - # 异常处理配置 # 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers # 设置为 False 时,使用全局异常处理器统一处理 @@ -161,6 +154,11 @@ class Config: self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1")) self.stream_first_byte_timeout = self._parse_ttfb_timeout() + # 请求体读取超时(秒) + # REQUEST_BODY_TIMEOUT: 等待客户端发送完整请求体的超时时间 + # 默认 60 秒,防止客户端发送不完整请求导致连接卡死 + self.request_body_timeout = float(os.getenv("REQUEST_BODY_TIMEOUT", "60.0")) + # 内部请求 User-Agent 配置(用于查询上游模型列表等) # 可通过环境变量覆盖默认值,模拟对应 CLI 客户端 self.internal_user_agent_claude_cli = os.getenv( diff --git a/src/middleware/plugin_middleware.py b/src/middleware/plugin_middleware.py index 6ca59c5..ab8eaa0 100644 --- a/src/middleware/plugin_middleware.py +++ b/src/middleware/plugin_middleware.py @@ -203,28 +203,21 @@ class PluginMiddleware: """ 获取客户端 IP 地址,支持代理头 - 注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头, - 仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。 - 如果服务直接暴露公网,攻击者可伪造这些头绕过限流。 + 优先级:X-Real-IP > X-Forwarded-For > 直连 IP + X-Real-IP 由最外层 Nginx 设置,最可靠 """ - # 从配置获取可信代理层数(默认为 1,即信任最近一层代理) - trusted_proxy_count = getattr(config, "trusted_proxy_count", 1) - - # 优先从代理头获取真实 IP - forwarded_for = request.headers.get("x-forwarded-for") - if forwarded_for: - # X-Forwarded-For 格式: "client, proxy1, proxy2" - # 从右往左数 trusted_proxy_count 个,取其左边的第一个 - ips = [ip.strip() for ip in forwarded_for.split(",")] - if len(ips) > trusted_proxy_count: - return ips[-(trusted_proxy_count + 1)] - elif ips: - return ips[0] - + # 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠) real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip.strip() + # 检查 X-Forwarded-For,取第一个 IP(原始客户端) + forwarded_for = request.headers.get("x-forwarded-for") + if forwarded_for: + ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()] + if ips: + return ips[0] + # 回退到直连 IP if request.client: return request.client.host diff --git a/src/utils/request_utils.py b/src/utils/request_utils.py index 4aaeebe..a9f3ebd 100644 --- a/src/utils/request_utils.py +++ b/src/utils/request_utils.py @@ -7,22 +7,20 @@ from typing import Optional from fastapi import Request -from src.config import config - def get_client_ip(request: Request) -> str: """ 获取客户端真实IP地址 按优先级检查: - 1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取) - 2. X-Real-IP 头(Nginx 代理) + 1. X-Real-IP 头(最可靠,由最外层可信 Nginx 直接设置) + 2. X-Forwarded-For 头的第一个 IP(原始客户端) 3. 直接客户端IP 安全说明: - - 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数 - - 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP - - 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造 + - X-Real-IP 优先级最高,因为它通常由最外层 Nginx 设置为 $remote_addr, + Nginx 会直接覆盖这个头,不会传递客户端伪造的值 + - 只要最外层 Nginx 配置了 proxy_set_header X-Real-IP $remote_addr; 即可正确获取真实 IP Args: request: FastAPI Request 对象 @@ -30,30 +28,19 @@ def get_client_ip(request: Request) -> str: Returns: str: 客户端IP地址,如果无法获取则返回 "unknown" """ - trusted_proxy_count = config.trusted_proxy_count - - # 如果不信任任何代理,直接返回连接 IP - if trusted_proxy_count == 0: - if request.client and request.client.host: - return request.client.host - return "unknown" - - # 优先检查 X-Forwarded-For 头(可能包含代理链) - forwarded_for = request.headers.get("X-Forwarded-For") - if forwarded_for: - # X-Forwarded-For 格式: "client, proxy1, proxy2" - # 从右往左数 trusted_proxy_count 个,取其左边的第一个 - ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()] - if len(ips) > trusted_proxy_count: - return ips[-(trusted_proxy_count + 1)] - elif ips: - return ips[0] - - # 检查 X-Real-IP 头(通常由 Nginx 设置) + # 优先检查 X-Real-IP 头(由最外层 Nginx 设置,最可靠) real_ip = request.headers.get("X-Real-IP") if real_ip: return real_ip.strip() + # 检查 X-Forwarded-For 头,取第一个 IP(原始客户端) + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + # X-Forwarded-For 格式: "client, proxy1, proxy2" + ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()] + if ips: + return ips[0] + # 回退到直接客户端IP if request.client and request.client.host: return request.client.host @@ -109,36 +96,26 @@ def get_request_metadata(request: Request) -> dict: } -def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str: +def extract_ip_from_headers(headers: dict) -> str: """ 从HTTP头字典中提取IP地址(用于中间件等场景) Args: headers: HTTP头字典 - trusted_proxy_count: 可信代理层数,None 时使用配置值 Returns: str: 客户端IP地址 """ - if trusted_proxy_count is None: - trusted_proxy_count = config.trusted_proxy_count - - # 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP) - if trusted_proxy_count == 0: - return "unknown" - - # 检查 X-Forwarded-For - forwarded_for = headers.get("x-forwarded-for", "") - if forwarded_for: - ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()] - if len(ips) > trusted_proxy_count: - return ips[-(trusted_proxy_count + 1)] - elif ips: - return ips[0] - - # 检查 X-Real-IP + # 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠) real_ip = headers.get("x-real-ip", "") if real_ip: return real_ip.strip() + # 检查 X-Forwarded-For,取第一个 IP + forwarded_for = headers.get("x-forwarded-for", "") + if forwarded_for: + ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()] + if ips: + return ips[0] + return "unknown"