refactor: optimize database session lifecycle and middleware architecture

- Improve database pool capacity logging with detailed configuration parameters
- Optimize database session dependency injection with middleware-managed lifecycle
- Simplify plugin middleware by delegating session creation to FastAPI dependencies
- Fix import path in auth routes (relative to absolute)
- Add safety checks for database session management across middleware exception handlers
- Ensure session cleanup only when not managed by middleware (avoid premature cleanup)
This commit is contained in:
fawney19
2025-12-18 00:35:46 +08:00
parent 9d5c84f9d3
commit b579420690
7 changed files with 113 additions and 113 deletions

View File

@@ -11,7 +11,6 @@
import asyncio
import math
import os
from contextlib import asynccontextmanager
from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple
@@ -40,6 +39,7 @@ class ConcurrencyManager:
self._memory_lock: asyncio.Lock = asyncio.Lock()
self._memory_endpoint_counts: dict[str, int] = {}
self._memory_key_counts: dict[str, int] = {}
self._owns_redis: bool = False
self._memory_initialized = True
async def initialize(self) -> None:
@@ -47,41 +47,29 @@ class ConcurrencyManager:
if self._redis is not None:
return
# 优先使用 REDIS_URL如果没有则根据密码构建 URL
redis_url = os.getenv("REDIS_URL")
if not redis_url:
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
redis_password = os.getenv("REDIS_PASSWORD")
if redis_password:
redis_url = f"redis://:{redis_password}@localhost:6379/0"
else:
redis_url = "redis://localhost:6379/0"
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
)
# 测试连接
await self._redis.ping()
# 脱敏显示(隐藏密码)
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
logger.info(f"[OK] Redis 连接成功: {safe_url}")
# 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
from src.clients.redis_client import get_redis_client
self._redis = await get_redis_client(require_redis=False)
self._owns_redis = False
if self._redis:
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
else:
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
except Exception as e:
logger.error(f"[ERROR] Redis 连接失败: {e}")
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
self._redis = None
self._owns_redis = False
async def close(self) -> None:
"""关闭 Redis 连接"""
if self._redis:
if self._redis and self._owns_redis:
await self._redis.close()
self._redis = None
logger.info("Redis 连接已关闭")
logger.info("ConcurrencyManager Redis 连接已关闭")
self._redis = None
self._owns_redis = False
def _get_endpoint_key(self, endpoint_id: str) -> str:
"""获取 Endpoint 并发计数的 Redis Key"""

View File

@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
"""
import time
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from sqlalchemy.orm import Session
@@ -72,11 +72,7 @@ class RPMLimiter:
# 获取当前分钟窗口
now = datetime.now(timezone.utc)
window_start = now.replace(second=0, microsecond=0)
window_end = (
window_start.replace(minute=window_start.minute + 1)
if window_start.minute < 59
else window_start.replace(hour=window_start.hour + 1, minute=0)
)
window_end = window_start + timedelta(minutes=1)
# 查找或创建追踪记录
tracking = (