Initial commit

This commit is contained in:
fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

3
src/utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .timeout import AsyncTimeoutError, run_with_timeout, with_timeout
__all__ = ["with_timeout", "run_with_timeout", "AsyncTimeoutError"]

203
src/utils/auth_utils.py Normal file
View File

@@ -0,0 +1,203 @@
"""
认证工具函数
提供统一的用户认证和授权功能
"""
from typing import Optional
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from src.services.auth.service import AuthService
from ..core.exceptions import ForbiddenException
from src.core.logger import logger
from ..database import get_db
from ..models.database import User, UserRole
# Shared HTTP Bearer scheme: extracts the token from the Authorization header
# for the dependency functions below.
security = HTTPBearer()
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db)
) -> User:
    """
    Resolve the currently authenticated user.

    Unified FastAPI authentication dependency: verifies the Bearer token,
    loads the user from the database, and cross-checks identity fields
    (email and creation time) to defend against user-ID reuse.

    Args:
        credentials: Bearer token credentials extracted by ``security``.
        db: Database session.

    Returns:
        User: The authenticated user object.

    Raises:
        HTTPException: On authentication failure (status depends on cause).
    """
    token = credentials.credentials
    try:
        # Validate token format and signature.
        try:
            payload = AuthService.verify_token(token)
        except HTTPException as token_error:
            # Preserve the original HTTP status (e.g. 401 Unauthorized); do not convert to 403.
            logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
            raise  # Re-raise the original exception so the status code is kept.
        except Exception as token_error:
            logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
            raise ForbiddenException("无效的Token")
        user_id = payload.get("user_id")
        token_email = payload.get("email")
        token_created_at = payload.get("created_at")
        if not user_id:
            logger.error(f"Token缺少user_id字段: payload={payload}")
            raise ForbiddenException("无效的认证凭据")
        if not token_email:
            logger.error(f"Token缺少email字段: payload={payload}")
            raise ForbiddenException("无效的认证凭据")
        # Details only logged in DEBUG mode.
        logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
        # user_id must be a string-formatted UUID.
        if not isinstance(user_id, str):
            logger.error(f"Token中user_id格式错误: {type(user_id)} - {user_id}")
            raise ForbiddenException("认证信息格式错误,请重新登录")
        # Fetch the user via the service layer to avoid session-state issues.
        try:
            from src.services.user.service import UserService
            user = UserService.get_user(db, user_id)
        except Exception as db_error:
            logger.error(f"数据库查询失败: user_id={user_id}, error={db_error}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="数据库查询失败,请稍后重试",
            )
        if not user:
            logger.error(f"用户不存在: user_id={user_id}")
            raise ForbiddenException("用户不存在或已禁用")
        if not user.is_active:
            logger.error(f"用户已禁用: user_id={user_id}")
            raise ForbiddenException("用户不存在或已禁用")
        # Verify the email matches, preventing identity confusion from user-ID reuse.
        if user.email != token_email:
            logger.error(f"Token邮箱不匹配: Token中的邮箱={token_email}, 数据库中的邮箱={user.email}")
            raise ForbiddenException("身份验证失败")
        # Verify the user's creation time matches as a second guard against ID reuse.
        if token_created_at and user.created_at:
            try:
                from datetime import datetime
                token_created = datetime.fromisoformat(token_created_at.replace("Z", "+00:00"))
                # Allow a 1-second discrepancy to account for timestamp precision.
                time_diff = abs((user.created_at - token_created).total_seconds())
                if time_diff > 1:
                    logger.error(f"Token创建时间不匹配: Token时间={token_created_at}, 用户创建时间={user.created_at}")
                    raise ForbiddenException("身份验证失败")
            except ValueError as e:
                # Unparseable timestamp in the token: log and skip this check.
                logger.warning(f"Token时间格式解析失败: {e}")
        logger.debug(f"成功获取用户: user_id={user_id}, email={user.email}")
        return user
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"认证失败,未预期的错误: {e}")
        # Return 500 instead of 401 so the frontend's logout logic is not triggered.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="认证服务暂时不可用"
        )
