diff --git a/src/api/auth/routes.py b/src/api/auth/routes.py
index 2c32559..208780e 100644
--- a/src/api/auth/routes.py
+++ b/src/api/auth/routes.py
@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
 
 class AuthRegisterAdapter(AuthPublicAdapter):
     async def handle(self, context):  # type: ignore[override]
-        from ..models.database import SystemConfig
+        from src.models.database import SystemConfig
 
         db = context.db
         payload = context.ensure_json_body()
diff --git a/src/clients/redis_client.py b/src/clients/redis_client.py
index 96ea0e5..1337679 100644
--- a/src/clients/redis_client.py
+++ b/src/clients/redis_client.py
@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
     if _redis_manager is None:
         _redis_manager = RedisClientManager()
 
+    # If not connected yet (e.g. degraded at startup, or after close()), try to re-initialize.
+    # initialize() has built-in circuit-breaker logic, so frequent retries do not cause churn.
+    if _redis_manager.get_client() is None:
         await _redis_manager.initialize(require_redis=require_redis)
     return _redis_manager.get_client()
diff --git a/src/core/batch_committer.py b/src/core/batch_committer.py
index 44ed06a..87169dd 100644
--- a/src/core/batch_committer.py
+++ b/src/core/batch_committer.py
@@ -46,6 +46,11 @@ class BatchCommitter:
 
     def mark_dirty(self, session: Session):
         """Mark the session as having pending changes to commit."""
+        # Request-scoped transactions are committed/rolled back by the middleware; skip them so background tasks cannot commit mid-request.
+        if session is None:
+            return
+        if session.info.get("managed_by_middleware"):
+            return
         self._pending_sessions.add(session)
 
     async def _batch_commit_loop(self):
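
Reviewer note: the guard above relies on SQLAlchemy's per-session Session.info dict to carry the flag that the reworked get_db() (further down in this patch) sets on middleware-managed sessions. A minimal sketch of how the new mark_dirty() filter behaves; it is not part of the patch, and the in-memory engine and the module-level pending set are stand-ins for BatchCommitter internals:

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

engine = create_engine("sqlite://")            # stand-in engine, for the sketch only
SessionLocal = sessionmaker(bind=engine)

background_session = SessionLocal()            # e.g. created by a background task
request_session = SessionLocal()
request_session.info["managed_by_middleware"] = True   # what the new get_db() does

pending: set[Session] = set()                  # stand-in for BatchCommitter._pending_sessions

def mark_dirty(session: Session) -> None:
    # Same guards as the patch: ignore None and middleware-managed sessions.
    if session is None or session.info.get("managed_by_middleware"):
        return
    pending.add(session)

mark_dirty(background_session)   # queued for the next batch commit
mark_dirty(request_session)      # skipped: PluginMiddleware commits it before the response
assert background_session in pending and request_session not in pending
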
diff --git a/src/database/database.py b/src/database/database.py
index 6a07e9d..e89f2ad 100644
--- a/src/database/database.py
+++ b/src/database/database.py
@@ -5,6 +5,7 @@
 import time
 from typing import AsyncGenerator, Generator, Optional
 
+from starlette.requests import Request
 from sqlalchemy import create_engine, event
 from sqlalchemy.engine import Engine
 from sqlalchemy.ext.asyncio import (
@@ -150,9 +151,22 @@ def _log_pool_capacity():
     theoretical = config.db_pool_size + config.db_max_overflow
     workers = max(1, config.worker_processes)
     total_estimated = theoretical * workers
-    logger.info("Database connection pool configuration")
-    if total_estimated > config.db_pool_warn_threshold:
-        logger.warning("Database connection demand may exceed the threshold; reduce the pool size or the number of workers")
+    safe_limit = config.pg_max_connections - config.pg_reserved_connections
+    logger.info(
+        "Database connection pool config: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
+        config.db_pool_size,
+        config.db_max_overflow,
+        workers,
+        total_estimated,
+        safe_limit,
+    )
+    if total_estimated > safe_limit:
+        logger.warning(
+            "Total connection pool demand may exceed the PostgreSQL limit: %s > %s (pg_max_connections - reserved); "
+            "consider lowering DB_POOL_SIZE/DB_MAX_OVERFLOW or reducing the number of workers",
+            total_estimated,
+            safe_limit,
+        )
 
 
 def _ensure_async_engine() -> AsyncEngine:
@@ -185,7 +199,7 @@
     # Create the async engine
     _async_engine = create_async_engine(
         ASYNC_DATABASE_URL,
-        poolclass=QueuePool,  # use QueuePool
+        # AsyncEngine cannot use QueuePool; the default AsyncAdaptedQueuePool is used instead
         pool_size=config.db_pool_size,
         max_overflow=config.db_max_overflow,
         pool_timeout=config.db_pool_timeout,
@@ -220,16 +234,39 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
         await session.close()
 
 
-def get_db() -> Generator[Session, None, None]:
+def get_db(request: Request = None) -> Generator[Session, None, None]:  # type: ignore[assignment]
     """Get a database session.
 
     Note: transaction management is controlled explicitly by the business layer (manual commit/rollback);
     this function only creates and closes the session and never commits automatically.
+
+    When called via Depends(get_db) inside a FastAPI request context, the Request object is injected
+    automatically so a middleware-managed session can be reused; when called directly outside a request
+    context, request is None and the function falls back to a standalone session.
     """
+    # FastAPI request context: prefer the session the middleware bound to request.state.db
+    if request is not None:
+        existing_db = getattr(getattr(request, "state", None), "db", None)
+        if isinstance(existing_db, Session):
+            yield existing_db
+            return
+
     # Make sure the engine is initialized
     _ensure_engine()
 
     db = _SessionLocal()
+
+    # If the middleware declares that it manages the session lifecycle, bind the session to request.state
+    # and let the middleware commit/rollback/close (not closed here, so streaming responses keep their session).
+    managed_by_middleware = bool(
+        request is not None
+        and hasattr(request, "state")
+        and getattr(request.state, "db_managed_by_middleware", False)
+    )
+    if managed_by_middleware:
+        request.state.db = db
+        db.info["managed_by_middleware"] = True
+
     try:
         yield db
         # No automatic commit; business code manages the transaction explicitly
@@ -241,12 +278,13 @@
                 logger.debug(f"Error while rolling back the transaction (ignorable): {rollback_error}")
         raise
     finally:
-        try:
-            db.close()  # make sure the connection is returned to the pool
-        except Exception as close_error:
-            # Log close errors (e.g. IllegalStateChangeError);
-            # the connection pool takes care of recycling the connection
-            logger.debug(f"Error while closing the database connection (ignorable): {close_error}")
+        if not managed_by_middleware:
+            try:
+                db.close()  # make sure the connection is returned to the pool
+            except Exception as close_error:
+                # Log close errors (e.g. IllegalStateChangeError);
+                # the connection pool takes care of recycling the connection
+                logger.debug(f"Error while closing the database connection (ignorable): {close_error}")
 
 
 def create_session() -> Session:
@@ -336,7 +374,7 @@ def init_admin_user(db: Session):
            admin.set_password(config.admin_password)
 
            db.add(admin)
-           db.commit()  # flush to obtain the ID, but do not commit
+           db.flush()  # assign the ID without committing (the outer init_db commits once at the end)
 
            logger.info(f"Admin account created: {admin.email} ({admin.username})")
     except Exception as e:
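
Reviewer note: a short usage sketch, not part of the patch, of the two call paths the reworked get_db() supports. The /demo route, the background_job() helper and the import of get_db through the src.database package are assumptions for illustration only:

from fastapi import Depends, FastAPI, Request
from sqlalchemy.orm import Session

from src.database import get_db   # assumed re-export; the module itself lives at src/database/database.py

app = FastAPI()

@app.get("/demo")                  # hypothetical endpoint
def demo(request: Request, db: Session = Depends(get_db)):
    # Request path: FastAPI injects Request into get_db(); if PluginMiddleware has set
    # request.state.db_managed_by_middleware, this session is bound to request.state.db
    # and committed/closed by the middleware rather than by this handler.
    return {"managed": bool(db.info.get("managed_by_middleware"))}

def background_job():
    # Non-request path: request stays None, so get_db() hands out a standalone session
    # and the caller commits and closes explicitly.
    gen = get_db()
    db = next(gen)
    try:
        db.commit()                # explicit transaction management, as before
    finally:
        gen.close()                # drives get_db()'s finally block -> db.close()
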
diff --git a/src/middleware/plugin_middleware.py b/src/middleware/plugin_middleware.py
index 030c5e0..a262c19 100644
--- a/src/middleware/plugin_middleware.py
+++ b/src/middleware/plugin_middleware.py
@@ -14,7 +14,6 @@
 from starlette.responses import Response as StarletteResponse
 
 from src.config import config
 from src.core.logger import logger
-from src.database import get_db
 from src.plugins.manager import get_plugin_manager
 from src.plugins.rate_limit.base import RateLimitResult
@@ -71,26 +70,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
         start_time = time.time()
         request.state.request_id = request.headers.get("x-request-id", "")
         request.state.start_time = start_time
+        # Marker: any session created via Depends(get_db) during this request has its lifecycle managed by this middleware
+        request.state.db_managed_by_middleware = True
 
-        # Get the FastAPI application instance from request.app (not from the app passed to __init__)
-        # so that the real FastAPI instance and its dependency_overrides are reachable
-        db_func = get_db
-        if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
-            if get_db in request.app.dependency_overrides:
-                db_func = request.app.dependency_overrides[get_db]
-                logger.debug("Using overridden get_db from app.dependency_overrides")
-
-        # Create a database session for plugins and downstream handlers that need it
-        db_gen = db_func()
-        db = None
         response = None
         exception_to_raise = None
 
         try:
-            # Get a database session
-            db = next(db_gen)
-            request.state.db = db
-
             # 1. Rate-limit plugin calls (optional feature)
             rate_limit_result = await self._call_rate_limit_plugins(request)
             if rate_limit_result and not rate_limit_result.allowed:
@@ -111,10 +97,17 @@ class PluginMiddleware(BaseHTTPMiddleware):
             # 3. Commit critical database transactions (before returning the response)
             # This ensures that Usage records, quota deductions and other critical data are persisted before the response is returned
             try:
-                db.commit()
+                db = getattr(request.state, "db", None)
+                if isinstance(db, Session):
+                    db.commit()
             except Exception as commit_error:
                 logger.error(f"Critical transaction commit failed: {commit_error}")
-                db.rollback()
+                try:
+                    if isinstance(db, Session):
+                        db.rollback()
+                except Exception:
+                    pass
+                await self._call_error_plugins(request, commit_error, start_time)
                 # Return a 500 error because the data may be inconsistent
                 response = JSONResponse(
                     status_code=500,
@@ -139,14 +132,18 @@
 
         except RuntimeError as e:
             if str(e) == "No response returned.":
-                if db:
-                    db.rollback()
+                db = getattr(request.state, "db", None)
+                if isinstance(db, Session):
+                    try:
+                        db.rollback()
+                    except Exception:
+                        pass
 
                 logger.error("Downstream handler completed without returning a response")
                 await self._call_error_plugins(request, e, start_time)
 
-                if db:
+                if isinstance(db, Session):
                     try:
                         db.commit()
                     except Exception:
@@ -167,14 +164,18 @@
 
         except Exception as e:
             # Roll back the database transaction
-            if db:
-                db.rollback()
+            db = getattr(request.state, "db", None)
+            if isinstance(db, Session):
+                try:
+                    db.rollback()
+                except Exception:
+                    pass
 
             # Call error-handling plugins
             await self._call_error_plugins(request, e, start_time)
 
             # Try to commit the error logs
-            if db:
+            if isinstance(db, Session):
                 try:
                     db.commit()
                 except:
@@ -183,38 +184,13 @@
             exception_to_raise = e
 
         finally:
-            # Make sure the database session is closed properly
-            # Note: the various states need to be handled safely to avoid IllegalStateChangeError
-            if db is not None:
+            db = getattr(request.state, "db", None)
+            if isinstance(db, Session):
                 try:
-                    # Check whether the session can safely be rolled back
-                    # Only attempt a rollback when no transaction operation is in progress
-                    if db.is_active and not db.get_transaction().is_active:
-                        # The transaction is not active; safe to roll back
-                        pass
-                    elif db.is_active:
-                        # The transaction is active; try to roll back
-                        try:
-                            db.rollback()
-                        except Exception as rollback_error:
-                            # Rollback failed (a commit may be in progress); ignore the error
-                            logger.debug(f"Rollback skipped: {rollback_error}")
-                except Exception:
-                    # Error while checking the state; ignore it
-                    pass
-
-                # Close the session by driving the generator's finally block (the standard pattern);
-                # this runs get_db()'s finally block, which calls db.close()
-                try:
-                    next(db_gen, None)
-                except StopIteration:
-                    # Normal case: the generator is already exhausted
-                    pass
-                except Exception as cleanup_error:
-                    # Ignore cleanup errors such as IllegalStateChangeError;
-                    # they are usually caused by inconsistent transaction state and do not affect business logic
-                    if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
-                        logger.warning(f"Database cleanup warning: {cleanup_error}")
+                    db.close()
+                except Exception as close_error:
+                    # The connection pool handles recycling; an exception here must not affect the response
+                    logger.debug(f"Error while closing the database connection (ignorable): {close_error}")
 
         # Handle the exception and the response after the finally block
         if exception_to_raise:
@@ -250,7 +226,7 @@
         return False
 
     async def _get_rate_limit_key_and_config(
-        self, request: Request, db: Session
+        self, request: Request
     ) -> tuple[Optional[str], Optional[int]]:
         """
         Get the rate-limit key and configuration
@@ -318,14 +294,8 @@
             # No rate-limit plugins are registered; allow the request through
             return None
 
-        # Get a database session
-        db = getattr(request.state, "db", None)
-        if not db:
-            logger.warning("Rate-limit check: unable to get a database session")
-            return None
-
-        # Get the rate-limit key and configuration (from the database)
-        key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
+        # Get the rate-limit key and configuration
+        key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
         if not key:
             # Endpoints that do not need rate limiting (e.g. unclassified paths); skip silently
             return None
@@ -336,7 +306,7 @@
                 key=key,
                 endpoint=request.url.path,
                 method=request.method,
-                rate_limit=rate_limit_value,  # pass the limit configured in the database
+                rate_limit=rate_limit_value,  # pass the configured limit
             )
             # Type check: make sure a RateLimitResult is returned
             if isinstance(result, RateLimitResult):
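
Reviewer note: with the dependency_overrides handling removed from dispatch(), only FastAPI resolves get_db now, and the middleware just reads whatever session ends up on request.state.db, so test overrides go through the standard mechanism. A hypothetical sketch, assuming an app entry point at src.main and a /health endpoint that may not match the real project layout:

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from src.database import get_db   # assumed re-export used as the dependency key
from src.main import app          # hypothetical application entry point

test_engine = create_engine("sqlite://")
TestSession = sessionmaker(bind=test_engine)

def override_get_db():
    db = TestSession()
    try:
        yield db                   # injected wherever routes declare Depends(get_db)
    finally:
        db.close()

app.dependency_overrides[get_db] = override_get_db

def test_request_uses_override():
    client = TestClient(app)
    assert client.get("/health").status_code == 200   # hypothetical endpoint
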
diff --git a/src/services/rate_limit/concurrency_manager.py b/src/services/rate_limit/concurrency_manager.py
index a968155..b1af9b1 100644
--- a/src/services/rate_limit/concurrency_manager.py
+++ b/src/services/rate_limit/concurrency_manager.py
@@ -11,7 +11,6 @@
 import asyncio
 import math
-import os
 from contextlib import asynccontextmanager
 from datetime import timedelta  # noqa: F401 - kept for potential future use
 from typing import Optional, Tuple
@@ -40,6 +39,7 @@ class ConcurrencyManager:
         self._memory_lock: asyncio.Lock = asyncio.Lock()
         self._memory_endpoint_counts: dict[str, int] = {}
         self._memory_key_counts: dict[str, int] = {}
+        self._owns_redis: bool = False
         self._memory_initialized = True
 
     async def initialize(self) -> None:
@@ -47,41 +47,29 @@
         """Initialize the Redis connection."""
         if self._redis is not None:
             return
 
-        # Prefer REDIS_URL; if it is not set, build the URL from the password
-        redis_url = os.getenv("REDIS_URL")
-
-        if not redis_url:
-            # Local development mode: build the URL from REDIS_PASSWORD
-            redis_password = os.getenv("REDIS_PASSWORD")
-            if redis_password:
-                redis_url = f"redis://:{redis_password}@localhost:6379/0"
-            else:
-                redis_url = "redis://localhost:6379/0"
-
         try:
-            self._redis = await aioredis.from_url(
-                redis_url,
-                encoding="utf-8",
-                decode_responses=True,
-                socket_timeout=5.0,
-                socket_connect_timeout=5.0,
-            )
-            # Test the connection
-            await self._redis.ping()
-            # Redact the URL for logging (hide the password)
-            safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
-            logger.info(f"[OK] Redis connected: {safe_url}")
+            # Reuse the global Redis client (with circuit breaker/degradation) instead of creating another connection pool
+            from src.clients.redis_client import get_redis_client
+
+            self._redis = await get_redis_client(require_redis=False)
+            self._owns_redis = False
+            if self._redis:
+                logger.info("[OK] ConcurrencyManager is reusing the global Redis client")
+            else:
+                logger.warning("[WARN] Redis unavailable; concurrency control degrades to in-memory mode (only safe on a single instance)")
         except Exception as e:
-            logger.error(f"[ERROR] Redis connection failed: {e}")
-            logger.warning("[WARN] Concurrency control will be disabled (only safe on a single instance)")
+            logger.error(f"[ERROR] Failed to obtain the global Redis client: {e}")
+            logger.warning("[WARN] Concurrency control degrades to in-memory mode (only safe on a single instance)")
             self._redis = None
+            self._owns_redis = False
 
     async def close(self) -> None:
         """Close the Redis connection."""
-        if self._redis:
+        if self._redis and self._owns_redis:
             await self._redis.close()
-            self._redis = None
-            logger.info("Redis connection closed")
+            logger.info("ConcurrencyManager Redis connection closed")
+        self._redis = None
+        self._owns_redis = False
 
     def _get_endpoint_key(self, endpoint_id: str) -> str:
         """Get the Redis key for the Endpoint concurrency counter."""
diff --git a/src/services/rate_limit/rpm_limiter.py b/src/services/rate_limit/rpm_limiter.py
index 095fbcc..70479a6 100644
--- a/src/services/rate_limit/rpm_limiter.py
+++ b/src/services/rate_limit/rpm_limiter.py
@@ -3,7 +3,7 @@ RPM (Requests Per Minute) rate-limiting service
 """
 
 import time
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from typing import Dict, Tuple
 
 from sqlalchemy.orm import Session
@@ -72,11 +72,7 @@
         # Get the current minute window
         now = datetime.now(timezone.utc)
         window_start = now.replace(second=0, microsecond=0)
-        window_end = (
-            window_start.replace(minute=window_start.minute + 1)
-            if window_start.minute < 59
-            else window_start.replace(hour=window_start.hour + 1, minute=0)
-        )
+        window_end = window_start + timedelta(minutes=1)
 
         # Find or create the tracking record
         tracking = (
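
Reviewer note: a quick worked example, not part of the patch, of the boundary case the old replace()-based expression mishandled; timedelta rolls over hours, days and years for free:

from datetime import datetime, timedelta, timezone

window_start = datetime(2024, 12, 31, 23, 59, tzinfo=timezone.utc)

# Old expression: minute == 59 takes the else branch, and hour + 1 == 24 is invalid.
try:
    window_start.replace(hour=window_start.hour + 1, minute=0)
except ValueError as exc:
    print("replace() fails at 23:59:", exc)       # hour must be in 0..23

# New expression from the patch: always correct.
print(window_start + timedelta(minutes=1))        # 2025-01-01 00:00:00+00:00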