Files
Aether/src/middleware/plugin_middleware.py
fawney19 7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00

404 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一的插件中间件(纯 ASGI 实现)
负责协调所有插件的调用
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware
以避免 Starlette 已知的流式响应兼容性问题。
"""
import hashlib
import time
from typing import Optional
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from src.config import config
from src.core.logger import logger
from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
class PluginMiddleware:
    """
    Unified plugin-dispatch middleware (pure ASGI implementation).

    Responsibilities:
    - Performance monitoring (through the "monitor" plugin)
    - Optional rate limiting (through the "rate_limit" plugin)
    - Database session lifecycle management: commit/rollback/close of the
      session a request dependency stored on ``request.state.db``

    Authentication is NOT handled here; each route declares it explicitly
    through ``Depends()``.

    Why pure ASGI instead of BaseHTTPMiddleware:
    - BaseHTTPMiddleware buffers the whole response body, which does not suit
      streaming responses.
    - BaseHTTPMiddleware has known compatibility issues with StreamingResponse.
    - Pure ASGI passes streaming responses straight through with no overhead.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        self.plugin_manager = get_plugin_manager()

        # Rate-limit values are read from application config.
        self.llm_api_rate_limit = config.llm_api_rate_limit
        self.public_api_rate_limit = config.public_api_rate_limit

        # Path prefixes that bypass rate limiting entirely (static assets,
        # docs, and areas protected by other mechanisms). Matching is a plain
        # prefix match, so "/docs" also covers e.g. "/docs/oauth2-redirect".
        self.skip_rate_limit_paths = [
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
            "/static/",
            "/assets/",
            "/api/admin/",  # admin console already requires JWT auth
            "/api/auth/",  # auth endpoints are limited by the route-level IPRateLimiter
            "/api/users/",  # user endpoints
            "/api/monitoring/",  # monitoring endpoints
        ]

        # LLM API endpoints, which use a dedicated rate-limit strategy
        # (keyed by API key, falling back to client IP).
        self.llm_api_paths = [
            "/v1/messages",
            "/v1/chat/completions",
            "/v1/responses",
            "/v1/completions",
        ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        For HTTP requests the order of operations is:
        1. rate-limit check (may short-circuit with a 429 response),
        2. pre-request plugins,
        3. the downstream app (error plugins run if it raises, then the
           exception is re-raised),
        4. DB session cleanup (always),
        5. post-request plugins (only if no exception escaped and an
           ``http.response.start`` message was observed).
        """
        if scope["type"] != "http":
            # Non-HTTP traffic (e.g. WebSocket) is passed through untouched.
            await self.app(scope, receive, send)
            return

        # Wrap the scope in a Request so the helpers can use its API.
        request = Request(scope, receive, send)
        start_time = time.time()

        # Starlette's Request always exposes ``state`` (a State instance).
        # The cleanup step relies on attributes set here and by downstream
        # dependencies (e.g. ``request.state.db``) being shared.
        request.state.request_id = request.headers.get("x-request-id", "")
        request.state.start_time = start_time
        # Signals that if a session is created via Depends(get_db) during the
        # request, this middleware owns its lifecycle.
        request.state.db_managed_by_middleware = True

        # 1. Rate limiting, before the downstream app is invoked.
        rate_limit_result = await self._call_rate_limit_plugins(request)
        if rate_limit_result and not rate_limit_result.allowed:
            await self._send_rate_limit_response(send, rate_limit_result)
            return

        # 2. Pre-request plugins.
        await self._call_pre_request_plugins(request)

        # Status code captured from the http.response.start message.
        response_status_code: int = 0

        async def send_wrapper(message: Message) -> None:
            nonlocal response_status_code
            if message["type"] == "http.response.start":
                response_status_code = message.get("status", 0)
            await send(message)

        # 3. Invoke the downstream app; body messages are forwarded as-is.
        exception_occurred: Optional[Exception] = None
        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as e:
            exception_occurred = e
            await self._call_error_plugins(request, e, start_time)
            raise
        finally:
            # 4. Always release the DB session, whether or not the app raised.
            await self._cleanup_db_session(request, exception_occurred)

        # 5. Post-request plugins, only for requests that completed normally.
        if exception_occurred is None and response_status_code > 0:
            await self._call_post_request_plugins(
                request, response_status_code, start_time
            )

    async def _send_rate_limit_response(
        self, send: Send, result: RateLimitResult
    ) -> None:
        """Send a 429 JSON error response directly over ASGI.

        Args:
            send: ASGI send callable.
            result: the rejecting rate-limit result. Its ``message`` (if set)
                becomes the error message and its ``headers`` (if set) are
                forwarded as response headers.
        """
        import json

        body = json.dumps({
            "type": "error",
            "error": {
                "type": "rate_limit_error",
                "message": result.message or "Rate limit exceeded",
            },
        }).encode("utf-8")

        # Explicit content-length gives the server a fixed-length body for
        # this single-message response.
        headers = [
            (b"content-type", b"application/json"),
            (b"content-length", str(len(body)).encode("latin-1")),
        ]
        framing_headers = {b"content-type", b"content-length"}
        if result.headers:
            for key, value in result.headers.items():
                name = key.lower().encode()
                # Never let plugin headers duplicate/contradict the ones above.
                if name in framing_headers:
                    continue
                headers.append((name, str(value).encode()))

        await send({
            "type": "http.response.start",
            "status": 429,
            "headers": headers,
        })
        await send({
            "type": "http.response.body",
            "body": body,
        })

    async def _cleanup_db_session(
        self, request: Request, exception: Optional[Exception]
    ) -> None:
        """Finish and close the request's DB session, if one was created.

        Transaction policy:
        - If ``request.state.tx_committed_by_route`` is truthy, the route has
          already committed; the middleware only closes the session.
        - Otherwise the middleware commits on success and rolls back when the
          downstream app raised.
        This avoids double commits while staying backward compatible.

        Only ``sqlalchemy.orm.Session`` instances are handled. These are
        synchronous, so commit/rollback/close block the event loop while
        they run.

        NOTE(review): exceptions that framework exception handlers convert
        into responses (e.g. HTTPException -> 4xx) presumably never reach this
        middleware, so ``exception`` is None and pending changes are
        committed. Confirm this is intended for routes that raise after
        writing.
        """
        from sqlalchemy.orm import Session

        db = getattr(request.state, "db", None)
        if not isinstance(db, Session):
            return

        # Whether the route already committed its own transaction.
        tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)

        try:
            if exception is not None:
                # The app raised: roll back regardless of who owns the commit.
                try:
                    db.rollback()
                except Exception as rollback_error:
                    logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
            elif not tx_committed_by_route:
                # Completed normally and the route did not commit: commit here.
                try:
                    db.commit()
                except Exception as commit_error:
                    logger.error(f"关键事务提交失败: {commit_error}")
                    try:
                        db.rollback()
                    except Exception as rollback_error:
                        # Best effort: the session is closed below regardless.
                        logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
            # tx_committed_by_route is truthy: nothing to commit.
        finally:
            # Close the session so its connection returns to the pool.
            try:
                db.close()
            except Exception as close_error:
                logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")

    def _get_client_ip(self, request: Request) -> str:
        """
        Return the client IP address, honouring proxy headers when configured.

        ``config.trusted_proxy_count`` (default 1) is the number of reverse
        proxies in front of this service. Each proxy appends the address of
        its peer to X-Forwarded-For, so the last N entries were written by
        trusted proxies. The N-th entry from the right is the client address
        the outermost trusted proxy saw. Anything further left is
        client-supplied and spoofable, so it is never used when enough
        entries exist.

        With ``trusted_proxy_count <= 0`` (service directly exposed),
        X-Forwarded-For and X-Real-IP are ignored because any client could
        forge them.

        Returns:
            The chosen IP string, the socket peer address, or ``"unknown"``.
        """
        try:
            trusted_proxy_count = int(getattr(config, "trusted_proxy_count", 1))
        except (TypeError, ValueError):
            trusted_proxy_count = 1

        if trusted_proxy_count > 0:
            forwarded_for = request.headers.get("x-forwarded-for")
            if forwarded_for:
                # Format: "client, proxy1, proxy2"; blank entries are dropped.
                ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
                if len(ips) >= trusted_proxy_count:
                    return ips[-trusted_proxy_count]
                if ips:
                    # Fewer entries than configured proxies; best effort is
                    # the left-most recorded address.
                    return ips[0]

            real_ip = request.headers.get("x-real-ip", "").strip()
            if real_ip:
                return real_ip

        # Fall back to the directly connected peer.
        if request.client:
            return request.client.host
        return "unknown"

    def _is_llm_api_path(self, path: str) -> bool:
        """Return True if ``path`` starts with one of ``self.llm_api_paths``."""
        return path.startswith(tuple(self.llm_api_paths))

    async def _get_rate_limit_key_and_config(
        self, request: Request
    ) -> tuple[Optional[str], Optional[int]]:
        """
        Determine the rate-limit key and limit value for a request.

        Strategy:
        - LLM API paths (/v1/messages, /v1/chat/completions, ...): keyed by
          API key, falling back to client IP when no key is present.
        - /api/public/*: keyed by client IP with the public API limit.
        - /api/admin/*, /api/auth/*: skipped earlier via
          ``skip_rate_limit_paths`` (auth uses the route-level IPRateLimiter).

        Side effects: sets ``request.state.rate_limit_key_type`` and
        ``request.state.rate_limit_value`` for limited paths.

        Returns:
            ``(key, rate_limit_value)``, or ``(None, None)`` for paths that are
            not rate limited.
        """
        path = request.url.path

        if self._is_llm_api_path(path):
            # A non-empty Bearer token takes precedence over x-api-key. An
            # empty "Bearer " header must not discard a valid x-api-key.
            api_key = request.headers.get("x-api-key", "").strip()
            auth_header = request.headers.get("authorization", "")
            if auth_header.lower().startswith("bearer "):
                bearer_token = auth_header[7:].strip()
                if bearer_token:
                    api_key = bearer_token

            if api_key:
                # Key on a truncated SHA-256 of the API key so the raw key never
                # appears in limiter keys or logs.
                key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
                key = f"llm_api_key:{key_hash}"
                request.state.rate_limit_key_type = "api_key"
            else:
                # No API key: fall back to the client IP.
                client_ip = self._get_client_ip(request)
                key = f"llm_ip:{client_ip}"
                request.state.rate_limit_key_type = "ip"

            rate_limit = self.llm_api_rate_limit
            request.state.rate_limit_value = rate_limit
            return key, rate_limit

        if path.startswith("/api/public/"):
            client_ip = self._get_client_ip(request)
            key = f"public_ip:{client_ip}"
            rate_limit = self.public_api_rate_limit
            request.state.rate_limit_key_type = "public_ip"
            request.state.rate_limit_value = rate_limit
            return key, rate_limit

        # Any other path is not rate limited.
        return None, None

    async def _call_rate_limit_plugins(self, request: Request) -> Optional[RateLimitResult]:
        """Run the rate-limit plugin for this request.

        Returns:
            The plugin's ``RateLimitResult``, or None when the request is not
            subject to limiting (skipped path, no or disabled plugin,
            unclassified path, non-RateLimitResult return, or plugin error).
            Errors fail open: the request is allowed through.
        """
        path = request.url.path
        # Prefix match also covers the exact-path case.
        if path.startswith(tuple(self.skip_rate_limit_paths)):
            return None

        rate_limit_plugin = self.plugin_manager.get_plugin("rate_limit")
        if not rate_limit_plugin or not rate_limit_plugin.enabled:
            return None

        key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
        if not key:
            return None

        try:
            # Check first, then consume one token only if allowed.
            # NOTE(review): check_limit and consume are two separate awaits, so
            # concurrent requests may both pass the check. Confirm the plugin
            # tolerates this.
            result = await rate_limit_plugin.check_limit(
                key=key,
                endpoint=path,
                method=request.method,
                rate_limit=rate_limit_value,
            )
            if not isinstance(result, RateLimitResult):
                return None

            if result.allowed:
                await rate_limit_plugin.consume(
                    key=key,
                    amount=1,
                    rate_limit=rate_limit_value,
                )
            else:
                logger.warning(
                    "速率限制触发: {}",
                    getattr(request.state, "rate_limit_key_type", "unknown"),
                )
            return result
        except Exception as e:
            logger.error(f"Rate limit error: {e}")
            # Fail open: allow the request when the limiter errors.
            return None

    async def _call_pre_request_plugins(self, request: Request) -> None:
        """Hook for plugins that run before the request (reserved extension point)."""
        pass

    async def _call_post_request_plugins(
        self, request: Request, status_code: int, start_time: float
    ) -> None:
        """Report request count and duration to the monitor plugin, if enabled.

        Runs after the response has been sent, so any plugin failure
        (including looking the plugin up) is logged rather than raised.
        """
        duration = time.time() - start_time

        try:
            monitor_plugin = self.plugin_manager.get_plugin("monitor")
            if not (monitor_plugin and monitor_plugin.enabled):
                return

            # NOTE(review): the raw URL path is used as a label; paths that embed
            # IDs may produce high-cardinality metrics.
            monitor_labels = {
                "method": request.method,
                "endpoint": request.url.path,
                "status": str(status_code),
                "status_class": f"{status_code // 100}xx",
            }
            await monitor_plugin.increment(
                "http_requests_total",
                labels=monitor_labels,
            )
            await monitor_plugin.timing(
                "http_request_duration",
                duration,
                labels=monitor_labels,
            )
        except Exception as e:
            logger.error(f"Monitor plugin failed: {e}")

    async def _call_error_plugins(
        self, request: Request, error: Exception, start_time: float
    ) -> None:
        """Notify the notification plugin about a severe error.

        HTTPExceptions with status < 500 are treated as expected client errors
        and not reported. This method never raises. The caller re-raises the
        original exception afterwards, and a failure here must not replace it.
        """
        # Starlette's HTTPException is the base class of FastAPI's, so this
        # covers exceptions raised through either API.
        from starlette.exceptions import HTTPException

        if isinstance(error, HTTPException) and error.status_code < 500:
            return

        duration = time.time() - start_time
        try:
            notification_plugin = self.plugin_manager.get_plugin("notification")
            if notification_plugin and notification_plugin.enabled:
                await notification_plugin.send_error(
                    error=error,
                    context={
                        "endpoint": f"{request.method} {request.url.path}",
                        "request_id": getattr(request.state, "request_id", ""),
                        "duration": duration,
                    },
                )
        except Exception as e:
            logger.error(f"Notification plugin failed: {e}")