Aether/src/database/database.py

"""
数据库连接和初始化
"""
import time
from typing import AsyncGenerator, Generator, Optional
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import Pool, QueuePool
from ..config import config
from src.core.logger import logger
from ..models.database import Base, SystemConfig, User, UserRole
# 延迟初始化的数据库引擎和会话工厂
_engine: Optional[Engine] = None
_SessionLocal: Optional[sessionmaker] = None
_async_engine: Optional[AsyncEngine] = None
_AsyncSessionLocal: Optional[async_sessionmaker] = None
# 连接池监控
_last_pool_warning: float = 0.0
POOL_WARNING_INTERVAL = 60 # 每60秒最多警告一次
def _setup_pool_monitoring(engine: Engine):
    """Register connection pool monitoring events."""

    @event.listens_for(engine, "connect")
    def receive_connect(dbapi_conn, connection_record):
        """Called when a new DBAPI connection is created."""
        pass

    @event.listens_for(engine, "checkout")
    def receive_checkout(dbapi_conn, connection_record, connection_proxy):
        """Called when a connection is checked out from the pool."""
        global _last_pool_warning
        pool = engine.pool
        # Current pool state
        checked_out = pool.checkedout()
        pool_size = pool.size()
        overflow = pool.overflow()
        max_capacity = config.db_pool_size + config.db_max_overflow
        # Usage rate as a percentage of total capacity
        usage_rate = (checked_out / max_capacity) * 100 if max_capacity > 0 else 0
        # Warn when usage exceeds the configured threshold
        if usage_rate >= config.db_pool_warn_threshold:
            current_time = time.time()
            # Rate-limit the warning
            if current_time - _last_pool_warning > POOL_WARNING_INTERVAL:
                _last_pool_warning = current_time
                logger.warning(
                    f"Database connection pool usage is high: checked_out={checked_out}, "
                    f"pool_size={pool_size}, overflow={overflow}, "
                    f"max_capacity={max_capacity}, usage_rate={usage_rate:.1f}%, "
                    f"threshold={config.db_pool_warn_threshold}%"
                )

def get_pool_status() -> dict:
    """Return the current connection pool status."""
    engine = _ensure_engine()
    pool = engine.pool
    return {
        "checked_out": pool.checkedout(),
        "pool_size": pool.size(),
        "overflow": pool.overflow(),
        "max_capacity": config.db_pool_size + config.db_max_overflow,
        "pool_timeout": config.db_pool_timeout,
    }

def log_pool_status():
    """Log the connection pool status (for monitoring)."""
    try:
        status = get_pool_status()
        usage_rate = (
            (status["checked_out"] / status["max_capacity"] * 100)
            if status["max_capacity"] > 0
            else 0
        )
        logger.info(
            f"Database connection pool status: checked_out={status['checked_out']}, "
            f"pool_size={status['pool_size']}, overflow={status['overflow']}, "
            f"max_capacity={status['max_capacity']}, usage_rate={usage_rate:.1f}%"
        )
    except Exception as e:
        logger.error(f"Failed to get connection pool status: {e}")

def _ensure_engine() -> Engine:
    """
    Ensure the database engine has been created (lazy initialization).

    This lets tests and CLI tools import this module without immediately
    connecting to the database.
    """
    global _engine, _SessionLocal
    if _engine is not None:
        return _engine
    # Read the database configuration
    DATABASE_URL = config.database_url
    # Validate the database type (production requires PostgreSQL; test
    # environments may use other databases)
    is_production = config.environment == "production"
    if is_production and not DATABASE_URL.startswith("postgresql://"):
        raise ValueError("Only PostgreSQL is supported in production; please configure a valid DATABASE_URL")
    # Create the engine
    _engine = create_engine(
        DATABASE_URL,
        poolclass=QueuePool,  # queue-based connection pool
        pool_size=config.db_pool_size,  # base pool size
        max_overflow=config.db_max_overflow,  # maximum overflow connections
        pool_timeout=config.db_pool_timeout,  # checkout timeout (seconds)
        pool_recycle=config.db_pool_recycle,  # connection recycle time (seconds)
        pool_pre_ping=True,  # verify connections are alive before use
        echo=False,  # disable SQL echo logging (too verbose)
    )
    # Set up connection pool monitoring
    _setup_pool_monitoring(_engine)
    # Create the session factory
    _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
    _log_pool_capacity()
    logger.debug(f"Database engine initialized: {DATABASE_URL.split('@')[-1] if '@' in DATABASE_URL else 'local'}")
    return _engine

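# Illustration (added comment): because the engine is created lazily, importing this
# module does not open a database connection. For example, a test or CLI tool can
# read the configured URL without a live database. The import path below is assumed
# from this file's location and may need adjusting:
#
#     from src.database import database   # no connection is made on import
#     url = database.get_db_url()         # still no connection until get_db()/init_db() runs
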
def _log_pool_capacity():
    theoretical = config.db_pool_size + config.db_max_overflow
    workers = max(1, config.worker_processes)
    total_estimated = theoretical * workers
    logger.info(
        f"Database connection pool config: per-worker capacity={theoretical}, "
        f"workers={workers}, estimated_total_connections={total_estimated}"
    )
    if total_estimated > config.db_pool_warn_threshold:
        logger.warning("Database connection demand may exceed the configured threshold; reduce the pool size or the number of workers")

def _ensure_async_engine() -> AsyncEngine:
    """
    Ensure the async database engine has been created (lazy initialization).

    This lets async routes use non-blocking database access.
    """
    global _async_engine, _AsyncSessionLocal
    if _async_engine is not None:
        return _async_engine
    # Read the database configuration and convert it to an async URL
    DATABASE_URL = config.database_url
    # Convert the sync URL to an async one (postgresql:// -> postgresql+asyncpg://)
    if DATABASE_URL.startswith("postgresql://"):
        ASYNC_DATABASE_URL = DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1)
    elif DATABASE_URL.startswith("sqlite:///"):
        ASYNC_DATABASE_URL = DATABASE_URL.replace("sqlite:///", "sqlite+aiosqlite:///", 1)
    else:
        raise ValueError(f"Unsupported database type: {DATABASE_URL}")
    # Validate the database type (production requires PostgreSQL)
    is_production = config.environment == "production"
    if is_production and not ASYNC_DATABASE_URL.startswith("postgresql+asyncpg://"):
        raise ValueError("Only PostgreSQL is supported in production; please configure a valid DATABASE_URL")
    # Create the async engine. Note: the synchronous QueuePool cannot be used with an
    # asyncio engine; create_async_engine defaults to AsyncAdaptedQueuePool instead.
    _async_engine = create_async_engine(
        ASYNC_DATABASE_URL,
        pool_size=config.db_pool_size,
        max_overflow=config.db_max_overflow,
        pool_timeout=config.db_pool_timeout,
        pool_recycle=config.db_pool_recycle,
        pool_pre_ping=True,
        echo=False,
    )
    # Create the async session factory
    _AsyncSessionLocal = async_sessionmaker(
        _async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )
    logger.debug(f"Async database engine initialized: {ASYNC_DATABASE_URL.split('@')[-1] if '@' in ASYNC_DATABASE_URL else 'local'}")
    return _async_engine

async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session."""
    # Make sure the async engine has been initialized
    _ensure_async_engine()
    async with _AsyncSessionLocal() as session:
        try:
            yield session
        finally:
            await session.close()

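# Usage sketch (added comment): FastAPI-style dependency injection is assumed here,
# inferred from the generator-dependency shape; the route, `router`, `Depends`, and
# `select` import are hypothetical and not defined in this module:
#
#     @router.get("/me")
#     async def read_me(db: AsyncSession = Depends(get_async_db)):
#         result = await db.execute(select(User).where(User.is_active == True))
#         return result.scalars().first()
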
def get_db() -> Generator[Session, None, None]:
    """Yield a database session.

    Note: transaction management is controlled explicitly by the business logic
    layer (manual commit/rollback). This function only creates and closes the
    session; it does not commit automatically.
    """
    # Make sure the engine has been initialized
    _ensure_engine()
    db = _SessionLocal()
    try:
        yield db
        # No automatic commit here; transactions are managed explicitly by business code
    except Exception:
        try:
            db.rollback()  # roll back any uncommitted transaction on failure
        except Exception as rollback_error:
            # Log rollback errors (a commit may already be in progress)
            logger.debug(f"Error while rolling back transaction (safe to ignore): {rollback_error}")
        raise
    finally:
        try:
            db.close()  # make sure the connection is returned to the pool
        except Exception as close_error:
            # Log close errors (e.g. IllegalStateChangeError); the pool still
            # takes care of reclaiming the connection
            logger.debug(f"Error while closing database session (safe to ignore): {close_error}")

def create_session() -> Session:
    """
    Create a new database session.

    Note: the caller is responsible for closing the session, either via a
    with-statement or by calling session.close() manually.

    Example:
        db = create_session()
        try:
            ...  # use db
        finally:
            db.close()
    """
    _ensure_engine()
    return _SessionLocal()

def get_db_url() -> str:
    """Return the currently configured database URL (for scripts/tests)."""
    return config.database_url

def init_db():
    """Initialize the database."""
    logger.info("Initializing database...")
    # Make sure the engine has been created
    engine = _ensure_engine()
    # Create all tables via SQLAlchemy metadata
    Base.metadata.create_all(bind=engine)
    db = _SessionLocal()
    try:
        # Create the admin account (if configured via environment variables)
        init_admin_user(db)
        # Add default model configurations
        init_default_models(db)
        # Add system configuration entries
        init_system_configs(db)
        db.commit()
        logger.info("Database initialization complete")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        db.rollback()
        raise
    finally:
        db.close()

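# Usage sketch (added comment): init_db() skips rows that already exist (the admin
# account and system config keys are checked first), so it is safe to run at startup
# or from a one-off script. The entry point below is a hypothetical example:
#
#     if __name__ == "__main__":
#         init_db()
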
def init_admin_user(db: Session):
    """Create the admin account from environment variables."""
    # Warn if the default credentials are being used
    if config.admin_email == "admin@localhost" and config.admin_password == "admin123":
        logger.warning("Using the default admin account configuration; please change it to secure credentials")
    # Check whether an admin already exists
    existing_admin = (
        db.query(User)
        .filter((User.email == config.admin_email) | (User.username == config.admin_username))
        .first()
    )
    if existing_admin:
        logger.info(f"Admin account already exists: {existing_admin.email}")
        return
    try:
        # Create the admin account
        admin = User(
            email=config.admin_email,
            username=config.admin_username,
            role=UserRole.ADMIN,
            quota_usd=1000.0,
            is_active=True,
        )
        admin.set_password(config.admin_password)
        db.add(admin)
        db.commit()  # commit immediately so the admin account gets an ID and is persisted
        logger.info(f"Admin account created: {admin.email} ({admin.username})")
    except Exception as e:
        logger.error(f"Failed to create admin account: {e}")
        raise

def init_default_models(db: Session):
    """Initialize default model configurations."""
    # Note: as a relay/proxy service, no model configuration is preset here.
    # Model configuration is managed dynamically via the GlobalModel and Model tables.
    # This function is kept as a hook for possible future default-model initialization.
    pass

def init_system_configs(db: Session):
    """Initialize system configuration entries."""
    configs = [
        {"key": "default_user_quota_usd", "value": 10.0, "description": "Default USD quota for new users"},
        {"key": "rate_limit_per_minute", "value": 60, "description": "Request limit per minute"},
        {"key": "enable_registration", "value": False, "description": "Whether user registration is open"},
        {"key": "require_email_verification", "value": False, "description": "Whether email verification is required"},
        {"key": "api_key_expire_days", "value": 365, "description": "API key expiration in days"},
    ]
    for config_data in configs:
        existing = db.query(SystemConfig).filter_by(key=config_data["key"]).first()
        if not existing:
            # Named system_config so the module-level `config` object is not shadowed
            system_config = SystemConfig(**config_data)
            db.add(system_config)
            logger.info(f"Added system config: {config_data['key']}")

def reset_db():
    """Reset the database (development use only)."""
    logger.warning("Resetting database...")
    # Make sure the engine has been created
    engine = _ensure_engine()
    Base.metadata.drop_all(bind=engine)
    init_db()