mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
361 lines
13 KiB
Python
361 lines
13 KiB
Python
"""
|
||
主应用入口
|
||
采用模块化架构设计
|
||
"""
|
||
|
||
from contextlib import asynccontextmanager
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
|
||
from src.api.admin import router as admin_router
|
||
from src.api.announcements import router as announcement_router
|
||
|
||
# API路由
|
||
from src.api.auth import router as auth_router
|
||
from src.api.dashboard import router as dashboard_router
|
||
from src.api.monitoring import router as monitoring_router
|
||
from src.api.public import router as public_router
|
||
from src.api.user_me import router as me_router
|
||
from src.clients.http_client import HTTPClientPool, close_http_clients
|
||
|
||
# 核心模块
|
||
from src.config import config
|
||
from src.core.exceptions import ExceptionHandlers, ProxyException
|
||
from src.core.logger import logger
|
||
from src.database import init_db
|
||
|
||
from src.middleware.plugin_middleware import PluginMiddleware
|
||
from src.plugins.manager import get_plugin_manager
|
||
|
||
|
||
|
||
async def initialize_providers():
|
||
"""从数据库初始化提供商(仅用于日志记录)"""
|
||
from sqlalchemy.orm import Session
|
||
|
||
from src.database.database import create_session
|
||
from src.models.database import Provider
|
||
|
||
try:
|
||
# 创建数据库会话
|
||
db: Session = create_session()
|
||
|
||
try:
|
||
# 从数据库加载所有活跃的提供商
|
||
providers = (
|
||
db.query(Provider)
|
||
.filter(Provider.is_active.is_(True))
|
||
.order_by(Provider.provider_priority.asc())
|
||
.all()
|
||
)
|
||
|
||
if not providers:
|
||
logger.warning("数据库中未找到活跃的提供商")
|
||
return
|
||
|
||
# 记录提供商信息
|
||
logger.info(f"从数据库加载了 {len(providers)} 个活跃提供商")
|
||
for provider in providers:
|
||
# 统计端点信息
|
||
endpoint_count = len(provider.endpoints) if provider.endpoints else 0
|
||
active_endpoints = (
|
||
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
|
||
)
|
||
|
||
logger.info(f"提供商: {provider.name} (端点: {active_endpoints}/{endpoint_count})")
|
||
|
||
finally:
|
||
db.close()
|
||
|
||
except Exception:
|
||
logger.exception("从数据库初始化提供商失败")
|
||
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
"""应用生命周期管理"""
|
||
# 禁用uvicorn的access日志(在子进程中执行)
|
||
import logging
|
||
|
||
logging.getLogger("uvicorn.access").setLevel(logging.CRITICAL)
|
||
logging.getLogger("uvicorn.access").disabled = True
|
||
|
||
# 启动时执行
|
||
logger.info("=" * 60)
|
||
from src import __version__
|
||
|
||
logger.info(f"AI Proxy v{__version__} - GlobalModel Architecture")
|
||
logger.info("=" * 60)
|
||
|
||
# 安全配置验证(生产环境会阻止启动)
|
||
security_errors = config.validate_security_config()
|
||
if security_errors:
|
||
for error in security_errors:
|
||
logger.error(f"[SECURITY] {error}")
|
||
if config.environment == "production":
|
||
raise RuntimeError(
|
||
"Security configuration errors detected. "
|
||
"Please fix the following issues before starting in production:\n"
|
||
+ "\n".join(f" - {e}" for e in security_errors)
|
||
)
|
||
|
||
# 记录启动警告(密码、连接池、JWT 等)
|
||
config.log_startup_warnings()
|
||
|
||
# 初始化数据库
|
||
logger.info("初始化数据库...")
|
||
init_db()
|
||
|
||
# 从数据库初始化提供商
|
||
await initialize_providers()
|
||
|
||
# 初始化全局HTTP客户端池
|
||
logger.info("初始化全局HTTP客户端池...")
|
||
HTTPClientPool.get_default_client() # 预创建默认客户端
|
||
|
||
# 初始化全局Redis客户端(可根据配置降级为内存模式)
|
||
logger.info("初始化全局Redis客户端...")
|
||
from src.clients.redis_client import get_redis_client
|
||
|
||
redis_client = None
|
||
try:
|
||
redis_client = await get_redis_client(require_redis=config.require_redis)
|
||
if redis_client:
|
||
logger.info("[OK] Redis客户端初始化成功,缓存亲和性功能已启用")
|
||
else:
|
||
logger.warning("[WARN] Redis未启用或连接失败,将使用内存缓存亲和性(仅适用于单实例/开发环境)")
|
||
except RuntimeError as e:
|
||
if config.require_redis:
|
||
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
||
raise
|
||
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
||
redis_client = None
|
||
|
||
# 初始化并发管理器(内部会使用Redis)
|
||
logger.info("初始化并发管理器...")
|
||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||
|
||
concurrency_manager = await get_concurrency_manager()
|
||
|
||
# 初始化批量提交器(提升数据库并发能力)
|
||
logger.info("初始化批量提交器...")
|
||
from src.core.batch_committer import init_batch_committer
|
||
|
||
await init_batch_committer()
|
||
logger.info("[OK] 批量提交器已启动,数据库写入性能优化已启用")
|
||
|
||
# 初始化插件系统
|
||
logger.info("初始化插件系统...")
|
||
plugin_manager = get_plugin_manager()
|
||
init_results = await plugin_manager.initialize_all()
|
||
successful = sum(1 for success in init_results.values() if success)
|
||
logger.info(f"插件初始化完成: {successful}/{len(init_results)} 个插件成功启动")
|
||
|
||
# 注册格式转换器
|
||
logger.info("注册格式转换器...")
|
||
from src.api.handlers.base.format_converter_registry import register_all_converters
|
||
|
||
register_all_converters()
|
||
|
||
logger.info(f"服务启动成功: http://{config.host}:{config.port}")
|
||
logger.info("=" * 60)
|
||
|
||
# 启动月卡额度重置调度器(仅一个 worker 执行)
|
||
logger.info("启动月卡额度重置调度器...")
|
||
from src.services.system.cleanup_scheduler import get_cleanup_scheduler
|
||
from src.services.usage.quota_scheduler import get_quota_scheduler
|
||
from src.utils.task_coordinator import StartupTaskCoordinator
|
||
|
||
quota_scheduler = get_quota_scheduler()
|
||
cleanup_scheduler = get_cleanup_scheduler()
|
||
task_coordinator = StartupTaskCoordinator(redis_client)
|
||
|
||
# 启动额度调度器
|
||
quota_scheduler_active = await task_coordinator.acquire("quota_scheduler")
|
||
if quota_scheduler_active:
|
||
await quota_scheduler.start()
|
||
else:
|
||
logger.info("检测到其他 worker 已运行额度调度器,本实例跳过")
|
||
quota_scheduler = None
|
||
|
||
# 启动清理调度器
|
||
cleanup_scheduler_active = await task_coordinator.acquire("cleanup_scheduler")
|
||
if cleanup_scheduler_active:
|
||
logger.info("启动使用记录清理调度器...")
|
||
await cleanup_scheduler.start()
|
||
else:
|
||
logger.info("检测到其他 worker 已运行清理调度器,本实例跳过")
|
||
cleanup_scheduler = None
|
||
|
||
# 启动统一的定时任务调度器
|
||
from src.services.system.scheduler import get_scheduler
|
||
|
||
task_scheduler = get_scheduler()
|
||
task_scheduler.start()
|
||
|
||
yield # 应用运行期间
|
||
|
||
# 关闭时执行
|
||
logger.info("正在关闭服务...")
|
||
|
||
# 停止批量提交器(确保所有待提交的数据都被保存)
|
||
logger.info("停止批量提交器...")
|
||
from src.core.batch_committer import shutdown_batch_committer
|
||
|
||
await shutdown_batch_committer()
|
||
logger.info("[OK] 批量提交器已停止,所有待提交数据已保存")
|
||
|
||
# 停止清理调度器
|
||
if cleanup_scheduler:
|
||
logger.info("停止使用记录清理调度器...")
|
||
await cleanup_scheduler.stop()
|
||
await task_coordinator.release("cleanup_scheduler")
|
||
|
||
# 停止月卡额度重置调度器,并释放分布式锁
|
||
logger.info("停止月卡额度重置调度器...")
|
||
if quota_scheduler:
|
||
await quota_scheduler.stop()
|
||
if task_coordinator:
|
||
await task_coordinator.release("quota_scheduler")
|
||
|
||
# 停止统一的定时任务调度器
|
||
logger.info("停止定时任务调度器...")
|
||
task_scheduler.stop()
|
||
|
||
# 关闭插件系统
|
||
logger.info("关闭插件系统...")
|
||
await plugin_manager.shutdown_all()
|
||
|
||
# 关闭并发管理器
|
||
logger.info("关闭并发管理器...")
|
||
if concurrency_manager:
|
||
await concurrency_manager.close()
|
||
|
||
# 关闭全局Redis客户端
|
||
logger.info("关闭全局Redis客户端...")
|
||
from src.clients.redis_client import close_redis_client
|
||
|
||
await close_redis_client()
|
||
|
||
# 关闭HTTP客户端池
|
||
logger.info("关闭HTTP客户端池...")
|
||
await close_http_clients()
|
||
|
||
logger.info("服务已关闭")
|
||
|
||
|
||
from src import __version__ as app_version
|
||
|
||
app = FastAPI(
|
||
title="AI Proxy with Modular Architecture",
|
||
version=app_version,
|
||
description="AI代理服务,采用模块化架构,支持插件化扩展",
|
||
lifespan=lifespan,
|
||
)
|
||
|
||
# 注册全局异常处理器
|
||
# 注意:异常处理器的注册顺序很重要,必须先注册更通用的异常类型,再注册具体的
|
||
# ProxyException 处理器的启用由配置控制:
|
||
# - propagate_provider_exceptions=True (默认): 不注册,让异常传播到路由层以记录 provider_request_headers
|
||
# - propagate_provider_exceptions=False: 注册全局处理器统一处理
|
||
if not config.propagate_provider_exceptions:
|
||
app.add_exception_handler(ProxyException, ExceptionHandlers.handle_proxy_exception)
|
||
app.add_exception_handler(Exception, ExceptionHandlers.handle_generic_exception)
|
||
app.add_exception_handler(HTTPException, ExceptionHandlers.handle_http_exception)
|
||
|
||
# 添加插件中间件(包含认证、审计、速率限制等功能)
|
||
app.add_middleware(PluginMiddleware)
|
||
|
||
# CORS配置 - 使用环境变量配置允许的域名
|
||
# 生产环境必须通过 CORS_ORIGINS 环境变量显式指定允许的域名
|
||
# 开发环境默认允许本地前端访问
|
||
if config.cors_origins:
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=config.cors_origins, # 使用配置的白名单
|
||
allow_credentials=config.cors_allow_credentials,
|
||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
|
||
allow_headers=["*"],
|
||
expose_headers=["*"],
|
||
)
|
||
logger.info(f"CORS已启用,允许的源: {config.cors_origins}")
|
||
else:
|
||
# 没有配置CORS源,不允许跨域
|
||
logger.warning(
|
||
f"CORS未配置,不允许跨域请求。如需启用CORS,请设置 CORS_ORIGINS 环境变量(当前环境: {config.environment})"
|
||
)
|
||
|
||
# 注册路由
|
||
app.include_router(auth_router) # 认证相关
|
||
app.include_router(admin_router) # 管理员端点
|
||
app.include_router(me_router) # 用户个人端点
|
||
app.include_router(announcement_router) # 公告系统
|
||
app.include_router(dashboard_router) # 仪表盘端点
|
||
app.include_router(public_router) # 公开API端点(用户可查看提供商和模型)
|
||
app.include_router(monitoring_router) # 监控端点
|
||
|
||
|
||
|
||
def main():
|
||
# 初始化新日志系统
|
||
debug_mode = config.environment == "development"
|
||
# 日志系统已在导入时自动初始化
|
||
|
||
# Parse log level
|
||
log_level = config.log_level.split()[0].lower()
|
||
if log_level not in ["debug", "info", "warning", "error", "critical"]:
|
||
log_level = "info"
|
||
|
||
# 自定义uvicorn日志配置,完全禁用access日志
|
||
uvicorn_log_config = {
|
||
"version": 1,
|
||
"disable_existing_loggers": False,
|
||
"formatters": {
|
||
"default": {
|
||
"format": "%(levelprefix)s %(message)s",
|
||
"use_colors": True,
|
||
},
|
||
},
|
||
"handlers": {
|
||
"default": {
|
||
"formatter": "default",
|
||
"class": "logging.StreamHandler",
|
||
"stream": "ext://sys.stderr",
|
||
},
|
||
},
|
||
"loggers": {
|
||
"uvicorn": {"handlers": ["default"], "level": log_level.upper()},
|
||
"uvicorn.error": {"level": log_level.upper()},
|
||
"uvicorn.access": {"handlers": [], "level": "CRITICAL"}, # 禁用access日志
|
||
},
|
||
}
|
||
|
||
# Start server
|
||
# 根据环境设置热重载
|
||
uvicorn.run(
|
||
"src.main:app",
|
||
host=config.host,
|
||
port=config.port,
|
||
log_level=log_level,
|
||
reload=config.environment == "development", # 只在开发环境启用热重载
|
||
access_log=False, # 禁用 uvicorn 访问日志,使用自定义中间件
|
||
log_config=uvicorn_log_config, # 使用自定义日志配置
|
||
)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 使用安全的方式清屏,避免命令注入风险
|
||
try:
|
||
import os
|
||
|
||
if os.name == "nt": # Windows
|
||
os.system("cls")
|
||
else: # Unix/Linux/MacOS
|
||
print("\033[2J\033[H", end="") # ANSI escape sequence
|
||
except:
|
||
pass # 清屏失败不影响程序运行
|
||
|
||
main()
|