Initial commit

fawney19
2025-12-10 20:52:44 +08:00
commit f784106826
485 changed files with 110993 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
"""
System services package.
Includes system configuration, audit logging, announcements, and related features.
"""
from src.services.system.announcement import AnnouncementService
from src.services.system.audit import AuditService
from src.services.system.cleanup_scheduler import CleanupScheduler
from src.services.system.config import SystemConfigService
from src.services.system.scheduler import APP_TIMEZONE, TaskScheduler, get_scheduler
from src.services.system.sync_stats import SyncStatsService
__all__ = [
"SystemConfigService",
"AuditService",
"AnnouncementService",
"CleanupScheduler",
"SyncStatsService",
"TaskScheduler",
"get_scheduler",
"APP_TIMEZONE",
]

View File

@@ -0,0 +1,241 @@
"""
Announcement service
"""
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy import and_, or_
from sqlalchemy.orm import Session
from src.core.exceptions import ForbiddenException, NotFoundException
from src.core.logger import logger
from src.models.database import Announcement, AnnouncementRead, User, UserRole
class AnnouncementService:
"""Announcement service"""
@staticmethod
def create_announcement(
db: Session,
author_id: str, # UUID
title: str,
content: str,
type: str = "info",
priority: int = 0,
is_pinned: bool = False,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
) -> Announcement:
"""Create an announcement"""
# Verify that the author is an administrator
author = db.query(User).filter(User.id == author_id).first()
if not author or author.role != UserRole.ADMIN:
raise ForbiddenException("Only administrators can create announcements")
# Validate the announcement type
if type not in ["info", "warning", "maintenance", "important"]:
raise ValueError("Invalid announcement type")
announcement = Announcement(
title=title,
content=content,
type=type,
priority=priority,
author_id=author_id,
is_pinned=is_pinned,
start_time=start_time,
end_time=end_time,
is_active=True,
)
db.add(announcement)
db.commit()
db.refresh(announcement)
logger.info(f"Created announcement: {announcement.id} - {title}")
return announcement
@staticmethod
def get_announcements(
db: Session,
user_id: Optional[str] = None, # UUID
active_only: bool = True,
include_read_status: bool = False,
limit: int = 50,
offset: int = 0,
) -> dict:
"""Get a paginated list of announcements"""
query = db.query(Announcement)
# Filter conditions
if active_only:
now = datetime.now(timezone.utc)
query = query.filter(
Announcement.is_active == True,
or_(Announcement.start_time == None, Announcement.start_time <= now),
or_(Announcement.end_time == None, Announcement.end_time >= now),
)
# Order: pinned first, then by priority and creation time
query = query.order_by(
Announcement.is_pinned.desc(),
Announcement.priority.desc(),
Announcement.created_at.desc(),
)
# Pagination
total = query.count()
announcements = query.offset(offset).limit(limit).all()
# Fetch read status
read_announcement_ids = set()
unread_count = 0
if user_id and include_read_status:
read_records = (
db.query(AnnouncementRead.announcement_id)
.filter(AnnouncementRead.user_id == user_id)
.all()
)
read_announcement_ids = {r[0] for r in read_records}
unread_count = total - len(read_announcement_ids)
# Build the response payload
items = []
for announcement in announcements:
item = {
"id": announcement.id,
"title": announcement.title,
"content": announcement.content,
"type": announcement.type,
"priority": announcement.priority,
"is_pinned": announcement.is_pinned,
"is_active": announcement.is_active,
"author": {"id": announcement.author.id, "username": announcement.author.username},
"start_time": announcement.start_time,
"end_time": announcement.end_time,
"created_at": announcement.created_at,
"updated_at": announcement.updated_at,
}
if include_read_status and user_id:
item["is_read"] = announcement.id in read_announcement_ids
items.append(item)
result = {"items": items, "total": total}
if include_read_status and user_id:
result["unread_count"] = unread_count
return result
@staticmethod
def get_announcement(db: Session, announcement_id: str) -> Announcement: # UUID
"""Get a single announcement"""
announcement = db.query(Announcement).filter(Announcement.id == announcement_id).first()
if not announcement:
raise NotFoundException("Announcement not found")
return announcement
@staticmethod
def update_announcement(
db: Session,
announcement_id: str, # UUID
user_id: str, # UUID
title: Optional[str] = None,
content: Optional[str] = None,
type: Optional[str] = None,
priority: Optional[int] = None,
is_active: Optional[bool] = None,
is_pinned: Optional[bool] = None,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
) -> Announcement:
"""Update an announcement"""
# Verify that the user is an administrator
user = db.query(User).filter(User.id == user_id).first()
if not user or user.role != UserRole.ADMIN:
raise ForbiddenException("Only administrators can update announcements")
announcement = AnnouncementService.get_announcement(db, announcement_id)
# Update only the provided fields
if title is not None:
announcement.title = title
if content is not None:
announcement.content = content
if type is not None:
if type not in ["info", "warning", "maintenance", "important"]:
raise ValueError("Invalid announcement type")
announcement.type = type
if priority is not None:
announcement.priority = priority
if is_active is not None:
announcement.is_active = is_active
if is_pinned is not None:
announcement.is_pinned = is_pinned
if start_time is not None:
announcement.start_time = start_time
if end_time is not None:
announcement.end_time = end_time
db.commit()
db.refresh(announcement)
logger.info(f"Updated announcement: {announcement_id}")
return announcement
@staticmethod
def delete_announcement(db: Session, announcement_id: str, user_id: str) -> None: # UUID
"""Delete an announcement"""
# Verify that the user is an administrator
user = db.query(User).filter(User.id == user_id).first()
if not user or user.role != UserRole.ADMIN:
raise ForbiddenException("Only administrators can delete announcements")
announcement = AnnouncementService.get_announcement(db, announcement_id)
db.delete(announcement)
db.commit()
logger.info(f"Deleted announcement: {announcement_id}")
@staticmethod
def mark_as_read(db: Session, announcement_id: str, user_id: str) -> None: # UUID
"""Mark an announcement as read"""
# Ensure the announcement exists
announcement = AnnouncementService.get_announcement(db, announcement_id)
# Check whether it has already been marked as read
existing = (
db.query(AnnouncementRead)
.filter(
AnnouncementRead.user_id == user_id,
AnnouncementRead.announcement_id == announcement_id,
)
.first()
)
if not existing:
read_record = AnnouncementRead(user_id=user_id, announcement_id=announcement_id)
db.add(read_record)
db.commit()
logger.info(f"User {user_id} marked announcement {announcement_id} as read")
@staticmethod
def get_active_announcements(db: Session, user_id: Optional[str] = None) -> dict: # UUID
"""Get currently active announcements (for homepage display)"""
return AnnouncementService.get_announcements(
db=db,
user_id=user_id,
active_only=True,
include_read_status=True if user_id else False,
limit=10,
)
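A brief usage illustration for the service above: the sketch below is a minimal, hypothetical wiring into a FastAPI route layer. The router, the route paths, and the use of get_db as a dependency are assumptions for illustration only and are not part of this commit.

from typing import Optional

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from src.database import get_db  # session dependency, mirroring imports used elsewhere in this commit
from src.services.system.announcement import AnnouncementService

router = APIRouter()  # hypothetical router, for illustration only


@router.get("/announcements/active")
def list_active_announcements(user_id: Optional[str] = None, db: Session = Depends(get_db)):
    # Returns up to 10 active announcements; read status is included only when a user is known.
    return AnnouncementService.get_active_announcements(db, user_id=user_id)


@router.post("/announcements/{announcement_id}/read")
def mark_announcement_read(announcement_id: str, user_id: str, db: Session = Depends(get_db)):
    # Idempotent: calling this twice for the same user/announcement pair is a no-op.
    AnnouncementService.mark_as_read(db, announcement_id, user_id)
    return {"ok": True}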

View File

@@ -0,0 +1,459 @@
"""
Audit log service.
Records all important operations and security events.
"""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import get_db
from src.models.database import AuditEventType, AuditLog
from src.utils.transaction_manager import transactional
# The audit models have been moved to src/models/database.py
class AuditService:
"""Audit service"""
@staticmethod
@transactional(commit=False) # No automatic commit here; the caller decides when to commit
def log_event(
db: Session,
event_type: AuditEventType,
description: str,
user_id: Optional[str] = None, # UUID
api_key_id: Optional[str] = None, # UUID
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_id: Optional[str] = None,
status_code: Optional[int] = None,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Optional[AuditLog]:
"""
Record an audit event.
Args:
db: Database session
event_type: Event type
description: Event description
user_id: User ID
api_key_id: API key ID
ip_address: IP address
user_agent: User agent
request_id: Request ID
status_code: Status code
error_message: Error message
metadata: Additional metadata
Returns:
The audit log record, or None if logging failed
"""
try:
audit_log = AuditLog(
event_type=event_type.value,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
event_metadata=metadata,
)
db.add(audit_log)
db.commit() # Commit immediately to release database locks
db.refresh(audit_log)
# Also write to the system log
log_message = (
f"AUDIT [{event_type.value}] - {description} | "
f"user_id={user_id}, ip={ip_address}"
)
if event_type in [
AuditEventType.UNAUTHORIZED_ACCESS,
AuditEventType.SUSPICIOUS_ACTIVITY,
]:
logger.warning(log_message)
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
logger.info(log_message)
else:
logger.debug(log_message)
return audit_log
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
db.rollback()
return None
@staticmethod
def log_login_attempt(
db: Session,
email: str,
success: bool,
ip_address: str,
user_agent: str,
user_id: Optional[str] = None, # UUID
error_reason: Optional[str] = None,
):
"""
Record a login attempt.
Args:
db: Database session
email: Login email
success: Whether the attempt succeeded
ip_address: IP address
user_agent: User agent
user_id: User ID (when successful)
error_reason: Failure reason
"""
event_type = AuditEventType.LOGIN_SUCCESS if success else AuditEventType.LOGIN_FAILED
description = f"Login attempt for {email}"
if not success and error_reason:
description += f": {error_reason}"
AuditService.log_event(
db=db,
event_type=event_type,
description=description,
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
metadata={"email": email},
)
@staticmethod
def log_api_request(
db: Session,
user_id: str, # UUID
api_key_id: str, # UUID
request_id: str,
model: str,
provider: str,
success: bool,
ip_address: str,
status_code: int,
error_message: Optional[str] = None,
input_tokens: Optional[int] = None,
output_tokens: Optional[int] = None,
cost_usd: Optional[float] = None,
):
"""
Record an API request.
Args:
db: Database session
user_id: User ID
api_key_id: API key ID
request_id: Request ID
model: Model name
provider: Provider name
success: Whether the request succeeded
ip_address: IP address
status_code: Status code
error_message: Error message
input_tokens: Input tokens
output_tokens: Output tokens
cost_usd: Cost (USD)
"""
event_type = AuditEventType.REQUEST_SUCCESS if success else AuditEventType.REQUEST_FAILED
description = f"API request to {provider}/{model}"
metadata = {"model": model, "provider": provider}
if input_tokens:
metadata["input_tokens"] = input_tokens
if output_tokens:
metadata["output_tokens"] = output_tokens
if cost_usd:
metadata["cost_usd"] = cost_usd
AuditService.log_event(
db=db,
event_type=event_type,
description=description,
user_id=user_id,
api_key_id=api_key_id,
request_id=request_id,
ip_address=ip_address,
status_code=status_code,
error_message=error_message,
metadata=metadata,
)
@staticmethod
def log_security_event(
db: Session,
event_type: AuditEventType,
description: str,
ip_address: str,
user_id: Optional[str] = None, # UUID
severity: str = "medium",
details: Optional[Dict[str, Any]] = None,
):
"""
Record a security event.
Args:
db: Database session
event_type: Event type
description: Event description
ip_address: IP address
user_id: User ID
severity: Severity (low, medium, high, critical)
details: Additional details
"""
event_metadata = {"severity": severity}
if details:
event_metadata.update(details)
AuditService.log_event(
db=db,
event_type=event_type,
description=description,
user_id=user_id,
ip_address=ip_address,
metadata=event_metadata,
)
# For high-severity events, emit a concise error-level alert
if severity in ["high", "critical"]:
logger.error(f"Security alert [{severity.upper()}]: {description}")
@staticmethod
def get_user_audit_logs(
db: Session,
user_id: str, # UUID
event_types: Optional[List[AuditEventType]] = None,
limit: int = 100,
) -> List[AuditLog]:
"""
Get a user's audit logs.
Args:
db: Database session
user_id: User ID
event_types: Event type filter
limit: Maximum number of records to return
Returns:
List of audit logs
"""
query = db.query(AuditLog).filter(AuditLog.user_id == user_id)
if event_types:
event_type_values = [et.value for et in event_types]
query = query.filter(AuditLog.event_type.in_(event_type_values))
return query.order_by(AuditLog.created_at.desc()).limit(limit).all()
@staticmethod
def get_suspicious_activities(db: Session, hours: int = 24, limit: int = 100) -> List[AuditLog]:
"""
Get suspicious activity.
Args:
db: Database session
hours: Time window in hours
limit: Maximum number of records to return
Returns:
List of suspicious activity records
"""
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
suspicious_types = [
AuditEventType.SUSPICIOUS_ACTIVITY.value,
AuditEventType.UNAUTHORIZED_ACCESS.value,
AuditEventType.LOGIN_FAILED.value,
AuditEventType.REQUEST_RATE_LIMITED.value,
]
return (
db.query(AuditLog)
.filter(AuditLog.event_type.in_(suspicious_types), AuditLog.created_at >= cutoff_time)
.order_by(AuditLog.created_at.desc())
.limit(limit)
.all()
)
@staticmethod
def analyze_user_behavior(db: Session, user_id: str, days: int = 30) -> Dict[str, Any]: # UUID
"""
Analyze user behavior.
Args:
db: Database session
user_id: User ID
days: Number of days to analyze
Returns:
Behavior analysis result
"""
from sqlalchemy import func
cutoff_time = datetime.now(timezone.utc) - timedelta(days=days)
# Count events by type
event_counts = (
db.query(AuditLog.event_type, func.count(AuditLog.id).label("count"))
.filter(AuditLog.user_id == user_id, AuditLog.created_at >= cutoff_time)
.group_by(AuditLog.event_type)
.all()
)
# Count failed requests
failed_requests = (
db.query(func.count(AuditLog.id))
.filter(
AuditLog.user_id == user_id,
AuditLog.event_type == AuditEventType.REQUEST_FAILED.value,
AuditLog.created_at >= cutoff_time,
)
.scalar()
)
# Count successful requests
success_requests = (
db.query(func.count(AuditLog.id))
.filter(
AuditLog.user_id == user_id,
AuditLog.event_type == AuditEventType.REQUEST_SUCCESS.value,
AuditLog.created_at >= cutoff_time,
)
.scalar()
)
# Get recent suspicious activity
recent_suspicious = (
db.query(AuditLog)
.filter(
AuditLog.user_id == user_id,
AuditLog.event_type.in_(
[
AuditEventType.SUSPICIOUS_ACTIVITY.value,
AuditEventType.UNAUTHORIZED_ACCESS.value,
]
),
AuditLog.created_at >= cutoff_time,
)
.count()
)
return {
"user_id": user_id,
"period_days": days,
"event_counts": {event: count for event, count in event_counts},
"failed_requests": failed_requests or 0,
"success_requests": success_requests or 0,
"success_rate": (
success_requests / (success_requests + failed_requests)
if (success_requests + failed_requests) > 0
else 0
),
"suspicious_activities": recent_suspicious,
"analysis_time": datetime.now(timezone.utc).isoformat(),
}
@staticmethod
def log_event_auto(
event_type: AuditEventType,
description: str,
user_id: Optional[str] = None,
api_key_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
request_id: Optional[str] = None,
status_code: Optional[int] = None,
error_message: Optional[str] = None,
event_metadata: Optional[Dict[str, Any]] = None,
db: Optional[Session] = None,
) -> Optional[AuditLog]:
"""
Audit logging with automatic database session management.
Intended for middleware and other places where a session is not readily available.
Args:
event_type: Event type
description: Event description
user_id: User ID
api_key_id: API key ID
ip_address: IP address
user_agent: User agent
request_id: Request ID
status_code: Status code
error_message: Error message
event_metadata: Additional metadata
db: Database session (optional; created automatically if not provided)
Returns:
The audit log record, or None if logging failed
"""
# If a database session was provided, use it (no automatic commit here)
if db is not None:
try:
audit_log = AuditService.log_event(
db=db,
event_type=event_type,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
metadata=event_metadata,
)
# Note: do not commit here; the caller decides when to commit
return audit_log
except Exception as e:
logger.error(f"Failed to log audit event: {e}")
return None
# If no session was provided, create and manage one automatically
db_session = None
try:
db_session = next(get_db())
audit_log = AuditService.log_event(
db=db_session,
event_type=event_type,
description=description,
user_id=user_id,
api_key_id=api_key_id,
ip_address=ip_address,
user_agent=user_agent,
request_id=request_id,
status_code=status_code,
error_message=error_message,
metadata=event_metadata,
)
db_session.commit()
return audit_log
except Exception as e:
logger.error(f"Failed to log audit event with auto session: {e}")
if db_session is not None:
db_session.rollback()
return None
finally:
if db_session is not None:
db_session.close()
# Global audit service instance
audit_service = AuditService()
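As a usage note, the sketch below shows how log_event_auto might be called from request middleware, where no database session is injected. The middleware class and its registration are assumptions for illustration, not part of this commit.

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from src.models.database import AuditEventType
from src.services.system.audit import AuditService


class AuditMiddleware(BaseHTTPMiddleware):  # hypothetical middleware, for illustration only
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if response.status_code == 401:
            # No db argument: log_event_auto opens, commits, and closes its own session.
            AuditService.log_event_auto(
                event_type=AuditEventType.UNAUTHORIZED_ACCESS,
                description=f"Unauthorized access to {request.url.path}",
                ip_address=request.client.host if request.client else None,
                status_code=401,
            )
        return response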

View File

@@ -0,0 +1,597 @@
"""
Scheduled cleanup of usage records.
Tiered cleanup strategy:
- detail_log_retention_days: compress request_body and response_body into the compressed columns
- header_retention_days: clear request_headers and response_headers
- log_retention_days: delete the entire record
Statistics aggregation tasks:
- Aggregate the previous day's statistics early each morning
- Update the global statistics summary
APScheduler is used for task scheduling, with timezone support.
"""
import asyncio
from datetime import datetime, timedelta, timezone
from sqlalchemy import delete
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import create_session
from src.models.database import Usage
from src.services.system.config import SystemConfigService
from src.services.system.scheduler import get_scheduler
from src.services.system.stats_aggregator import StatsAggregatorService
from src.services.user.apikey import ApiKeyService
from src.utils.compression import compress_json
class CleanupScheduler:
"""Usage record cleanup scheduler"""
def __init__(self):
self.running = False
self._interval_tasks = []
async def start(self):
"""Start the scheduler"""
if self.running:
logger.warning("Cleanup scheduler already running")
return
self.running = True
logger.info("Usage record cleanup scheduler started")
scheduler = get_scheduler()
# Register scheduled jobs (using the business timezone)
# Statistics aggregation job - runs at 01:00
scheduler.add_cron_job(
self._scheduled_stats_aggregation,
hour=1,
minute=0,
job_id="stats_aggregation",
name="统计数据聚合",
)
# Cleanup job - runs at 03:00
scheduler.add_cron_job(
self._scheduled_cleanup,
hour=3,
minute=0,
job_id="usage_cleanup",
name="Usage record cleanup",
)
# Connection pool monitoring - every 5 minutes
scheduler.add_interval_job(
self._scheduled_monitor,
minutes=5,
job_id="pool_monitor",
name="Connection pool monitoring",
)
# Pending status cleanup - every 5 minutes
scheduler.add_interval_job(
self._scheduled_pending_cleanup,
minutes=5,
job_id="pending_cleanup",
name="Pending status cleanup",
)
# Run the initialization tasks once at startup
asyncio.create_task(self._run_startup_tasks())
async def _run_startup_tasks(self):
"""Initialization tasks run once at startup"""
# Delay slightly so the system is fully started
await asyncio.sleep(2)
try:
logger.info("Running initial cleanup task at startup...")
await self._perform_cleanup()
except Exception as e:
logger.exception(f"Startup cleanup task failed: {e}")
try:
logger.info("Checking statistics data at startup...")
await self._perform_stats_aggregation(backfill=True)
except Exception as e:
logger.exception(f"Startup statistics aggregation task failed: {e}")
async def stop(self):
"""Stop the scheduler"""
if not self.running:
return
self.running = False
scheduler = get_scheduler()
scheduler.stop()
logger.info("Usage record cleanup scheduler stopped")
# ========== Task functions (APScheduler calls the async functions directly) ==========
async def _scheduled_stats_aggregation(self):
"""Statistics aggregation task (invoked on schedule)"""
await self._perform_stats_aggregation()
async def _scheduled_cleanup(self):
"""Cleanup task (invoked on schedule)"""
await self._perform_cleanup()
async def _scheduled_monitor(self):
"""Monitoring task (invoked on schedule)"""
try:
from src.database import log_pool_status
log_pool_status()
except Exception as e:
logger.exception(f"Connection pool monitoring task failed: {e}")
async def _scheduled_pending_cleanup(self):
"""Pending cleanup task (invoked on schedule)"""
await self._perform_pending_cleanup()
# ========== Actual task implementations ==========
async def _perform_stats_aggregation(self, backfill: bool = False):
"""Run the statistics aggregation task
Args:
backfill: Whether to backfill historical data (used on first startup)
"""
db = create_session()
try:
# Check whether statistics aggregation is enabled
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("Statistics aggregation is disabled; skipping aggregation task")
return
logger.info("Starting statistics aggregation...")
if backfill:
# Backfill historical data on first startup
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
logger.info("First run detected; starting backfill of historical statistics...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"Historical data backfill complete: {count} day(s)")
return
# Aggregate yesterday's data
now = datetime.now(timezone.utc)
yesterday = (now - timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
StatsAggregatorService.aggregate_daily_stats(db, yesterday)
# Aggregate yesterday's data for all users
from src.models.database import User as DBUser
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday)
except Exception as e:
logger.warning(f"Failed to aggregate statistics for user {user_id}: {e}")
# Roll back the failed operation for this user and continue with the others
try:
db.rollback()
except Exception:
pass
# Update the global summary
StatsAggregatorService.update_summary(db)
logger.info("Statistics aggregation complete")
except Exception as e:
logger.exception(f"Statistics aggregation task failed: {e}")
db.rollback()
finally:
db.close()
async def _perform_pending_cleanup(self):
"""Clean up stale pending requests"""
db = create_session()
try:
from src.services.usage.service import UsageService
# Get the configured timeout (default 10 minutes)
timeout_minutes = SystemConfigService.get_config(
db, "pending_request_timeout_minutes", 10
)
# Run the cleanup
cleaned_count = UsageService.cleanup_stale_pending_requests(
db, timeout_minutes=timeout_minutes
)
if cleaned_count > 0:
logger.info(f"Cleaned up {cleaned_count} timed-out pending/streaming requests")
except Exception as e:
logger.exception(f"Failed to clean up pending requests: {e}")
db.rollback()
finally:
db.close()
async def _perform_cleanup(self):
"""Run the cleanup task"""
db = create_session()
try:
# Check whether auto cleanup is enabled
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
logger.info("Auto cleanup is disabled; skipping cleanup task")
return
logger.info("Starting tiered cleanup of usage records...")
# Get configuration parameters
detail_retention = SystemConfigService.get_config(db, "detail_log_retention_days", 7)
compressed_retention = SystemConfigService.get_config(
db, "compressed_log_retention_days", 90
)
header_retention = SystemConfigService.get_config(db, "header_retention_days", 90)
log_retention = SystemConfigService.get_config(db, "log_retention_days", 365)
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
now = datetime.now(timezone.utc)
# 1. Compress detailed logs (body columns -> compressed columns)
detail_cutoff = now - timedelta(days=detail_retention)
body_compressed = await self._cleanup_body_fields(db, detail_cutoff, batch_size)
# 2. Clear compressed columns (after 90 days)
compressed_cutoff = now - timedelta(days=compressed_retention)
compressed_cleaned = await self._cleanup_compressed_fields(
db, compressed_cutoff, batch_size
)
# 3. Clear request headers
header_cutoff = now - timedelta(days=header_retention)
header_cleaned = await self._cleanup_header_fields(db, header_cutoff, batch_size)
# 4. Delete expired records
log_cutoff = now - timedelta(days=log_retention)
records_deleted = await self._delete_old_records(db, log_cutoff, batch_size)
# 5. Clean up expired API keys
auto_delete = SystemConfigService.get_config(db, "auto_delete_expired_keys", False)
keys_cleaned = ApiKeyService.cleanup_expired_keys(db, auto_delete=auto_delete)
logger.info(
f"Cleanup complete: compressed {body_compressed} records, "
f"cleared compressed fields on {compressed_cleaned} records, "
f"cleared headers on {header_cleaned} records, "
f"deleted {records_deleted} records, "
f"cleaned up {keys_cleaned} expired keys"
)
except Exception as e:
logger.exception(f"Cleanup task failed: {e}")
db.rollback()
finally:
db.close()
async def _cleanup_body_fields(
self, db: Session, cutoff_time: datetime, batch_size: int
) -> int:
"""Compress the request_body and response_body columns into the compressed columns
Processed record by record to ensure each row is updated correctly
"""
from sqlalchemy import null, update
total_compressed = 0
no_progress_count = 0 # Consecutive no-progress counter
processed_ids: set = set() # IDs already processed, to guard against reprocessing
while True:
batch_db = create_session()
try:
# 1. Query the records that need compression
# Note: exclude records that are already NULL or JSON null
records = (
batch_db.query(Usage.id, Usage.request_body, Usage.response_body)
.filter(Usage.created_at < cutoff_time)
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
.limit(batch_size)
.all()
)
if not records:
break
# Filter out records whose values are actually None (JSON null is parsed as Python None)
valid_records = [
(rid, req, resp)
for rid, req, resp in records
if req is not None or resp is not None
]
if not valid_records:
# All records in this batch have JSON null bodies; clean them up
logger.warning(
f"Detected {len(records)} records whose body fields are JSON null; cleaning them up"
)
for record_id, _, _ in records:
batch_db.execute(
update(Usage)
.where(Usage.id == record_id)
.values(request_body=null(), response_body=null())
)
batch_db.commit()
continue
# Detect repeated IDs, which would indicate that updates are not taking effect
current_ids = {r[0] for r in valid_records}
repeated_ids = current_ids & processed_ids
if repeated_ids:
logger.error(
f"Detected records being processed repeatedly, IDs: {list(repeated_ids)[:5]}... "
"Database updates are not taking effect; terminating the loop"
)
break
batch_success = 0
# 2. Update record by record (ensure each one is handled correctly)
for record_id, req_body, resp_body in valid_records:
try:
# Use null() so SQL NULL is stored rather than JSON null
result = batch_db.execute(
update(Usage)
.where(Usage.id == record_id)
.values(
request_body=null(),
response_body=null(),
request_body_compressed=compress_json(req_body)
if req_body
else None,
response_body_compressed=compress_json(resp_body)
if resp_body
else None,
)
)
if result.rowcount > 0:
batch_success += 1
processed_ids.add(record_id)
except Exception as e:
logger.warning(f"Failed to compress record {record_id}: {e}")
continue
batch_db.commit()
# 3. Check whether any real progress was made
if batch_success == 0:
no_progress_count += 1
if no_progress_count >= 3:
logger.error(
f"Compressing body fields made no progress for {no_progress_count} consecutive batches; "
"terminating the loop to avoid spinning forever"
)
break
else:
no_progress_count = 0 # Reset the counter
total_compressed += batch_success
logger.debug(
f"Compressed body fields on {batch_success} records, {total_compressed} in total"
)
await asyncio.sleep(0.1)
except Exception as e:
logger.exception(f"Failed to compress body fields: {e}")
try:
batch_db.rollback()
except Exception:
pass
break
finally:
batch_db.close()
return total_compressed
async def _cleanup_compressed_fields(
self, db: Session, cutoff_time: datetime, batch_size: int
) -> int:
"""Clear the compressed columns (delete compressed bodies after 90 days)
Each batch uses a short-lived session to avoid ORM cache issues
"""
from sqlalchemy import null, update
total_cleaned = 0
while True:
batch_db = create_session()
try:
# Query records whose compressed fields need clearing
records_to_clean = (
batch_db.query(Usage.id)
.filter(Usage.created_at < cutoff_time)
.filter(
(Usage.request_body_compressed.isnot(None))
| (Usage.response_body_compressed.isnot(None))
)
.limit(batch_size)
.all()
)
if not records_to_clean:
break
record_ids = [r.id for r in records_to_clean]
# Bulk update, using null() to ensure SQL NULL is stored
result = batch_db.execute(
update(Usage)
.where(Usage.id.in_(record_ids))
.values(
request_body_compressed=null(),
response_body_compressed=null(),
)
)
rows_updated = result.rowcount
batch_db.commit()
if rows_updated == 0:
logger.warning("Clearing compressed fields: rowcount=0, something may be wrong")
break
total_cleaned += rows_updated
logger.debug(f"Cleared compressed fields on {rows_updated} records, {total_cleaned} in total")
await asyncio.sleep(0.1)
except Exception as e:
logger.exception(f"Failed to clear compressed fields: {e}")
try:
batch_db.rollback()
except Exception:
pass
break
finally:
batch_db.close()
return total_cleaned
async def _cleanup_header_fields(
self, db: Session, cutoff_time: datetime, batch_size: int
) -> int:
"""Clear the request_headers, response_headers, and provider_request_headers columns
Each batch uses a short-lived session to avoid ORM cache issues
"""
from sqlalchemy import null, update
total_cleaned = 0
while True:
batch_db = create_session()
try:
# First query the record IDs that need clearing (batched)
records_to_clean = (
batch_db.query(Usage.id)
.filter(Usage.created_at < cutoff_time)
.filter(
(Usage.request_headers.isnot(None))
| (Usage.response_headers.isnot(None))
| (Usage.provider_request_headers.isnot(None))
)
.limit(batch_size)
.all()
)
if not records_to_clean:
break
record_ids = [r.id for r in records_to_clean]
# Bulk update, using null() to ensure SQL NULL is stored
result = batch_db.execute(
update(Usage)
.where(Usage.id.in_(record_ids))
.values(
request_headers=null(),
response_headers=null(),
provider_request_headers=null(),
)
)
rows_updated = result.rowcount
batch_db.commit()
if rows_updated == 0:
logger.warning("Clearing header fields: rowcount=0, something may be wrong")
break
total_cleaned += rows_updated
logger.debug(
f"Cleared header fields on {rows_updated} records, {total_cleaned} in total"
)
await asyncio.sleep(0.1)
except Exception as e:
logger.exception(f"Failed to clear header fields: {e}")
try:
batch_db.rollback()
except Exception:
pass
break
finally:
batch_db.close()
return total_cleaned
async def _delete_old_records(self, db: Session, cutoff_time: datetime, batch_size: int) -> int:
"""Delete expired records entirely"""
total_deleted = 0
while True:
try:
# Query the record IDs to delete (batched)
records_to_delete = (
db.query(Usage.id)
.filter(Usage.created_at < cutoff_time)
.limit(batch_size)
.all()
)
if not records_to_delete:
break
record_ids = [r.id for r in records_to_delete]
# Perform the deletion
result = db.execute(
delete(Usage)
.where(Usage.id.in_(record_ids))
.execution_options(synchronize_session=False)
)
rows_deleted = result.rowcount
db.commit()
total_deleted += rows_deleted
logger.debug(f"Deleted {rows_deleted} expired records, {total_deleted} in total")
await asyncio.sleep(0.1)
except Exception as e:
logger.exception(f"Failed to delete expired records: {e}")
try:
db.rollback()
except Exception:
pass
break
return total_deleted
# Global singleton
_cleanup_scheduler = None
def get_cleanup_scheduler() -> CleanupScheduler:
"""Get the cleanup scheduler singleton"""
global _cleanup_scheduler
if _cleanup_scheduler is None:
_cleanup_scheduler = CleanupScheduler()
return _cleanup_scheduler
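For context, the sketch below shows one way the cleanup scheduler could be tied to the application lifecycle. The FastAPI lifespan hook is an assumption for illustration; only get_cleanup_scheduler() and get_scheduler() come from this commit.

from contextlib import asynccontextmanager

from fastapi import FastAPI

from src.services.system.cleanup_scheduler import get_cleanup_scheduler
from src.services.system.scheduler import get_scheduler


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Register the cleanup and aggregation jobs, then start the shared APScheduler instance.
    await get_cleanup_scheduler().start()
    get_scheduler().start()
    yield
    # Stops the cleanup scheduler, which in turn stops the underlying TaskScheduler.
    await get_cleanup_scheduler().stop()


app = FastAPI(lifespan=lifespan)  # hypothetical application object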

View File

@@ -0,0 +1,257 @@
"""
System configuration service
"""
import json
from enum import Enum
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import Provider, SystemConfig
class LogLevel(str, Enum):
"""Request logging level"""
BASIC = "basic" # Record only basic info (tokens, cost, etc.)
HEADERS = "headers" # Record basic info plus request/response headers (sensitive values are masked)
FULL = "full" # Record the full request and response, including bodies (sensitive values are masked)
class SystemConfigService:
"""System configuration service class"""
# Default configuration
DEFAULT_CONFIGS = {
"request_log_level": {
"value": LogLevel.BASIC.value,
"description": "Request logging level: basic (basic info), headers (with request headers), full (full request and response)",
},
"max_request_body_size": {
"value": 1048576, # 1MB
"description": "Maximum request body size to record (bytes); larger request bodies are truncated. Affects only database records, not the real API request",
},
"max_response_body_size": {
"value": 1048576, # 1MB
"description": "Maximum response body size to record (bytes); larger response bodies are truncated. Affects only database records, not the real API response",
},
"sensitive_headers": {
"value": ["authorization", "x-api-key", "api-key", "cookie", "set-cookie"],
"description": "List of sensitive headers; these headers are masked",
},
# Tiered cleanup policy
"detail_log_retention_days": {
"value": 7,
"description": "Days to keep detailed logs; afterwards request_body and response_body are compressed into the compressed columns",
},
"compressed_log_retention_days": {
"value": 90,
"description": "Days to keep compressed logs; afterwards the compressed body columns are deleted (headers and statistics are kept)",
},
"header_retention_days": {
"value": 90,
"description": "Days to keep request headers; afterwards the request_headers and response_headers columns are cleared",
},
"log_retention_days": {
"value": 365,
"description": "Days to keep full log records; afterwards the entire record is deleted (core statistics are kept)",
},
"enable_auto_cleanup": {
"value": True,
"description": "Whether to enable the automatic cleanup task, which runs tiered cleanup in the early morning every day",
},
"cleanup_batch_size": {
"value": 1000,
"description": "Number of records cleaned per batch, to avoid oversized operations hurting database performance",
},
"provider_priority_mode": {
"value": "provider",
"description": "Priority strategy: provider (provider-first mode) or global_key (global key-first mode)",
},
"auto_delete_expired_keys": {
"value": False,
"description": "Whether to automatically delete expired API keys (True = hard delete, False = disable only); admin-configurable only",
},
}
@classmethod
def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]:
"""Get a system configuration value"""
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config:
return config.value
# If the configuration does not exist, check the defaults
if key in cls.DEFAULT_CONFIGS:
return cls.DEFAULT_CONFIGS[key]["value"]
return default
@staticmethod
def set_config(db: Session, key: str, value: Any, description: Optional[str] = None) -> SystemConfig:
"""Set a system configuration value"""
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config:
# Update the existing configuration
config.value = value
if description:
config.description = description
else:
# Create a new configuration entry
config = SystemConfig(key=key, value=value, description=description)
db.add(config)
db.commit()
db.refresh(config)
return config
@staticmethod
def get_default_provider(db: Session) -> Optional[str]:
"""
Get the system default provider.
Priority: 1. the default provider set by an administrator; 2. the first active provider in the database
"""
# First, try the default provider set by an administrator
default_provider = SystemConfigService.get_config(db, "default_provider")
if default_provider:
return default_provider
# If not set, fall back to the first active provider in the database
first_provider = db.query(Provider).filter(Provider.is_active == True).first()
if first_provider:
return first_provider.name
return None
@staticmethod
def set_default_provider(db: Session, provider_name: str) -> SystemConfig:
"""Set the system default provider"""
return SystemConfigService.set_config(
db, "default_provider", provider_name, "System default provider, used when a user has not set a personal provider"
)
@staticmethod
def get_all_configs(db: Session) -> list:
"""Get all system configurations"""
configs = db.query(SystemConfig).all()
return [
{
"key": config.key,
"value": config.value,
"description": config.description,
"updated_at": config.updated_at.isoformat(),
}
for config in configs
]
@staticmethod
def delete_config(db: Session, key: str) -> bool:
"""Delete a system configuration entry"""
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config:
db.delete(config)
db.commit()
return True
return False
@classmethod
def init_default_configs(cls, db: Session):
"""Initialize default configurations"""
for key, default_config in cls.DEFAULT_CONFIGS.items():
if not db.query(SystemConfig).filter(SystemConfig.key == key).first():
config = SystemConfig(
key=key,
value=default_config["value"],
description=default_config["description"],
)
db.add(config)
db.commit()
logger.info("Default system configuration initialized")
@classmethod
def get_log_level(cls, db: Session) -> LogLevel:
"""Get the request logging level"""
level = cls.get_config(db, "request_log_level", LogLevel.BASIC.value)
if isinstance(level, str):
return LogLevel(level)
return level
@classmethod
def should_log_headers(cls, db: Session) -> bool:
"""Whether request headers should be recorded"""
log_level = cls.get_log_level(db)
return log_level in [LogLevel.HEADERS, LogLevel.FULL]
@classmethod
def should_log_body(cls, db: Session) -> bool:
"""Whether request and response bodies should be recorded"""
log_level = cls.get_log_level(db)
return log_level == LogLevel.FULL
@classmethod
def should_mask_sensitive_data(cls, db: Session) -> bool:
"""Whether sensitive data should be masked (always masked)"""
_ = db # Kept for interface consistency
return True
@classmethod
def get_sensitive_headers(cls, db: Session) -> list:
"""Get the list of sensitive headers"""
return cls.get_config(db, "sensitive_headers", [])
@classmethod
def mask_sensitive_headers(cls, db: Session, headers: Dict[str, Any]) -> Dict[str, Any]:
"""Mask sensitive headers"""
if not cls.should_mask_sensitive_data(db):
return headers
sensitive_headers = cls.get_sensitive_headers(db)
masked_headers = {}
for key, value in headers.items():
if key.lower() in [h.lower() for h in sensitive_headers]:
# Keep the first and last 4 characters; replace the middle with asterisks
if len(str(value)) > 8:
masked_value = str(value)[:4] + "****" + str(value)[-4:]
else:
masked_value = "****"
masked_headers[key] = masked_value
else:
masked_headers[key] = value
return masked_headers
@classmethod
def truncate_body(cls, db: Session, body: Any, is_request: bool = True) -> Any:
"""Truncate an oversized request or response body"""
max_size_key = "max_request_body_size" if is_request else "max_response_body_size"
max_size = cls.get_config(db, max_size_key, 102400)
if not body:
return body
# Convert to a string to measure the size
body_str = json.dumps(body) if isinstance(body, (dict, list)) else str(body)
if len(body_str) > max_size:
# Truncate and add a hint
truncated_str = body_str[:max_size]
if isinstance(body, (dict, list)):
try:
# Try to keep a JSON-like structure
return {
"_truncated": True,
"_original_size": len(body_str),
"_content": truncated_str,
}
except Exception:
pass
return truncated_str + f"\n... (truncated, original size: {len(body_str)} bytes)"
return body
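To make the masking rule above concrete (keep the first and last 4 characters, star out the middle), here is a small, self-contained usage sketch; the session handling simply mirrors how other modules in this commit use create_session.

from src.database import create_session
from src.services.system.config import SystemConfigService

db = create_session()
try:
    headers = {
        "Authorization": "Bearer sk-abc123456789xyz",
        "Content-Type": "application/json",
    }
    masked = SystemConfigService.mask_sensitive_headers(db, headers)
    # "Authorization" is in the default sensitive list, so its value becomes "Bear****9xyz";
    # "Content-Type" passes through unchanged.
    print(masked)
finally:
    db.close()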

View File

@@ -0,0 +1,187 @@
"""
Unified scheduled task scheduler.
All scheduled tasks are managed with APScheduler, with timezone support.
Scheduled tasks configure their execution times in the application timezone (APP_TIMEZONE),
while data storage remains in UTC.
"""
import os
from datetime import datetime
from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from src.core.logger import logger
# Application timezone configuration; defaults to Asia/Shanghai
APP_TIMEZONE = os.getenv("APP_TIMEZONE", "Asia/Shanghai")
class TaskScheduler:
"""Unified scheduled task scheduler"""
_instance: Optional["TaskScheduler"] = None
def __init__(self):
self.scheduler = AsyncIOScheduler(timezone=APP_TIMEZONE)
self._started = False
@classmethod
def get_instance(cls) -> "TaskScheduler":
"""Get the scheduler singleton"""
if cls._instance is None:
cls._instance = TaskScheduler()
return cls._instance
def add_cron_job(
self,
func,
hour: int,
minute: int = 0,
job_id: str = None,
name: str = None,
**kwargs,
):
"""
Add a cron-style scheduled job.
Args:
func: Function to execute
hour: Execution time (hour), in the business timezone
minute: Execution time (minute)
job_id: Job ID
name: Job name (used in logs)
**kwargs: Arguments passed to the job function
"""
trigger = CronTrigger(hour=hour, minute=minute, timezone=APP_TIMEZONE)
job_id = job_id or func.__name__
display_name = name or job_id
self.scheduler.add_job(
func,
trigger,
id=job_id,
name=display_name,
replace_existing=True,
kwargs=kwargs,
)
logger.info(
f"Registered cron job: {display_name}, "
f"runs at {hour:02d}:{minute:02d} ({APP_TIMEZONE})"
)
def add_interval_job(
self,
func,
seconds: int = None,
minutes: int = None,
hours: int = None,
job_id: str = None,
name: str = None,
**kwargs,
):
"""
Add an interval-based job.
Args:
func: Function to execute
seconds: Interval in seconds
minutes: Interval in minutes
hours: Interval in hours
job_id: Job ID
name: Job name
**kwargs: Arguments passed to the job function
"""
# Build the trigger arguments, filtering out None values
trigger_kwargs = {}
if seconds is not None:
trigger_kwargs["seconds"] = seconds
if minutes is not None:
trigger_kwargs["minutes"] = minutes
if hours is not None:
trigger_kwargs["hours"] = hours
trigger = IntervalTrigger(**trigger_kwargs)
job_id = job_id or func.__name__
display_name = name or job_id
# Build a human-readable interval description
interval_parts = []
if hours:
interval_parts.append(f"{hours}h")
if minutes:
interval_parts.append(f"{minutes}m")
if seconds:
interval_parts.append(f"{seconds}s")
interval_desc = "".join(interval_parts) or "unknown interval"
self.scheduler.add_job(
func,
trigger,
id=job_id,
name=display_name,
replace_existing=True,
kwargs=kwargs,
)
logger.info(f"Registered interval job: {display_name}, interval: {interval_desc}")
def start(self):
"""Start the scheduler"""
if self._started:
logger.warning("Scheduler is already running")
return
self.scheduler.start()
self._started = True
logger.info(f"Task scheduler started, application timezone: {APP_TIMEZONE}")
# Log the next run times
self._log_next_run_times()
def stop(self):
"""Stop the scheduler"""
if not self._started:
return
self.scheduler.shutdown(wait=False)
self._started = False
logger.info("Task scheduler stopped")
def _log_next_run_times(self):
"""Log the next run time of every registered job"""
jobs = self.scheduler.get_jobs()
if not jobs:
return
logger.info("Registered scheduled jobs:")
for job in jobs:
next_run = job.next_run_time
if next_run:
# Compute the time until the next run
now = datetime.now(next_run.tzinfo)
delta = next_run - now
hours, remainder = divmod(int(delta.total_seconds()), 3600)
minutes = remainder // 60
logger.info(
f" - {job.name}: next run {next_run.strftime('%Y-%m-%d %H:%M')} "
f"(in {hours}h {minutes}m)"
)
@property
def is_running(self) -> bool:
"""Whether the scheduler is running"""
return self._started
# Convenience function
def get_scheduler() -> TaskScheduler:
"""Get the scheduler singleton"""
return TaskScheduler.get_instance()
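The sketch below shows how another module might register its own job with the shared scheduler; the task function and its schedule are hypothetical, while get_scheduler, add_cron_job, and add_interval_job are the APIs defined above.

from src.services.system.scheduler import get_scheduler


async def refresh_provider_health():  # hypothetical task, for illustration only
    ...


scheduler = get_scheduler()
# Daily run at 02:30 in APP_TIMEZONE; replace_existing=True inside add_cron_job means
# re-registering the same job_id after a restart is safe.
scheduler.add_cron_job(refresh_provider_health, hour=2, minute=30, job_id="provider_health", name="Provider health refresh")
# An every-15-minutes variant would instead use:
# scheduler.add_interval_job(refresh_provider_health, minutes=15, job_id="provider_health_interval")
scheduler.start()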

View File

@@ -0,0 +1,436 @@
"""Statistics aggregation service
Implements pre-aggregated statistics so that not every request requires a full-table scan.
"""
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
from sqlalchemy import and_, func
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import (
ApiKey,
RequestCandidate,
StatsDaily,
StatsSummary,
StatsUserDaily,
Usage,
)
from src.models.database import User as DBUser
class StatsAggregatorService:
"""Statistics aggregation service"""
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""Aggregate statistics for the given date
Args:
db: Database session
date: Date to aggregate (normalized to the start of that day in UTC)
Returns:
The StatsDaily record
"""
# Normalize to the start of the day in UTC
day_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=timezone.utc)
day_end = day_start + timedelta(days=1)
# Check whether a record for this date already exists
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# Base request statistics
base_query = db.query(Usage).filter(
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
)
total_requests = base_query.count()
# If there are no requests, return an empty record immediately
if total_requests == 0:
stats.total_requests = 0
stats.success_requests = 0
stats.error_requests = 0
stats.input_tokens = 0
stats.output_tokens = 0
stats.cache_creation_tokens = 0
stats.cache_read_tokens = 0
stats.total_cost = 0.0
stats.actual_total_cost = 0.0
stats.input_cost = 0.0
stats.output_cost = 0.0
stats.cache_creation_cost = 0.0
stats.cache_read_cost = 0.0
stats.avg_response_time_ms = 0.0
stats.fallback_count = 0
if not existing:
db.add(stats)
db.commit()
return stats
# Number of error requests
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
# Token and cost aggregation
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost"),
func.sum(Usage.input_cost_usd).label("input_cost"),
func.sum(Usage.output_cost_usd).label("output_cost"),
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
func.avg(Usage.response_time_ms).label("avg_response_time"),
)
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
.first()
)
# Fallback statistics (requests with more than one executed candidate)
fallback_subquery = (
db.query(
RequestCandidate.request_id,
func.count(RequestCandidate.id).label("executed_count"),
)
.filter(
and_(
RequestCandidate.created_at >= day_start,
RequestCandidate.created_at < day_end,
RequestCandidate.status.in_(["success", "failed"]),
)
)
.group_by(RequestCandidate.request_id)
.subquery()
)
fallback_count = (
db.query(func.count())
.select_from(fallback_subquery)
.filter(fallback_subquery.c.executed_count > 1)
.scalar()
or 0
)
# Usage dimension statistics
unique_models = (
db.query(func.count(func.distinct(Usage.model)))
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
.scalar()
or 0
)
unique_providers = (
db.query(func.count(func.distinct(Usage.provider)))
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
.scalar()
or 0
)
# Update the statistics record
stats.total_requests = total_requests
stats.success_requests = total_requests - error_requests
stats.error_requests = error_requests
stats.input_tokens = int(aggregated.input_tokens or 0)
stats.output_tokens = int(aggregated.output_tokens or 0)
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
stats.total_cost = float(aggregated.total_cost or 0)
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
stats.input_cost = float(aggregated.input_cost or 0)
stats.output_cost = float(aggregated.output_cost or 0)
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
stats.fallback_count = fallback_count
stats.unique_models = unique_models
stats.unique_providers = unique_providers
if not existing:
db.add(stats)
db.commit()
logger.info(f"[StatsAggregator] Aggregation for {day_start.date()} complete: {total_requests} requests")
return stats
@staticmethod
def aggregate_user_daily_stats(
db: Session, user_id: str, date: datetime
) -> StatsUserDaily:
"""Aggregate statistics for the given user and date"""
day_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=timezone.utc)
day_end = day_start + timedelta(days=1)
existing = (
db.query(StatsUserDaily)
.filter(and_(StatsUserDaily.user_id == user_id, StatsUserDaily.date == day_start))
.first()
)
if existing:
stats = existing
else:
stats = StatsUserDaily(id=str(uuid.uuid4()), user_id=user_id, date=day_start)
# Per-user request statistics
base_query = db.query(Usage).filter(
and_(
Usage.user_id == user_id,
Usage.created_at >= day_start,
Usage.created_at < day_end,
)
)
total_requests = base_query.count()
if total_requests == 0:
stats.total_requests = 0
stats.success_requests = 0
stats.error_requests = 0
stats.input_tokens = 0
stats.output_tokens = 0
stats.cache_creation_tokens = 0
stats.cache_read_tokens = 0
stats.total_cost = 0.0
if not existing:
db.add(stats)
db.commit()
return stats
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
)
.filter(
and_(
Usage.user_id == user_id,
Usage.created_at >= day_start,
Usage.created_at < day_end,
)
)
.first()
)
stats.total_requests = total_requests
stats.success_requests = total_requests - error_requests
stats.error_requests = error_requests
stats.input_tokens = int(aggregated.input_tokens or 0)
stats.output_tokens = int(aggregated.output_tokens or 0)
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
stats.total_cost = float(aggregated.total_cost or 0)
if not existing:
db.add(stats)
db.commit()
return stats
@staticmethod
def update_summary(db: Session) -> StatsSummary:
"""Update the global statistics summary
Summarizes all data up to and including yesterday.
"""
now = datetime.now(timezone.utc)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
cutoff_date = today # Exclude today
# Get or create the summary record
summary = db.query(StatsSummary).first()
if not summary:
summary = StatsSummary(id=str(uuid.uuid4()), cutoff_date=cutoff_date)
# Aggregate historical data from stats_daily
daily_aggregated = (
db.query(
func.sum(StatsDaily.total_requests).label("total_requests"),
func.sum(StatsDaily.success_requests).label("success_requests"),
func.sum(StatsDaily.error_requests).label("error_requests"),
func.sum(StatsDaily.input_tokens).label("input_tokens"),
func.sum(StatsDaily.output_tokens).label("output_tokens"),
func.sum(StatsDaily.cache_creation_tokens).label("cache_creation_tokens"),
func.sum(StatsDaily.cache_read_tokens).label("cache_read_tokens"),
func.sum(StatsDaily.total_cost).label("total_cost"),
func.sum(StatsDaily.actual_total_cost).label("actual_total_cost"),
)
.filter(StatsDaily.date < cutoff_date)
.first()
)
# User / API key statistics
total_users = db.query(func.count(DBUser.id)).scalar() or 0
active_users = (
db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar() or 0
)
total_api_keys = db.query(func.count(ApiKey.id)).scalar() or 0
active_api_keys = (
db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar() or 0
)
# Update the summary
summary.cutoff_date = cutoff_date
summary.all_time_requests = int(daily_aggregated.total_requests or 0)
summary.all_time_success_requests = int(daily_aggregated.success_requests or 0)
summary.all_time_error_requests = int(daily_aggregated.error_requests or 0)
summary.all_time_input_tokens = int(daily_aggregated.input_tokens or 0)
summary.all_time_output_tokens = int(daily_aggregated.output_tokens or 0)
summary.all_time_cache_creation_tokens = int(daily_aggregated.cache_creation_tokens or 0)
summary.all_time_cache_read_tokens = int(daily_aggregated.cache_read_tokens or 0)
summary.all_time_cost = float(daily_aggregated.total_cost or 0)
summary.all_time_actual_cost = float(daily_aggregated.actual_total_cost or 0)
summary.total_users = total_users
summary.active_users = active_users
summary.total_api_keys = total_api_keys
summary.active_api_keys = active_api_keys
db.add(summary)
db.commit()
logger.info(f"[StatsAggregator] Global summary updated, cutoff date: {cutoff_date.date()}")
return summary
@staticmethod
def get_today_realtime_stats(db: Session) -> dict:
"""Get today's real-time statistics (for merging with pre-aggregated data)"""
now = datetime.now(timezone.utc)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
base_query = db.query(Usage).filter(Usage.created_at >= today)
total_requests = base_query.count()
if total_requests == 0:
return {
"total_requests": 0,
"success_requests": 0,
"error_requests": 0,
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
"total_cost": 0.0,
"actual_total_cost": 0.0,
}
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
func.sum(Usage.output_tokens).label("output_tokens"),
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost"),
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost"),
)
.filter(Usage.created_at >= today)
.first()
)
return {
"total_requests": total_requests,
"success_requests": total_requests - error_requests,
"error_requests": error_requests,
"input_tokens": int(aggregated.input_tokens or 0),
"output_tokens": int(aggregated.output_tokens or 0),
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0),
"cache_read_tokens": int(aggregated.cache_read_tokens or 0),
"total_cost": float(aggregated.total_cost or 0),
"actual_total_cost": float(aggregated.actual_total_cost or 0),
}
@staticmethod
def get_combined_stats(db: Session) -> dict:
"""Get combined statistics (pre-aggregated + today's real-time data)"""
summary = db.query(StatsSummary).first()
today_stats = StatsAggregatorService.get_today_realtime_stats(db)
if not summary:
# No pre-aggregated data yet; return today's data only
return today_stats
return {
"total_requests": summary.all_time_requests + today_stats["total_requests"],
"success_requests": summary.all_time_success_requests
+ today_stats["success_requests"],
"error_requests": summary.all_time_error_requests + today_stats["error_requests"],
"input_tokens": summary.all_time_input_tokens + today_stats["input_tokens"],
"output_tokens": summary.all_time_output_tokens + today_stats["output_tokens"],
"cache_creation_tokens": summary.all_time_cache_creation_tokens
+ today_stats["cache_creation_tokens"],
"cache_read_tokens": summary.all_time_cache_read_tokens
+ today_stats["cache_read_tokens"],
"total_cost": summary.all_time_cost + today_stats["total_cost"],
"actual_total_cost": summary.all_time_actual_cost + today_stats["actual_total_cost"],
"total_users": summary.total_users,
"active_users": summary.active_users,
"total_api_keys": summary.total_api_keys,
"active_api_keys": summary.active_api_keys,
}
@staticmethod
def backfill_historical_data(db: Session, days: int = 365) -> int:
"""Backfill historical data (used on first deployment)
Args:
db: Database session
days: Number of days to backfill
Returns:
Number of days backfilled
"""
now = datetime.now(timezone.utc)
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
# Find the earliest Usage record
earliest = db.query(func.min(Usage.created_at)).scalar()
if not earliest:
logger.info("[StatsAggregator] No historical data to backfill")
return 0
# Compute the date range to backfill
earliest_date = earliest.replace(hour=0, minute=0, second=0, microsecond=0)
start_date = max(earliest_date, today - timedelta(days=days))
count = 0
current_date = start_date
while current_date < today:
StatsAggregatorService.aggregate_daily_stats(db, current_date)
count += 1
current_date += timedelta(days=1)
# Update the summary
if count > 0:
StatsAggregatorService.update_summary(db)
logger.info(f"[StatsAggregator] Historical data backfill complete: {count} day(s)")
return count
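To make the pre-aggregation model concrete: dashboard-style totals are the StatsSummary values (everything up to yesterday) plus today's real-time numbers, which is exactly what get_combined_stats returns. A minimal usage sketch, assuming the same create_session helper used elsewhere in this commit:

from src.database import create_session
from src.services.system.stats_aggregator import StatsAggregatorService

db = create_session()
try:
    combined = StatsAggregatorService.get_combined_stats(db)
    # Example: if the summary holds 1_000_000 historical requests and 4_200 requests have been
    # recorded so far today, combined["total_requests"] is 1_004_200.
    print(combined["total_requests"], combined["total_cost"])
finally:
    db.close()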

View File

@@ -0,0 +1,142 @@
"""
API key statistics synchronization service.
Periodically synchronizes API key statistics to keep them consistent with actual usage records.
"""
from typing import Optional
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.core.logger import logger
from src.models.database import ApiKey, Usage
class SyncStatsService:
"""API key statistics synchronization service"""
# Pagination batch size
BATCH_SIZE = 100
@staticmethod
def sync_api_key_stats(db: Session, api_key_id: Optional[str] = None) -> dict: # UUID
"""
Synchronize API key statistics.
Args:
db: Database session
api_key_id: ID of the API key to synchronize; if not provided, all keys are synchronized
Returns:
Synchronization result statistics
"""
result = {"synced": 0, "updated": 0, "errors": 0}
try:
# Get the API keys to synchronize (paginated to avoid loading too much data at once)
if api_key_id:
api_keys = db.query(ApiKey).filter(ApiKey.id == api_key_id).all()
else:
# Paginate to avoid loading all rows at once
offset = 0
api_keys = []
while True:
batch = db.query(ApiKey).offset(offset).limit(SyncStatsService.BATCH_SIZE).all()
if not batch:
break
api_keys.extend(batch)
offset += SyncStatsService.BATCH_SIZE
for api_key in api_keys:
try:
# Compute the actual usage statistics
stats = (
db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_cost_usd).label("cost"),
)
.filter(Usage.api_key_id == api_key.id)
.first()
)
actual_requests = stats.requests or 0
actual_cost = float(stats.cost or 0)
# Get the last-used timestamp
last_usage = (
db.query(Usage.created_at)
.filter(Usage.api_key_id == api_key.id)
.order_by(Usage.created_at.desc())
.first()
)
# Check whether an update is needed
needs_update = False
if api_key.total_requests != actual_requests:
logger.info(f"API key {api_key.id} request count mismatch: {api_key.total_requests} -> {actual_requests}")
api_key.total_requests = actual_requests
needs_update = True
if abs(api_key.total_cost_usd - actual_cost) > 0.0001:
logger.info(f"API key {api_key.id} cost mismatch: {api_key.total_cost_usd} -> {actual_cost}")
api_key.total_cost_usd = actual_cost
needs_update = True
if last_usage and api_key.last_used_at != last_usage[0]:
api_key.last_used_at = last_usage[0]
needs_update = True
result["synced"] += 1
if needs_update:
result["updated"] += 1
logger.info(f"Updated statistics for API key {api_key.id}")
except Exception as e:
logger.error(f"Error synchronizing statistics for API key {api_key.id}: {e}")
result["errors"] += 1
# Roll back the failed operation and continue with the remaining keys
try:
db.rollback()
except Exception:
pass
# Commit all changes
db.commit()
logger.info(f"Sync complete: processed {result['synced']} keys, updated {result['updated']}, errors {result['errors']}")
except Exception as e:
logger.error(f"Error synchronizing statistics: {e}")
db.rollback()
raise
return result
@staticmethod
def get_api_key_real_stats(db: Session, api_key_id: str) -> dict: # UUID
"""
Get an API key's actual statistics, computed directly from usage records.
Args:
db: Database session
api_key_id: API key ID
Returns:
The actual statistics
"""
# Compute the actual usage statistics
stats = (
db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_cost_usd).label("cost"),
func.max(Usage.created_at).label("last_used"),
)
.filter(Usage.api_key_id == api_key_id)
.first()
)
return {
"total_requests": stats.requests or 0,
"total_cost_usd": float(stats.cost or 0),
"last_used_at": stats.last_used,
}
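Finally, a hedged sketch of how the synchronization service might be invoked, for example from an admin-triggered maintenance task; the surrounding wiring is an assumption, and only SyncStatsService itself comes from this commit.

from src.database import create_session
from src.services.system.sync_stats import SyncStatsService

db = create_session()
try:
    # Synchronize every key; pass api_key_id="..." to restrict the sync to a single key.
    result = SyncStatsService.sync_api_key_stats(db)
    print(f"synced={result['synced']} updated={result['updated']} errors={result['errors']}")
finally:
    db.close()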