mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 02:02:27 +08:00
Initial commit
This commit is contained in:
3
src/utils/__init__.py
Normal file
3
src/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .timeout import AsyncTimeoutError, run_with_timeout, with_timeout
|
||||
|
||||
__all__ = ["with_timeout", "run_with_timeout", "AsyncTimeoutError"]
|
||||
203
src/utils/auth_utils.py
Normal file
203
src/utils/auth_utils.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
认证工具函数
|
||||
提供统一的用户认证和授权功能
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.services.auth.service import AuthService
|
||||
|
||||
from ..core.exceptions import ForbiddenException
|
||||
from src.core.logger import logger
|
||||
from ..database import get_db
|
||||
from ..models.database import User, UserRole
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db)
) -> User:
    """
    Resolve the currently authenticated user.

    Unified FastAPI authentication dependency: verifies the bearer token,
    loads the user, and cross-checks token claims against the database row.

    Args:
        credentials: Bearer token credentials
        db: database session

    Returns:
        User: the authenticated user object

    Raises:
        HTTPException: on any authentication failure (401/403/500)
    """
    token = credentials.credentials

    try:
        # Verify token format and signature.
        try:
            payload = AuthService.verify_token(token)
        except HTTPException as token_error:
            # Keep the original HTTP status (e.g. 401 Unauthorized); do not
            # convert it to 403.
            logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
            raise  # re-raise unchanged to preserve the status code
        except Exception as token_error:
            logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
            raise ForbiddenException("无效的Token")

        user_id = payload.get("user_id")
        token_email = payload.get("email")
        token_created_at = payload.get("created_at")

        if not user_id:
            logger.error(f"Token缺少user_id字段: payload={payload}")
            raise ForbiddenException("无效的认证凭据")

        if not token_email:
            logger.error(f"Token缺少email字段: payload={payload}")
            raise ForbiddenException("无效的认证凭据")

        # Detailed trace only at DEBUG level.
        logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")

        # user_id must be a string (UUID).
        if not isinstance(user_id, str):
            logger.error(f"Token中user_id格式错误: {type(user_id)} - {user_id}")
            raise ForbiddenException("认证信息格式错误,请重新登录")

        # Load the user; imported lazily to avoid an import cycle.
        try:
            from src.services.user.service import UserService

            user = UserService.get_user(db, user_id)
        except Exception as db_error:
            logger.error(f"数据库查询失败: user_id={user_id}, error={db_error}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="数据库查询失败,请稍后重试",
            )

        if not user:
            logger.error(f"用户不存在: user_id={user_id}")
            raise ForbiddenException("用户不存在或已禁用")

        if not user.is_active:
            logger.error(f"用户已禁用: user_id={user_id}")
            raise ForbiddenException("用户不存在或已禁用")

        # Cross-check the email claim (guards against user-ID reuse /
        # identity confusion).
        if user.email != token_email:
            logger.error(f"Token邮箱不匹配: Token中的邮箱={token_email}, 数据库中的邮箱={user.email}")
            raise ForbiddenException("身份验证失败")

        # Cross-check the creation timestamp claim (also guards ID reuse).
        # NOTE(review): assumes user.created_at and the parsed token time
        # share timezone-awareness — a naive/aware mix would raise TypeError
        # and fall into the generic 500 handler below; verify.
        if token_created_at and user.created_at:
            try:
                from datetime import datetime

                token_created = datetime.fromisoformat(token_created_at.replace("Z", "+00:00"))
                # Allow a 1-second skew for timestamp precision differences.
                time_diff = abs((user.created_at - token_created).total_seconds())
                if time_diff > 1:
                    logger.error(f"Token创建时间不匹配: Token时间={token_created_at}, 用户创建时间={user.created_at}")
                    raise ForbiddenException("身份验证失败")
            except ValueError as e:
                # Unparseable timestamp claim is tolerated (logged only).
                logger.warning(f"Token时间格式解析失败: {e}")

        logger.debug(f"成功获取用户: user_id={user_id}, email={user.email}")
        return user

    except HTTPException:
        # NOTE(review): the ForbiddenException raises above are only
        # re-raised here (rather than converted to 500) if
        # ForbiddenException subclasses HTTPException — confirm in
        # src.core.exceptions.
        raise
    except Exception as e:
        logger.error(f"认证失败,未预期的错误: {e}")
        # Return 500 rather than 401 so the frontend does not trigger its
        # logout flow on transient server errors.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="认证服务暂时不可用"
        )