def get_current_user_from_header(
    authorization: Optional[str] = Header(None), db: Session = Depends(get_db)
) -> User:
    """
    Resolve the current user from the raw ``Authorization`` header (compatibility helper).

    Args:
        authorization: Raw Authorization header value ("Bearer <token>").
        db: Database session.

    Returns:
        User: The authenticated user object.

    Raises:
        HTTPException: On authentication failure.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise ForbiddenException("未提供认证令牌")
    # Strip only the leading "Bearer " prefix; str.replace() would also remove
    # any later occurrence of the substring inside the token itself.
    token = authorization[len("Bearer "):]
    try:
        payload = AuthService.verify_token(token)
        user_id = payload.get("user_id")
        if not user_id:
            raise ForbiddenException("无效的认证凭据")
        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ForbiddenException("用户不存在")
        if not user.is_active:
            raise ForbiddenException("用户已被禁用")
        return user
    except HTTPException:
        # Preserve the original HTTPException (including 401).
        raise
    except Exception as e:
        logger.error(f"认证失败: {e}")
        raise ForbiddenException("认证失败")
def require_admin(current_user: User = Depends(get_current_user)) -> User:
    """
    Dependency restricting an endpoint to administrators.

    Args:
        current_user: The authenticated user (injected).

    Returns:
        User: The same user, guaranteed to be an administrator.

    Raises:
        HTTPException: 403 when the user is not an administrator.
    """
    is_admin = current_user.role == UserRole.ADMIN
    if is_admin:
        return current_user
    raise ForbiddenException("需要管理员权限")
def require_role(required_role: UserRole):
    """
    Factory producing a dependency that enforces a specific user role.

    Args:
        required_role: Role the caller must hold.

    Returns:
        A dependency callable yielding the user or raising 403.
    """
    def check_role(current_user: User = Depends(get_current_user)) -> User:
        if current_user.role == required_role:
            return current_user
        raise ForbiddenException(f"需要{required_role.value}权限")

    return check_role

View File

@@ -0,0 +1,75 @@
"""缓存装饰器工具"""
import functools
import json
from typing import Any, Callable, Optional
from src.core.logger import logger
from src.clients.redis_client import get_redis_client_sync
def cache_result(key_prefix: str, ttl: int = 60, user_specific: bool = True) -> Callable:
    """
    Decorator that caches an async function's JSON-serializable result in Redis.

    Args:
        key_prefix: Prefix for the cache key.
        ttl: Cache expiry in seconds.
        user_specific: Whether to key the cache per user (read from
            ``context.user.id``, where ``context`` is assumed to be the first
            positional argument).

    Notes:
        All cache failures degrade gracefully: on any Redis/serialization error
        the wrapped function is executed normally. Exceptions raised by the
        wrapped function itself always propagate and never cause a second
        invocation (the original implementation re-ran the function inside its
        blanket ``except``, executing side effects twice).
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            redis_client = get_redis_client_sync()
            # Redis unavailable: execute the original function directly.
            if redis_client is None:
                return await func(*args, **kwargs)
            cache_key = None
            try:
                # The context object is conventionally the first positional argument.
                context = args[0] if args else None
                if user_specific and context and hasattr(context, "user") and context.user:
                    cache_key = f"{key_prefix}:user:{context.user.id}"
                else:
                    cache_key = f"{key_prefix}:global"
                # Fold extra request attributes (e.g. days/limit) into the key.
                if context is not None and hasattr(context, "__dict__"):
                    for attr_name in ["days", "limit"]:
                        if hasattr(context, attr_name):
                            attr_value = getattr(context, attr_name)
                            cache_key += f":{attr_name}:{attr_value}"
                # Try the cache first.
                cached = await redis_client.get(cache_key)
                if cached:
                    logger.debug(f"缓存命中: {cache_key}")
                    return json.loads(cached)
            except Exception as e:
                logger.warning(f"缓存处理出错: {e}, 直接执行原函数")
                cache_key = None  # Also skip the cache write below.
            # Run the wrapped function OUTSIDE the cache try-block so its own
            # exceptions propagate instead of triggering a second call.
            result = await func(*args, **kwargs)
            if cache_key is not None:
                try:
                    await redis_client.setex(
                        cache_key, ttl, json.dumps(result, ensure_ascii=False, default=str)
                    )
                    logger.debug(f"缓存已保存: {cache_key}, TTL: {ttl}s")
                except Exception as e:
                    logger.warning(f"保存缓存失败: {e}")
            return result

        return wrapper

    return decorator

75
src/utils/compression.py Normal file
View File

