mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 02:32:27 +08:00
refactor: 简化 IP 获取逻辑并将请求体超时配置化
- 移除 TRUSTED_PROXY_COUNT 配置,改为优先使用 X-Real-IP 头 - 添加 REQUEST_BODY_TIMEOUT 环境变量,默认 60 秒 - 统一 get_client_ip 逻辑,优先级:X-Real-IP > X-Forwarded-For > 直连 IP
This commit is contained in:
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, User
|
||||
from src.utils.request_utils import get_client_ip
|
||||
|
||||
|
||||
|
||||
@@ -86,7 +87,7 @@ class ApiRequestContext:
|
||||
setattr(request.state, "request_id", request_id)
|
||||
|
||||
start_time = time.time()
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
client_ip = get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "unknown")
|
||||
|
||||
context = cls(
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Optional, Tuple
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
@@ -64,13 +65,17 @@ class ApiRequestPipeline:
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
# 添加30秒超时防止卡死
|
||||
raw_body = await asyncio.wait_for(http_request.body(), timeout=30.0)
|
||||
# 添加超时防止卡死
|
||||
raw_body = await asyncio.wait_for(
|
||||
http_request.body(), timeout=config.request_body_timeout
|
||||
)
|
||||
logger.debug(f"[Pipeline] Raw body读取完成 | size={len(raw_body) if raw_body is not None else 0} bytes")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("读取请求体超时(30s),可能客户端未发送完整请求体")
|
||||
timeout_sec = int(config.request_body_timeout)
|
||||
logger.error(f"读取请求体超时({timeout_sec}s),可能客户端未发送完整请求体")
|
||||
raise HTTPException(
|
||||
status_code=408, detail="Request timeout: body not received within 30 seconds"
|
||||
status_code=408,
|
||||
detail=f"Request timeout: body not received within {timeout_sec} seconds",
|
||||
)
|
||||
else:
|
||||
logger.debug(f"[Pipeline] 非写请求跳过读取Body | method={http_request.method}")
|
||||
|
||||
@@ -106,13 +106,6 @@ class Config:
|
||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||
|
||||
# 可信代理配置
|
||||
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||
|
||||
# 异常处理配置
|
||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||
# 设置为 False 时,使用全局异常处理器统一处理
|
||||
@@ -161,6 +154,11 @@ class Config:
|
||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
|
||||
|
||||
# 请求体读取超时(秒)
|
||||
# REQUEST_BODY_TIMEOUT: 等待客户端发送完整请求体的超时时间
|
||||
# 默认 60 秒,防止客户端发送不完整请求导致连接卡死
|
||||
self.request_body_timeout = float(os.getenv("REQUEST_BODY_TIMEOUT", "60.0"))
|
||||
|
||||
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
|
||||
self.internal_user_agent_claude_cli = os.getenv(
|
||||
|
||||
@@ -203,28 +203,21 @@ class PluginMiddleware:
|
||||
"""
|
||||
获取客户端 IP 地址,支持代理头
|
||||
|
||||
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||
优先级:X-Real-IP > X-Forwarded-For > 直连 IP
|
||||
X-Real-IP 由最外层 Nginx 设置,最可靠
|
||||
"""
|
||||
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||
|
||||
# 优先从代理头获取真实 IP
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For,取第一个 IP(原始客户端)
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
# 回退到直连 IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
@@ -7,22 +7,20 @@ from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.config import config
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址
|
||||
|
||||
按优先级检查:
|
||||
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
|
||||
2. X-Real-IP 头(Nginx 代理)
|
||||
1. X-Real-IP 头(最可靠,由最外层可信 Nginx 直接设置)
|
||||
2. X-Forwarded-For 头的第一个 IP(原始客户端)
|
||||
3. 直接客户端IP
|
||||
|
||||
安全说明:
|
||||
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
|
||||
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
|
||||
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
|
||||
- X-Real-IP 优先级最高,因为它通常由最外层 Nginx 设置为 $remote_addr,
|
||||
Nginx 会直接覆盖这个头,不会传递客户端伪造的值
|
||||
- 只要最外层 Nginx 配置了 proxy_set_header X-Real-IP $remote_addr; 即可正确获取真实 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
@@ -30,30 +28,19 @@ def get_client_ip(request: Request) -> str:
|
||||
Returns:
|
||||
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
||||
"""
|
||||
trusted_proxy_count = config.trusted_proxy_count
|
||||
|
||||
# 如果不信任任何代理,直接返回连接 IP
|
||||
if trusted_proxy_count == 0:
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
||||
# 优先检查 X-Real-IP 头(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For 头,取第一个 IP(原始客户端)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
# 回退到直接客户端IP
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
@@ -109,36 +96,26 @@ def get_request_metadata(request: Request) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
|
||||
def extract_ip_from_headers(headers: dict) -> str:
|
||||
"""
|
||||
从HTTP头字典中提取IP地址(用于中间件等场景)
|
||||
|
||||
Args:
|
||||
headers: HTTP头字典
|
||||
trusted_proxy_count: 可信代理层数,None 时使用配置值
|
||||
|
||||
Returns:
|
||||
str: 客户端IP地址
|
||||
"""
|
||||
if trusted_proxy_count is None:
|
||||
trusted_proxy_count = config.trusted_proxy_count
|
||||
|
||||
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
|
||||
if trusted_proxy_count == 0:
|
||||
return "unknown"
|
||||
|
||||
# 检查 X-Forwarded-For
|
||||
forwarded_for = headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 检查 X-Real-IP
|
||||
# 优先检查 X-Real-IP(由最外层 Nginx 设置,最可靠)
|
||||
real_ip = headers.get("x-real-ip", "")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# 检查 X-Forwarded-For,取第一个 IP
|
||||
forwarded_for = headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if ips:
|
||||
return ips[0]
|
||||
|
||||
return "unknown"
|
||||
|
||||
Reference in New Issue
Block a user