|
||||
|
||||
|
||||
def get_current_user_from_header(
    authorization: Optional[str] = Header(None), db: Session = Depends(get_db)
) -> User:
    """
    Resolve the current user from the raw Authorization header (compat helper).

    Args:
        authorization: Authorization header value ("Bearer <token>")
        db: database session

    Returns:
        User: the authenticated user object

    Raises:
        HTTPException: on authentication failure
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise ForbiddenException("未提供认证令牌")

    # Fixed: str.replace() would strip EVERY occurrence of "Bearer " in the
    # value, corrupting any token that happens to contain that substring.
    # Only the leading scheme prefix must be removed.
    token = authorization.removeprefix("Bearer ")

    try:
        payload = AuthService.verify_token(token)
        user_id = payload.get("user_id")

        if not user_id:
            raise ForbiddenException("无效的认证凭据")

        user = db.query(User).filter(User.id == user_id).first()
        if not user:
            raise ForbiddenException("用户不存在")

        if not user.is_active:
            raise ForbiddenException("用户已被禁用")

        return user
    except HTTPException:
        # Preserve the original HTTPException (including 401s).
        raise
    except Exception as e:
        logger.error(f"认证失败: {e}")
        raise ForbiddenException("认证失败")
|
||||
|
||||
|
||||
def require_admin(current_user: User = Depends(get_current_user)) -> User:
    """
    Dependency that requires the admin role.

    Args:
        current_user: the authenticated user

    Returns:
        User: the same user, guaranteed to hold the admin role

    Raises:
        HTTPException: 403 when the user is not an admin
    """
    is_admin = current_user.role == UserRole.ADMIN
    if not is_admin:
        raise ForbiddenException("需要管理员权限")
    return current_user
|
||||
|
||||
|
||||
def require_role(required_role: UserRole):
    """
    Factory building a dependency that requires one specific role.

    Args:
        required_role: the role the caller must hold

    Returns:
        A FastAPI dependency returning the user when the role matches.
    """

    def check_role(current_user: User = Depends(get_current_user)) -> User:
        # Guard-clause form: pass through on a match, reject otherwise.
        if current_user.role == required_role:
            return current_user
        raise ForbiddenException(f"需要{required_role.value}权限")

    return check_role
|
||||
75
src/utils/cache_decorator.py
Normal file
75
src/utils/cache_decorator.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""缓存装饰器工具"""
|
||||
|
||||
import functools
|
||||
import json
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
|
||||
|
||||
def cache_result(key_prefix: str, ttl: int = 60, user_specific: bool = True) -> Callable:
    """
    Decorator caching an async function's JSON-serializable result in Redis.

    Args:
        key_prefix: cache key prefix
        ttl: cache TTL in seconds
        user_specific: cache per user (read from ``context.user.id``)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs) -> Any:
            redis_client = get_redis_client_sync()

            # When Redis is unavailable, just run the original function.
            if redis_client is None:
                return await func(*args, **kwargs)

            try:
                # By convention the context object is the first positional arg.
                context = args[0] if args else None

                if user_specific and context and hasattr(context, "user") and context.user:
                    cache_key = f"{key_prefix}:user:{context.user.id}"
                else:
                    cache_key = f"{key_prefix}:global"

                # Fold extra request attributes (e.g. days/limit) into the key
                # so differently-parameterized calls do not share an entry.
                # Fixed: the original indexed args[0] unconditionally, raising
                # IndexError (and skipping the cache with a spurious warning)
                # whenever the function was called with no positional args.
                if context is not None and hasattr(context, "__dict__"):
                    for attr_name in ["days", "limit"]:
                        if hasattr(context, attr_name):
                            attr_value = getattr(context, attr_name)
                            cache_key += f":{attr_name}:{attr_value}"

                # Try the cache first.
                cached = await redis_client.get(cache_key)
                if cached:
                    logger.debug(f"缓存命中: {cache_key}")
                    return json.loads(cached)

                # Cache miss: run the wrapped function.
                result = await func(*args, **kwargs)

                # Best-effort write-back; a cache failure must not fail the call.
                try:
                    await redis_client.setex(
                        cache_key, ttl, json.dumps(result, ensure_ascii=False, default=str)
                    )
                    logger.debug(f"缓存已保存: {cache_key}, TTL: {ttl}s")
                except Exception as e:
                    logger.warning(f"保存缓存失败: {e}")

                return result

            except Exception as e:
                logger.warning(f"缓存处理出错: {e}, 直接执行原函数")
                return await func(*args, **kwargs)

        return wrapper

    return decorator
|
||||
75
src/utils/compression.py
Normal file
75
src/utils/compression.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
数据压缩/解压工具
|
||||
|
||||
提供JSON数据的gzip压缩和解压功能
|
||||
"""
|
||||
|
||||
import gzip
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def compress_json(data: Any) -> Optional[bytes]:
    """
    Serialize *data* to JSON and gzip-compress it.

    Args:
        data: any JSON-serializable value

    Returns:
        The gzip-compressed bytes, or None when *data* is None or
        serialization/compression fails.
    """
    if data is None:
        return None

    try:
        # Encode to UTF-8 JSON, then compress in one pass.
        encoded = json.dumps(data, ensure_ascii=False).encode("utf-8")
        return gzip.compress(encoded, compresslevel=6)
    except Exception:
        # Best-effort: callers treat None as "no compressed payload".
        return None