@@ -0,0 +1,75 @@
"""
数据压缩/解压工具
提供JSON数据的gzip压缩和解压功能
"""
import gzip
import json
from typing import Any, Optional
def compress_json(data: Any) -> Optional[bytes]:
    """
    Serialize *data* to JSON and gzip-compress it.

    Args:
        data: Any JSON-serializable value.

    Returns:
        The gzip-compressed bytes, or ``None`` when *data* is ``None`` or
        serialization/compression fails.
    """
    if data is None:
        return None
    try:
        encoded = json.dumps(data, ensure_ascii=False).encode("utf-8")
        return gzip.compress(encoded, compresslevel=6)
    except Exception:
        # Best-effort: signal failure with None instead of raising.
        return None
def decompress_json(compressed_data: Optional[bytes]) -> Optional[Any]:
    """
    Inverse of ``compress_json``: gunzip the bytes and parse them as JSON.

    Args:
        compressed_data: gzip-compressed JSON bytes.

    Returns:
        The decoded value, or ``None`` when the input is ``None`` or cannot be
        decompressed/parsed.
    """
    if compressed_data is None:
        return None
    try:
        raw = gzip.decompress(compressed_data)
        return json.loads(raw.decode("utf-8"))
    except Exception:
        # Corrupt or non-gzip input: fail soft.
        return None
def get_body_size(data: Any) -> int:
    """
    Size in bytes of *data* once serialized to JSON (UTF-8).

    Args:
        data: Any JSON-serializable value.

    Returns:
        The byte count; 0 for ``None`` or unserializable input.
    """
    if data is None:
        return 0
    try:
        serialized = json.dumps(data, ensure_ascii=False)
        return len(serialized.encode("utf-8"))
    except Exception:
        return 0

View File

@@ -0,0 +1,64 @@
"""
数据库方言兼容性辅助函数
"""
from sqlalchemy import func
from sqlalchemy.sql.elements import ClauseElement
def date_trunc_portable(dialect_name: str, interval: str, column) -> "ClauseElement":
    """
    Cross-database date truncation helper.

    Args:
        dialect_name: Database dialect name ('postgresql', 'sqlite', 'mysql').
        interval: Truncation interval ('day', 'week', 'month', 'year').
        column: The date/datetime column expression.

    Returns:
        SQLAlchemy ClauseElement truncating *column* to *interval*.

    Raises:
        NotImplementedError: For an unsupported database dialect.
        ValueError: For an interval the given dialect cannot express.
    """
    if dialect_name == "postgresql":
        # Native support via date_trunc.
        return func.date_trunc(interval, column)
    # strftime / date_format patterns for the remaining dialects.
    sqlite_formats = {
        "year": "%Y",
        "month": "%Y-%m",
        "week": "%Y-%W",
        "day": "%Y-%m-%d",
    }
    mysql_formats = {
        "year": "%Y",
        "month": "%Y-%m",
        "day": "%Y-%m-%d",
    }
    if dialect_name == "sqlite":
        fmt = sqlite_formats.get(interval)
        if fmt is None:
            raise ValueError(f"Unsupported interval for SQLite: {interval}")
        return func.strftime(fmt, column)
    if dialect_name == "mysql":
        fmt = mysql_formats.get(interval)
        if fmt is None:
            raise ValueError(f"Unsupported interval for MySQL: {interval}")
        return func.date_format(column, fmt)
    raise NotImplementedError(
        f"Unsupported database dialect: {dialect_name}. "
        f"Supported dialects: postgresql, sqlite, mysql"
    )

114
src/utils/request_utils.py Normal file
View File

@@ -0,0 +1,114 @@
"""
请求处理工具函数
提供统一的HTTP请求信息提取功能
"""
from typing import Optional
from fastapi import Request
def get_client_ip(request: "Request") -> str:
    """
    Best-effort extraction of the real client IP.

    Checked in priority order:
    1. ``X-Forwarded-For`` (proxy chain; first entry is the client)
    2. ``X-Real-IP`` (typically set by Nginx)
    3. The socket-level client address

    Args:
        request: FastAPI Request object.

    Returns:
        str: The client IP, or "unknown" when none can be determined.
    """
    # Proxy chain header: "client, proxy1, proxy2" — first entry wins.
    forwarded_chain = request.headers.get("X-Forwarded-For")
    if forwarded_chain:
        candidate = forwarded_chain.split(",")[0].strip()
        if candidate:
            return candidate
    # Single-value header set by reverse proxies such as Nginx.
    proxy_ip = request.headers.get("X-Real-IP")
    if proxy_ip:
        return proxy_ip.strip()
    # Fall back to the direct connection's address.
    conn = request.client
    if conn and conn.host:
        return conn.host
    return "unknown"
