refactor: optimize database session lifecycle and middleware architecture

- Improve database pool capacity logging with detailed configuration parameters
- Optimize database session dependency injection with middleware-managed lifecycle
- Simplify plugin middleware by delegating session creation to FastAPI dependencies
- Fix import path in auth routes (relative to absolute)
- Add safety checks for database session management across middleware exception handlers
- Ensure session cleanup runs only when the session is not managed by middleware (avoids premature cleanup)
This commit is contained in:
fawney19
2025-12-18 00:35:46 +08:00
parent 9d5c84f9d3
commit b579420690
7 changed files with 113 additions and 113 deletions

View File

@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
class AuthRegisterAdapter(AuthPublicAdapter): class AuthRegisterAdapter(AuthPublicAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
from ..models.database import SystemConfig from src.models.database import SystemConfig
db = context.db db = context.db
payload = context.ensure_json_body() payload = context.ensure_json_body()

View File

@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
if _redis_manager is None: if _redis_manager is None:
_redis_manager = RedisClientManager() _redis_manager = RedisClientManager()
# 如果尚未连接(例如启动时降级、或 close() 后),尝试重新初始化。
# initialize() 内部包含熔断器逻辑,避免频繁重试导致抖动。
if _redis_manager.get_client() is None:
await _redis_manager.initialize(require_redis=require_redis) await _redis_manager.initialize(require_redis=require_redis)
return _redis_manager.get_client() return _redis_manager.get_client()

View File

@@ -46,6 +46,11 @@ class BatchCommitter:
def mark_dirty(self, session: Session): def mark_dirty(self, session: Session):
"""标记 Session 有待提交的更改""" """标记 Session 有待提交的更改"""
# 请求级事务由中间件统一 commit/rollback,避免后台任务在请求中途误提交。
if session is None:
return
if session.info.get("managed_by_middleware"):
return
self._pending_sessions.add(session) self._pending_sessions.add(session)
async def _batch_commit_loop(self): async def _batch_commit_loop(self):

View File

@@ -5,6 +5,7 @@
import time import time
from typing import AsyncGenerator, Generator, Optional from typing import AsyncGenerator, Generator, Optional
from starlette.requests import Request
from sqlalchemy import create_engine, event from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import ( from sqlalchemy.ext.asyncio import (
@@ -150,9 +151,22 @@ def _log_pool_capacity():
theoretical = config.db_pool_size + config.db_max_overflow theoretical = config.db_pool_size + config.db_max_overflow
workers = max(1, config.worker_processes) workers = max(1, config.worker_processes)
total_estimated = theoretical * workers total_estimated = theoretical * workers
logger.info("数据库连接池配置") safe_limit = config.pg_max_connections - config.pg_reserved_connections
if total_estimated > config.db_pool_warn_threshold: logger.info(
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数") "数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
config.db_pool_size,
config.db_max_overflow,
workers,
total_estimated,
safe_limit,
)
if total_estimated > safe_limit:
logger.warning(
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved),"
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
total_estimated,
safe_limit,
)
def _ensure_async_engine() -> AsyncEngine: def _ensure_async_engine() -> AsyncEngine:
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
# 创建异步引擎 # 创建异步引擎
_async_engine = create_async_engine( _async_engine = create_async_engine(
ASYNC_DATABASE_URL, ASYNC_DATABASE_URL,
poolclass=QueuePool, # 使用队列连接池 # AsyncEngine 不能使用 QueuePool,默认使用 AsyncAdaptedQueuePool
pool_size=config.db_pool_size, pool_size=config.db_pool_size,
max_overflow=config.db_max_overflow, max_overflow=config.db_max_overflow,
pool_timeout=config.db_pool_timeout, pool_timeout=config.db_pool_timeout,
@@ -220,16 +234,39 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
await session.close() await session.close()
def get_db() -> Generator[Session, None, None]: def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
"""获取数据库会话 """获取数据库会话
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback 注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback
这里只负责会话的创建和关闭,不自动提交 这里只负责会话的创建和关闭,不自动提交
在 FastAPI 请求上下文中通过 Depends(get_db) 调用时,会自动注入 Request 对象,
支持中间件管理的 session 复用;在非请求上下文中直接调用 get_db() 时,
request 为 None,退化为独立 session 模式。
""" """
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
if request is not None:
existing_db = getattr(getattr(request, "state", None), "db", None)
if isinstance(existing_db, Session):
yield existing_db
return
# 确保引擎已初始化 # 确保引擎已初始化
_ensure_engine() _ensure_engine()
db = _SessionLocal() db = _SessionLocal()
# 如果中间件声明会统一管理会话生命周期,则把 session 绑定到 request.state
# 并由中间件负责 commit/rollback/close,这里不关闭,避免流式响应提前释放会话
managed_by_middleware = bool(
request is not None
and hasattr(request, "state")
and getattr(request.state, "db_managed_by_middleware", False)
)
if managed_by_middleware:
request.state.db = db
db.info["managed_by_middleware"] = True
try: try:
yield db yield db
# 不再自动 commit由业务代码显式管理事务 # 不再自动 commit由业务代码显式管理事务
@@ -241,6 +278,7 @@ def get_db() -> Generator[Session, None, None]:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}") logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
raise raise
finally: finally:
if not managed_by_middleware:
try: try:
db.close() # 确保连接返回池 db.close() # 确保连接返回池
except Exception as close_error: except Exception as close_error:
@@ -336,7 +374,7 @@ def init_admin_user(db: Session):
admin.set_password(config.admin_password) admin.set_password(config.admin_password)
db.add(admin) db.add(admin)
db.commit() # 刷新以获取ID但不提交 db.flush() # 分配ID但不提交事务(由外层 init_db 统一 commit)
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})") logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
except Exception as e: except Exception as e:

View File

@@ -14,7 +14,6 @@ from starlette.responses import Response as StarletteResponse
from src.config import config from src.config import config
from src.core.logger import logger from src.core.logger import logger
from src.database import get_db
from src.plugins.manager import get_plugin_manager from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult from src.plugins.rate_limit.base import RateLimitResult
@@ -71,26 +70,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
start_time = time.time() start_time = time.time()
request.state.request_id = request.headers.get("x-request-id", "") request.state.request_id = request.headers.get("x-request-id", "")
request.state.start_time = start_time request.state.start_time = start_time
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
request.state.db_managed_by_middleware = True
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数)
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
db_func = get_db
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
if get_db in request.app.dependency_overrides:
db_func = request.app.dependency_overrides[get_db]
logger.debug("Using overridden get_db from app.dependency_overrides")
# 创建数据库会话供需要的插件或后续处理使用
db_gen = db_func()
db = None
response = None response = None
exception_to_raise = None exception_to_raise = None
try: try:
# 获取数据库会话
db = next(db_gen)
request.state.db = db
# 1. 限流插件调用(可选功能) # 1. 限流插件调用(可选功能)
rate_limit_result = await self._call_rate_limit_plugins(request) rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed: if rate_limit_result and not rate_limit_result.allowed:
@@ -111,10 +97,17 @@ class PluginMiddleware(BaseHTTPMiddleware):
# 3. 提交关键数据库事务(在返回响应前) # 3. 提交关键数据库事务(在返回响应前)
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化 # 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
try: try:
db = getattr(request.state, "db", None)
if isinstance(db, Session):
db.commit() db.commit()
except Exception as commit_error: except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}") logger.error(f"关键事务提交失败: {commit_error}")
try:
if isinstance(db, Session):
db.rollback() db.rollback()
except Exception:
pass
await self._call_error_plugins(request, commit_error, start_time)
# 返回 500 错误,因为数据可能不一致 # 返回 500 错误,因为数据可能不一致
response = JSONResponse( response = JSONResponse(
status_code=500, status_code=500,
@@ -139,14 +132,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
except RuntimeError as e: except RuntimeError as e:
if str(e) == "No response returned.": if str(e) == "No response returned.":
if db: db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback() db.rollback()
except Exception:
pass
logger.error("Downstream handler completed without returning a response") logger.error("Downstream handler completed without returning a response")
await self._call_error_plugins(request, e, start_time) await self._call_error_plugins(request, e, start_time)
if db: if isinstance(db, Session):
try: try:
db.commit() db.commit()
except Exception: except Exception:
@@ -167,14 +164,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
except Exception as e: except Exception as e:
# 回滚数据库事务 # 回滚数据库事务
if db: db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback() db.rollback()
except Exception:
pass
# 错误处理插件调用 # 错误处理插件调用
await self._call_error_plugins(request, e, start_time) await self._call_error_plugins(request, e, start_time)
# 尝试提交错误日志 # 尝试提交错误日志
if db: if isinstance(db, Session):
try: try:
db.commit() db.commit()
except: except:
@@ -183,38 +184,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
exception_to_raise = e exception_to_raise = e
finally: finally:
# 确保数据库会话被正确关闭 db = getattr(request.state, "db", None)
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError if isinstance(db, Session):
if db is not None:
try: try:
# 检查会话是否可以安全地进行回滚 db.close()
# 只有当没有进行中的事务操作时才尝试回滚 except Exception as close_error:
if db.is_active and not db.get_transaction().is_active: # 连接池会处理连接的回收,这里的异常不应影响响应
# 事务不在活跃状态,可以安全回滚 logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
pass
elif db.is_active:
# 事务在活跃状态,尝试回滚
try:
db.rollback()
except Exception as rollback_error:
# 回滚失败(可能是 commit 正在进行中),忽略错误
logger.debug(f"Rollback skipped: {rollback_error}")
except Exception:
# 检查状态时出错,忽略
pass
# 通过触发生成器的 finally 块来关闭会话(标准模式)
# 这会调用 get_db() 的 finally 块,执行 db.close()
try:
next(db_gen, None)
except StopIteration:
# 正常情况:生成器已耗尽
pass
except Exception as cleanup_error:
# 忽略 IllegalStateChangeError 等清理错误
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
logger.warning(f"Database cleanup warning: {cleanup_error}")
# 在 finally 块之后处理异常和响应 # 在 finally 块之后处理异常和响应
if exception_to_raise: if exception_to_raise:
@@ -250,7 +226,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
return False return False
async def _get_rate_limit_key_and_config( async def _get_rate_limit_key_and_config(
self, request: Request, db: Session self, request: Request
) -> tuple[Optional[str], Optional[int]]: ) -> tuple[Optional[str], Optional[int]]:
""" """
获取速率限制的key和配置 获取速率限制的key和配置
@@ -318,14 +294,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
# 如果没有限流插件,允许通过 # 如果没有限流插件,允许通过
return None return None
# 获取数据库会话 # 获取速率限制的 key 和配置
db = getattr(request.state, "db", None) key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
if not db:
logger.warning("速率限制检查:无法获取数据库会话")
return None
# 获取速率限制的key和配置从数据库
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
if not key: if not key:
# 不需要限流的端点(如未分类路径),静默跳过 # 不需要限流的端点(如未分类路径),静默跳过
return None return None
@@ -336,7 +306,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
key=key, key=key,
endpoint=request.url.path, endpoint=request.url.path,
method=request.method, method=request.method,
rate_limit=rate_limit_value, # 传入数据库配置的限制值 rate_limit=rate_limit_value, # 传入配置的限制值
) )
# 类型检查确保返回的是RateLimitResult类型 # 类型检查确保返回的是RateLimitResult类型
if isinstance(result, RateLimitResult): if isinstance(result, RateLimitResult):