|
||||
|
||||
|
||||
def decompress_json(compressed_data: Optional[bytes]) -> Optional[Any]:
    """
    Decompress gzip bytes and parse them as JSON.

    Args:
        compressed_data: gzip-compressed JSON bytes

    Returns:
        The parsed value, or None when input is None or decoding fails.
    """
    if compressed_data is None:
        return None

    try:
        raw = gzip.decompress(compressed_data)
        return json.loads(raw.decode("utf-8"))
    except Exception:
        # Corrupt or non-gzip input is treated as "no data".
        return None
|
||||
|
||||
|
||||
def get_body_size(data: Any) -> int:
    """
    Measure the UTF-8 byte length of *data* serialized as JSON.

    Args:
        data: any JSON-serializable value

    Returns:
        Byte count; 0 for None or unserializable input.
    """
    if data is None:
        return 0
    try:
        body = json.dumps(data, ensure_ascii=False).encode("utf-8")
    except Exception:
        return 0
    return len(body)
|
||||
64
src/utils/database_helpers.py
Normal file
64
src/utils/database_helpers.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
数据库方言兼容性辅助函数
|
||||
"""
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.sql.elements import ClauseElement
|
||||
|
||||
|
||||
def date_trunc_portable(dialect_name: str, interval: str, column) -> ClauseElement:
    """
    Cross-dialect date truncation helper.

    Args:
        dialect_name: database dialect ('postgresql', 'sqlite', 'mysql')
        interval: truncation interval ('day', 'week', 'month', 'year';
                  'week' is not available on MySQL)
        column: the date/datetime column expression

    Returns:
        A SQLAlchemy clause element grouping rows by the given interval.

    Raises:
        ValueError: unsupported interval for the chosen dialect
        NotImplementedError: unsupported dialect

    Examples:
        >>> period_func = date_trunc_portable('postgresql', 'week', Usage.created_at)
        >>> # equivalent to func.date_trunc('week', Usage.created_at)
        >>> period_func = date_trunc_portable('sqlite', 'month', Usage.created_at)
        >>> # equivalent to func.strftime("%Y-%m", Usage.created_at)
    """
    if dialect_name == "postgresql":
        # PostgreSQL has native date_trunc; the result stays a timestamp.
        return func.date_trunc(interval, column)

    if dialect_name == "sqlite":
        # SQLite: format the date as a string bucket via strftime.
        sqlite_formats = {
            "year": "%Y",
            "month": "%Y-%m",
            "week": "%Y-%W",
            "day": "%Y-%m-%d",
        }
        fmt = sqlite_formats.get(interval)
        if fmt is None:
            raise ValueError(f"Unsupported interval for SQLite: {interval}")
        return func.strftime(fmt, column)

    if dialect_name == "mysql":
        # MySQL: date_format plays the same role (no week bucket defined here).
        mysql_formats = {
            "year": "%Y",
            "month": "%Y-%m",
            "day": "%Y-%m-%d",
        }
        fmt = mysql_formats.get(interval)
        if fmt is None:
            raise ValueError(f"Unsupported interval for MySQL: {interval}")
        return func.date_format(column, fmt)

    raise NotImplementedError(
        f"Unsupported database dialect: {dialect_name}. "
        f"Supported dialects: postgresql, sqlite, mysql"
    )
|
||||
114
src/utils/request_utils.py
Normal file
114
src/utils/request_utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
请求处理工具函数
|
||||
提供统一的HTTP请求信息提取功能
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
    """
    Resolve the real client IP for a request.

    Checked in priority order:
    1. X-Forwarded-For (proxy chains; first entry is the original client)
    2. X-Real-IP (typically set by Nginx)
    3. the direct socket peer address

    Args:
        request: FastAPI Request object

    Returns:
        str: the client IP, or "unknown" when none can be determined
    """
    xff = request.headers.get("X-Forwarded-For")
    if xff:
        # "client, proxy1, proxy2" — the first non-empty entry is the client.
        candidate = xff.split(",")[0].strip()
        if candidate:
            return candidate

    x_real_ip = request.headers.get("X-Real-IP")
    if x_real_ip:
        return x_real_ip.strip()

    # Fall back to the direct peer address.
    client = request.client
    if client and client.host:
        return client.host

    return "unknown"
|
||||
|
||||
|
||||
def get_user_agent(request: Request) -> str:
    """
    Read the request's user agent string.

    Args:
        request: FastAPI Request object

    Returns:
        str: the User-Agent header value, or "unknown" when absent
    """
    user_agent = request.headers.get("User-Agent")
    return user_agent if user_agent is not None else "unknown"
|
||||
|
||||
|
||||
def get_request_id(request: Request) -> Optional[str]:
    """
    Read the request ID stashed on request.state, if any.

    Args:
        request: FastAPI Request object

    Returns:
        Optional[str]: the request ID, or None when middleware never set one
    """
    state = request.state
    if hasattr(state, "request_id"):
        return state.request_id
    return None
|
||||
|
||||
|
||||
def get_request_metadata(request: Request) -> dict:
    """
    Collect the full request metadata in one dict.

    Args:
        request: FastAPI Request object

    Returns:
        dict: client IP, user agent, request ID, method, path, query string,
        and content headers
    """
    headers = request.headers
    query = request.query_params
    metadata = {
        "client_ip": get_client_ip(request),
        "user_agent": get_user_agent(request),
        "request_id": get_request_id(request),
        "method": request.method,
        "path": request.url.path,
        "query_params": str(query) if query else None,
        "content_type": headers.get("Content-Type"),
        "content_length": headers.get("Content-Length"),
    }
    return metadata
|
||||
|
||||
|
||||
def extract_ip_from_headers(headers: dict) -> str:
    """
    Extract the client IP from a plain header dict (middleware-style use).

    NOTE(review): keys are looked up lowercase — assumes the caller has
    already normalized header names (ASGI-style); verify at the call sites.

    Args:
        headers: mapping of header name -> value

    Returns:
        str: the client IP, or "unknown"
    """
    # X-Forwarded-For first: "client, proxy1, proxy2".
    forwarded_for = headers.get("x-forwarded-for", "")
    if forwarded_for:
        client_ip = forwarded_for.split(",")[0].strip()
        # Fixed: a malformed value like ", proxy" used to return "" here;
        # now we fall through to the other sources, matching get_client_ip.
        if client_ip:
            return client_ip

    # X-Real-IP next.
    real_ip = headers.get("x-real-ip", "")
    if real_ip:
        return real_ip.strip()

    return "unknown"
|
||||
109
src/utils/sse_parser.py
Normal file
109
src/utils/sse_parser.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class SSEEventParser:
    """Lightweight SSE parser: feed it lines, get completed events back."""

    def __init__(self) -> None:
        self._reset_buffer()

    def _reset_buffer(self) -> None:
        # Fields of the event currently being assembled; "data" is a list of
        # lines that get joined with "\n" on finalize.
        self._buffer: Dict[str, Optional[str] | List[str]] = {
            "event": None,
            "data": [],
            "id": None,
            "retry": None,
        }

    def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
        # An event without any data lines is dropped (buffer still reset).
        data_lines = self._buffer.get("data", [])
        if not data_lines:
            self._reset_buffer()
            return None

        data_str = "\n".join(data_lines)
        event = {
            "event": self._buffer.get("event"),
            "data": data_str,
            "id": self._buffer.get("id"),
            "retry": self._buffer.get("retry"),
        }

        self._reset_buffer()
        return event

    def feed_line(self, line: Optional[str]) -> List[Dict[str, Optional[str]]]:
        """Process one line of SSE text; return any events it completed."""

        normalized_line = (line or "").rstrip("\r")
        events: List[Dict[str, Optional[str]]] = []

        # A blank line terminates the current event.
        if normalized_line == "":
            event = self._finalize_event()
            if event:
                events.append(event)
            return events

        # Comment lines (":...") are ignored.
        # NOTE(review): the "::" carve-out deviates from the SSE spec (which
        # treats any leading ":" as a comment) — presumably intentional for
        # some upstream producer; confirm.
        if normalized_line.startswith(":") and not normalized_line.startswith("::"):
            return events

        if normalized_line.startswith("event:"):
            _, rest = normalized_line.split(":", 1)
            value = rest.lstrip()

            # Some producers cram "event: x data: y" onto one line; split it.
            # NOTE(review): the guard checks " data:" (with a space) but the
            # split uses "data:" — a value containing "data:" without a
            # leading space is NOT split; confirm this asymmetry is wanted.
            if " data:" in value:
                event_part, data_part = value.split("data:", 1)
                event_name = event_part.strip() or None
                data_value = data_part.lstrip()
                self._buffer["event"] = event_name
                if data_value:
                    self._append_data_line(data_value)
                event = self._finalize_event()
                if event:
                    events.append(event)
            else:
                event_name = value.strip() or None
                self._buffer["event"] = event_name
            return events

        if normalized_line.startswith("data:"):
            # If data is already buffered, finalize the previous event first —
            # this handles streams that omit the blank-line separator between
            # consecutive data lines (so multi-line data is NOT merged here).
            existing_data = self._buffer.get("data", [])
            if existing_data and len(existing_data) > 0:
                event = self._finalize_event()
                if event:
                    events.append(event)

            _, rest = normalized_line.split(":", 1)
            # Per the SSE spec, a single leading space after ":" is stripped.
            self._append_data_line(rest[1:] if rest.startswith(" ") else rest)
            return events

        if normalized_line.startswith("id:"):
            _, rest = normalized_line.split(":", 1)
            self._buffer["id"] = rest.strip() or None
            return events

        if normalized_line.startswith("retry:"):
            _, rest = normalized_line.split(":", 1)
            self._buffer["retry"] = rest.strip() or None
            return events

        # Unknown line: treat it as data (some producers omit the "data:"
        # prefix).
        self._append_data_line(normalized_line)
        return events

    def flush(self) -> List[Dict[str, Optional[str]]]:
        """Call at end of stream to emit any still-buffered event."""

        event = self._finalize_event()
        return [event] if event else []

    def _append_data_line(self, value: str) -> None:
        # Defensive: recreate the list if the buffer slot was clobbered.
        data_lines = self._buffer.get("data")
        if isinstance(data_lines, list):
            data_lines.append(value)
        else:
            self._buffer["data"] = [value]
|
||||
97
src/utils/task_coordinator.py
Normal file
97
src/utils/task_coordinator.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""分布式任务协调器,确保仅有一个 worker 执行特定任务"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import pathlib
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
try:
|
||||
import fcntl # type: ignore
|
||||
except ImportError: # pragma: no cover - Windows 环境
|
||||
fcntl = None
|
||||
|
||||
|
||||
class StartupTaskCoordinator:
    """Ensure a task runs in only one process/instance, via a Redis or file lock."""

    def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
        # redis_client is optional; without it we go straight to file locks.
        self.redis = redis_client
        # Redis lock tokens per task name (proves ownership on release).
        self._tokens: Dict[str, str] = {}
        # Open file handles holding flock()s (or sentinels without fcntl).
        self._file_handles: Dict[str, object] = {}
        self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
        if not self._lock_dir.exists():
            self._lock_dir.mkdir(parents=True, exist_ok=True)

    def _redis_key(self, name: str) -> str:
        # Namespaced Redis key for a task lock.
        return f"task_lock:{name}"

    async def acquire(self, name: str, ttl: Optional[int] = None) -> bool:
        """Try to take the lock for *name*; True when this caller owns it."""
        ttl = ttl or int(os.getenv("TASK_COORDINATOR_LOCK_TTL", "86400"))

        if self.redis:
            # SET NX EX with a unique token — only the owner can release.
            token = str(uuid.uuid4())
            try:
                acquired = await self.redis.set(self._redis_key(name), token, nx=True, ex=ttl)
                if acquired:
                    self._tokens[name] = token
                    logger.info(f"任务 {name} 通过 Redis 锁独占执行")
                    return True
                return False
            except Exception as exc:  # pragma: no cover - fall back on Redis errors
                logger.warning(f"Redis 锁获取失败,回退到文件锁: {exc}")

        return await self._acquire_file_lock(name)

    async def release(self, name: str):
        """Release the lock for *name* (Redis token-checked, then file lock)."""
        if self.redis and name in self._tokens:
            token = self._tokens.pop(name)
            # Compare-and-delete Lua script: only delete the key when it still
            # holds our token, so we never free a lock another owner re-took.
            script = """
            if redis.call('GET', KEYS[1]) == ARGV[1] then
                return redis.call('DEL', KEYS[1])
            end
            return 0
            """
            try:
                await self.redis.eval(script, 1, self._redis_key(name), token)
            except Exception as exc:  # pragma: no cover
                logger.warning(f"释放 Redis 锁失败: {exc}")

        handle = self._file_handles.pop(name, None)
        if handle and fcntl:
            try:
                fcntl.flock(handle, fcntl.LOCK_UN)
            finally:
                handle.close()

    async def _acquire_file_lock(self, name: str) -> bool:
        # Without fcntl (e.g. Windows) degrade to a per-process lock.
        if fcntl is None:
            if name in self._file_handles:
                return False
            self._file_handles[name] = object()
            logger.warning("操作系统不支持文件锁,任务锁仅在当前进程生效")
            return True

        lock_path = self._lock_dir / f"{name}.lock"
        # The handle must stay open for the flock to remain held.
        handle = open(lock_path, "a+")
        try:
            fcntl.flock(handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
            self._file_handles[name] = handle
            logger.info(f"任务 {name} 使用文件锁独占执行")
            return True
        except BlockingIOError:
            # Another process holds the lock.
            handle.close()
            return False
|
||||
|
||||
|
||||
async def ensure_singleton_task(name: str, redis_client=None, ttl: Optional[int] = None):
    """Convenience coroutine: build a coordinator and try to acquire *name*.

    Returns:
        tuple: (coordinator, acquired)
    """
    coordinator = StartupTaskCoordinator(redis_client)
    acquired = await coordinator.acquire(name, ttl=ttl)
    return coordinator, acquired
|
||||
141
src/utils/timeout.py
Normal file
141
src/utils/timeout.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
超时保护工具
|
||||
|
||||
为异步函数和操作提供超时保护
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncTimeoutError(TimeoutError):
    """Raised when an async operation exceeds its time budget.

    Carries the operation label and the budget for logging/telemetry.
    """

    def __init__(self, message: str, operation: str, timeout: float):
        super().__init__(message)
        # Extra context beyond the plain TimeoutError message.
        self.operation = operation
        self.timeout = timeout
|
||||
|
||||
|
||||
def with_timeout(seconds: float, operation_name: Optional[str] = None):
    """
    Decorator adding a timeout guard to an async function.

    Args:
        seconds: timeout budget in seconds
        operation_name: label used in logs (defaults to the function name)

    Raises:
        AsyncTimeoutError: when the wrapped coroutine exceeds *seconds*

    Usage:
        @with_timeout(30.0)
        async def my_async_function():
            ...

        @with_timeout(60.0, operation_name="API request")
        async def api_call():
            ...
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        async def guarded(*args, **kwargs):
            op_name = operation_name if operation_name else func.__name__
            try:
                return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
            except asyncio.TimeoutError:
                # Translate the bare asyncio error into the richer domain one.
                logger.warning(f"操作超时: {op_name} (timeout={seconds}s)")
                raise AsyncTimeoutError(
                    f"{op_name} 操作超时({seconds}秒)",
                    operation=op_name,
                    timeout=seconds,
                )

        return guarded

    return decorator