def get_user_agent(request: "Request") -> str:
    """
    Return the request's User-Agent header.

    Args:
        request: FastAPI Request object.

    Returns:
        str: The User-Agent string, or "unknown" when absent.
    """
    agent = request.headers.get("User-Agent", "unknown")
    return agent
def get_request_id(request: "Request") -> Optional[str]:
    """
    Return the request ID stored on ``request.state``, if any.

    Args:
        request: FastAPI Request object.

    Returns:
        Optional[str]: The request ID, or ``None`` when not set.
    """
    state = request.state
    return getattr(state, "request_id", None)
def get_request_metadata(request: "Request") -> dict:
    """
    Collect a dictionary of request metadata for logging/auditing.

    Args:
        request: FastAPI Request object.

    Returns:
        dict: Client address, user agent, request ID, method, path,
        query string, and content headers.
    """
    headers = request.headers
    query = request.query_params
    metadata = {
        "client_ip": get_client_ip(request),
        "user_agent": get_user_agent(request),
        "request_id": get_request_id(request),
        "method": request.method,
        "path": request.url.path,
        "query_params": str(query) if query else None,
        "content_type": headers.get("Content-Type"),
        "content_length": headers.get("Content-Length"),
    }
    return metadata
def extract_ip_from_headers(headers: dict) -> str:
    """
    Extract the client IP from a plain header dict (middleware scenarios).

    Args:
        headers: HTTP header mapping with lower-cased keys.

    Returns:
        str: The client IP, or "unknown" when no relevant header exists.
    """
    # X-Forwarded-For may carry a proxy chain; the first entry is the client.
    forwarded = headers.get("x-forwarded-for", "")
    if forwarded:
        return forwarded.split(",")[0].strip()
    # Single-value header set by reverse proxies.
    direct = headers.get("x-real-ip", "")
    if direct:
        return direct.strip()
    return "unknown"

109
src/utils/sse_parser.py Normal file
View File

@@ -0,0 +1,109 @@
from typing import Dict, List, Optional
class SSEEventParser:
    """Lightweight SSE parser: feed it lines, get back completed events."""

    def __init__(self) -> None:
        self._reset_buffer()

    def _reset_buffer(self) -> None:
        # Accumulates the fields of the event currently being assembled.
        self._buffer: Dict[str, Optional[str] | List[str]] = {
            "event": None,
            "data": [],
            "id": None,
            "retry": None,
        }

    def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
        # An event with no data lines is dropped (matching SSE semantics).
        pending = self._buffer.get("data", [])
        if not pending:
            self._reset_buffer()
            return None
        completed = {
            "event": self._buffer.get("event"),
            "data": "\n".join(pending),
            "id": self._buffer.get("id"),
            "retry": self._buffer.get("retry"),
        }
        self._reset_buffer()
        return completed

    def feed_line(self, line: Optional[str]) -> List[Dict[str, Optional[str]]]:
        """Process one line of SSE text; return any events completed by it."""
        text = (line or "").rstrip("\r")
        completed: List[Dict[str, Optional[str]]] = []

        # A blank line terminates the current event.
        if text == "":
            finished = self._finalize_event()
            if finished:
                completed.append(finished)
            return completed

        # Comment lines are ignored ("::" is deliberately NOT treated as a comment).
        if text.startswith(":") and not text.startswith("::"):
            return completed

        if text.startswith("event:"):
            value = text.split(":", 1)[1].lstrip()
            if " data:" in value:
                # Some producers cram "event: X data: Y" onto one line;
                # treat it as a complete single event.
                name_part, payload = value.split("data:", 1)
                self._buffer["event"] = name_part.strip() or None
                payload = payload.lstrip()
                if payload:
                    self._append_data_line(payload)
                finished = self._finalize_event()
                if finished:
                    completed.append(finished)
            else:
                self._buffer["event"] = value.strip() or None
            return completed

        if text.startswith("data:"):
            # A second data: line without a separating blank line closes the
            # previous event first (handles producers that omit blank lines).
            if self._buffer.get("data", []):
                finished = self._finalize_event()
                if finished:
                    completed.append(finished)
            payload = text.split(":", 1)[1]
            # Per the SSE spec, a single leading space is not part of the data.
            if payload.startswith(" "):
                payload = payload[1:]
            self._append_data_line(payload)
            return completed

        if text.startswith("id:"):
            self._buffer["id"] = text.split(":", 1)[1].strip() or None
            return completed

        if text.startswith("retry:"):
            self._buffer["retry"] = text.split(":", 1)[1].strip() or None
            return completed

        # Unknown line: treat as data (some producers omit the "data:" prefix).
        self._append_data_line(text)
        return completed

    def flush(self) -> List[Dict[str, Optional[str]]]:
        """Call at end-of-stream to emit any partially buffered event."""
        finished = self._finalize_event()
        return [finished] if finished else []

    def _append_data_line(self, value: str) -> None:
        current = self._buffer.get("data")
        if isinstance(current, list):
            current.append(value)
        else:
            self._buffer["data"] = [value]

