Initial commit

Author: fawney19
Date: 2025-12-10 20:52:44 +08:00
Commit: f784106826
485 changed files with 110993 additions and 0 deletions

src/database/__init__.py (new file, 20 lines)

@@ -0,0 +1,20 @@
"""
数据库模块
"""
from ..models.database import ApiKey, Base, Usage, User, UserQuota
from .database import create_session, get_async_db, get_db, get_db_url, init_db, log_pool_status
__all__ = [
"Base",
"User",
"ApiKey",
"Usage",
"UserQuota",
"get_db",
"get_async_db",
"init_db",
"create_session",
"get_db_url",
"log_pool_status",
]
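
The package exposes both the ORM models and the session helpers through a single import path. A minimal usage sketch, assuming the project root is on sys.path:

    from src.database import get_db_url, init_db

    init_db()            # create tables and seed defaults
    print(get_db_url())  # the configured connection string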


@@ -0,0 +1,41 @@
"""
异步数据库工具
提供在异步上下文中安全使用同步数据库操作的工具
"""
import asyncio
from functools import wraps
from typing import Any, Callable, TypeVar
T = TypeVar("T")
async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
"""
在线程池中运行同步函数,避免阻塞事件循环
用法:
result = await run_in_executor(some_sync_function, arg1, arg2)
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
"""
装饰器:包装同步数据库函数为异步函数
用法:
@async_wrap_sync_db
def get_user(db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first()
# 现在可以在异步上下文中调用
user = await get_user(db, 123)
"""
@wraps(func)
async def wrapper(*args, **kwargs):
return await run_in_executor(func, *args, **kwargs)
return wrapper
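
A hedged sketch of how the decorator composes with the session helpers; fetch_user_count and main are hypothetical names, the model import mirrors the one in src/database/__init__.py, and the sketch assumes async_wrap_sync_db is importable alongside the helpers above:

    import asyncio
    from sqlalchemy.orm import Session
    from src.database import create_session
    from src.models.database import User

    @async_wrap_sync_db
    def fetch_user_count(db: Session) -> int:
        # The blocking query runs in a worker thread, not on the event loop
        return db.query(User).count()

    async def main() -> None:
        db = create_session()
        try:
            print(await fetch_user_count(db))
        finally:
            db.close()

    asyncio.run(main())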

src/database/database.py (new file, 382 lines)

@@ -0,0 +1,382 @@
"""
数据库连接和初始化
"""
import time
from typing import AsyncGenerator, Generator, Optional
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import Pool, QueuePool
from ..config import config
from src.core.logger import logger
from ..models.database import Base, SystemConfig, User, UserRole
# 延迟初始化的数据库引擎和会话工厂
_engine: Optional[Engine] = None
_SessionLocal: Optional[sessionmaker] = None
_async_engine: Optional[AsyncEngine] = None
_AsyncSessionLocal: Optional[async_sessionmaker] = None
# 连接池监控
_last_pool_warning: float = 0.0
POOL_WARNING_INTERVAL = 60 # 每60秒最多警告一次
def _setup_pool_monitoring(engine: Engine):
    """Register connection pool monitoring events."""

    @event.listens_for(engine, "connect")
    def receive_connect(dbapi_conn, connection_record):
        """Hook for connection creation (currently a no-op)."""
        pass

    @event.listens_for(engine, "checkout")
    def receive_checkout(dbapi_conn, connection_record, connection_proxy):
        """Monitor checkouts from the connection pool."""
        global _last_pool_warning
        pool = engine.pool
        # Read the current pool status
        checked_out = pool.checkedout()
        pool_size = pool.size()
        overflow = pool.overflow()
        max_capacity = config.db_pool_size + config.db_max_overflow
        # Compute the usage rate
        usage_rate = (checked_out / max_capacity) * 100 if max_capacity > 0 else 0
        # Warn if the usage rate exceeds the configured threshold
        if usage_rate >= config.db_pool_warn_threshold:
            current_time = time.time()
            # Rate-limit the warnings
            if current_time - _last_pool_warning > POOL_WARNING_INTERVAL:
                _last_pool_warning = current_time
                logger.warning(
                    f"Database connection pool usage is high: checked_out={checked_out}, "
                    f"pool_size={pool_size}, overflow={overflow}, "
                    f"max_capacity={max_capacity}, usage_rate={usage_rate:.1f}%, "
                    f"threshold={config.db_pool_warn_threshold}%"
                )
def get_pool_status() -> dict:
    """Return the current connection pool status."""
    engine = _ensure_engine()
    pool = engine.pool
    return {
        "checked_out": pool.checkedout(),
        "pool_size": pool.size(),
        "overflow": pool.overflow(),
        "max_capacity": config.db_pool_size + config.db_max_overflow,
        "pool_timeout": config.db_pool_timeout,
    }
def log_pool_status():
    """Log the connection pool status (for monitoring)."""
    try:
        status = get_pool_status()
        usage_rate = (
            (status["checked_out"] / status["max_capacity"] * 100)
            if status["max_capacity"] > 0
            else 0
        )
        logger.info(
            f"Database connection pool status: checked_out={status['checked_out']}, "
            f"pool_size={status['pool_size']}, overflow={status['overflow']}, "
            f"max_capacity={status['max_capacity']}, usage_rate={usage_rate:.1f}%"
        )
    except Exception as e:
        logger.error(f"Failed to get connection pool status: {e}")
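

# A hedged sketch (not part of this commit's API): one way an application
# could surface log_pool_status() periodically from an async task. The
# function name and the 300-second interval are illustrative choices.
async def _example_monitor_pool(interval_seconds: int = 300) -> None:
    import asyncio
    while True:
        log_pool_status()
        await asyncio.sleep(interval_seconds)
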
def _ensure_engine() -> Engine:
    """
    Ensure the database engine has been created (lazy loading).
    This lets tests and CLI tools import the module without immediately connecting to the database.
    """
    global _engine, _SessionLocal
    if _engine is not None:
        return _engine
    # Read the database configuration
    DATABASE_URL = config.database_url
    # Validate the database type (production requires PostgreSQL; test environments may use other databases)
    is_production = config.environment == "production"
    if is_production and not DATABASE_URL.startswith("postgresql://"):
        raise ValueError("Only PostgreSQL is supported in production; please configure a valid DATABASE_URL")
    # Create the engine
    _engine = create_engine(
        DATABASE_URL,
        poolclass=QueuePool,  # use a queued connection pool
        pool_size=config.db_pool_size,  # connection pool size
        max_overflow=config.db_max_overflow,  # maximum overflow connections
        pool_timeout=config.db_pool_timeout,  # connection timeout (seconds)
        pool_recycle=config.db_pool_recycle,  # connection recycle time (seconds)
        pool_pre_ping=True,  # check connection liveness
        echo=False,  # disable SQL echo (too verbose)
    )
    # Set up connection pool monitoring
    _setup_pool_monitoring(_engine)
    # Create the session factory
    _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
    _log_pool_capacity()
    logger.debug(f"Database engine initialized: {DATABASE_URL.split('@')[-1] if '@' in DATABASE_URL else 'local'}")
    return _engine
def _log_pool_capacity():
    """Log the configured pool capacity and the estimated total across workers."""
    theoretical = config.db_pool_size + config.db_max_overflow
    workers = max(1, config.worker_processes)
    total_estimated = theoretical * workers
    logger.info(
        f"Database pool config: pool_size={config.db_pool_size}, "
        f"max_overflow={config.db_max_overflow}, workers={workers}, "
        f"estimated_total_connections={total_estimated}"
    )
    if total_estimated > config.db_pool_warn_threshold:
        logger.warning(
            "Estimated database connection demand may exceed the configured threshold; "
            "consider reducing the pool size or the number of workers"
        )
def _ensure_async_engine() -> AsyncEngine:
    """
    Ensure the async database engine has been created (lazy loading).
    This lets async routes use non-blocking database access.
    """
    global _async_engine, _AsyncSessionLocal
    if _async_engine is not None:
        return _async_engine
    # Read the database configuration and convert it to an async URL
    DATABASE_URL = config.database_url
    # Convert the sync URL to an async one (postgresql:// -> postgresql+asyncpg://)
    if DATABASE_URL.startswith("postgresql://"):
        ASYNC_DATABASE_URL = DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1)
    elif DATABASE_URL.startswith("sqlite:///"):
        ASYNC_DATABASE_URL = DATABASE_URL.replace("sqlite:///", "sqlite+aiosqlite:///", 1)
    else:
        raise ValueError(f"Unsupported database type: {DATABASE_URL}")
    # Validate the database type (production requires PostgreSQL)
    is_production = config.environment == "production"
    if is_production and not ASYNC_DATABASE_URL.startswith("postgresql+asyncpg://"):
        raise ValueError("Only PostgreSQL is supported in production; please configure a valid DATABASE_URL")
    # Create the async engine. Note: the synchronous QueuePool is not compatible
    # with async engines, so the default AsyncAdaptedQueuePool is used instead.
    _async_engine = create_async_engine(
        ASYNC_DATABASE_URL,
        pool_size=config.db_pool_size,
        max_overflow=config.db_max_overflow,
        pool_timeout=config.db_pool_timeout,
        pool_recycle=config.db_pool_recycle,
        pool_pre_ping=True,
        echo=False,
    )
    # Create the async session factory. Note: Session in SQLAlchemy 2.0 (which
    # async_sessionmaker requires) no longer accepts an autocommit parameter.
    _AsyncSessionLocal = async_sessionmaker(
        _async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    logger.debug(f"Async database engine initialized: {ASYNC_DATABASE_URL.split('@')[-1] if '@' in ASYNC_DATABASE_URL else 'local'}")
    return _async_engine
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield an async database session."""
    # Make sure the async engine is initialized
    _ensure_async_engine()
    # The context manager closes the session when the generator exits
    async with _AsyncSessionLocal() as session:
        yield session
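

# A hedged usage sketch: get_async_db is shaped like a FastAPI dependency.
# FastAPI, `app`, and the route below are assumptions of this example, not
# part of this module:
#
#   from fastapi import Depends
#
#   @app.get("/users/{user_id}")
#   async def read_user(user_id: int, db: AsyncSession = Depends(get_async_db)):
#       return await db.get(User, user_id)
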
def get_db() -> Generator[Session, None, None]:
    """Yield a database session.
    Note: transaction management is controlled explicitly by the business logic
    layer (manual commit/rollback); this function only creates and closes the
    session and never commits automatically.
    """
    # Make sure the engine is initialized
    _ensure_engine()
    db = _SessionLocal()
    try:
        yield db
        # No automatic commit; business code manages transactions explicitly
    except Exception:
        try:
            db.rollback()  # roll back any uncommitted transaction on failure
        except Exception as rollback_error:
            # Log rollback errors (a commit may be in progress)
            logger.debug(f"Error while rolling back transaction (safe to ignore): {rollback_error}")
        raise
    finally:
        try:
            db.close()  # make sure the connection is returned to the pool
        except Exception as close_error:
            # Log close errors (e.g. IllegalStateChangeError);
            # the pool handles reclaiming the connection
            logger.debug(f"Error while closing database connection (safe to ignore): {close_error}")
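

# A hedged sketch of consuming get_db outside a framework: next() advances the
# generator to the yield, and closing the generator runs the finally block
# that returns the connection to the pool.
#
#   gen = get_db()
#   db = next(gen)
#   try:
#       ...  # use db
#   finally:
#       gen.close()
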
def create_session() -> Session:
    """
    Create a new database session.
    Note: the caller is responsible for closing the session,
    either in a with statement or by calling session.close() manually.
    Example:
        db = create_session()
        try:
            ...  # use db
        finally:
            db.close()
    """
    _ensure_engine()
    return _SessionLocal()


def get_db_url() -> str:
    """Return the configured database URL (for scripts/tests)."""
    return config.database_url
def init_db():
    """Initialize the database."""
    logger.info("Initializing database...")
    # Make sure the engine is created
    engine = _ensure_engine()
    # Create all tables via SQLAlchemy
    Base.metadata.create_all(bind=engine)
    db = _SessionLocal()
    try:
        # Create the admin account (if configured via environment variables)
        init_admin_user(db)
        # Add default model configurations
        init_default_models(db)
        # Add system configurations
        init_system_configs(db)
        db.commit()
        logger.info("Database initialization complete")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        db.rollback()
        raise
    finally:
        db.close()
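

# A hedged bootstrap sketch (hypothetical, not part of this commit): running
# init_db() once from a script before starting application workers.
#
#   if __name__ == "__main__":
#       init_db()
#       log_pool_status()
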
def init_admin_user(db: Session):
    """Create the admin account from environment variables."""
    # Warn when the default credentials are in use
    if config.admin_email == "admin@localhost" and config.admin_password == "admin123":
        logger.warning("Using the default admin account configuration; please switch to secure credentials")
    # Check whether an admin already exists
    existing_admin = (
        db.query(User)
        .filter((User.email == config.admin_email) | (User.username == config.admin_username))
        .first()
    )
    if existing_admin:
        logger.info(f"Admin account already exists: {existing_admin.email}")
        return
    try:
        # Create the admin account
        admin = User(
            email=config.admin_email,
            username=config.admin_username,
            role=UserRole.ADMIN,
            quota_usd=1000.0,
            is_active=True,
        )
        admin.set_password(config.admin_password)
        db.add(admin)
        db.commit()  # persist the admin account immediately
        logger.info(f"Admin account created: {admin.email} ({admin.username})")
    except Exception as e:
        logger.error(f"Failed to create admin account: {e}")
        raise
def init_default_models(db: Session):
    """Initialize default model configurations."""
    # Note: as a relay/proxy service we no longer preset model configurations;
    # models should be managed dynamically via the Model and ModelMapping tables.
    # This function is kept for possible future default model initialization.
    pass


def init_system_configs(db: Session):
    """Initialize system configurations."""
    configs = [
        {"key": "default_user_quota_usd", "value": 10.0, "description": "Default USD quota for new users"},
        {"key": "rate_limit_per_minute", "value": 60, "description": "Request limit per minute"},
        {"key": "enable_registration", "value": False, "description": "Whether user registration is open"},
        {"key": "require_email_verification", "value": False, "description": "Whether email verification is required"},
        {"key": "api_key_expire_days", "value": 365, "description": "API key expiry in days"},
    ]
    for config_data in configs:
        existing = db.query(SystemConfig).filter_by(key=config_data["key"]).first()
        if not existing:
            # Use a distinct name to avoid shadowing the module-level `config`
            system_config = SystemConfig(**config_data)
            db.add(system_config)
            logger.info(f"Added system config: {config_data['key']}")
def reset_db():
    """Reset the database (development only)."""
    logger.warning("Resetting database...")
    # Make sure the engine is created
    engine = _ensure_engine()
    Base.metadata.drop_all(bind=engine)
    init_db()
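

# Since reset_db() drops every table, a hedged guard sketch (the environment
# check mirrors the one used in _ensure_engine):
#
#   if config.environment == "production":
#       raise RuntimeError("Refusing to reset the database in production")
#   reset_db()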