mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
refactor: optimize database session lifecycle and middleware architecture
- Improve database pool capacity logging with detailed configuration parameters - Optimize database session dependency injection with middleware-managed lifecycle - Simplify plugin middleware by delegating session creation to FastAPI dependencies - Fix import path in auth routes (relative to absolute) - Add safety checks for database session management across middleware exception handlers - Ensure session cleanup only when not managed by middleware (avoid premature cleanup)
This commit is contained in:
@@ -14,7 +14,6 @@ from starlette.responses import Response as StarletteResponse
|
||||
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.plugins.manager import get_plugin_manager
|
||||
from src.plugins.rate_limit.base import RateLimitResult
|
||||
|
||||
@@ -71,26 +70,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
start_time = time.time()
|
||||
request.state.request_id = request.headers.get("x-request-id", "")
|
||||
request.state.start_time = start_time
|
||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||
request.state.db_managed_by_middleware = True
|
||||
|
||||
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数)
|
||||
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
|
||||
db_func = get_db
|
||||
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
|
||||
if get_db in request.app.dependency_overrides:
|
||||
db_func = request.app.dependency_overrides[get_db]
|
||||
logger.debug("Using overridden get_db from app.dependency_overrides")
|
||||
|
||||
# 创建数据库会话供需要的插件或后续处理使用
|
||||
db_gen = db_func()
|
||||
db = None
|
||||
response = None
|
||||
exception_to_raise = None
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
db = next(db_gen)
|
||||
request.state.db = db
|
||||
|
||||
# 1. 限流插件调用(可选功能)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
@@ -111,10 +97,17 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
# 3. 提交关键数据库事务(在返回响应前)
|
||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
||||
try:
|
||||
db.commit()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
db.rollback()
|
||||
try:
|
||||
if isinstance(db, Session):
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
await self._call_error_plugins(request, commit_error, start_time)
|
||||
# 返回 500 错误,因为数据可能不一致
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
@@ -139,14 +132,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except RuntimeError as e:
|
||||
if str(e) == "No response returned.":
|
||||
if db:
|
||||
db.rollback()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error("Downstream handler completed without returning a response")
|
||||
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
if db:
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except Exception:
|
||||
@@ -167,14 +164,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except Exception as e:
|
||||
# 回滚数据库事务
|
||||
if db:
|
||||
db.rollback()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 错误处理插件调用
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
# 尝试提交错误日志
|
||||
if db:
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except:
|
||||
@@ -183,38 +184,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
exception_to_raise = e
|
||||
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
|
||||
if db is not None:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
# 检查会话是否可以安全地进行回滚
|
||||
# 只有当没有进行中的事务操作时才尝试回滚
|
||||
if db.is_active and not db.get_transaction().is_active:
|
||||
# 事务不在活跃状态,可以安全回滚
|
||||
pass
|
||||
elif db.is_active:
|
||||
# 事务在活跃状态,尝试回滚
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception as rollback_error:
|
||||
# 回滚失败(可能是 commit 正在进行中),忽略错误
|
||||
logger.debug(f"Rollback skipped: {rollback_error}")
|
||||
except Exception:
|
||||
# 检查状态时出错,忽略
|
||||
pass
|
||||
|
||||
# 通过触发生成器的 finally 块来关闭会话(标准模式)
|
||||
# 这会调用 get_db() 的 finally 块,执行 db.close()
|
||||
try:
|
||||
next(db_gen, None)
|
||||
except StopIteration:
|
||||
# 正常情况:生成器已耗尽
|
||||
pass
|
||||
except Exception as cleanup_error:
|
||||
# 忽略 IllegalStateChangeError 等清理错误
|
||||
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
|
||||
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
|
||||
logger.warning(f"Database cleanup warning: {cleanup_error}")
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
# 在 finally 块之后处理异常和响应
|
||||
if exception_to_raise:
|
||||
@@ -250,7 +226,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
return False
|
||||
|
||||
async def _get_rate_limit_key_and_config(
|
||||
self, request: Request, db: Session
|
||||
self, request: Request
|
||||
) -> tuple[Optional[str], Optional[int]]:
|
||||
"""
|
||||
获取速率限制的key和配置
|
||||
@@ -318,14 +294,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
# 如果没有限流插件,允许通过
|
||||
return None
|
||||
|
||||
# 获取数据库会话
|
||||
db = getattr(request.state, "db", None)
|
||||
if not db:
|
||||
logger.warning("速率限制检查:无法获取数据库会话")
|
||||
return None
|
||||
|
||||
# 获取速率限制的key和配置(从数据库)
|
||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
|
||||
# 获取速率限制的 key 和配置
|
||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
|
||||
if not key:
|
||||
# 不需要限流的端点(如未分类路径),静默跳过
|
||||
return None
|
||||
@@ -336,7 +306,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
key=key,
|
||||
endpoint=request.url.path,
|
||||
method=request.method,
|
||||
rate_limit=rate_limit_value, # 传入数据库配置的限制值
|
||||
rate_limit=rate_limit_value, # 传入配置的限制值
|
||||
)
|
||||
# 类型检查:确保返回的是RateLimitResult类型
|
||||
if isinstance(result, RateLimitResult):
|
||||
|
||||
Reference in New Issue
Block a user