View File

@@ -0,0 +1,97 @@
"""分布式任务协调器,确保仅有一个 worker 执行特定任务"""
from __future__ import annotations
import asyncio
import os
import pathlib
import uuid
from typing import Dict, Optional
from src.core.logger import logger
try:
import fcntl # type: ignore
except ImportError: # pragma: no cover - Windows 环境
fcntl = None
class StartupTaskCoordinator:
    """Guarantee a task runs in only one process/instance via a Redis or file lock."""

    def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
        # Optional async Redis client; when absent (or failing) the
        # coordinator falls back to OS-level file locks.
        self.redis = redis_client
        # Lock name -> token issued when the Redis lock was acquired.
        self._tokens: Dict[str, str] = {}
        # Lock name -> open file handle (or a sentinel object on platforms without fcntl).
        self._file_handles: Dict[str, object] = {}
        self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
        if not self._lock_dir.exists():
            self._lock_dir.mkdir(parents=True, exist_ok=True)

    def _redis_key(self, name: str) -> str:
        # Namespaced Redis key for the named task lock.
        return f"task_lock:{name}"

    async def acquire(self, name: str, ttl: Optional[int] = None) -> bool:
        """Try to acquire the named lock; True when this process now owns it."""
        ttl = ttl or int(os.getenv("TASK_COORDINATOR_LOCK_TTL", "86400"))
        if self.redis:
            token = str(uuid.uuid4())
            try:
                # SET NX EX: succeeds only when no other holder exists; the random
                # token lets release() verify ownership before deleting.
                acquired = await self.redis.set(self._redis_key(name), token, nx=True, ex=ttl)
                if acquired:
                    self._tokens[name] = token
                    logger.info(f"任务 {name} 通过 Redis 锁独占执行")
                    return True
                return False
            except Exception as exc:  # pragma: no cover - fall back on Redis errors
                logger.warning(f"Redis 锁获取失败,回退到文件锁: {exc}")
        return await self._acquire_file_lock(name)

    async def release(self, name: str):
        """Release the named lock (Redis token check first, then any file lock)."""
        if self.redis and name in self._tokens:
            token = self._tokens.pop(name)
            # Compare-and-delete in Lua so we never delete a lock another
            # process re-acquired after ours expired.
            script = """
            if redis.call('GET', KEYS[1]) == ARGV[1] then
                return redis.call('DEL', KEYS[1])
            end
            return 0
            """
            try:
                await self.redis.eval(script, 1, self._redis_key(name), token)
            except Exception as exc:  # pragma: no cover
                logger.warning(f"释放 Redis 锁失败: {exc}")
        handle = self._file_handles.pop(name, None)
        if handle and fcntl:
            try:
                fcntl.flock(handle, fcntl.LOCK_UN)
            finally:
                handle.close()

    async def _acquire_file_lock(self, name: str) -> bool:
        """Acquire a non-blocking exclusive file lock; True when this process owns it."""
        if fcntl is None:
            # Platforms without fcntl (e.g. Windows): degrade to a lock that is
            # only effective within the current process.
            if name in self._file_handles:
                return False
            self._file_handles[name] = object()
            logger.warning("操作系统不支持文件锁,任务锁仅在当前进程生效")
            return True
        lock_path = self._lock_dir / f"{name}.lock"
        handle = open(lock_path, "a+")
        try:
            # LOCK_NB: fail immediately instead of blocking when another
            # process already holds the lock.
            fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
            self._file_handles[name] = handle
            logger.info(f"任务 {name} 使用文件锁独占执行")
            return True
        except BlockingIOError:
            handle.close()
            return False
async def ensure_singleton_task(name: str, redis_client=None, ttl: Optional[int] = None):
    """
    Convenience coroutine: build a coordinator and try to acquire *name*.

    Returns:
        tuple: ``(coordinator, acquired)``.
    """
    coord = StartupTaskCoordinator(redis_client)
    locked = await coord.acquire(name, ttl=ttl)
    return coord, locked

141
src/utils/timeout.py Normal file
View File

