Files
Aether/src/middleware/plugin_middleware.py

404 lines
15 KiB
Python
Raw Normal View History

2025-12-10 20:52:44 +08:00
"""
统一的插件中间件（纯 ASGI 实现）

负责协调所有插件的调用。
注意：使用纯 ASGI middleware 而非 BaseHTTPMiddleware，
以避免 Starlette 已知的流式响应兼容性问题。

"""
import hashlib
import json
import time
from typing import Optional

from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from src.config import config
from src.core.logger import logger
from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult


class PluginMiddleware:
    """Unified plugin-dispatch middleware, implemented as a pure ASGI app.

    Responsibilities:
    - performance monitoring (via the ``monitor`` plugin)
    - rate limiting (optional, via the ``rate_limit`` plugin)
    - lifecycle management of the per-request database session

    Authentication is NOT handled here: each route declares it explicitly
    through ``Depends()``.

    Why pure ASGI instead of ``BaseHTTPMiddleware``:
    - ``BaseHTTPMiddleware`` buffers the whole response body, which is
      unfriendly to streaming responses;
    - ``BaseHTTPMiddleware`` has known compatibility issues with
      ``StreamingResponse``;
    - a pure ASGI middleware forwards streamed messages as-is, with no extra
      overhead.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Wrap ``app`` and load plugin manager and rate-limit configuration."""
        self.app = app
        self.plugin_manager = get_plugin_manager()

        # Rate-limit values are read from configuration once, at construction.
        self.llm_api_rate_limit = config.llm_api_rate_limit
        self.public_api_rate_limit = config.public_api_rate_limit

        # Path prefixes that bypass rate limiting entirely (static assets,
        # docs, and areas covered by other mechanisms).
        self.skip_rate_limit_paths = [
            "/health",
            "/docs",
            "/redoc",
            "/openapi.json",
            "/favicon.ico",
            "/static/",
            "/assets/",
            "/api/admin/",  # admin backend already requires JWT auth
            "/api/auth/",  # limited by the route-level IPRateLimiter instead
            "/api/users/",  # user endpoints
            "/api/monitoring/",  # monitoring endpoints
        ]

        # LLM API path prefixes, which use their own rate-limit strategy.
        self.llm_api_paths = [
            "/v1/messages",
            "/v1/chat/completions",
            "/v1/responses",
            "/v1/completions",
        ]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point."""
        if scope["type"] != "http":
            # Non-HTTP traffic (e.g. WebSocket) is passed straight through.
            await self.app(scope, receive, send)
            return

        # A Request view over the scope lets the helpers reuse Starlette's
        # header / URL / state accessors.
        request = Request(scope, receive, send)
        start_time = time.time()

        # Starlette's Request always has a ``state`` (a State instance).
        request.state.request_id = request.headers.get("x-request-id", "")
        request.state.start_time = start_time
        # If a session is created via Depends(get_db) during this request,
        # this middleware owns its commit/rollback/close lifecycle.
        request.state.db_managed_by_middleware = True

        # 1. Rate limiting, before the downstream app is invoked.
        rate_limit_result = await self._call_rate_limit_plugins(request)
        if rate_limit_result and not rate_limit_result.allowed:
            await self._send_rate_limit_response(send, rate_limit_result)
            return

        # 2. Pre-request plugins.
        await self._call_pre_request_plugins(request)

        response_status_code: int = 0

        async def send_wrapper(message: Message) -> None:
            # Record the response status on its way out; forward unchanged.
            nonlocal response_status_code
            if message["type"] == "http.response.start":
                response_status_code = message.get("status", 0)
            await send(message)

        # 3. Downstream application.
        exception_occurred: Optional[BaseException] = None
        try:
            await self.app(scope, receive, send_wrapper)
        except BaseException as exc:
            # BaseException rather than Exception: asyncio.CancelledError is
            # not an Exception subclass, and an interrupted request must make
            # the cleanup below roll back instead of commit.
            exception_occurred = exc
            if isinstance(exc, Exception):
                await self._call_error_plugins(request, exc, start_time)
            raise
        finally:
            # 4. Always release the DB session, whether or not the app failed.
            await self._cleanup_db_session(request, exception_occurred)

        # 5. Post-request plugins. Reached only on success (failures re-raise
        #    above), and only if a response was actually started.
        if response_status_code > 0:
            await self._call_post_request_plugins(
                request, response_status_code, start_time
            )

    async def _send_rate_limit_response(
        self, send: Send, result: RateLimitResult
    ) -> None:
        """Send a complete HTTP 429 response directly over ``send``.

        The JSON body has the shape
        ``{"type": "error", "error": {"type": "rate_limit_error", "message": ...}}``.
        Any headers carried by ``result`` are appended with lower-cased names.
        """
        body = json.dumps(
            {
                "type": "error",
                "error": {
                    "type": "rate_limit_error",
                    "message": result.message or "Rate limit exceeded",
                },
            }
        ).encode("utf-8")

        headers = [
            (b"content-type", b"application/json"),
            # Explicit length so the server can send a fixed-size response
            # rather than falling back to chunked transfer encoding.
            (b"content-length", str(len(body)).encode("ascii")),
        ]
        for key, value in (result.headers or {}).items():
            headers.append((key.lower().encode(), str(value).encode()))

        await send(
            {
                "type": "http.response.start",
                "status": 429,
                "headers": headers,
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": body,
            }
        )

    async def _cleanup_db_session(
        self, request: Request, exception: Optional[BaseException]
    ) -> None:
        """Finish and close the request's SQLAlchemy session, if one exists.

        Transaction policy:
        - if ``exception`` is not None, the transaction is rolled back
          (regardless of who was responsible for committing);
        - else if ``request.state.tx_committed_by_route`` is truthy, the route
          already committed, so the session is only closed;
        - otherwise the middleware commits on the route's behalf.

        This avoids double commits while staying compatible with routes that
        rely on the middleware to commit. The session is always closed.

        NOTE(review): this runs only after ``await self.app(...)`` has
        returned, by which point the response has normally been sent, so a
        failed commit is logged but cannot change what the client received.
        """
        from sqlalchemy.orm import Session

        db = getattr(request.state, "db", None)
        if not isinstance(db, Session):
            return

        tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
        try:
            if exception is not None:
                try:
                    db.rollback()
                except Exception as rollback_error:
                    logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
            elif not tx_committed_by_route:
                try:
                    db.commit()
                except Exception as commit_error:
                    logger.error(f"关键事务提交失败: {commit_error}")
                    try:
                        db.rollback()
                    except Exception as rollback_error:
                        # Previously swallowed silently; keep best-effort but
                        # leave a trace.
                        logger.debug(f"提交失败后回滚事务出错(可忽略): {rollback_error}")
            # When tx_committed_by_route is truthy there is nothing to commit.
        finally:
            # Close the session so its connection goes back to the pool.
            try:
                db.close()
            except Exception as close_error:
                logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")

    def _get_client_ip(self, request: Request) -> str:
        """Return the client IP address, honouring proxy headers.

        Resolution order: ``X-Forwarded-For``, then ``X-Real-IP``, then the
        direct peer address, then ``"unknown"``. Empty header entries are
        ignored.

        Security: ``X-Forwarded-For`` and ``X-Real-IP`` are trusted. That is
        only safe behind a trusted reverse proxy (Nginx, Cloudflare, ...); if
        the service is exposed directly, attackers can forge these headers to
        bypass per-IP rate limiting.
        """
        # Number of entries at the right-hand end of X-Forwarded-For that are
        # skipped as trusted proxies (default 1).
        # NOTE(review): with one proxy that appends the peer address to the
        # header, the entry just left of the last one is client-supplied;
        # confirm this setting matches the deployment topology.
        trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)

        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            # Format: "client, proxy1, proxy2". Skip trusted_proxy_count
            # entries from the right and take the one immediately to their
            # left. Blank entries are dropped so one is never returned as an IP.
            ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
            if len(ips) > trusted_proxy_count:
                return ips[-(trusted_proxy_count + 1)]
            if ips:
                return ips[0]

        real_ip = request.headers.get("x-real-ip", "").strip()
        if real_ip:
            return real_ip

        # Fall back to the directly connected peer.
        if request.client:
            return request.client.host
        return "unknown"

    def _is_llm_api_path(self, path: str) -> bool:
        """Return True if ``path`` starts with any configured LLM API prefix."""
        return path.startswith(tuple(self.llm_api_paths))

    async def _get_rate_limit_key_and_config(
        self, request: Request
    ) -> tuple[Optional[str], Optional[int]]:
        """Choose the rate-limit bucket key and limit value for ``request``.

        Strategy:
        - LLM API endpoints (``/v1/messages``, ``/v1/chat/completions``, ...):
          one bucket per API key, or per client IP when no key is supplied;
          both use ``llm_api_rate_limit``.
        - ``/api/public/*``: one bucket per client IP, using
          ``public_api_rate_limit``.
        - ``/api/admin/*`` and ``/api/auth/*`` are exempted earlier via
          ``skip_rate_limit_paths`` (auth endpoints are limited by the
          route-level IPRateLimiter).
        - anything else: not rate limited.

        When a bucket is chosen, ``request.state.rate_limit_key_type`` and
        ``request.state.rate_limit_value`` are set.

        Returns:
            ``(key, rate_limit_value)``, or ``(None, None)`` if the path is
            not rate limited.
        """
        path = request.url.path

        if self._is_llm_api_path(path):
            api_key = request.headers.get("x-api-key", "").strip()
            auth_header = request.headers.get("authorization", "")
            if auth_header.lower().startswith("bearer "):
                bearer_token = auth_header[7:].strip()
                # A Bearer token takes precedence, but an empty one must not
                # discard a supplied x-api-key.
                if bearer_token:
                    api_key = bearer_token

            if api_key:
                # Use a truncated hash as the key so the full API key is not
                # leaked into logs.
                # NOTE(review): the key is not yet authenticated at this point,
                # so fabricated keys each get their own bucket.
                key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
                key = f"llm_api_key:{key_hash}"
                request.state.rate_limit_key_type = "api_key"
            else:
                # No API key: fall back to a per-IP bucket (same limit value).
                client_ip = self._get_client_ip(request)
                key = f"llm_ip:{client_ip}"
                request.state.rate_limit_key_type = "ip"

            rate_limit = self.llm_api_rate_limit
            request.state.rate_limit_value = rate_limit
            return key, rate_limit

        if path.startswith("/api/public/"):
            client_ip = self._get_client_ip(request)
            key = f"public_ip:{client_ip}"
            rate_limit = self.public_api_rate_limit
            request.state.rate_limit_key_type = "public_ip"
            request.state.rate_limit_value = rate_limit
            return key, rate_limit

        return None, None

    async def _call_rate_limit_plugins(self, request: Request) -> Optional[RateLimitResult]:
        """Run the ``rate_limit`` plugin for ``request``.

        Returns the plugin's ``RateLimitResult`` (allowed or denied) when a
        check was performed. Returns None when the path is exempt, the plugin
        is missing or disabled, the path has no bucket, the plugin returned
        something other than a ``RateLimitResult``, or the plugin raised
        (fail-open).
        """
        path = request.url.path
        # Prefix match (an exact match is a special case of startswith).
        if path.startswith(tuple(self.skip_rate_limit_paths)):
            return None

        rate_limit_plugin = self.plugin_manager.get_plugin("rate_limit")
        if not rate_limit_plugin or not rate_limit_plugin.enabled:
            return None

        key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
        if not key:
            return None

        try:
            result = await rate_limit_plugin.check_limit(
                key=key,
                endpoint=path,
                method=request.method,
                rate_limit=rate_limit_value,
            )
            if not isinstance(result, RateLimitResult):
                return None

            if result.allowed:
                # The check passed; now actually consume a token.
                # NOTE(review): check and consume are separate awaits, so
                # concurrent requests may both pass before either consumes.
                await rate_limit_plugin.consume(
                    key=key,
                    amount=1,
                    rate_limit=rate_limit_value,
                )
            else:
                logger.warning(
                    "速率限制触发: {}",
                    getattr(request.state, "rate_limit_key_type", "unknown"),
                )
            return result
        except Exception as e:
            # Fail open: a broken limiter must not block all traffic.
            logger.error(f"Rate limit error: {e}")
            return None

    async def _call_pre_request_plugins(self, request: Request) -> None:
        """Hook for plugins that run before the request (currently a no-op)."""
        pass

    async def _call_post_request_plugins(
        self, request: Request, status_code: int, start_time: float
    ) -> None:
        """Report request count and duration to the ``monitor`` plugin.

        Plugin failures are logged and swallowed.
        """
        duration = time.time() - start_time
        try:
            monitor_plugin = self.plugin_manager.get_plugin("monitor")
            if not monitor_plugin or not monitor_plugin.enabled:
                return

            # NOTE(review): "endpoint" is the raw URL path, so paths that embed
            # IDs produce one label set per distinct URL.
            monitor_labels = {
                "method": request.method,
                "endpoint": request.url.path,
                "status": str(status_code),
                "status_class": f"{status_code // 100}xx",
            }
            await monitor_plugin.increment(
                "http_requests_total",
                labels=monitor_labels,
            )
            await monitor_plugin.timing(
                "http_request_duration",
                duration,
                labels=monitor_labels,
            )
        except Exception as e:
            logger.error(f"Monitor plugin failed: {e}")

    async def _call_error_plugins(
        self, request: Request, error: Exception, start_time: float
    ) -> None:
        """Send a server-side failure to the ``notification`` plugin.

        HTTP exceptions with a status code below 500 are client errors and are
        not reported. Plugin failures are logged and swallowed.
        """
        # Starlette's HTTPException is the base class of FastAPI's, so this
        # recognises HTTP errors raised through either framework.
        from starlette.exceptions import HTTPException

        if isinstance(error, HTTPException) and error.status_code < 500:
            return

        duration = time.time() - start_time
        try:
            notification_plugin = self.plugin_manager.get_plugin("notification")
            if not notification_plugin or not notification_plugin.enabled:
                return
            await notification_plugin.send_error(
                error=error,
                context={
                    "endpoint": f"{request.method} {request.url.path}",
                    "request_id": getattr(request.state, "request_id", ""),
                    "duration": duration,
                },
            )
        except Exception as e:
            logger.error(f"Notification plugin failed: {e}")