mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 08:12:26 +08:00
refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with a pure ASGI implementation in the plugin middleware for better streaming-response handling
- Add a trusted proxy count configuration for client IP extraction in reverse-proxy environments
- Implement an audit log cleanup scheduler with a configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in the middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for the pipeline architecture
This commit is contained in:
@@ -1,16 +1,17 @@
|
||||
"""
|
||||
统一的插件中间件
|
||||
统一的插件中间件(纯 ASGI 实现)
|
||||
负责协调所有插件的调用
|
||||
|
||||
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware,
|
||||
以避免 Starlette 已知的流式响应兼容性问题。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
@@ -18,20 +19,25 @@ from src.plugins.manager import get_plugin_manager
|
||||
from src.plugins.rate_limit.base import RateLimitResult
|
||||
|
||||
|
||||
|
||||
class PluginMiddleware(BaseHTTPMiddleware):
|
||||
class PluginMiddleware:
|
||||
"""
|
||||
统一的插件调用中间件
|
||||
统一的插件调用中间件(纯 ASGI 实现)
|
||||
|
||||
职责:
|
||||
- 性能监控
|
||||
- 限流控制 (可选)
|
||||
- 数据库会话生命周期管理
|
||||
|
||||
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
||||
|
||||
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
|
||||
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
|
||||
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
|
||||
- 纯 ASGI 可以直接透传流式响应,无额外开销
|
||||
"""
|
||||
|
||||
def __init__(self, app: Any) -> None:
|
||||
super().__init__(app)
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
self.plugin_manager = get_plugin_manager()
|
||||
|
||||
# 从配置读取速率限制值
|
||||
@@ -61,152 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
"/v1/completions",
|
||||
]
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
|
||||
) -> StarletteResponse:
|
||||
"""处理请求并调用相应插件"""
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""ASGI 入口点"""
|
||||
if scope["type"] != "http":
|
||||
# 非 HTTP 请求(如 WebSocket)直接透传
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# 构建 Request 对象以便复用现有逻辑
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
# 记录请求开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 设置 request.state 属性
|
||||
# 注意:Starlette 的 Request 对象总是有 state 属性(State 实例)
|
||||
request.state.request_id = request.headers.get("x-request-id", "")
|
||||
request.state.start_time = start_time
|
||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||
request.state.db_managed_by_middleware = True
|
||||
|
||||
response = None
|
||||
exception_to_raise = None
|
||||
# 1. 限流检查(在调用下游之前)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
# 限流触发,返回429
|
||||
await self._send_rate_limit_response(send, rate_limit_result)
|
||||
return
|
||||
|
||||
# 2. 预处理插件调用
|
||||
await self._call_pre_request_plugins(request)
|
||||
|
||||
# 用于捕获响应状态码
|
||||
response_status_code: int = 0
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
nonlocal response_status_code
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_status_code = message.get("status", 0)
|
||||
|
||||
await send(message)
|
||||
|
||||
# 3. 调用下游应用
|
||||
exception_occurred: Optional[Exception] = None
|
||||
try:
|
||||
# 1. 限流插件调用(可选功能)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
# 限流触发,返回429
|
||||
headers = rate_limit_result.headers or {}
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=rate_limit_result.message or "Rate limit exceeded",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# 2. 预处理插件调用
|
||||
await self._call_pre_request_plugins(request)
|
||||
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 提交关键数据库事务(在返回响应前)
|
||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
||||
try:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
try:
|
||||
if isinstance(db, Session):
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
await self._call_error_plugins(request, commit_error, start_time)
|
||||
# 返回 500 错误,因为数据可能不一致
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "database_error",
|
||||
"message": "数据保存失败,请重试",
|
||||
},
|
||||
},
|
||||
)
|
||||
# 跳过后处理插件,直接返回错误响应
|
||||
return response
|
||||
|
||||
# 4. 后处理插件调用(监控等,非关键操作)
|
||||
# 这些操作失败不应影响用户响应
|
||||
await self._call_post_request_plugins(request, response, start_time)
|
||||
|
||||
# 注意:不在此处添加限流响应头,因为在BaseHTTPMiddleware中
|
||||
# 响应返回后修改headers会导致Content-Length不匹配错误
|
||||
# 限流响应头已在返回429错误时正确包含(见上面的HTTPException)
|
||||
|
||||
except RuntimeError as e:
|
||||
if str(e) == "No response returned.":
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error("Downstream handler completed without returning a response")
|
||||
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "internal_error",
|
||||
"message": "Internal server error: downstream handler returned no response.",
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
exception_to_raise = e
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
except Exception as e:
|
||||
# 回滚数据库事务
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
exception_occurred = e
|
||||
# 错误处理插件调用
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
raise
|
||||
finally:
|
||||
# 4. 数据库会话清理(无论成功与否)
|
||||
await self._cleanup_db_session(request, exception_occurred)
|
||||
|
||||
# 尝试提交错误日志
|
||||
if isinstance(db, Session):
|
||||
# 5. 后处理插件调用(仅在成功时)
|
||||
if not exception_occurred and response_status_code > 0:
|
||||
await self._call_post_request_plugins(request, response_status_code, start_time)
|
||||
|
||||
async def _send_rate_limit_response(
|
||||
self, send: Send, result: RateLimitResult
|
||||
) -> None:
|
||||
"""发送 429 限流响应"""
|
||||
import json
|
||||
|
||||
body = json.dumps({
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "rate_limit_error",
|
||||
"message": result.message or "Rate limit exceeded",
|
||||
},
|
||||
}).encode("utf-8")
|
||||
|
||||
headers = [(b"content-type", b"application/json")]
|
||||
if result.headers:
|
||||
for key, value in result.headers.items():
|
||||
headers.append((key.lower().encode(), str(value).encode()))
|
||||
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 429,
|
||||
"headers": headers,
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
})
|
||||
|
||||
async def _cleanup_db_session(
|
||||
self, request: Request, exception: Optional[Exception]
|
||||
) -> None:
|
||||
"""清理数据库会话
|
||||
|
||||
事务策略:
|
||||
- 如果 request.state.tx_committed_by_route 为 True,说明路由已自行提交,中间件只负责 close
|
||||
- 否则由中间件统一 commit/rollback
|
||||
|
||||
这避免了双重提交的问题,同时保持向后兼容。
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
db = getattr(request.state, "db", None)
|
||||
if not isinstance(db, Session):
|
||||
return
|
||||
|
||||
# 检查是否由路由层已经提交
|
||||
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
|
||||
|
||||
try:
|
||||
if exception is not None:
|
||||
# 发生异常,回滚事务(无论谁负责提交)
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception as rollback_error:
|
||||
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||
elif not tx_committed_by_route:
|
||||
# 正常完成且路由未自行提交,由中间件提交事务
|
||||
try:
|
||||
db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
exception_to_raise = e
|
||||
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
# 如果 tx_committed_by_route 为 True,跳过 commit(路由已提交)
|
||||
finally:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
# 在 finally 块之后处理异常和响应
|
||||
if exception_to_raise:
|
||||
raise exception_to_raise
|
||||
|
||||
return response
|
||||
# 关闭会话,归还连接到连接池
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""
|
||||
获取客户端 IP 地址,支持代理头
|
||||
|
||||
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||
"""
|
||||
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||
|
||||
# 优先从代理头获取真实 IP
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 可能包含多个 IP,取第一个
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
@@ -248,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
api_key = request.headers.get("x-api-key", "")
|
||||
|
||||
if auth_header.startswith("Bearer "):
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
api_key = auth_header[7:]
|
||||
|
||||
if api_key:
|
||||
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
||||
import hashlib
|
||||
|
||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
||||
key = f"llm_api_key:{key_hash}"
|
||||
request.state.rate_limit_key_type = "api_key"
|
||||
@@ -319,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
else:
|
||||
# 限流触发,记录日志
|
||||
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
|
||||
logger.warning(
|
||||
"速率限制触发: {}",
|
||||
getattr(request.state, "rate_limit_key_type", "unknown"),
|
||||
)
|
||||
return result
|
||||
return None
|
||||
except Exception as e:
|
||||
@@ -332,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
pass
|
||||
|
||||
async def _call_post_request_plugins(
|
||||
self, request: Request, response: StarletteResponse, start_time: float
|
||||
self, request: Request, status_code: int, start_time: float
|
||||
) -> None:
|
||||
"""调用请求后的插件"""
|
||||
|
||||
@@ -345,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
monitor_labels = {
|
||||
"method": request.method,
|
||||
"endpoint": request.url.path,
|
||||
"status": str(response.status_code),
|
||||
"status_class": f"{response.status_code // 100}xx",
|
||||
"status": str(status_code),
|
||||
"status_class": f"{status_code // 100}xx",
|
||||
}
|
||||
|
||||
# 记录请求计数
|
||||
@@ -368,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
self, request: Request, error: Exception, start_time: float
|
||||
) -> None:
|
||||
"""调用错误处理插件"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
@@ -380,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
error=error,
|
||||
context={
|
||||
"endpoint": f"{request.method} {request.url.path}",
|
||||
"request_id": request.state.request_id,
|
||||
"request_id": getattr(request.state, "request_id", ""),
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user