@@ -0,0 +1,141 @@
"""
超时保护工具
为异步函数和操作提供超时保护
"""
import asyncio
from functools import wraps
from typing import Any, Callable, Optional, TypeVar
from src.core.logger import logger
T = TypeVar("T")
class AsyncTimeoutError(TimeoutError):
    """Raised when an async operation exceeds its deadline."""

    def __init__(self, message: str, operation: str, timeout: float):
        # Keep the human-readable message on the base class; expose the
        # offending operation and its deadline as attributes for callers.
        super().__init__(message)
        self.operation = operation
        self.timeout = timeout
def with_timeout(seconds: float, operation_name: Optional[str] = None):
    """
    Decorator adding timeout protection to an async function.

    Args:
        seconds: Deadline in seconds.
        operation_name: Name used in logs; defaults to the function name.

    Raises:
        AsyncTimeoutError: When the wrapped call exceeds the deadline.

    Usage:
        @with_timeout(30.0)
        async def my_async_function():
            ...
    """
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            label = operation_name if operation_name else func.__name__
            try:
                return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
            except asyncio.TimeoutError:
                logger.warning(f"操作超时: {label} (timeout={seconds}s)")
                raise AsyncTimeoutError(
                    f"{label} 操作超时({seconds}秒)",
                    operation=label,
                    timeout=seconds,
                )

        return wrapper

    return decorator
async def run_with_timeout(
    coro,
    timeout: float,
    operation_name: str = "operation",
    default: "T" = None,
    raise_on_timeout: bool = True,
) -> "T":
    """
    Await *coro* with a deadline (functional counterpart of ``with_timeout``).

    Args:
        coro: The coroutine object to await.
        timeout: Deadline in seconds.
        operation_name: Name used in logs.
        default: Value returned on timeout (only when ``raise_on_timeout`` is False).
        raise_on_timeout: Whether to raise ``AsyncTimeoutError`` on timeout.

    Returns:
        The coroutine's result, or *default* after a swallowed timeout.
    """
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(f"操作超时: {operation_name} (timeout={timeout}s)")
        if not raise_on_timeout:
            return default
        raise AsyncTimeoutError(
            f"{operation_name} 操作超时({timeout}秒)",
            operation=operation_name,
            timeout=timeout,
        )
class TimeoutContext:
    """
    Async context manager enforcing a deadline on the enclosed block.

    On entry a watchdog timer is armed; when the deadline passes it cancels
    the current task, and ``__aexit__`` converts that cancellation into an
    ``AsyncTimeoutError``. (The previous implementation was a no-op stub that
    never enforced the documented timeout.)

    Usage:
        async with TimeoutContext(30.0, "数据库查询") as ctx:
            result = await db.query(...)
        # Raises AsyncTimeoutError if the block runs longer than 30 seconds.
    """

    def __init__(self, timeout: float, operation_name: str = "operation"):
        self.timeout = timeout
        self.operation_name = operation_name
        self._task: Optional[asyncio.Task] = None
        self._timer = None      # loop.call_later handle for the watchdog
        self._expired = False   # set by the watchdog when the deadline fires

    def _expire(self) -> None:
        # Runs in the event loop when the deadline passes: flag and cancel.
        self._expired = True
        if self._task is not None:
            self._task.cancel()

    async def __aenter__(self):
        self._task = asyncio.current_task()
        self._timer = asyncio.get_running_loop().call_later(self.timeout, self._expire)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._timer is not None:
            self._timer.cancel()
        # Translate our own cancellation into a timeout error; any unrelated
        # exception (or external cancellation) passes through untouched.
        if self._expired and exc_type is asyncio.CancelledError:
            raise AsyncTimeoutError(
                f"{self.operation_name} 操作超时({self.timeout}秒)",
                operation=self.operation_name,
                timeout=self.timeout,
            ) from exc_val
        return False
class _WatchdogTimeoutContext:
    """Internal async context manager backing ``with_timeout_context``."""

    def __init__(self, timeout: float, operation_name: str):
        self._timeout = timeout
        self._operation_name = operation_name
        self._task = None
        self._timer = None
        self._expired = False

    def _expire(self) -> None:
        # Deadline passed: flag it and cancel the task running the block.
        self._expired = True
        if self._task is not None:
            self._task.cancel()

    async def __aenter__(self):
        self._task = asyncio.current_task()
        self._timer = asyncio.get_running_loop().call_later(self._timeout, self._expire)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._timer is not None:
            self._timer.cancel()
        if self._expired and exc_type is asyncio.CancelledError:
            logger.warning(f"操作超时: {self._operation_name} (timeout={self._timeout}s)")
            raise AsyncTimeoutError(
                f"{self._operation_name} 操作超时({self._timeout}秒)",
                operation=self._operation_name,
                timeout=self._timeout,
            ) from exc_val
        return False


