mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
Initial commit
This commit is contained in:
304
src/utils/transaction_manager.py
Normal file
304
src/utils/transaction_manager.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
数据库事务管理工具
|
||||
提供事务装饰器和事务上下文管理器
|
||||
支持同步和异步函数
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Generator, Optional
|
||||
|
||||
from sqlalchemy.exc import DatabaseError, IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class TransactionError(Exception):
|
||||
"""事务处理异常"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _find_db_session(args, kwargs) -> Optional[Session]:
|
||||
"""从参数中查找数据库会话"""
|
||||
# 从位置参数中查找Session
|
||||
for arg in args:
|
||||
if isinstance(arg, Session):
|
||||
return arg
|
||||
|
||||
# 从关键字参数中查找Session
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, Session):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def transactional(commit: bool = True, rollback_on_error: bool = True):
|
||||
"""
|
||||
事务装饰器,支持同步和异步函数
|
||||
|
||||
Args:
|
||||
commit: 是否在成功时自动提交,默认True
|
||||
rollback_on_error: 是否在错误时自动回滚,默认True
|
||||
|
||||
Usage:
|
||||
@transactional()
|
||||
def create_user_with_api_key(db: Session, ...):
|
||||
# 同步方法会在事务中执行
|
||||
pass
|
||||
|
||||
@transactional()
|
||||
async def create_user_async(db: Session, ...):
|
||||
# 异步方法也会在事务中执行
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
# 检查是否是异步函数
|
||||
if inspect.iscoroutinefunction(func):
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs) -> Any:
|
||||
db_session = _find_db_session(args, kwargs)
|
||||
|
||||
if not db_session:
|
||||
raise TransactionError(
|
||||
f"No SQLAlchemy Session found in arguments for {func.__name__}"
|
||||
)
|
||||
|
||||
# 检查是否已经在事务中
|
||||
if db_session.in_transaction():
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
transaction_id = f"{func.__module__}.{func.__name__}"
|
||||
logger.debug(f"开始异步事务: {transaction_id}")
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
logger.debug(f"异步事务提交成功: {transaction_id}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
if rollback_on_error:
|
||||
try:
|
||||
db_session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(
|
||||
f"异步事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"异步事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
return async_wrapper
|
||||
else:
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs) -> Any:
|
||||
db_session = _find_db_session(args, kwargs)
|
||||
|
||||
if not db_session:
|
||||
raise TransactionError(
|
||||
f"No SQLAlchemy Session found in arguments for {func.__name__}"
|
||||
)
|
||||
|
||||
# 检查是否已经在事务中
|
||||
if db_session.in_transaction():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
transaction_id = f"{func.__module__}.{func.__name__}"
|
||||
logger.debug(f"开始事务: {transaction_id}")
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
logger.debug(f"事务提交成功: {transaction_id}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
if rollback_on_error:
|
||||
try:
|
||||
db_session.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(
|
||||
f"事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction_scope(
|
||||
db: Session,
|
||||
commit_on_success: bool = True,
|
||||
rollback_on_error: bool = True,
|
||||
operation_name: Optional[str] = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
事务上下文管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
commit_on_success: 成功时是否自动提交
|
||||
rollback_on_error: 失败时是否自动回滚
|
||||
operation_name: 操作名称,用于日志
|
||||
|
||||
Usage:
|
||||
with transaction_scope(db, operation_name="create_user") as tx:
|
||||
user = User(...)
|
||||
tx.add(user)
|
||||
# 自动提交或回滚
|
||||
"""
|
||||
operation_name = operation_name or "database_operation"
|
||||
|
||||
# 检查是否已经在事务中
|
||||
if db.in_transaction():
|
||||
# 已经在事务中,直接返回session
|
||||
logger.debug(f"使用现有事务: {operation_name}")
|
||||
yield db
|
||||
return
|
||||
|
||||
logger.debug(f"开始事务范围: {operation_name}")
|
||||
|
||||
try:
|
||||
yield db
|
||||
|
||||
if commit_on_success:
|
||||
db.commit()
|
||||
logger.debug(f"事务范围提交成功: {operation_name}")
|
||||
|
||||
except Exception as e:
|
||||
if rollback_on_error:
|
||||
db.rollback()
|
||||
logger.error(f"事务范围回滚: {operation_name} - {type(e).__name__}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def retry_on_database_error(max_retries: int = 3, delay: float = 0.1):
|
||||
"""
|
||||
数据库错误重试装饰器
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
delay: 重试延迟(秒)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any:
|
||||
import random
|
||||
import time
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
except (DatabaseError, IntegrityError) as e:
|
||||
if attempt < max_retries - 1:
|
||||
# 随机化延迟,避免多个请求同时重试
|
||||
actual_delay = delay * (2**attempt) + random.uniform(0, 0.1)
|
||||
logger.warning(f"数据库操作失败,{actual_delay:.2f}秒后重试 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
|
||||
time.sleep(actual_delay)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"数据库操作最终失败,已达最大重试次数({max_retries}): {func.__name__} - {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class BatchOperation:
|
||||
"""
|
||||
批量操作管理器
|
||||
用于处理大量数据插入/更新操作
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, batch_size: int = 100):
|
||||
self.db = db
|
||||
self.batch_size = batch_size
|
||||
self.operations = []
|
||||
self.operation_count = 0
|
||||
|
||||
def add(self, obj):
|
||||
"""添加对象到批处理"""
|
||||
self.operations.append(("add", obj))
|
||||
self.operation_count += 1
|
||||
|
||||
if self.operation_count >= self.batch_size:
|
||||
self.flush()
|
||||
|
||||
def update(self, obj):
|
||||
"""添加更新操作到批处理"""
|
||||
self.operations.append(("merge", obj))
|
||||
self.operation_count += 1
|
||||
|
||||
if self.operation_count >= self.batch_size:
|
||||
self.flush()
|
||||
|
||||
def flush(self):
|
||||
"""执行当前批次的所有操作"""
|
||||
if not self.operations:
|
||||
return
|
||||
|
||||
logger.debug(f"执行批量操作: {len(self.operations)} 项")
|
||||
|
||||
try:
|
||||
for operation, obj in self.operations:
|
||||
if operation == "add":
|
||||
self.db.add(obj)
|
||||
elif operation == "merge":
|
||||
self.db.merge(obj)
|
||||
|
||||
self.db.flush() # 只flush,不提交
|
||||
logger.debug(f"批量操作flush完成: {len(self.operations)} 项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量操作失败({len(self.operations)}项): {type(e).__name__}: {str(e)}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
# 清空操作列表
|
||||
self.operations.clear()
|
||||
self.operation_count = 0
|
||||
|
||||
def commit(self):
|
||||
"""提交所有操作"""
|
||||
self.flush() # 确保所有操作都已flush
|
||||
self.db.commit()
|
||||
logger.debug("批量操作提交完成")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
# 正常退出,提交事务
|
||||
self.commit()
|
||||
else:
|
||||
# 异常退出,回滚事务
|
||||
self.db.rollback()
|
||||
logger.error(f"批量操作异常退出,已回滚: {str(exc_val)}")
|
||||
Reference in New Issue
Block a user