|
||||
|
||||
|
||||
async def run_with_timeout(
    coro,
    timeout: float,
    operation_name: str = "operation",
    default: T = None,
    raise_on_timeout: bool = True,
) -> T:
    """
    Await *coro* under a timeout guard (functional form).

    Args:
        coro: the coroutine object to await
        timeout: timeout budget in seconds
        operation_name: label used in logs
        default: value returned on timeout when raise_on_timeout is False
        raise_on_timeout: raise AsyncTimeoutError on timeout instead of
            returning *default*

    Returns:
        The coroutine's result, or *default* on a swallowed timeout.

    Usage:
        result = await run_with_timeout(
            my_async_function(),
            timeout=30.0,
            operation_name="API request"
        )
    """
    try:
        return await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(f"操作超时: {operation_name} (timeout={timeout}s)")
        if not raise_on_timeout:
            return default
        raise AsyncTimeoutError(
            f"{operation_name} 操作超时({timeout}秒)",
            operation=operation_name,
            timeout=timeout,
        )
|
||||
|
||||
|
||||
class TimeoutContext:
    """
    Timeout context manager (placeholder).

    WARNING(review): this class is currently a NO-OP — __aexit__ does nothing
    and no timer or task is ever started, so the advertised timeout is not
    enforced. The original docstring claimed an AsyncTimeoutError would be
    raised; it is not. Use asyncio.timeout (Python 3.11+), with_timeout, or
    run_with_timeout instead until this is implemented.

    Usage:
        async with TimeoutContext(30.0, "db query") as ctx:
            result = await db.query(...)
    """

    def __init__(self, timeout: float, operation_name: str = "operation"):
        # Stored for introspection only; nothing reads them yet.
        self.timeout = timeout
        self.operation_name = operation_name
        self._task: Optional[asyncio.Task] = None

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # asyncio.timeout is only available on Python 3.11+;
        # a generic implementation was never filled in here.
        pass