View File

@@ -11,7 +11,6 @@
import asyncio import asyncio
import math import math
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta # noqa: F401 - kept for potential future use from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -40,6 +39,7 @@ class ConcurrencyManager:
self._memory_lock: asyncio.Lock = asyncio.Lock() self._memory_lock: asyncio.Lock = asyncio.Lock()
self._memory_endpoint_counts: dict[str, int] = {} self._memory_endpoint_counts: dict[str, int] = {}
self._memory_key_counts: dict[str, int] = {} self._memory_key_counts: dict[str, int] = {}
self._owns_redis: bool = False
self._memory_initialized = True self._memory_initialized = True
async def initialize(self) -> None: async def initialize(self) -> None:
@@ -47,41 +47,29 @@ class ConcurrencyManager:
if self._redis is not None: if self._redis is not None:
return return
# 优先使用 REDIS_URL如果没有则根据密码构建 URL
redis_url = os.getenv("REDIS_URL")
if not redis_url:
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
redis_password = os.getenv("REDIS_PASSWORD")
if redis_password:
redis_url = f"redis://:{redis_password}@localhost:6379/0"
else:
redis_url = "redis://localhost:6379/0"
try: try:
self._redis = await aioredis.from_url( # 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
redis_url, from src.clients.redis_client import get_redis_client
encoding="utf-8",
decode_responses=True, self._redis = await get_redis_client(require_redis=False)
socket_timeout=5.0, self._owns_redis = False
socket_connect_timeout=5.0, if self._redis:
) logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
# 测试连接 else:
await self._redis.ping() logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
# 脱敏显示(隐藏密码)
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
logger.info(f"[OK] Redis 连接成功: {safe_url}")
except Exception as e: except Exception as e:
logger.error(f"[ERROR] Redis 连接失败: {e}") logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)") logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
self._redis = None self._redis = None
self._owns_redis = False
async def close(self) -> None: async def close(self) -> None:
"""关闭 Redis 连接""" """关闭 Redis 连接"""
if self._redis: if self._redis and self._owns_redis:
await self._redis.close() await self._redis.close()
logger.info("ConcurrencyManager Redis 连接已关闭")
self._redis = None self._redis = None
logger.info("Redis 连接已关闭") self._owns_redis = False
def _get_endpoint_key(self, endpoint_id: str) -> str: def _get_endpoint_key(self, endpoint_id: str) -> str:
"""获取 Endpoint 并发计数的 Redis Key""" """获取 Endpoint 并发计数的 Redis Key"""

View File

@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
""" """
import time import time
from datetime import datetime, timezone from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple from typing import Dict, Tuple
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -72,11 +72,7 @@ class RPMLimiter:
# 获取当前分钟窗口 # 获取当前分钟窗口
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
window_start = now.replace(second=0, microsecond=0) window_start = now.replace(second=0, microsecond=0)
window_end = ( window_end = window_start + timedelta(minutes=1)
window_start.replace(minute=window_start.minute + 1)
if window_start.minute < 59
else window_start.replace(hour=window_start.hour + 1, minute=0)
)
# 查找或创建追踪记录 # 查找或创建追踪记录
tracking = ( tracking = (