refactor: optimize middleware with pure ASGI implementation and enhance security measures

- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
This commit is contained in:
fawney19
2025-12-18 19:07:20 +08:00
parent c7b971cfe7
commit 7b932d7afb
24 changed files with 497 additions and 219 deletions

View File

@@ -1,16 +1,17 @@
"""
统一的插件中间件
统一的插件中间件(纯 ASGI 实现)
负责协调所有插件的调用
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware
以避免 Starlette 已知的流式响应兼容性问题。
"""
import hashlib
import time
from typing import Any, Awaitable, Callable, Optional
from typing import Optional
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from src.config import config
from src.core.logger import logger
@@ -18,20 +19,25 @@ from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
class PluginMiddleware(BaseHTTPMiddleware):
class PluginMiddleware:
"""
统一的插件调用中间件
统一的插件调用中间件(纯 ASGI 实现)
职责:
- 性能监控
- 限流控制 (可选)
- 数据库会话生命周期管理
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
- 纯 ASGI 可以直接透传流式响应,无额外开销
"""
def __init__(self, app: Any) -> None:
super().__init__(app)
def __init__(self, app: ASGIApp) -> None:
self.app = app
self.plugin_manager = get_plugin_manager()
# 从配置读取速率限制值
@@ -61,152 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
"/v1/completions",
]
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
) -> StarletteResponse:
"""处理请求并调用相应插件"""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI 入口点"""
if scope["type"] != "http":
# 非 HTTP 请求(如 WebSocket直接透传
await self.app(scope, receive, send)
return
# 构建 Request 对象以便复用现有逻辑
request = Request(scope, receive, send)
# 记录请求开始时间
start_time = time.time()
# 设置 request.state 属性
# 注意Starlette 的 Request 对象总是有 state 属性State 实例)
request.state.request_id = request.headers.get("x-request-id", "")
request.state.start_time = start_time
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
request.state.db_managed_by_middleware = True
response = None
exception_to_raise = None
# 1. 限流检查(在调用下游之前)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
await self._send_rate_limit_response(send, rate_limit_result)
return
# 2. 预处理插件调用
await self._call_pre_request_plugins(request)
# 用于捕获响应状态码
response_status_code: int = 0
async def send_wrapper(message: Message) -> None:
nonlocal response_status_code
if message["type"] == "http.response.start":
response_status_code = message.get("status", 0)
await send(message)
# 3. 调用下游应用
exception_occurred: Optional[Exception] = None
try:
# 1. 限流插件调用(可选功能)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
headers = rate_limit_result.headers or {}
raise HTTPException(
status_code=429,
detail=rate_limit_result.message or "Rate limit exceeded",
headers=headers,
)
# 2. 预处理插件调用
await self._call_pre_request_plugins(request)
# 处理请求
response = await call_next(request)
# 3. 提交关键数据库事务(在返回响应前)
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
try:
db = getattr(request.state, "db", None)
if isinstance(db, Session):
db.commit()
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
try:
if isinstance(db, Session):
db.rollback()
except Exception:
pass
await self._call_error_plugins(request, commit_error, start_time)
# 返回 500 错误,因为数据可能不一致
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "database_error",
"message": "数据保存失败,请重试",
},
},
)
# 跳过后处理插件,直接返回错误响应
return response
# 4. 后处理插件调用(监控等,非关键操作)
# 这些操作失败不应影响用户响应
await self._call_post_request_plugins(request, response, start_time)
# 注意不在此处添加限流响应头因为在BaseHTTPMiddleware中
# 响应返回后修改headers会导致Content-Length不匹配错误
# 限流响应头已在返回429错误时正确包含见上面的HTTPException
except RuntimeError as e:
if str(e) == "No response returned.":
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback()
except Exception:
pass
logger.error("Downstream handler completed without returning a response")
await self._call_error_plugins(request, e, start_time)
if isinstance(db, Session):
try:
db.commit()
except Exception:
pass
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "internal_error",
"message": "Internal server error: downstream handler returned no response.",
},
},
)
else:
exception_to_raise = e
await self.app(scope, receive, send_wrapper)
except Exception as e:
# 回滚数据库事务
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback()
except Exception:
pass
exception_occurred = e
# 错误处理插件调用
await self._call_error_plugins(request, e, start_time)
raise
finally:
# 4. 数据库会话清理(无论成功与否)
await self._cleanup_db_session(request, exception_occurred)
# 尝试提交错误日志
if isinstance(db, Session):
# 5. 后处理插件调用(仅在成功时)
if not exception_occurred and response_status_code > 0:
await self._call_post_request_plugins(request, response_status_code, start_time)
async def _send_rate_limit_response(
self, send: Send, result: RateLimitResult
) -> None:
"""发送 429 限流响应"""
import json
body = json.dumps({
"type": "error",
"error": {
"type": "rate_limit_error",
"message": result.message or "Rate limit exceeded",
},
}).encode("utf-8")
headers = [(b"content-type", b"application/json")]
if result.headers:
for key, value in result.headers.items():
headers.append((key.lower().encode(), str(value).encode()))
await send({
"type": "http.response.start",
"status": 429,
"headers": headers,
})
await send({
"type": "http.response.body",
"body": body,
})
async def _cleanup_db_session(
self, request: Request, exception: Optional[Exception]
) -> None:
"""清理数据库会话
事务策略:
- 如果 request.state.tx_committed_by_route 为 True说明路由已自行提交中间件只负责 close
- 否则由中间件统一 commit/rollback
这避免了双重提交的问题,同时保持向后兼容。
"""
from sqlalchemy.orm import Session
db = getattr(request.state, "db", None)
if not isinstance(db, Session):
return
# 检查是否由路由层已经提交
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
try:
if exception is not None:
# 发生异常,回滚事务(无论谁负责提交)
try:
db.rollback()
except Exception as rollback_error:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
elif not tx_committed_by_route:
# 正常完成且路由未自行提交,由中间件提交事务
try:
db.commit()
except:
pass
exception_to_raise = e
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
try:
db.rollback()
except Exception:
pass
# 如果 tx_committed_by_route 为 True跳过 commit路由已提交
finally:
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.close()
except Exception as close_error:
# 连接池会处理连接的回收,这里的异常不应影响响应
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
# 在 finally 块之后处理异常和响应
if exception_to_raise:
raise exception_to_raise
return response
# 关闭会话,归还连接到连接池
try:
db.close()
except Exception as close_error:
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def _get_client_ip(self, request: Request) -> str:
"""
获取客户端 IP 地址,支持代理头
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
仅当服务部署在可信代理(如 Nginx、CloudFlare后面时才安全。
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
"""
# 从配置获取可信代理层数(默认为 1即信任最近一层代理
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
# 优先从代理头获取真实 IP
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# X-Forwarded-For 可能包含多个 IP取第一个
return forwarded_for.split(",")[0].strip()
# X-Forwarded-For 格式: "client, proxy1, proxy2"
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
ips = [ip.strip() for ip in forwarded_for.split(",")]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
real_ip = request.headers.get("x-real-ip")
if real_ip:
@@ -248,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
auth_header = request.headers.get("authorization", "")
api_key = request.headers.get("x-api-key", "")
if auth_header.startswith("Bearer "):
if auth_header.lower().startswith("bearer "):
api_key = auth_header[7:]
if api_key:
# 使用 API Key 的哈希作为限制 key避免日志泄露完整 key
import hashlib
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
key = f"llm_api_key:{key_hash}"
request.state.rate_limit_key_type = "api_key"
@@ -319,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
)
else:
# 限流触发,记录日志
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
logger.warning(
"速率限制触发: {}",
getattr(request.state, "rate_limit_key_type", "unknown"),
)
return result
return None
except Exception as e:
@@ -332,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
pass
async def _call_post_request_plugins(
self, request: Request, response: StarletteResponse, start_time: float
self, request: Request, status_code: int, start_time: float
) -> None:
"""调用请求后的插件"""
@@ -345,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
monitor_labels = {
"method": request.method,
"endpoint": request.url.path,
"status": str(response.status_code),
"status_class": f"{response.status_code // 100}xx",
"status": str(status_code),
"status_class": f"{status_code // 100}xx",
}
# 记录请求计数
@@ -368,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
self, request: Request, error: Exception, start_time: float
) -> None:
"""调用错误处理插件"""
from fastapi import HTTPException
duration = time.time() - start_time
@@ -380,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
error=error,
context={
"endpoint": f"{request.method} {request.url.path}",
"request_id": request.state.request_id,
"request_id": getattr(request.state, "request_id", ""),
"duration": duration,
},
)