def with_timeout_context(timeout: float, operation_name: str = "operation"):
    """
    Build an async timeout context manager.

    Fix: the previous version was declared ``async def`` and returned the
    result of ``asyncio.timeout(...)`` from a coroutine, so the documented
    ``async with with_timeout_context(...)`` usage raised TypeError (a
    coroutine object has no ``__aenter__``). It is now a plain factory
    returning an async context manager, matching the documented usage and
    working on Python versions without ``asyncio.timeout`` (3.11+).

    Usage:
        async with with_timeout_context(30.0, "API请求"):
            result = await api_call()

    Raises:
        AsyncTimeoutError: (from the context) when the block exceeds *timeout*.
    """
    return _WatchdogTimeoutContext(timeout, operation_name)

View File

@@ -0,0 +1,304 @@
"""
数据库事务管理工具
提供事务装饰器和事务上下文管理器
支持同步和异步函数
"""
import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional
from sqlalchemy.exc import DatabaseError, IntegrityError
from sqlalchemy.orm import Session
from src.core.logger import logger
class TransactionError(Exception):
    """Raised when transaction management cannot proceed (e.g. no Session found)."""
def _find_db_session(args, kwargs) -> Optional[Session]:
"""从参数中查找数据库会话"""
# 从位置参数中查找Session
for arg in args:
if isinstance(arg, Session):
return arg
# 从关键字参数中查找Session
for value in kwargs.values():
if isinstance(value, Session):
return value
return None
def transactional(commit: bool = True, rollback_on_error: bool = True):
    """
    Transaction decorator supporting both sync and async functions.

    The decorated function must receive a SQLAlchemy ``Session`` somewhere in
    its positional or keyword arguments; the decorator locates it, commits on
    success, and rolls back on failure. When the session is already inside a
    transaction, the function runs as-is so nested calls do not double-commit.

    Args:
        commit: Auto-commit on success (default True).
        rollback_on_error: Auto-rollback on failure (default True).

    Raises:
        TransactionError: When no Session is found among the arguments.

    Usage:
        @transactional()
        def create_user_with_api_key(db: Session, ...):
            # Sync functions run inside a transaction.
            pass

        @transactional()
        async def create_user_async(db: Session, ...):
            # Async functions do too.
            pass
    """
    def decorator(func: Callable) -> Callable:
        # Pick the wrapper flavor matching the wrapped function.
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs) -> Any:
                db_session = _find_db_session(args, kwargs)
                if not db_session:
                    raise TransactionError(
                        f"No SQLAlchemy Session found in arguments for {func.__name__}"
                    )
                # Already inside a transaction: delegate without managing it here.
                if db_session.in_transaction():
                    return await func(*args, **kwargs)
                transaction_id = f"{func.__module__}.{func.__name__}"
                logger.debug(f"开始异步事务: {transaction_id}")
                try:
                    result = await func(*args, **kwargs)
                    if commit:
                        db_session.commit()
                        logger.debug(f"异步事务提交成功: {transaction_id}")
                    return result
                except Exception as e:
                    if rollback_on_error:
                        try:
                            db_session.rollback()
                        except Exception:
                            # Rollback failures are deliberately swallowed; the
                            # original error (re-raised below) is the one that matters.
                            pass
                        logger.error(
                            f"异步事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    else:
                        logger.error(
                            f"异步事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    raise
            return async_wrapper
        else:
            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs) -> Any:
                db_session = _find_db_session(args, kwargs)
                if not db_session:
                    raise TransactionError(
                        f"No SQLAlchemy Session found in arguments for {func.__name__}"
                    )
                # Already inside a transaction: delegate without managing it here.
                if db_session.in_transaction():
                    return func(*args, **kwargs)
                transaction_id = f"{func.__module__}.{func.__name__}"
                logger.debug(f"开始事务: {transaction_id}")
                try:
                    result = func(*args, **kwargs)
                    if commit:
                        db_session.commit()
                        logger.debug(f"事务提交成功: {transaction_id}")
                    return result
                except Exception as e:
                    if rollback_on_error:
                        try:
                            db_session.rollback()
                        except Exception:
                            # Rollback failures are deliberately swallowed; the
                            # original error (re-raised below) is the one that matters.
                            pass
                        logger.error(
                            f"事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    else:
                        logger.error(
                            f"事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    raise
            return sync_wrapper
    return decorator
@contextmanager
def transaction_scope(
    db: "Session",
    commit_on_success: bool = True,
    rollback_on_error: bool = True,
    operation_name: Optional[str] = None,
) -> "Generator[Session, None, None]":
    """
    Context manager wrapping a block in a database transaction.

    Args:
        db: Database session.
        commit_on_success: Commit automatically when the block succeeds.
        rollback_on_error: Roll back automatically when the block raises.
        operation_name: Label used in log messages.

    Usage:
        with transaction_scope(db, operation_name="create_user") as tx:
            tx.add(User(...))
        # Commits or rolls back automatically.
    """
    label = operation_name or "database_operation"
    # A transaction is already open: just run the block inside it.
    if db.in_transaction():
        logger.debug(f"使用现有事务: {label}")
        yield db
        return
    logger.debug(f"开始事务范围: {label}")
    try:
        yield db
        if commit_on_success:
            db.commit()
            logger.debug(f"事务范围提交成功: {label}")
    except Exception as e:
        if rollback_on_error:
            db.rollback()
        logger.error(f"事务范围回滚: {label} - {type(e).__name__}: {str(e)}")
        raise
def retry_on_database_error(max_retries: int = 3, delay: float = 0.1):
    """
    Decorator retrying a function on transient database errors.

    Uses exponential backoff with a small random jitter so concurrent callers
    do not retry in lockstep.

    Args:
        max_retries: Maximum number of attempts (values < 1 behave like 1,
            i.e. a single attempt with no retry; previously the function was
            silently skipped and None returned).
        delay: Base delay in seconds; attempt ``k`` waits ``delay * 2**k`` plus jitter.
    """
    # Hoisted out of the wrapper so they are not re-imported on every call.
    import random
    import time

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Guard against max_retries < 1, which previously skipped the loop
            # entirely and returned None without ever calling the function.
            attempts = max(1, max_retries)
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except (DatabaseError, IntegrityError) as e:
                    if attempt >= attempts - 1:
                        logger.error(
                            f"数据库操作最终失败,已达最大重试次数({max_retries}): {func.__name__} - {str(e)}"
                        )
                        raise
                    # Randomized exponential backoff to avoid thundering-herd retries.
                    actual_delay = delay * (2**attempt) + random.uniform(0, 0.1)
                    logger.warning(f"数据库操作失败,{actual_delay:.2f}秒后重试 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
                    time.sleep(actual_delay)

        return wrapper

    return decorator
class BatchOperation:
    """
    Manager for large batches of insert/update operations.

    Buffers add/merge operations and flushes them to the session in chunks
    of ``batch_size``, keeping memory usage and round-trips bounded.
    """

    def __init__(self, db: "Session", batch_size: int = 100):
        self.db = db
        self.batch_size = batch_size
        self.operations = []       # pending (kind, object) pairs
        self.operation_count = 0   # number of buffered operations

    def _enqueue(self, kind: str, obj) -> None:
        # Buffer one operation and flush when the batch is full.
        self.operations.append((kind, obj))
        self.operation_count += 1
        if self.operation_count >= self.batch_size:
            self.flush()

    def add(self, obj):
        """Queue an object for insertion."""
        self._enqueue("add", obj)

    def update(self, obj):
        """Queue an object for update (session merge)."""
        self._enqueue("merge", obj)

    def flush(self):
        """Apply all buffered operations to the session (flush only, no commit)."""
        if not self.operations:
            return
        logger.debug(f"执行批量操作: {len(self.operations)}")
        try:
            for kind, obj in self.operations:
                if kind == "add":
                    self.db.add(obj)
                elif kind == "merge":
                    self.db.merge(obj)
            self.db.flush()  # flush only; committing is the caller's decision
            logger.debug(f"批量操作flush完成: {len(self.operations)}")
        except Exception as e:
            logger.error(f"批量操作失败({len(self.operations)}项): {type(e).__name__}: {str(e)}")
            raise
        finally:
            # Always drop the buffer, even after a failure.
            self.operations.clear()
            self.operation_count = 0

    def commit(self):
        """Flush any remaining operations, then commit the transaction."""
        self.flush()
        self.db.commit()
        logger.debug("批量操作提交完成")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            # Clean exit: commit everything.
            self.commit()
        else:
            # Exit with error: roll back.
            self.db.rollback()
            logger.error(f"批量操作异常退出,已回滚: {str(exc_val)}")