|
||||
|
||||
|
||||
async def with_timeout_context(timeout: float, operation_name: str = "operation"):
|
||||
"""
|
||||
超时上下文管理器(Python 3.11+ asyncio.timeout 的替代)
|
||||
|
||||
Usage:
|
||||
async with with_timeout_context(30.0, "API请求"):
|
||||
result = await api_call()
|
||||
"""
|
||||
try:
|
||||
# Python 3.11+ 使用内置的 asyncio.timeout
|
||||
return asyncio.timeout(timeout)
|
||||
except AttributeError:
|
||||
# Python 3.10 及以下版本的兼容实现
|
||||
# 注意:这个简单实现不支持嵌套取消
|
||||
pass
|
||||
304
src/utils/transaction_manager.py
Normal file
304
src/utils/transaction_manager.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
数据库事务管理工具
|
||||
提供事务装饰器和事务上下文管理器
|
||||
支持同步和异步函数
|
||||
"""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Generator, Optional
|
||||
|
||||
from sqlalchemy.exc import DatabaseError, IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
|
||||
class TransactionError(Exception):
    """Raised when transactional execution cannot proceed (e.g. no session found)."""
|
||||
|
||||
|
||||
def _find_db_session(args, kwargs) -> Optional[Session]:
    """Locate a SQLAlchemy Session among positional then keyword arguments."""
    # Positional args take precedence; first match wins.
    candidates = list(args) + list(kwargs.values())
    for candidate in candidates:
        if isinstance(candidate, Session):
            return candidate
    return None
|
||||
|
||||
|
||||
def transactional(commit: bool = True, rollback_on_error: bool = True):
    """
    Transaction decorator supporting both sync and async functions.

    The decorated function must receive a SQLAlchemy Session somewhere in its
    arguments; it is discovered via _find_db_session. When the session is
    already inside a transaction the function runs as-is (the outer scope owns
    commit/rollback).

    Args:
        commit: auto-commit on success (default True)
        rollback_on_error: auto-rollback on error (default True)

    Raises:
        TransactionError: when no Session is found among the arguments

    Usage:
        @transactional()
        def create_user_with_api_key(db: Session, ...):
            pass

        @transactional()
        async def create_user_async(db: Session, ...):
            pass
    """

    def decorator(func: Callable) -> Callable:
        # Pick the wrapper flavor once, at decoration time.
        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs) -> Any:
                db_session = _find_db_session(args, kwargs)

                if not db_session:
                    raise TransactionError(
                        f"No SQLAlchemy Session found in arguments for {func.__name__}"
                    )

                # Nested call: defer commit/rollback to the outer transaction.
                if db_session.in_transaction():
                    return await func(*args, **kwargs)

                transaction_id = f"{func.__module__}.{func.__name__}"
                logger.debug(f"开始异步事务: {transaction_id}")

                try:
                    result = await func(*args, **kwargs)

                    if commit:
                        db_session.commit()
                        logger.debug(f"异步事务提交成功: {transaction_id}")

                    return result

                except Exception as e:
                    if rollback_on_error:
                        # Best-effort rollback; a rollback failure must not
                        # mask the original exception.
                        try:
                            db_session.rollback()
                        except Exception:
                            pass
                        logger.error(
                            f"异步事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    else:
                        logger.error(
                            f"异步事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    raise

            return async_wrapper
        else:
            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs) -> Any:
                db_session = _find_db_session(args, kwargs)

                if not db_session:
                    raise TransactionError(
                        f"No SQLAlchemy Session found in arguments for {func.__name__}"
                    )

                # Nested call: defer commit/rollback to the outer transaction.
                if db_session.in_transaction():
                    return func(*args, **kwargs)

                transaction_id = f"{func.__module__}.{func.__name__}"
                logger.debug(f"开始事务: {transaction_id}")

                try:
                    result = func(*args, **kwargs)

                    if commit:
                        db_session.commit()
                        logger.debug(f"事务提交成功: {transaction_id}")

                    return result

                except Exception as e:
                    if rollback_on_error:
                        # Best-effort rollback; a rollback failure must not
                        # mask the original exception.
                        try:
                            db_session.rollback()
                        except Exception:
                            pass
                        logger.error(
                            f"事务回滚: {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    else:
                        logger.error(
                            f"事务异常(未回滚): {transaction_id} - {type(e).__name__}: {str(e)}"
                        )
                    raise

            return sync_wrapper

    return decorator
|
||||
|
||||
|
||||
@contextmanager
def transaction_scope(
    db: Session,
    commit_on_success: bool = True,
    rollback_on_error: bool = True,
    operation_name: Optional[str] = None,
) -> Generator[Session, None, None]:
    """
    Context manager scoping a unit of work to one transaction.

    Args:
        db: database session
        commit_on_success: commit automatically when the block succeeds
        rollback_on_error: roll back automatically when the block raises
        operation_name: label used in logs

    Usage:
        with transaction_scope(db, operation_name="create_user") as tx:
            tx.add(User(...))
        # committed on success, rolled back on error
    """
    label = operation_name or "database_operation"

    # Nested use: hand the session through untouched and let the outer
    # transaction decide commit/rollback.
    if db.in_transaction():
        logger.debug(f"使用现有事务: {label}")
        yield db
        return

    logger.debug(f"开始事务范围: {label}")

    try:
        yield db
        if commit_on_success:
            db.commit()
            logger.debug(f"事务范围提交成功: {label}")
    except Exception as e:
        if rollback_on_error:
            db.rollback()
        logger.error(f"事务范围回滚: {label} - {type(e).__name__}: {str(e)}")
        raise
|
||||
|
||||
|
||||
def retry_on_database_error(max_retries: int = 3, delay: float = 0.1):
    """
    Retry decorator for transient database errors (exponential backoff + jitter).

    Args:
        max_retries: maximum number of attempts
        delay: base delay in seconds, doubled each attempt
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            import random
            import time

            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)

                except (DatabaseError, IntegrityError) as e:
                    if attempt >= max_retries - 1:
                        logger.error(
                            f"数据库操作最终失败,已达最大重试次数({max_retries}): {func.__name__} - {str(e)}"
                        )
                        raise
                    # Jittered exponential backoff so concurrent callers do
                    # not retry in lockstep.
                    actual_delay = delay * (2**attempt) + random.uniform(0, 0.1)
                    logger.warning(f"数据库操作失败,{actual_delay:.2f}秒后重试 (尝试 {attempt + 1}/{max_retries}): {str(e)}")
                    time.sleep(actual_delay)

        return wrapper

    return decorator
