refactor: optimize provider query and stats aggregation logic

This commit is contained in:
fawney19
2025-12-17 16:41:10 +08:00
parent 50abb55c94
commit 1dac4cb156
21 changed files with 1753 additions and 592 deletions

View File

@@ -35,6 +35,7 @@ class CleanupScheduler:
def __init__(self):
self.running = False
self._interval_tasks = []
self._stats_aggregation_lock = asyncio.Lock()
async def start(self):
"""启动调度器"""
@@ -56,6 +57,14 @@ class CleanupScheduler:
job_id="stats_aggregation",
name="统计数据聚合",
)
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
scheduler.add_interval_job(
self._scheduled_stats_aggregation,
minutes=30,
job_id="stats_aggregation_backfill",
name="统计数据聚合补偿",
backfill=True,
)
# 清理任务 - 凌晨 3 点执行
scheduler.add_cron_job(
@@ -115,9 +124,9 @@ class CleanupScheduler:
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
async def _scheduled_stats_aggregation(self):
async def _scheduled_stats_aggregation(self, backfill: bool = False):
"""统计聚合任务(定时调用)"""
await self._perform_stats_aggregation()
await self._perform_stats_aggregation(backfill=backfill)
async def _scheduled_cleanup(self):
"""清理任务(定时调用)"""
@@ -144,136 +153,157 @@ class CleanupScheduler:
Args:
backfill: 是否回填历史数据(启动时检查缺失的日期)
"""
db = create_session()
try:
# 检查是否启用统计聚合
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("统计聚合已禁用,跳过聚合任务")
return
if self._stats_aggregation_lock.locked():
logger.info("统计聚合任务正在运行,跳过本次触发")
return
logger.info("开始执行统计数据聚合...")
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# 使用业务时区计算日期,确保与定时任务触发时间一致
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if backfill:
# 启动时检查并回填缺失的日期
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# 首次运行,回填所有历史数据
logger.info("检测到首次运行,开始回填历史统计数据...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
async with self._stats_aggregation_lock:
db = create_session()
try:
# 检查是否启用统计聚合
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("统计聚合已禁用,跳过聚合任务")
return
# 非首次运行,检查最近是否有缺失的日期需要回填
latest_stat = (
db.query(StatsDaily)
.order_by(StatsDaily.date.desc())
.first()
)
logger.info("开始执行统计数据聚合...")
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = (today_local.date() - timedelta(days=1))
missing_start_date = latest_business_date + timedelta(days=1)
# 使用业务时区计算日期,确保与定时任务触发时间一致
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if missing_start_date <= yesterday_business_date:
missing_days = (yesterday_business_date - missing_start_date).days + 1
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
if backfill:
# 启动时检查并回填缺失的日期
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# 首次运行,回填所有历史数据
logger.info("检测到首次运行,开始回填历史统计数据...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
return
current_date = missing_start_date
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
# 非首次运行,检查最近是否有缺失的日期需要回填
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
while current_date <= yesterday_business_date:
try:
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = today_local.date() - timedelta(days=1)
missing_start_date = latest_business_date + timedelta(days=1)
if missing_start_date <= yesterday_business_date:
missing_days = (
yesterday_business_date - missing_start_date
).days + 1
# 限制最大回填天数,防止停机很久后一次性回填太多
max_backfill_days: int = SystemConfigService.get_config(
db, "max_stats_backfill_days", 30
) or 30
if missing_days > max_backfill_days:
logger.warning(
f"缺失 {missing_days} 天数据超过最大回填限制 "
f"{max_backfill_days} 天,只回填最近 {max_backfill_days}"
)
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
# 聚合用户数据
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
missing_start_date = yesterday_business_date - timedelta(
days=max_backfill_days - 1
)
missing_days = max_backfill_days
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
)
current_date = missing_start_date
users = (
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
)
while current_date <= yesterday_business_date:
try:
db.rollback()
except Exception:
pass
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
)
StatsAggregatorService.aggregate_daily_stats(
db, current_date_local
)
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
try:
db.rollback()
except Exception:
pass
current_date += timedelta(days=1)
current_date += timedelta(days=1)
# 更新全局汇总
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
# 定时任务:聚合昨天的数据
# 注意:aggregate_daily_stats 期望业务时区的日期,不是 UTC
yesterday_local = today_local - timedelta(days=1)
# 定时任务:聚合昨天的数据
yesterday_local = today_local - timedelta(days=1)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
# 聚合所有用户的昨日数据
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
# 回滚当前用户的失败操作,继续处理其他用户
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
db.rollback()
except Exception:
pass
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, yesterday_local
)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
try:
db.rollback()
except Exception:
pass
# 更新全局汇总
StatsAggregatorService.update_summary(db)
StatsAggregatorService.update_summary(db)
logger.info("统计数据聚合完成")
logger.info("统计数据聚合完成")
except Exception as e:
logger.exception(f"统计聚合任务执行失败: {e}")
db.rollback()
finally:
db.close()
except Exception as e:
logger.exception(f"统计聚合任务执行失败: {e}")
try:
db.rollback()
except Exception:
pass
finally:
db.close()
async def _perform_pending_cleanup(self):
"""执行 pending 状态清理"""

View File

@@ -56,65 +56,44 @@ class StatsAggregatorService:
"""统计数据聚合服务"""
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
# 将业务日期转换为 UTC 时间范围
def compute_daily_stats(db: Session, date: datetime) -> dict:
"""计算指定业务日期的统计数据(不写入数据库)"""
day_start, day_end = _get_business_day_range(date)
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 检查是否已存在该日期的记录
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# 基础请求统计
base_query = db.query(Usage).filter(
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
)
total_requests = base_query.count()
# 如果没有请求,直接返回空记录
if total_requests == 0:
stats.total_requests = 0
stats.success_requests = 0
stats.error_requests = 0
stats.input_tokens = 0
stats.output_tokens = 0
stats.cache_creation_tokens = 0
stats.cache_read_tokens = 0
stats.total_cost = 0.0
stats.actual_total_cost = 0.0
stats.input_cost = 0.0
stats.output_cost = 0.0
stats.cache_creation_cost = 0.0
stats.cache_read_cost = 0.0
stats.avg_response_time_ms = 0.0
stats.fallback_count = 0
return {
"day_start": day_start,
"total_requests": 0,
"success_requests": 0,
"error_requests": 0,
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
"total_cost": 0.0,
"actual_total_cost": 0.0,
"input_cost": 0.0,
"output_cost": 0.0,
"cache_creation_cost": 0.0,
"cache_read_cost": 0.0,
"avg_response_time_ms": 0.0,
"fallback_count": 0,
"unique_models": 0,
"unique_providers": 0,
}
if not existing:
db.add(stats)
db.commit()
return stats
# 错误请求数
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
# Token 和成本聚合
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
@@ -157,7 +136,6 @@ class StatsAggregatorService:
or 0
)
# 使用维度统计
unique_models = (
db.query(func.count(func.distinct(Usage.model)))
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
@@ -171,31 +149,74 @@ class StatsAggregatorService:
or 0
)
return {
"day_start": day_start,
"total_requests": total_requests,
"success_requests": total_requests - error_requests,
"error_requests": error_requests,
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
"fallback_count": fallback_count,
"unique_models": unique_models,
"unique_providers": unique_providers,
}
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
    """Aggregate one business day of usage into its ``StatsDaily`` row.

    The figures are produced by ``compute_daily_stats``; this method only
    persists them. It updates the row whose ``date`` equals the computed
    ``day_start``, or inserts a new row when none exists, then commits.

    Args:
        db: Database session used for the lookup and the commit.
        date: Business date to aggregate (in the APP_TIMEZONE time zone).

    Returns:
        The inserted or updated ``StatsDaily`` record.
    """
    computed = StatsAggregatorService.compute_daily_stats(db, date)
    day_start = computed["day_start"]

    # stats_daily.date stores the UTC start of the business day, so look the
    # row up by that value to see whether this date was already aggregated.
    existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
    if existing is None:
        stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
    else:
        stats = existing

    # Explicit list (rather than iterating ``computed``) so that only known
    # columns are ever written onto the record.
    copied_fields = (
        "total_requests",
        "success_requests",
        "error_requests",
        "input_tokens",
        "output_tokens",
        "cache_creation_tokens",
        "cache_read_tokens",
        "total_cost",
        "actual_total_cost",
        "input_cost",
        "output_cost",
        "cache_creation_cost",
        "cache_read_cost",
        "avg_response_time_ms",
        "fallback_count",
        "unique_models",
        "unique_providers",
    )
    for field_name in copied_fields:
        setattr(stats, field_name, computed[field_name])

    if existing is None:
        db.add(stats)
    db.commit()

    # Log the business date passed in, not the UTC date stored on the row.
    logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
    return stats
@staticmethod