mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 18:52:28 +08:00
305 lines
9.2 KiB
Python
305 lines
9.2 KiB
Python
|
|
"""
|
|||
|
|
数据库事务管理工具
|
|||
|
|
提供事务装饰器和事务上下文管理器
|
|||
|
|
支持同步和异步函数
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import functools
|
|||
|
|
import inspect
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
from typing import Any, Callable, Generator, Optional
|
|||
|
|
|
|||
|
|
from sqlalchemy.exc import DatabaseError, IntegrityError
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from src.core.logger import logger
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TransactionError(Exception):
|
|||
|
|
"""事务处理异常"""
|
|||
|
|
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _find_db_session(args, kwargs) -> Optional[Session]:
|
|||
|
|
"""从参数中查找数据库会话"""
|
|||
|
|
# 从位置参数中查找Session
|
|||
|
|
for arg in args:
|
|||
|
|
if isinstance(arg, Session):
|
|||
|
|
return arg
|
|||
|
|
|
|||
|
|
# 从关键字参数中查找Session
|
|||
|
|
for value in kwargs.values():
|
|||
|
|
if isinstance(value, Session):
|
|||
|
|
return value
|
|||
|
|
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def transactional(commit: bool = True, rollback_on_error: bool = True):
|
|||
|
|
"""
|
|||
|
|
事务装饰器,支持同步和异步函数
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
commit: 是否在成功时自动提交,默认True
|
|||
|
|
rollback_on_error: 是否在错误时自动回滚,默认True
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
@transactional()
|
|||
|
|
def create_user_with_api_key(db: Session, ...):
|
|||
|
|
# 同步方法会在事务中执行
|
|||
|
|
pass
|
|||
|
|
|
|||
|
|
@transactional()
|
|||
|
|
async def create_user_async(db: Session, ...):
|
|||
|
|
# 异步方法也会在事务中执行
|
|||
|
|
pass
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def decorator(func: Callable) -> Callable:
|
|||
|
|
# 检查是否是异步函数
|
|||
|
|
if inspect.iscoroutinefunction(func):
|
|||
|
|
@functools.wraps(func)
|
|||
|
|
async def async_wrapper(*args, **kwargs) -> Any:
|
|||
|
|
db_session = _find_db_session(args, kwargs)
|
|||
|
|
|
|||
|
|
if not db_session:
|
|||
|
|
raise TransactionError(
|
|||
|
|
f"No SQLAlchemy Session found in arguments for {func.__name__}"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 检查是否已经在事务中
|
|||
|
|
if db_session.in_transaction():
|
|||
|
|
return await func(*args, **kwargs)
|
|||
|
|
|
|||
|
|
transaction_id = f"{func.__module__}.{func.__name__}"
|
|||
|
|
logger.debug(f"开始异步事务: {transaction_id}")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
result = await func(*args, **kwargs)
|
|||
|
|
|
|||
|
|
if commit:
|
|||
|
|
db_session.commit()
|
|||
|
|
logger.debug(f"异步事务提交成功: {transaction_id}")
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
if rollback_on_error:
|
|||
|
|
try:
|
|||
|
|
db_session.rollback()
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
logger.error(
|
|||
|
|
f"异步事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
logger.error(
|
|||
|
|
f"异步事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
|
|||
|
|
)
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
return async_wrapper
|
|||
|
|
else:
|
|||
|
|
@functools.wraps(func)
|
|||
|
|
def sync_wrapper(*args, **kwargs) -> Any:
|
|||
|
|
db_session = _find_db_session(args, kwargs)
|
|||
|
|
|
|||
|
|
if not db_session:
|
|||
|
|
raise TransactionError(
|
|||
|
|
f"No SQLAlchemy Session found in arguments for {func.__name__}"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 检查是否已经在事务中
|
|||
|
|
if db_session.in_transaction():
|
|||
|
|
return func(*args, **kwargs)
|
|||
|
|
|
|||
|
|
transaction_id = f"{func.__module__}.{func.__name__}"
|
|||
|
|
logger.debug(f"开始事务: {transaction_id}")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
result = func(*args, **kwargs)
|
|||
|
|
|
|||
|
|
if commit:
|
|||
|
|
db_session.commit()
|
|||
|
|
logger.debug(f"事务提交成功: {transaction_id}")
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
if rollback_on_error:
|
|||
|
|
try:
|
|||
|
|
db_session.rollback()
|
|||
|
|
except Exception:
|
|||
|
|
pass
|
|||
|
|
logger.error(
|
|||
|
|
f"事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
logger.error(
|
|||
|
|
f"事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
|
|||
|
|
)
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
return sync_wrapper
|
|||
|
|
|
|||
|
|
return decorator
|
|||
|
|
|
|||
|
|
|
|||
|
|
@contextmanager
|
|||
|
|
def transaction_scope(
|
|||
|
|
db: Session,
|
|||
|
|
commit_on_success: bool = True,
|
|||
|
|
rollback_on_error: bool = True,
|
|||
|
|
operation_name: Optional[str] = None,
|
|||
|
|
) -> Generator[Session, None, None]:
|
|||
|
|
"""
|
|||
|
|
事务上下文管理器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
db: 数据库会话
|
|||
|
|
commit_on_success: 成功时是否自动提交
|
|||
|
|
rollback_on_error: 失败时是否自动回滚
|
|||
|
|
operation_name: 操作名称,用于日志
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
with transaction_scope(db, operation_name="create_user") as tx:
|
|||
|
|
user = User(...)
|
|||
|
|
tx.add(user)
|
|||
|
|
# 自动提交或回滚
|
|||
|
|
"""
|
|||
|
|
operation_name = operation_name or "database_operation"
|
|||
|
|
|
|||
|
|
# 检查是否已经在事务中
|
|||
|
|
if db.in_transaction():
|
|||
|
|
# 已经在事务中,直接返回session
|
|||
|
|
logger.debug(f"使用现有事务: {operation_name}")
|
|||
|
|
yield db
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
logger.debug(f"开始事务范围: {operation_name}")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
yield db
|
|||
|
|
|
|||
|
|
if commit_on_success:
|
|||
|
|
db.commit()
|
|||
|
|
logger.debug(f"事务范围提交成功: {operation_name}")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
if rollback_on_error:
|
|||
|
|
db.rollback()
|
|||
|
|
logger.error(f"事务范围回滚: {operation_name} - {type(e).__name__}: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
def retry_on_database_error(max_retries: int = 3, delay: float = 0.1):
|
|||
|
|
"""
|
|||
|
|
数据库错误重试装饰器
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
max_retries: 最大重试次数
|
|||
|
|
delay: 重试延迟(秒)
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def decorator(func: Callable) -> Callable:
|
|||
|
|
@functools.wraps(func)
|
|||
|
|
def wrapper(*args, **kwargs) -> Any:
|
|||
|
|
import random
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
for attempt in range(max_retries):
|
|||
|
|
try:
|
|||
|
|
return func(*args, **kwargs)
|
|||
|
|
|
|||
|
|
except (DatabaseError, IntegrityError) as e:
|
|||
|
|
if attempt < max_retries - 1:
|
|||
|
|
# 随机化延迟,避免多个请求同时重试
|
|||
|
|
actual_delay = delay * (2**attempt) + random.uniform(0, 0.1)
|
|||
|
|
logger.warning(f"数据库操作失败,{actual_delay:.2f}秒后重试 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
|
|||
|
|
time.sleep(actual_delay)
|
|||
|
|
continue
|
|||
|
|
else:
|
|||
|
|
logger.error(
|
|||
|
|
f"数据库操作最终失败,已达最大重试次数({max_retries}): {func.__name__} - {str(e)}"
|
|||
|
|
)
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
return wrapper
|
|||
|
|
|
|||
|
|
return decorator
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BatchOperation:
|
|||
|
|
"""
|
|||
|
|
批量操作管理器
|
|||
|
|
用于处理大量数据插入/更新操作
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, db: Session, batch_size: int = 100):
|
|||
|
|
self.db = db
|
|||
|
|
self.batch_size = batch_size
|
|||
|
|
self.operations = []
|
|||
|
|
self.operation_count = 0
|
|||
|
|
|
|||
|
|
def add(self, obj):
|
|||
|
|
"""添加对象到批处理"""
|
|||
|
|
self.operations.append(("add", obj))
|
|||
|
|
self.operation_count += 1
|
|||
|
|
|
|||
|
|
if self.operation_count >= self.batch_size:
|
|||
|
|
self.flush()
|
|||
|
|
|
|||
|
|
def update(self, obj):
|
|||
|
|
"""添加更新操作到批处理"""
|
|||
|
|
self.operations.append(("merge", obj))
|
|||
|
|
self.operation_count += 1
|
|||
|
|
|
|||
|
|
if self.operation_count >= self.batch_size:
|
|||
|
|
self.flush()
|
|||
|
|
|
|||
|
|
def flush(self):
|
|||
|
|
"""执行当前批次的所有操作"""
|
|||
|
|
if not self.operations:
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
logger.debug(f"执行批量操作: {len(self.operations)} 项")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
for operation, obj in self.operations:
|
|||
|
|
if operation == "add":
|
|||
|
|
self.db.add(obj)
|
|||
|
|
elif operation == "merge":
|
|||
|
|
self.db.merge(obj)
|
|||
|
|
|
|||
|
|
self.db.flush() # 只flush,不提交
|
|||
|
|
logger.debug(f"批量操作flush完成: {len(self.operations)} 项")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"批量操作失败({len(self.operations)}项): {type(e).__name__}: {str(e)}")
|
|||
|
|
raise
|
|||
|
|
|
|||
|
|
finally:
|
|||
|
|
# 清空操作列表
|
|||
|
|
self.operations.clear()
|
|||
|
|
self.operation_count = 0
|
|||
|
|
|
|||
|
|
def commit(self):
|
|||
|
|
"""提交所有操作"""
|
|||
|
|
self.flush() # 确保所有操作都已flush
|
|||
|
|
self.db.commit()
|
|||
|
|
logger.debug("批量操作提交完成")
|
|||
|
|
|
|||
|
|
def __enter__(self):
|
|||
|
|
return self
|
|||
|
|
|
|||
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|||
|
|
if exc_type is None:
|
|||
|
|
# 正常退出,提交事务
|
|||
|
|
self.commit()
|
|||
|
|
else:
|
|||
|
|
# 异常退出,回滚事务
|
|||
|
|
self.db.rollback()
|
|||
|
|
logger.error(f"批量操作异常退出,已回滚: {str(exc_val)}")
|