mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
404 lines
15 KiB
Python
404 lines
15 KiB
Python
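"""
Unified plugin middleware (pure ASGI implementation).

Coordinates the invocation of all plugins.

Note: this is a pure ASGI middleware rather than a BaseHTTPMiddleware
subclass, to avoid Starlette's known compatibility problems with
streaming responses.
"""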
|
||
|
||
import hashlib
import json
import time
from typing import Optional

from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from src.config import config
from src.core.logger import logger
from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
|
||
|
||
|
||
class PluginMiddleware:
    """Unified plugin-dispatch middleware (pure ASGI implementation).

    Responsibilities:
    - performance monitoring
    - rate limiting (optional)
    - database session lifecycle management

    Note: authentication is declared explicitly by each route through
    ``Depends()`` and is not handled at the middleware layer.

    Why pure ASGI rather than BaseHTTPMiddleware:
    - BaseHTTPMiddleware buffers the whole response body, which is
      unfriendly to streaming responses
    - BaseHTTPMiddleware has known compatibility problems with
      StreamingResponse
    - pure ASGI passes streaming responses straight through with no
      extra overhead
    """

    def __init__(self, app: ASGIApp) -> None:
        """Wrap the downstream ASGI ``app`` and load rate-limit settings."""
        self.app = app
        self.plugin_manager = get_plugin_manager()

        # Rate-limit values are read from configuration.
        self.llm_api_rate_limit = config.llm_api_rate_limit
        self.public_api_rate_limit = config.public_api_rate_limit

        # Path prefixes that bypass rate limiting entirely (static assets,
        # documentation, ...).
        self.skip_rate_limit_paths = [
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
            "/static/",
            "/assets/",
            "/api/admin/",  # admin console already has JWT auth; no extra limiting
            "/api/auth/",  # auth endpoints (handled by the route-level IPRateLimiter)
            "/api/users/",  # user endpoints
            "/api/monitoring/",  # monitoring endpoints
        ]

        # LLM API endpoints (these need a dedicated rate-limit policy).
        self.llm_api_paths = [
            "/v1/messages",
            "/v1/chat/completions",
            "/v1/responses",
            "/v1/completions",
        ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        For HTTP requests: run the rate-limit check, call the downstream
        app, then clean up the database session and record metrics.
        Any other scope type is passed through untouched.
        """
        if scope["type"] != "http":
            # Non-HTTP traffic (e.g. WebSocket) is passed straight through.
            await self.app(scope, receive, send)
            return

        # Build a Request object so the existing helper logic can be reused.
        request = Request(scope, receive, send)

        start_time = time.time()

        # Starlette's Request always exposes a ``state`` attribute.
        request.state.request_id = request.headers.get("x-request-id", "")
        request.state.start_time = start_time
        # Marker: if a session is created through Depends(get_db) during the
        # request, this middleware owns its lifecycle.
        request.state.db_managed_by_middleware = True

        # 1. Rate-limit check (before calling downstream).
        rate_limit_result = await self._call_rate_limit_plugins(request)
        if rate_limit_result and not rate_limit_result.allowed:
            await self._send_rate_limit_response(send, rate_limit_result)
            return

        # 2. Pre-request plugins.
        await self._call_pre_request_plugins(request)

        # Captures the status code of the downstream response.
        response_status_code: int = 0

        async def send_wrapper(message: Message) -> None:
            nonlocal response_status_code

            if message["type"] == "http.response.start":
                response_status_code = message.get("status", 0)

            await send(message)

        # 3. Call the downstream application.
        failure: Optional[BaseException] = None
        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as exc:
            failure = exc
            await self._call_error_plugins(request, exc, start_time)
            raise
        except BaseException as exc:
            # asyncio.CancelledError (client disconnect, shutdown) and similar
            # are not Exception subclasses. Record them so the transaction of
            # an aborted request is rolled back instead of committed.
            failure = exc
            raise
        finally:
            # 4. Database session cleanup (always).
            await self._cleanup_db_session(request, failure)

            # 5. Post-request plugins (only on success).
            if failure is None and response_status_code > 0:
                await self._call_post_request_plugins(
                    request, response_status_code, start_time
                )

    async def _send_rate_limit_response(
        self, send: Send, result: RateLimitResult
    ) -> None:
        """Send a 429 JSON error response built from ``result``.

        ``result.message`` becomes the error message (with a generic
        fallback) and every entry of ``result.headers`` is added as a
        lower-cased response header.
        """
        body = json.dumps(
            {
                "type": "error",
                "error": {
                    "type": "rate_limit_error",
                    "message": result.message or "Rate limit exceeded",
                },
            }
        ).encode("utf-8")

        headers = [(b"content-type", b"application/json")]
        if result.headers:
            headers.extend(
                (key.lower().encode(), str(value).encode())
                for key, value in result.headers.items()
            )

        await send(
            {
                "type": "http.response.start",
                "status": 429,
                "headers": headers,
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": body,
            }
        )

    async def _cleanup_db_session(
        self, request: Request, exception: Optional[BaseException]
    ) -> None:
        """Finish and close the SQLAlchemy session stored on ``request.state.db``.

        Transaction policy:
        - if ``exception`` is set, roll back (regardless of who owns commits);
        - else if ``request.state.tx_committed_by_route`` is true, the route
          already committed and only ``close()`` is performed;
        - otherwise this middleware commits, rolling back if the commit fails.

        This avoids double commits while staying backward compatible. Errors
        raised by rollback/commit/close are logged and never propagated.
        """
        from sqlalchemy.orm import Session

        db = getattr(request.state, "db", None)
        if not isinstance(db, Session):
            return

        # Whether the route layer has already committed.
        tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)

        try:
            if exception is not None:
                try:
                    db.rollback()
                except Exception as rollback_error:
                    logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
            elif not tx_committed_by_route:
                try:
                    db.commit()
                except Exception as commit_error:
                    logger.error(f"关键事务提交失败: {commit_error}")
                    try:
                        db.rollback()
                    except Exception as rollback_error:
                        # Best effort only, but leave a trace instead of
                        # swallowing the error silently.
                        logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
        finally:
            # Close the session so the connection goes back to the pool.
            try:
                db.close()
            except Exception as close_error:
                logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")

    @staticmethod
    def _select_forwarded_ip(
        forwarded_for: str, trusted_proxy_count: int
    ) -> Optional[str]:
        """Pick the client address from an ``X-Forwarded-For`` value.

        Each proxy appends the address of the peer it received the request
        from, so only the rightmost ``trusted_proxy_count`` entries were
        written by trusted proxies; everything to their left is supplied by
        the client and can be forged. The client address is therefore the
        entry ``trusted_proxy_count`` positions from the right.

        Returns ``None`` when no proxy is trusted or the header holds no
        usable entry. If the header has fewer entries than trusted proxies,
        the leftmost entry is returned.
        """
        if trusted_proxy_count <= 0:
            return None

        hops = [hop for hop in (part.strip() for part in forwarded_for.split(",")) if hop]
        if not hops:
            return None
        if len(hops) >= trusted_proxy_count:
            return hops[-trusted_proxy_count]
        return hops[0]

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP address, honouring trusted proxy headers.

        ``config.trusted_proxy_count`` (default 1) is the number of trusted
        reverse proxies in front of this service. When it is greater than
        zero, ``X-Forwarded-For`` is consulted first, then ``X-Real-IP``.
        When it is zero or negative the proxy headers are ignored, because a
        directly exposed service would let clients forge them to evade rate
        limiting. The direct peer address is the fallback, and ``"unknown"``
        is returned when even that is unavailable.
        """
        trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)

        if trusted_proxy_count > 0:
            forwarded_for = request.headers.get("x-forwarded-for")
            if forwarded_for:
                forwarded_ip = self._select_forwarded_ip(
                    forwarded_for, trusted_proxy_count
                )
                if forwarded_ip:
                    return forwarded_ip

            real_ip = (request.headers.get("x-real-ip") or "").strip()
            if real_ip:
                return real_ip

        # Fall back to the directly connected peer.
        if request.client:
            return request.client.host

        return "unknown"

    def _is_llm_api_path(self, path: str) -> bool:
        """Return True when ``path`` starts with one of ``self.llm_api_paths``."""
        return path.startswith(tuple(self.llm_api_paths))

    async def _get_rate_limit_key_and_config(
        self, request: Request
    ) -> tuple[Optional[str], Optional[int]]:
        """Work out the rate-limit bucket key and limit for ``request``.

        Policy:
        - LLM API paths (``/v1/messages``, ``/v1/chat/completions``, ...):
          keyed by API key, or by client IP when no key is presented;
        - ``/api/public/*``: keyed by client IP;
        - anything else: not rate limited here.

        As a side effect ``request.state.rate_limit_key_type`` and
        ``request.state.rate_limit_value`` are set for limited requests.

        Returns:
            ``(key, rate_limit_value)``; ``(None, None)`` when the request is
            not subject to rate limiting.
        """
        path = request.url.path

        if self._is_llm_api_path(path):
            api_key = request.headers.get("x-api-key", "")
            auth_header = request.headers.get("authorization", "")

            if auth_header.lower().startswith("bearer "):
                # Strip so whitespace variants of one token share a bucket,
                # and keep the x-api-key value when the bearer token is empty.
                bearer_token = auth_header[7:].strip()
                if bearer_token:
                    api_key = bearer_token

            if api_key:
                # Key the bucket on a truncated hash so the full API key
                # never shows up in logs or storage keys.
                key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
                key = f"llm_api_key:{key_hash}"
                request.state.rate_limit_key_type = "api_key"
            else:
                # No API key: fall back to limiting by IP.
                client_ip = self._get_client_ip(request)
                key = f"llm_ip:{client_ip}"
                request.state.rate_limit_key_type = "ip"

            rate_limit = self.llm_api_rate_limit
            request.state.rate_limit_value = rate_limit
            return key, rate_limit

        if path.startswith("/api/public/"):
            client_ip = self._get_client_ip(request)
            key = f"public_ip:{client_ip}"
            rate_limit = self.public_api_rate_limit
            request.state.rate_limit_key_type = "public_ip"
            request.state.rate_limit_value = rate_limit
            return key, rate_limit

        # Every other endpoint is not rate limited here.
        return None, None

    async def _call_rate_limit_plugins(self, request: Request) -> Optional[RateLimitResult]:
        """Run the rate-limit plugin for ``request``.

        Returns the plugin's ``RateLimitResult`` (allowed or denied), or
        ``None`` when the request is not rate limited: skipped path, plugin
        missing or disabled, no bucket key, unexpected result type, or a
        plugin error (errors fail open so requests are let through).
        """
        path = request.url.path

        # Skip exempt paths (prefix match; an exact match is also a prefix).
        if path.startswith(tuple(self.skip_rate_limit_paths)):
            return None

        rate_limit_plugin = self.plugin_manager.get_plugin("rate_limit")
        if not rate_limit_plugin or not rate_limit_plugin.enabled:
            return None

        key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
        if not key:
            # Endpoint not subject to rate limiting; skip silently.
            return None

        try:
            result = await rate_limit_plugin.check_limit(
                key=key,
                endpoint=path,
                method=request.method,
                rate_limit=rate_limit_value,
            )
            # Only a genuine RateLimitResult is acted upon.
            if not isinstance(result, RateLimitResult):
                return None

            if result.allowed:
                # The check passed, so actually consume one token.
                await rate_limit_plugin.consume(
                    key=key,
                    amount=1,
                    rate_limit=rate_limit_value,
                )
            else:
                logger.warning(
                    "速率限制触发: {}",
                    getattr(request.state, "rate_limit_key_type", "unknown"),
                )
            return result
        except Exception as e:
            logger.error(f"Rate limit error: {e}")
            # Fail open: let the request through when the limiter breaks.
            return None

    async def _call_pre_request_plugins(self, request: Request) -> None:
        """Pre-request plugin hook (currently a reserved extension point)."""
        pass

    async def _call_post_request_plugins(
        self, request: Request, status_code: int, start_time: float
    ) -> None:
        """Record request count and duration through the monitor plugin.

        Does nothing when the monitor plugin is missing or disabled; plugin
        errors are logged and swallowed.
        """
        duration = time.time() - start_time

        monitor_plugin = self.plugin_manager.get_plugin("monitor")
        if not (monitor_plugin and monitor_plugin.enabled):
            return

        try:
            monitor_labels = {
                "method": request.method,
                "endpoint": request.url.path,
                "status": str(status_code),
                "status_class": f"{status_code // 100}xx",
            }

            # Request counter.
            await monitor_plugin.increment(
                "http_requests_total",
                labels=monitor_labels,
            )

            # Request duration.
            await monitor_plugin.timing(
                "http_request_duration",
                duration,
                labels=monitor_labels,
            )
        except Exception as e:
            logger.error(f"Monitor plugin failed: {e}")

    async def _call_error_plugins(
        self, request: Request, error: Exception, start_time: float
    ) -> None:
        """Notify the notification plugin about a serious request error.

        An error is reported unless it is an ``HTTPException`` with a status
        code below 500. Does nothing when the notification plugin is missing
        or disabled; plugin errors are logged and swallowed.
        """
        from fastapi import HTTPException

        duration = time.time() - start_time

        if isinstance(error, HTTPException) and error.status_code < 500:
            return

        notification_plugin = self.plugin_manager.get_plugin("notification")
        if not (notification_plugin and notification_plugin.enabled):
            return

        try:
            await notification_plugin.send_error(
                error=error,
                context={
                    "endpoint": f"{request.method} {request.url.path}",
                    "request_id": getattr(request.state, "request_id", ""),
                    "duration": duration,
                },
            )
        except Exception as e:
            logger.error(f"Notification plugin failed: {e}")
|