mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 10:12:27 +08:00
Initial commit
This commit is contained in:
20
src/database/__init__.py
Normal file
20
src/database/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
数据库模块
|
||||
"""
|
||||
|
||||
from ..models.database import ApiKey, Base, Usage, User, UserQuota
|
||||
from .database import create_session, get_async_db, get_db, get_db_url, init_db, log_pool_status
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"User",
|
||||
"ApiKey",
|
||||
"Usage",
|
||||
"UserQuota",
|
||||
"get_db",
|
||||
"get_async_db",
|
||||
"init_db",
|
||||
"create_session",
|
||||
"get_db_url",
|
||||
"log_pool_status",
|
||||
]
|
||||
41
src/database/async_utils.py
Normal file
41
src/database/async_utils.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
异步数据库工具
|
||||
提供在异步上下文中安全使用同步数据库操作的工具
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""
|
||||
在线程池中运行同步函数,避免阻塞事件循环
|
||||
|
||||
用法:
|
||||
result = await run_in_executor(some_sync_function, arg1, arg2)
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
||||
|
||||
|
||||
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
||||
"""
|
||||
装饰器:包装同步数据库函数为异步函数
|
||||
|
||||
用法:
|
||||
@async_wrap_sync_db
|
||||
def get_user(db: Session, user_id: int):
|
||||
return db.query(User).filter(User.id == user_id).first()
|
||||
|
||||
# 现在可以在异步上下文中调用
|
||||
user = await get_user(db, 123)
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
return await run_in_executor(func, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
382
src/database/database.py
Normal file
382
src/database/database.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
数据库连接和初始化
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import AsyncGenerator, Generator, Optional
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import Pool, QueuePool
|
||||
|
||||
from ..config import config
|
||||
from src.core.logger import logger
|
||||
from ..models.database import Base, SystemConfig, User, UserRole
|
||||
|
||||
|
||||
# 延迟初始化的数据库引擎和会话工厂
|
||||
_engine: Optional[Engine] = None
|
||||
_SessionLocal: Optional[sessionmaker] = None
|
||||
_async_engine: Optional[AsyncEngine] = None
|
||||
_AsyncSessionLocal: Optional[async_sessionmaker] = None
|
||||
|
||||
# 连接池监控
|
||||
_last_pool_warning: float = 0.0
|
||||
POOL_WARNING_INTERVAL = 60 # 每60秒最多警告一次
|
||||
|
||||
|
||||
def _setup_pool_monitoring(engine: Engine):
|
||||
"""设置连接池监控事件"""
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def receive_connect(dbapi_conn, connection_record):
|
||||
"""连接创建时的监控"""
|
||||
pass
|
||||
|
||||
@event.listens_for(engine, "checkout")
|
||||
def receive_checkout(dbapi_conn, connection_record, connection_proxy):
|
||||
"""从连接池检出连接时的监控"""
|
||||
global _last_pool_warning
|
||||
|
||||
pool = engine.pool
|
||||
# 获取连接池状态
|
||||
checked_out = pool.checkedout()
|
||||
pool_size = pool.size()
|
||||
overflow = pool.overflow()
|
||||
max_capacity = config.db_pool_size + config.db_max_overflow
|
||||
|
||||
# 计算使用率
|
||||
usage_rate = (checked_out / max_capacity) * 100 if max_capacity > 0 else 0
|
||||
|
||||
# 如果使用率超过阈值,发出警告
|
||||
if usage_rate >= config.db_pool_warn_threshold:
|
||||
current_time = time.time()
|
||||
# 避免频繁警告
|
||||
if current_time - _last_pool_warning > POOL_WARNING_INTERVAL:
|
||||
_last_pool_warning = current_time
|
||||
logger.warning(
|
||||
f"数据库连接池使用率过高: checked_out={checked_out}, "
|
||||
f"pool_size={pool_size}, overflow={overflow}, "
|
||||
f"max_capacity={max_capacity}, usage_rate={usage_rate:.1f}%, "
|
||||
f"threshold={config.db_pool_warn_threshold}%"
|
||||
)
|
||||
|
||||
|
||||
def get_pool_status() -> dict:
|
||||
"""获取连接池状态"""
|
||||
engine = _ensure_engine()
|
||||
pool = engine.pool
|
||||
|
||||
return {
|
||||
"checked_out": pool.checkedout(),
|
||||
"pool_size": pool.size(),
|
||||
"overflow": pool.overflow(),
|
||||
"max_capacity": config.db_pool_size + config.db_max_overflow,
|
||||
"pool_timeout": config.db_pool_timeout,
|
||||
}
|
||||
|
||||
|
||||
def log_pool_status():
|
||||
"""记录连接池状态到日志(用于监控)"""
|
||||
try:
|
||||
status = get_pool_status()
|
||||
usage_rate = (
|
||||
(status["checked_out"] / status["max_capacity"] * 100)
|
||||
if status["max_capacity"] > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"数据库连接池状态: checked_out={status['checked_out']}, "
|
||||
f"pool_size={status['pool_size']}, overflow={status['overflow']}, "
|
||||
f"max_capacity={status['max_capacity']}, usage_rate={usage_rate:.1f}%"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取连接池状态失败: {e}")
|
||||
|
||||
|
||||
def _ensure_engine() -> Engine:
|
||||
"""
|
||||
确保数据库引擎已创建(延迟加载)
|
||||
|
||||
这允许测试和 CLI 工具在导入模块时不会立即连接数据库
|
||||
"""
|
||||
global _engine, _SessionLocal
|
||||
|
||||
if _engine is not None:
|
||||
return _engine
|
||||
|
||||
# 获取数据库配置
|
||||
DATABASE_URL = config.database_url
|
||||
|
||||
# 验证数据库类型(生产环境要求 PostgreSQL,但允许测试环境使用其他数据库)
|
||||
is_production = config.environment == "production"
|
||||
if is_production and not DATABASE_URL.startswith("postgresql://"):
|
||||
raise ValueError("生产环境只支持 PostgreSQL 数据库,请配置正确的 DATABASE_URL")
|
||||
|
||||
# 创建引擎
|
||||
_engine = create_engine(
|
||||
DATABASE_URL,
|
||||
poolclass=QueuePool, # 使用队列连接池
|
||||
pool_size=config.db_pool_size, # 连接池大小
|
||||
max_overflow=config.db_max_overflow, # 最大溢出连接数
|
||||
pool_timeout=config.db_pool_timeout, # 连接超时(秒)
|
||||
pool_recycle=config.db_pool_recycle, # 连接回收时间(秒)
|
||||
pool_pre_ping=True, # 检查连接活性
|
||||
echo=False, # 关闭SQL日志输出(太冗长)
|
||||
)
|
||||
|
||||
# 设置连接池监控
|
||||
_setup_pool_monitoring(_engine)
|
||||
|
||||
# 创建会话工厂
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
|
||||
|
||||
_log_pool_capacity()
|
||||
|
||||
logger.debug(f"数据库引擎已初始化: {DATABASE_URL.split('@')[-1] if '@' in DATABASE_URL else 'local'}")
|
||||
|
||||
return _engine
|
||||
|
||||
|
||||
def _log_pool_capacity():
|
||||
theoretical = config.db_pool_size + config.db_max_overflow
|
||||
workers = max(1, config.worker_processes)
|
||||
total_estimated = theoretical * workers
|
||||
logger.info("数据库连接池配置")
|
||||
if total_estimated > config.db_pool_warn_threshold:
|
||||
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
|
||||
|
||||
|
||||
def _ensure_async_engine() -> AsyncEngine:
|
||||
"""
|
||||
确保异步数据库引擎已创建(延迟加载)
|
||||
|
||||
这允许异步路由使用非阻塞的数据库访问
|
||||
"""
|
||||
global _async_engine, _AsyncSessionLocal
|
||||
|
||||
if _async_engine is not None:
|
||||
return _async_engine
|
||||
|
||||
# 获取数据库配置并转换为异步URL
|
||||
DATABASE_URL = config.database_url
|
||||
|
||||
# 转换同步URL为异步URL(postgresql:// -> postgresql+asyncpg://)
|
||||
if DATABASE_URL.startswith("postgresql://"):
|
||||
ASYNC_DATABASE_URL = DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
elif DATABASE_URL.startswith("sqlite:///"):
|
||||
ASYNC_DATABASE_URL = DATABASE_URL.replace("sqlite:///", "sqlite+aiosqlite:///", 1)
|
||||
else:
|
||||
raise ValueError(f"不支持的数据库类型: {DATABASE_URL}")
|
||||
|
||||
# 验证数据库类型(生产环境要求 PostgreSQL)
|
||||
is_production = config.environment == "production"
|
||||
if is_production and not ASYNC_DATABASE_URL.startswith("postgresql+asyncpg://"):
|
||||
raise ValueError("生产环境只支持 PostgreSQL 数据库,请配置正确的 DATABASE_URL")
|
||||
|
||||
# 创建异步引擎
|
||||
_async_engine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
poolclass=QueuePool, # 使用队列连接池
|
||||
pool_size=config.db_pool_size,
|
||||
max_overflow=config.db_max_overflow,
|
||||
pool_timeout=config.db_pool_timeout,
|
||||
pool_recycle=config.db_pool_recycle,
|
||||
pool_pre_ping=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# 创建异步会话工厂
|
||||
_AsyncSessionLocal = async_sessionmaker(
|
||||
_async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
logger.debug(f"异步数据库引擎已初始化: {ASYNC_DATABASE_URL.split('@')[-1] if '@' in ASYNC_DATABASE_URL else 'local'}")
|
||||
|
||||
return _async_engine
|
||||
|
||||
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取异步数据库会话"""
|
||||
# 确保异步引擎已初始化
|
||||
_ensure_async_engine()
|
||||
|
||||
async with _AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
"""获取数据库会话
|
||||
|
||||
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback)
|
||||
这里只负责会话的创建和关闭,不自动提交
|
||||
"""
|
||||
# 确保引擎已初始化
|
||||
_ensure_engine()
|
||||
|
||||
db = _SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
# 不再自动 commit,由业务代码显式管理事务
|
||||
except Exception:
|
||||
try:
|
||||
db.rollback() # 失败时回滚未提交的事务
|
||||
except Exception as rollback_error:
|
||||
# 记录回滚错误(可能是 commit 正在进行中)
|
||||
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
db.close() # 确保连接返回池
|
||||
except Exception as close_error:
|
||||
# 记录关闭错误(如 IllegalStateChangeError)
|
||||
# 连接池会处理连接的回收
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
"""
|
||||
创建一个新的数据库会话
|
||||
|
||||
注意:调用者必须负责关闭会话
|
||||
推荐在 with 语句中使用或手动调用 session.close()
|
||||
|
||||
示例:
|
||||
db = create_session()
|
||||
try:
|
||||
# 使用 db
|
||||
finally:
|
||||
db.close()
|
||||
"""
|
||||
_ensure_engine()
|
||||
return _SessionLocal()
|
||||
|
||||
|
||||
def get_db_url() -> str:
|
||||
"""返回当前配置的数据库连接字符串(供脚本/测试使用)。"""
|
||||
return config.database_url
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
logger.info("初始化数据库...")
|
||||
|
||||
# 确保引擎已创建
|
||||
engine = _ensure_engine()
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# 数据库表已通过SQLAlchemy自动创建
|
||||
|
||||
db = _SessionLocal()
|
||||
try:
|
||||
# 创建管理员账户(如果环境变量中配置了)
|
||||
init_admin_user(db)
|
||||
|
||||
# 添加默认模型配置
|
||||
init_default_models(db)
|
||||
|
||||
# 添加系统配置
|
||||
init_system_configs(db)
|
||||
|
||||
db.commit()
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化失败: {e}")
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def init_admin_user(db: Session):
|
||||
"""从环境变量创建管理员账户"""
|
||||
# 检查是否使用默认凭据
|
||||
if config.admin_email == "admin@localhost" and config.admin_password == "admin123":
|
||||
logger.warning("使用默认管理员账户配置,建议修改为安全的凭据")
|
||||
|
||||
# 检查是否已存在管理员
|
||||
existing_admin = (
|
||||
db.query(User)
|
||||
.filter((User.email == config.admin_email) | (User.username == config.admin_username))
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_admin:
|
||||
logger.info(f"管理员账户已存在: {existing_admin.email}")
|
||||
return
|
||||
|
||||
try:
|
||||
# 创建管理员账户
|
||||
admin = User(
|
||||
email=config.admin_email,
|
||||
username=config.admin_username,
|
||||
role=UserRole.ADMIN,
|
||||
quota_usd=1000.0,
|
||||
is_active=True,
|
||||
)
|
||||
admin.set_password(config.admin_password)
|
||||
|
||||
db.add(admin)
|
||||
db.commit() # 刷新以获取ID,但不提交
|
||||
|
||||
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
|
||||
except Exception as e:
|
||||
logger.error(f"创建管理员账户失败: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def init_default_models(db: Session):
|
||||
"""初始化默认模型配置"""
|
||||
|
||||
# 注意:作为中转代理服务,不再预设模型配置
|
||||
# 模型配置应该通过 Model 和 ModelMapping 表动态管理
|
||||
# 这个函数保留用于未来可能的默认模型初始化
|
||||
pass
|
||||
|
||||
|
||||
def init_system_configs(db: Session):
|
||||
"""初始化系统配置"""
|
||||
|
||||
configs = [
|
||||
{"key": "default_user_quota_usd", "value": 10.0, "description": "新用户默认美元配额"},
|
||||
{"key": "rate_limit_per_minute", "value": 60, "description": "每分钟请求限制"},
|
||||
{"key": "enable_registration", "value": False, "description": "是否开放用户注册"},
|
||||
{"key": "require_email_verification", "value": False, "description": "是否需要邮箱验证"},
|
||||
{"key": "api_key_expire_days", "value": 365, "description": "API密钥过期天数"},
|
||||
]
|
||||
|
||||
for config_data in configs:
|
||||
existing = db.query(SystemConfig).filter_by(key=config_data["key"]).first()
|
||||
if not existing:
|
||||
config = SystemConfig(**config_data)
|
||||
db.add(config)
|
||||
logger.info(f"添加系统配置: {config_data['key']}")
|
||||
|
||||
|
||||
def reset_db():
|
||||
"""重置数据库(仅用于开发)"""
|
||||
logger.warning("重置数据库...")
|
||||
|
||||
# 确保引擎已创建
|
||||
engine = _ensure_engine()
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
init_db()
|
||||
Reference in New Issue
Block a user