mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 18:22:28 +08:00
305 lines
9.2 KiB
Python
305 lines
9.2 KiB
Python
"""
|
||
数据库事务管理工具
|
||
提供事务装饰器和事务上下文管理器
|
||
支持同步和异步函数
|
||
"""
|
||
|
||
import functools
|
||
import inspect
|
||
from contextlib import contextmanager
|
||
from typing import Any, Callable, Generator, Optional
|
||
|
||
from sqlalchemy.exc import DatabaseError, IntegrityError
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.core.logger import logger
|
||
|
||
|
||
|
||
class TransactionError(Exception):
|
||
"""事务处理异常"""
|
||
|
||
pass
|
||
|
||
|
||
def _find_db_session(args, kwargs) -> Optional[Session]:
|
||
"""从参数中查找数据库会话"""
|
||
# 从位置参数中查找Session
|
||
for arg in args:
|
||
if isinstance(arg, Session):
|
||
return arg
|
||
|
||
# 从关键字参数中查找Session
|
||
for value in kwargs.values():
|
||
if isinstance(value, Session):
|
||
return value
|
||
|
||
return None
|
||
|
||
|
||
def transactional(commit: bool = True, rollback_on_error: bool = True):
|
||
"""
|
||
事务装饰器,支持同步和异步函数
|
||
|
||
Args:
|
||
commit: 是否在成功时自动提交,默认True
|
||
rollback_on_error: 是否在错误时自动回滚,默认True
|
||
|
||
Usage:
|
||
@transactional()
|
||
def create_user_with_api_key(db: Session, ...):
|
||
# 同步方法会在事务中执行
|
||
pass
|
||
|
||
@transactional()
|
||
async def create_user_async(db: Session, ...):
|
||
# 异步方法也会在事务中执行
|
||
pass
|
||
"""
|
||
|
||
def decorator(func: Callable) -> Callable:
|
||
# 检查是否是异步函数
|
||
if inspect.iscoroutinefunction(func):
|
||
@functools.wraps(func)
|
||
async def async_wrapper(*args, **kwargs) -> Any:
|
||
db_session = _find_db_session(args, kwargs)
|
||
|
||
if not db_session:
|
||
raise TransactionError(
|
||
f"No SQLAlchemy Session found in arguments for {func.__name__}"
|
||
)
|
||
|
||
# 检查是否已经在事务中
|
||
if db_session.in_transaction():
|
||
return await func(*args, **kwargs)
|
||
|
||
transaction_id = f"{func.__module__}.{func.__name__}"
|
||
logger.debug(f"开始异步事务: {transaction_id}")
|
||
|
||
try:
|
||
result = await func(*args, **kwargs)
|
||
|
||
if commit:
|
||
db_session.commit()
|
||
logger.debug(f"异步事务提交成功: {transaction_id}")
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
if rollback_on_error:
|
||
try:
|
||
db_session.rollback()
|
||
except Exception:
|
||
pass
|
||
logger.error(
|
||
f"异步事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
|
||
)
|
||
else:
|
||
logger.error(
|
||
f"异步事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
|
||
)
|
||
raise
|
||
|
||
return async_wrapper
|
||
else:
|
||
@functools.wraps(func)
|
||
def sync_wrapper(*args, **kwargs) -> Any:
|
||
db_session = _find_db_session(args, kwargs)
|
||
|
||
if not db_session:
|
||
raise TransactionError(
|
||
f"No SQLAlchemy Session found in arguments for {func.__name__}"
|
||
)
|
||
|
||
# 检查是否已经在事务中
|
||
if db_session.in_transaction():
|
||
return func(*args, **kwargs)
|
||
|
||
transaction_id = f"{func.__module__}.{func.__name__}"
|
||
logger.debug(f"开始事务: {transaction_id}")
|
||
|
||
try:
|
||
result = func(*args, **kwargs)
|
||
|
||
if commit:
|
||
db_session.commit()
|
||
logger.debug(f"事务提交成功: {transaction_id}")
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
if rollback_on_error:
|
||
try:
|
||
db_session.rollback()
|
||
except Exception:
|
||
pass
|
||
logger.error(
|
||
f"事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
|
||
)
|
||
else:
|
||
logger.error(
|
||
f"事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
|
||
)
|
||
raise
|
||
|
||
return sync_wrapper
|
||
|
||
return decorator
|
||
|
||
|
||
@contextmanager
|
||
def transaction_scope(
|
||
db: Session,
|
||
commit_on_success: bool = True,
|
||
rollback_on_error: bool = True,
|
||
operation_name: Optional[str] = None,
|
||
) -> Generator[Session, None, None]:
|
||
"""
|
||
事务上下文管理器
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
commit_on_success: 成功时是否自动提交
|
||
rollback_on_error: 失败时是否自动回滚
|
||
operation_name: 操作名称,用于日志
|
||
|
||
Usage:
|
||
with transaction_scope(db, operation_name="create_user") as tx:
|
||
user = User(...)
|
||
tx.add(user)
|
||
# 自动提交或回滚
|
||
"""
|
||
operation_name = operation_name or "database_operation"
|
||
|
||
# 检查是否已经在事务中
|
||
if db.in_transaction():
|
||
# 已经在事务中,直接返回session
|
||
logger.debug(f"使用现有事务: {operation_name}")
|
||
yield db
|
||
return
|
||
|
||
logger.debug(f"开始事务范围: {operation_name}")
|
||
|
||
try:
|
||
yield db
|
||
|
||
if commit_on_success:
|
||
db.commit()
|
||
logger.debug(f"事务范围提交成功: {operation_name}")
|
||
|
||
except Exception as e:
|
||
if rollback_on_error:
|
||
db.rollback()
|
||
logger.error(f"事务范围回滚: {operation_name} - {type(e).__name__}: {str(e)}")
|
||
raise
|
||
|
||
|
||
def retry_on_database_error(max_retries: int = 3, delay: float = 0.1):
|
||
"""
|
||
数据库错误重试装饰器
|
||
|
||
Args:
|
||
max_retries: 最大重试次数
|
||
delay: 重试延迟(秒)
|
||
"""
|
||
|
||
def decorator(func: Callable) -> Callable:
|
||
@functools.wraps(func)
|
||
def wrapper(*args, **kwargs) -> Any:
|
||
import random
|
||
import time
|
||
|
||
for attempt in range(max_retries):
|
||
try:
|
||
return func(*args, **kwargs)
|
||
|
||
except (DatabaseError, IntegrityError) as e:
|
||
if attempt < max_retries - 1:
|
||
# 随机化延迟,避免多个请求同时重试
|
||
actual_delay = delay * (2**attempt) + random.uniform(0, 0.1)
|
||
logger.warning(f"数据库操作失败,{actual_delay:.2f}秒后重试 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
|
||
time.sleep(actual_delay)
|
||
continue
|
||
else:
|
||
logger.error(
|
||
f"数据库操作最终失败,已达最大重试次数({max_retries}): {func.__name__} - {str(e)}"
|
||
)
|
||
raise
|
||
|
||
return wrapper
|
||
|
||
return decorator
|
||
|
||
|
||
class BatchOperation:
|
||
"""
|
||
批量操作管理器
|
||
用于处理大量数据插入/更新操作
|
||
"""
|
||
|
||
def __init__(self, db: Session, batch_size: int = 100):
|
||
self.db = db
|
||
self.batch_size = batch_size
|
||
self.operations = []
|
||
self.operation_count = 0
|
||
|
||
def add(self, obj):
|
||
"""添加对象到批处理"""
|
||
self.operations.append(("add", obj))
|
||
self.operation_count += 1
|
||
|
||
if self.operation_count >= self.batch_size:
|
||
self.flush()
|
||
|
||
def update(self, obj):
|
||
"""添加更新操作到批处理"""
|
||
self.operations.append(("merge", obj))
|
||
self.operation_count += 1
|
||
|
||
if self.operation_count >= self.batch_size:
|
||
self.flush()
|
||
|
||
def flush(self):
|
||
"""执行当前批次的所有操作"""
|
||
if not self.operations:
|
||
return
|
||
|
||
logger.debug(f"执行批量操作: {len(self.operations)} 项")
|
||
|
||
try:
|
||
for operation, obj in self.operations:
|
||
if operation == "add":
|
||
self.db.add(obj)
|
||
elif operation == "merge":
|
||
self.db.merge(obj)
|
||
|
||
self.db.flush() # 只flush,不提交
|
||
logger.debug(f"批量操作flush完成: {len(self.operations)} 项")
|
||
|
||
except Exception as e:
|
||
logger.error(f"批量操作失败({len(self.operations)}项): {type(e).__name__}: {str(e)}")
|
||
raise
|
||
|
||
finally:
|
||
# 清空操作列表
|
||
self.operations.clear()
|
||
self.operation_count = 0
|
||
|
||
def commit(self):
|
||
"""提交所有操作"""
|
||
self.flush() # 确保所有操作都已flush
|
||
self.db.commit()
|
||
logger.debug("批量操作提交完成")
|
||
|
||
def __enter__(self):
|
||
return self
|
||
|
||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||
if exc_type is None:
|
||
# 正常退出,提交事务
|
||
self.commit()
|
||
else:
|
||
# 异常退出,回滚事务
|
||
self.db.rollback()
|
||
logger.error(f"批量操作异常退出,已回滚: {str(exc_val)}")
|