|
||||
|
||||
|
||||
class BatchOperation:
    """
    Batch operation manager for bulk insert/update work.

    Buffers add/merge operations and flushes them to the session every
    ``batch_size`` items; commit happens once, at commit() or context exit.
    """

    def __init__(self, db: Session, batch_size: int = 100):
        self.db = db
        self.batch_size = batch_size
        # Pending (operation, object) pairs awaiting flush.
        self.operations = []
        self.operation_count = 0

    def add(self, obj):
        """Queue an insert; auto-flushes when the batch is full."""
        self.operations.append(("add", obj))
        self.operation_count += 1

        if self.operation_count >= self.batch_size:
            self.flush()

    def update(self, obj):
        """Queue an update (session merge); auto-flushes when the batch is full."""
        self.operations.append(("merge", obj))
        self.operation_count += 1

        if self.operation_count >= self.batch_size:
            self.flush()

    def flush(self):
        """Apply all queued operations to the session (flush only, no commit)."""
        if not self.operations:
            return

        logger.debug(f"执行批量操作: {len(self.operations)} 项")

        try:
            for operation, obj in self.operations:
                if operation == "add":
                    self.db.add(obj)
                elif operation == "merge":
                    self.db.merge(obj)

            self.db.flush()  # flush only — commit stays with the caller
            logger.debug(f"批量操作flush完成: {len(self.operations)} 项")

        except Exception as e:
            logger.error(f"批量操作失败({len(self.operations)}项): {type(e).__name__}: {str(e)}")
            raise

        finally:
            # Clear the queue even on failure so a retry starts clean.
            self.operations.clear()
            self.operation_count = 0

    def commit(self):
        """Flush anything still queued, then commit the session."""
        self.flush()  # make sure nothing is left buffered
        self.db.commit()
        logger.debug("批量操作提交完成")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            # Clean exit: commit everything.
            self.commit()
        else:
            # Exception inside the block: roll the session back.
            self.db.rollback()
            logger.error(f"批量操作异常退出,已回滚: {str(exc_val)}")
|
||||
Reference in New Issue
Block a user