mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
Initial commit
This commit is contained in:
23
src/services/system/__init__.py
Normal file
23
src/services/system/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
系统服务模块
|
||||
|
||||
包含系统配置、审计日志、公告等功能。
|
||||
"""
|
||||
|
||||
from src.services.system.announcement import AnnouncementService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.system.cleanup_scheduler import CleanupScheduler
|
||||
from src.services.system.config import SystemConfigService
|
||||
from src.services.system.scheduler import APP_TIMEZONE, TaskScheduler, get_scheduler
|
||||
from src.services.system.sync_stats import SyncStatsService
|
||||
|
||||
__all__ = [
|
||||
"SystemConfigService",
|
||||
"AuditService",
|
||||
"AnnouncementService",
|
||||
"CleanupScheduler",
|
||||
"SyncStatsService",
|
||||
"TaskScheduler",
|
||||
"get_scheduler",
|
||||
"APP_TIMEZONE",
|
||||
]
|
||||
241
src/services/system/announcement.py
Normal file
241
src/services/system/announcement.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
公告系统服务
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import and_, or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.exceptions import ForbiddenException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Announcement, AnnouncementRead, User, UserRole
|
||||
|
||||
|
||||
|
||||
class AnnouncementService:
|
||||
"""公告系统服务"""
|
||||
|
||||
@staticmethod
|
||||
def create_announcement(
|
||||
db: Session,
|
||||
author_id: str, # UUID
|
||||
title: str,
|
||||
content: str,
|
||||
type: str = "info",
|
||||
priority: int = 0,
|
||||
is_pinned: bool = False,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> Announcement:
|
||||
"""创建公告"""
|
||||
# 验证作者是否为管理员
|
||||
author = db.query(User).filter(User.id == author_id).first()
|
||||
if not author or author.role != UserRole.ADMIN:
|
||||
raise ForbiddenException("Only administrators can create announcements")
|
||||
|
||||
# 验证类型
|
||||
if type not in ["info", "warning", "maintenance", "important"]:
|
||||
raise ValueError("Invalid announcement type")
|
||||
|
||||
announcement = Announcement(
|
||||
title=title,
|
||||
content=content,
|
||||
type=type,
|
||||
priority=priority,
|
||||
author_id=author_id,
|
||||
is_pinned=is_pinned,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
db.add(announcement)
|
||||
db.commit()
|
||||
db.refresh(announcement)
|
||||
|
||||
logger.info(f"Created announcement: {announcement.id} - {title}")
|
||||
return announcement
|
||||
|
||||
@staticmethod
|
||||
def get_announcements(
|
||||
db: Session,
|
||||
user_id: Optional[str] = None, # UUID
|
||||
active_only: bool = True,
|
||||
include_read_status: bool = False,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> dict:
|
||||
"""获取公告列表"""
|
||||
query = db.query(Announcement)
|
||||
|
||||
# 筛选条件
|
||||
if active_only:
|
||||
now = datetime.now(timezone.utc)
|
||||
query = query.filter(
|
||||
Announcement.is_active == True,
|
||||
or_(Announcement.start_time == None, Announcement.start_time <= now),
|
||||
or_(Announcement.end_time == None, Announcement.end_time >= now),
|
||||
)
|
||||
|
||||
# 排序:置顶优先,然后按优先级和创建时间
|
||||
query = query.order_by(
|
||||
Announcement.is_pinned.desc(),
|
||||
Announcement.priority.desc(),
|
||||
Announcement.created_at.desc(),
|
||||
)
|
||||
|
||||
# 分页
|
||||
total = query.count()
|
||||
announcements = query.offset(offset).limit(limit).all()
|
||||
|
||||
# 获取已读状态
|
||||
read_announcement_ids = set()
|
||||
unread_count = 0
|
||||
|
||||
if user_id and include_read_status:
|
||||
read_records = (
|
||||
db.query(AnnouncementRead.announcement_id)
|
||||
.filter(AnnouncementRead.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
read_announcement_ids = {r[0] for r in read_records}
|
||||
unread_count = total - len(read_announcement_ids)
|
||||
|
||||
# 构建返回数据
|
||||
items = []
|
||||
for announcement in announcements:
|
||||
item = {
|
||||
"id": announcement.id,
|
||||
"title": announcement.title,
|
||||
"content": announcement.content,
|
||||
"type": announcement.type,
|
||||
"priority": announcement.priority,
|
||||
"is_pinned": announcement.is_pinned,
|
||||
"is_active": announcement.is_active,
|
||||
"author": {"id": announcement.author.id, "username": announcement.author.username},
|
||||
"start_time": announcement.start_time,
|
||||
"end_time": announcement.end_time,
|
||||
"created_at": announcement.created_at,
|
||||
"updated_at": announcement.updated_at,
|
||||
}
|
||||
|
||||
if include_read_status and user_id:
|
||||
item["is_read"] = announcement.id in read_announcement_ids
|
||||
|
||||
items.append(item)
|
||||
|
||||
result = {"items": items, "total": total}
|
||||
|
||||
if include_read_status and user_id:
|
||||
result["unread_count"] = unread_count
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_announcement(db: Session, announcement_id: str) -> Announcement: # UUID
|
||||
"""获取单个公告"""
|
||||
announcement = db.query(Announcement).filter(Announcement.id == announcement_id).first()
|
||||
|
||||
if not announcement:
|
||||
raise NotFoundException("Announcement not found")
|
||||
|
||||
return announcement
|
||||
|
||||
@staticmethod
|
||||
def update_announcement(
|
||||
db: Session,
|
||||
announcement_id: str, # UUID
|
||||
user_id: str, # UUID
|
||||
title: Optional[str] = None,
|
||||
content: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
priority: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
is_pinned: Optional[bool] = None,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
) -> Announcement:
|
||||
"""更新公告"""
|
||||
# 验证用户是否为管理员
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or user.role != UserRole.ADMIN:
|
||||
raise ForbiddenException("Only administrators can update announcements")
|
||||
|
||||
announcement = AnnouncementService.get_announcement(db, announcement_id)
|
||||
|
||||
# 更新提供的字段
|
||||
if title is not None:
|
||||
announcement.title = title
|
||||
if content is not None:
|
||||
announcement.content = content
|
||||
if type is not None:
|
||||
if type not in ["info", "warning", "maintenance", "important"]:
|
||||
raise ValueError("Invalid announcement type")
|
||||
announcement.type = type
|
||||
if priority is not None:
|
||||
announcement.priority = priority
|
||||
if is_active is not None:
|
||||
announcement.is_active = is_active
|
||||
if is_pinned is not None:
|
||||
announcement.is_pinned = is_pinned
|
||||
if start_time is not None:
|
||||
announcement.start_time = start_time
|
||||
if end_time is not None:
|
||||
announcement.end_time = end_time
|
||||
|
||||
db.commit()
|
||||
db.refresh(announcement)
|
||||
|
||||
logger.info(f"Updated announcement: {announcement_id}")
|
||||
return announcement
|
||||
|
||||
@staticmethod
|
||||
def delete_announcement(db: Session, announcement_id: str, user_id: str) -> None: # UUID
|
||||
"""删除公告"""
|
||||
# 验证用户是否为管理员
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or user.role != UserRole.ADMIN:
|
||||
raise ForbiddenException("Only administrators can delete announcements")
|
||||
|
||||
announcement = AnnouncementService.get_announcement(db, announcement_id)
|
||||
|
||||
db.delete(announcement)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"Deleted announcement: {announcement_id}")
|
||||
|
||||
@staticmethod
|
||||
def mark_as_read(db: Session, announcement_id: str, user_id: str) -> None: # UUID
|
||||
"""标记公告为已读"""
|
||||
# 检查公告是否存在
|
||||
announcement = AnnouncementService.get_announcement(db, announcement_id)
|
||||
|
||||
# 检查是否已经标记为已读
|
||||
existing = (
|
||||
db.query(AnnouncementRead)
|
||||
.filter(
|
||||
AnnouncementRead.user_id == user_id,
|
||||
AnnouncementRead.announcement_id == announcement_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing:
|
||||
read_record = AnnouncementRead(user_id=user_id, announcement_id=announcement_id)
|
||||
db.add(read_record)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"User {user_id} marked announcement {announcement_id} as read")
|
||||
|
||||
@staticmethod
|
||||
def get_active_announcements(db: Session, user_id: Optional[str] = None) -> dict: # UUID
|
||||
"""获取当前有效的公告(首页展示用)"""
|
||||
return AnnouncementService.get_announcements(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
active_only=True,
|
||||
include_read_status=True if user_id else False,
|
||||
limit=10,
|
||||
)
|
||||
459
src/services/system/audit.py
Normal file
459
src/services/system/audit.py
Normal file
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
审计日志服务
|
||||
记录所有重要操作和安全事件
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, AuditLog
|
||||
from src.utils.transaction_manager import transactional
|
||||
|
||||
|
||||
|
||||
# 审计模型已移至 src/models/database.py
|
||||
|
||||
|
||||
class AuditService:
|
||||
"""审计服务"""
|
||||
|
||||
@staticmethod
|
||||
@transactional(commit=False) # 不自动提交,让调用方决定
|
||||
def log_event(
|
||||
db: Session,
|
||||
event_type: AuditEventType,
|
||||
description: str,
|
||||
user_id: Optional[str] = None, # UUID
|
||||
api_key_id: Optional[str] = None, # UUID
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
status_code: Optional[int] = None,
|
||||
error_message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> AuditLog:
|
||||
"""
|
||||
记录审计事件
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
event_type: 事件类型
|
||||
description: 事件描述
|
||||
user_id: 用户ID
|
||||
api_key_id: API密钥ID
|
||||
ip_address: IP地址
|
||||
user_agent: 用户代理
|
||||
request_id: 请求ID
|
||||
status_code: 状态码
|
||||
error_message: 错误消息
|
||||
metadata: 额外元数据
|
||||
|
||||
Returns:
|
||||
审计日志记录
|
||||
"""
|
||||
try:
|
||||
audit_log = AuditLog(
|
||||
event_type=event_type.value,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
|
||||
db.add(audit_log)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
db.refresh(audit_log)
|
||||
|
||||
# 同时记录到系统日志
|
||||
log_message = (
|
||||
f"AUDIT [{event_type.value}] - {description} | "
|
||||
f"user_id={user_id}, ip={ip_address}"
|
||||
)
|
||||
|
||||
if event_type in [
|
||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||
]:
|
||||
logger.warning(log_message)
|
||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||
logger.info(log_message)
|
||||
else:
|
||||
logger.debug(log_message)
|
||||
|
||||
return audit_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def log_login_attempt(
|
||||
db: Session,
|
||||
email: str,
|
||||
success: bool,
|
||||
ip_address: str,
|
||||
user_agent: str,
|
||||
user_id: Optional[str] = None, # UUID
|
||||
error_reason: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
记录登录尝试
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
email: 登录邮箱
|
||||
success: 是否成功
|
||||
ip_address: IP地址
|
||||
user_agent: 用户代理
|
||||
user_id: 用户ID(成功时)
|
||||
error_reason: 失败原因
|
||||
"""
|
||||
event_type = AuditEventType.LOGIN_SUCCESS if success else AuditEventType.LOGIN_FAILED
|
||||
description = f"Login attempt for {email}"
|
||||
if not success and error_reason:
|
||||
description += f": {error_reason}"
|
||||
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
metadata={"email": email},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_api_request(
|
||||
db: Session,
|
||||
user_id: str, # UUID
|
||||
api_key_id: str, # UUID
|
||||
request_id: str,
|
||||
model: str,
|
||||
provider: str,
|
||||
success: bool,
|
||||
ip_address: str,
|
||||
status_code: int,
|
||||
error_message: Optional[str] = None,
|
||||
input_tokens: Optional[int] = None,
|
||||
output_tokens: Optional[int] = None,
|
||||
cost_usd: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
记录API请求
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
api_key_id: API密钥ID
|
||||
request_id: 请求ID
|
||||
model: 模型名称
|
||||
provider: 提供商名称
|
||||
success: 是否成功
|
||||
ip_address: IP地址
|
||||
status_code: 状态码
|
||||
error_message: 错误消息
|
||||
input_tokens: 输入tokens
|
||||
output_tokens: 输出tokens
|
||||
cost_usd: 成本(美元)
|
||||
"""
|
||||
event_type = AuditEventType.REQUEST_SUCCESS if success else AuditEventType.REQUEST_FAILED
|
||||
description = f"API request to {provider}/{model}"
|
||||
|
||||
metadata = {"model": model, "provider": provider}
|
||||
|
||||
if input_tokens:
|
||||
metadata["input_tokens"] = input_tokens
|
||||
if output_tokens:
|
||||
metadata["output_tokens"] = output_tokens
|
||||
if cost_usd:
|
||||
metadata["cost_usd"] = cost_usd
|
||||
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
request_id=request_id,
|
||||
ip_address=ip_address,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def log_security_event(
|
||||
db: Session,
|
||||
event_type: AuditEventType,
|
||||
description: str,
|
||||
ip_address: str,
|
||||
user_id: Optional[str] = None, # UUID
|
||||
severity: str = "medium",
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
记录安全事件
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
event_type: 事件类型
|
||||
description: 事件描述
|
||||
ip_address: IP地址
|
||||
user_id: 用户ID
|
||||
severity: 严重程度 (low, medium, high, critical)
|
||||
details: 详细信息
|
||||
"""
|
||||
event_metadata = {"severity": severity}
|
||||
if details:
|
||||
event_metadata.update(details)
|
||||
|
||||
AuditService.log_event(
|
||||
db=db,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
metadata=event_metadata,
|
||||
)
|
||||
|
||||
# 对于高严重性事件,简化日志输出
|
||||
if severity in ["high", "critical"]:
|
||||
logger.error(f"安全告警 [{severity.upper()}]: {description}")
|
||||
|
||||
@staticmethod
|
||||
def get_user_audit_logs(
|
||||
db: Session,
|
||||
user_id: str, # UUID
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditLog]:
|
||||
"""
|
||||
获取用户的审计日志
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
event_types: 事件类型过滤
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
审计日志列表
|
||||
"""
|
||||
query = db.query(AuditLog).filter(AuditLog.user_id == user_id)
|
||||
|
||||
if event_types:
|
||||
event_type_values = [et.value for et in event_types]
|
||||
query = query.filter(AuditLog.event_type.in_(event_type_values))
|
||||
|
||||
return query.order_by(AuditLog.created_at.desc()).limit(limit).all()
|
||||
|
||||
@staticmethod
|
||||
def get_suspicious_activities(db: Session, hours: int = 24, limit: int = 100) -> List[AuditLog]:
|
||||
"""
|
||||
获取可疑活动
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
hours: 时间范围(小时)
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
可疑活动列表
|
||||
"""
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours)
|
||||
|
||||
suspicious_types = [
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY.value,
|
||||
AuditEventType.UNAUTHORIZED_ACCESS.value,
|
||||
AuditEventType.LOGIN_FAILED.value,
|
||||
AuditEventType.REQUEST_RATE_LIMITED.value,
|
||||
]
|
||||
|
||||
return (
|
||||
db.query(AuditLog)
|
||||
.filter(AuditLog.event_type.in_(suspicious_types), AuditLog.created_at >= cutoff_time)
|
||||
.order_by(AuditLog.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def analyze_user_behavior(db: Session, user_id: str, days: int = 30) -> Dict[str, Any]: # UUID
|
||||
"""
|
||||
分析用户行为
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
user_id: 用户ID
|
||||
days: 分析天数
|
||||
|
||||
Returns:
|
||||
行为分析结果
|
||||
"""
|
||||
from sqlalchemy import func
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
|
||||
# 统计各种事件类型
|
||||
event_counts = (
|
||||
db.query(AuditLog.event_type, func.count(AuditLog.id).label("count"))
|
||||
.filter(AuditLog.user_id == user_id, AuditLog.created_at >= cutoff_time)
|
||||
.group_by(AuditLog.event_type)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 统计失败请求
|
||||
failed_requests = (
|
||||
db.query(func.count(AuditLog.id))
|
||||
.filter(
|
||||
AuditLog.user_id == user_id,
|
||||
AuditLog.event_type == AuditEventType.REQUEST_FAILED.value,
|
||||
AuditLog.created_at >= cutoff_time,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# 统计成功请求
|
||||
success_requests = (
|
||||
db.query(func.count(AuditLog.id))
|
||||
.filter(
|
||||
AuditLog.user_id == user_id,
|
||||
AuditLog.event_type == AuditEventType.REQUEST_SUCCESS.value,
|
||||
AuditLog.created_at >= cutoff_time,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# 获取最近的可疑活动
|
||||
recent_suspicious = (
|
||||
db.query(AuditLog)
|
||||
.filter(
|
||||
AuditLog.user_id == user_id,
|
||||
AuditLog.event_type.in_(
|
||||
[
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY.value,
|
||||
AuditEventType.UNAUTHORIZED_ACCESS.value,
|
||||
]
|
||||
),
|
||||
AuditLog.created_at >= cutoff_time,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"period_days": days,
|
||||
"event_counts": {event: count for event, count in event_counts},
|
||||
"failed_requests": failed_requests or 0,
|
||||
"success_requests": success_requests or 0,
|
||||
"success_rate": (
|
||||
success_requests / (success_requests + failed_requests)
|
||||
if (success_requests + failed_requests) > 0
|
||||
else 0
|
||||
),
|
||||
"suspicious_activities": recent_suspicious,
|
||||
"analysis_time": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def log_event_auto(
|
||||
event_type: AuditEventType,
|
||||
description: str,
|
||||
user_id: Optional[str] = None,
|
||||
api_key_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
request_id: Optional[str] = None,
|
||||
status_code: Optional[int] = None,
|
||||
error_message: Optional[str] = None,
|
||||
event_metadata: Optional[Dict[str, Any]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[AuditLog]:
|
||||
"""
|
||||
自动管理数据库会话的审计日志记录方法
|
||||
适用于中间件等无法直接获取数据库会话的场景
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
description: 事件描述
|
||||
user_id: 用户ID
|
||||
api_key_id: API密钥ID
|
||||
ip_address: IP地址
|
||||
user_agent: 用户代理
|
||||
request_id: 请求ID
|
||||
status_code: 状态码
|
||||
error_message: 错误消息
|
||||
event_metadata: 额外元数据
|
||||
db: 数据库会话(可选,如不提供则自动创建)
|
||||
|
||||
Returns:
|
||||
审计日志记录
|
||||
"""
|
||||
# 如果提供了数据库会话,使用它(不自动提交)
|
||||
if db is not None:
|
||||
try:
|
||||
audit_log = AuditService.log_event(
|
||||
db=db,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
metadata=event_metadata,
|
||||
)
|
||||
# 注意:不在这里提交,让调用方决定何时提交
|
||||
return audit_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
return None
|
||||
|
||||
# 如果没有提供会话,自动创建并管理
|
||||
db_session = None
|
||||
try:
|
||||
db_session = next(get_db())
|
||||
|
||||
audit_log = AuditService.log_event(
|
||||
db=db_session,
|
||||
event_type=event_type,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
metadata=event_metadata,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return audit_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event with auto session: {e}")
|
||||
if db_session is not None:
|
||||
db_session.rollback()
|
||||
return None
|
||||
finally:
|
||||
if db_session is not None:
|
||||
db_session.close()
|
||||
|
||||
|
||||
# 全局审计服务实例
|
||||
audit_service = AuditService()
|
||||
597
src/services/system/cleanup_scheduler.py
Normal file
597
src/services/system/cleanup_scheduler.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""
|
||||
使用记录清理定时任务
|
||||
|
||||
分级清理策略:
|
||||
- detail_log_retention_days: 压缩 request_body 和 response_body 到压缩字段
|
||||
- header_retention_days: 清空 request_headers 和 response_headers
|
||||
- log_retention_days: 删除整条记录
|
||||
|
||||
统计聚合任务:
|
||||
- 每天凌晨聚合前一天的统计数据
|
||||
- 更新全局统计汇总
|
||||
|
||||
使用 APScheduler 进行任务调度,支持时区配置。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import create_session
|
||||
from src.models.database import Usage
|
||||
from src.services.system.config import SystemConfigService
|
||||
from src.services.system.scheduler import get_scheduler
|
||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
from src.utils.compression import compress_json
|
||||
|
||||
|
||||
class CleanupScheduler:
|
||||
"""使用记录清理调度器"""
|
||||
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self._interval_tasks = []
|
||||
|
||||
async def start(self):
|
||||
"""启动调度器"""
|
||||
if self.running:
|
||||
logger.warning("Cleanup scheduler already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
logger.info("使用记录清理调度器已启动")
|
||||
|
||||
scheduler = get_scheduler()
|
||||
|
||||
# 注册定时任务(使用业务时区)
|
||||
# 统计聚合任务 - 凌晨 1 点执行
|
||||
scheduler.add_cron_job(
|
||||
self._scheduled_stats_aggregation,
|
||||
hour=1,
|
||||
minute=0,
|
||||
job_id="stats_aggregation",
|
||||
name="统计数据聚合",
|
||||
)
|
||||
|
||||
# 清理任务 - 凌晨 3 点执行
|
||||
scheduler.add_cron_job(
|
||||
self._scheduled_cleanup,
|
||||
hour=3,
|
||||
minute=0,
|
||||
job_id="usage_cleanup",
|
||||
name="使用记录清理",
|
||||
)
|
||||
|
||||
# 连接池监控 - 每 5 分钟
|
||||
scheduler.add_interval_job(
|
||||
self._scheduled_monitor,
|
||||
minutes=5,
|
||||
job_id="pool_monitor",
|
||||
name="连接池监控",
|
||||
)
|
||||
|
||||
# Pending 状态清理 - 每 5 分钟
|
||||
scheduler.add_interval_job(
|
||||
self._scheduled_pending_cleanup,
|
||||
minutes=5,
|
||||
job_id="pending_cleanup",
|
||||
name="Pending状态清理",
|
||||
)
|
||||
|
||||
# 启动时执行一次初始化任务
|
||||
asyncio.create_task(self._run_startup_tasks())
|
||||
|
||||
async def _run_startup_tasks(self):
|
||||
"""启动时执行的初始化任务"""
|
||||
# 延迟一点执行,确保系统完全启动
|
||||
await asyncio.sleep(2)
|
||||
|
||||
try:
|
||||
logger.info("启动时执行首次清理任务...")
|
||||
await self._perform_cleanup()
|
||||
except Exception as e:
|
||||
logger.exception(f"启动时清理任务执行出错: {e}")
|
||||
|
||||
try:
|
||||
logger.info("启动时检查统计数据...")
|
||||
await self._perform_stats_aggregation(backfill=True)
|
||||
except Exception as e:
|
||||
logger.exception(f"启动时统计聚合任务出错: {e}")
|
||||
|
||||
async def stop(self):
|
||||
"""停止调度器"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
scheduler = get_scheduler()
|
||||
scheduler.stop()
|
||||
|
||||
logger.info("使用记录清理调度器已停止")
|
||||
|
||||
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
||||
|
||||
async def _scheduled_stats_aggregation(self):
|
||||
"""统计聚合任务(定时调用)"""
|
||||
await self._perform_stats_aggregation()
|
||||
|
||||
async def _scheduled_cleanup(self):
|
||||
"""清理任务(定时调用)"""
|
||||
await self._perform_cleanup()
|
||||
|
||||
async def _scheduled_monitor(self):
|
||||
"""监控任务(定时调用)"""
|
||||
try:
|
||||
from src.database import log_pool_status
|
||||
|
||||
log_pool_status()
|
||||
except Exception as e:
|
||||
logger.exception(f"连接池监控任务出错: {e}")
|
||||
|
||||
async def _scheduled_pending_cleanup(self):
|
||||
"""Pending 清理任务(定时调用)"""
|
||||
await self._perform_pending_cleanup()
|
||||
|
||||
# ========== 实际任务实现 ==========
|
||||
|
||||
async def _perform_stats_aggregation(self, backfill: bool = False):
|
||||
"""执行统计聚合任务
|
||||
|
||||
Args:
|
||||
backfill: 是否回填历史数据(首次启动时使用)
|
||||
"""
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用统计聚合
|
||||
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
||||
logger.info("统计聚合已禁用,跳过聚合任务")
|
||||
return
|
||||
|
||||
logger.info("开始执行统计数据聚合...")
|
||||
|
||||
if backfill:
|
||||
# 首次启动时回填历史数据
|
||||
from src.models.database import StatsSummary
|
||||
|
||||
summary = db.query(StatsSummary).first()
|
||||
if not summary:
|
||||
logger.info("检测到首次运行,开始回填历史统计数据...")
|
||||
days_to_backfill = SystemConfigService.get_config(
|
||||
db, "stats_backfill_days", 365
|
||||
)
|
||||
count = StatsAggregatorService.backfill_historical_data(
|
||||
db, days=days_to_backfill
|
||||
)
|
||||
logger.info(f"历史数据回填完成,共 {count} 天")
|
||||
return
|
||||
|
||||
# 聚合昨天的数据
|
||||
now = datetime.now(timezone.utc)
|
||||
yesterday = (now - timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday)
|
||||
|
||||
# 聚合所有用户的昨日数据
|
||||
from src.models.database import User as DBUser
|
||||
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday)
|
||||
except Exception as e:
|
||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||
# 回滚当前用户的失败操作,继续处理其他用户
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 更新全局汇总
|
||||
StatsAggregatorService.update_summary(db)
|
||||
|
||||
logger.info("统计数据聚合完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_pending_cleanup(self):
|
||||
"""执行 pending 状态清理"""
|
||||
db = create_session()
|
||||
try:
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
# 获取配置的超时时间(默认 10 分钟)
|
||||
timeout_minutes = SystemConfigService.get_config(
|
||||
db, "pending_request_timeout_minutes", 10
|
||||
)
|
||||
|
||||
# 执行清理
|
||||
cleaned_count = UsageService.cleanup_stale_pending_requests(
|
||||
db, timeout_minutes=timeout_minutes
|
||||
)
|
||||
|
||||
if cleaned_count > 0:
|
||||
logger.info(f"清理了 {cleaned_count} 条超时的 pending/streaming 请求")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"清理 pending 请求失败: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_cleanup(self):
|
||||
"""执行清理任务"""
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用自动清理
|
||||
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
|
||||
logger.info("自动清理已禁用,跳过清理任务")
|
||||
return
|
||||
|
||||
logger.info("开始执行使用记录分级清理...")
|
||||
|
||||
# 获取配置参数
|
||||
detail_retention = SystemConfigService.get_config(db, "detail_log_retention_days", 7)
|
||||
compressed_retention = SystemConfigService.get_config(
|
||||
db, "compressed_log_retention_days", 90
|
||||
)
|
||||
header_retention = SystemConfigService.get_config(db, "header_retention_days", 90)
|
||||
log_retention = SystemConfigService.get_config(db, "log_retention_days", 365)
|
||||
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 1. 压缩详细日志 (body 字段 -> 压缩字段)
|
||||
detail_cutoff = now - timedelta(days=detail_retention)
|
||||
body_compressed = await self._cleanup_body_fields(db, detail_cutoff, batch_size)
|
||||
|
||||
# 2. 清理压缩字段(90天后)
|
||||
compressed_cutoff = now - timedelta(days=compressed_retention)
|
||||
compressed_cleaned = await self._cleanup_compressed_fields(
|
||||
db, compressed_cutoff, batch_size
|
||||
)
|
||||
|
||||
# 3. 清理请求头
|
||||
header_cutoff = now - timedelta(days=header_retention)
|
||||
header_cleaned = await self._cleanup_header_fields(db, header_cutoff, batch_size)
|
||||
|
||||
# 4. 删除过期记录
|
||||
log_cutoff = now - timedelta(days=log_retention)
|
||||
records_deleted = await self._delete_old_records(db, log_cutoff, batch_size)
|
||||
|
||||
# 5. 清理过期的API Keys
|
||||
auto_delete = SystemConfigService.get_config(db, "auto_delete_expired_keys", False)
|
||||
keys_cleaned = ApiKeyService.cleanup_expired_keys(db, auto_delete=auto_delete)
|
||||
|
||||
logger.info(
|
||||
f"清理完成: 压缩 {body_compressed} 条, "
|
||||
f"清理压缩字段 {compressed_cleaned} 条, "
|
||||
f"清理header {header_cleaned} 条, "
|
||||
f"删除记录 {records_deleted} 条, "
|
||||
f"清理过期Keys {keys_cleaned} 条"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"清理任务执行失败: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _cleanup_body_fields(
|
||||
self, db: Session, cutoff_time: datetime, batch_size: int
|
||||
) -> int:
|
||||
"""压缩 request_body 和 response_body 字段到压缩字段
|
||||
|
||||
逐条处理,确保每条记录都正确更新
|
||||
"""
|
||||
from sqlalchemy import null, update
|
||||
|
||||
total_compressed = 0
|
||||
no_progress_count = 0 # 连续无进展计数
|
||||
processed_ids: set = set() # 记录已处理的 ID,防止重复处理
|
||||
|
||||
while True:
|
||||
batch_db = create_session()
|
||||
try:
|
||||
# 1. 查询需要压缩的记录
|
||||
# 注意:排除已经是 NULL 或 JSON null 的记录
|
||||
records = (
|
||||
batch_db.query(Usage.id, Usage.request_body, Usage.response_body)
|
||||
.filter(Usage.created_at < cutoff_time)
|
||||
.filter((Usage.request_body.isnot(None)) | (Usage.response_body.isnot(None)))
|
||||
.limit(batch_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not records:
|
||||
break
|
||||
|
||||
# 过滤掉实际值为 None 的记录(JSON null 被解析为 Python None)
|
||||
valid_records = [
|
||||
(rid, req, resp)
|
||||
for rid, req, resp in records
|
||||
if req is not None or resp is not None
|
||||
]
|
||||
|
||||
if not valid_records:
|
||||
# 所有记录都是 JSON null,需要清理它们
|
||||
logger.warning(
|
||||
f"检测到 {len(records)} 条记录的 body 字段为 JSON null,进行清理"
|
||||
)
|
||||
for record_id, _, _ in records:
|
||||
batch_db.execute(
|
||||
update(Usage)
|
||||
.where(Usage.id == record_id)
|
||||
.values(request_body=null(), response_body=null())
|
||||
)
|
||||
batch_db.commit()
|
||||
continue
|
||||
|
||||
# 检测是否有重复的 ID(说明更新未生效)
|
||||
current_ids = {r[0] for r in valid_records}
|
||||
repeated_ids = current_ids & processed_ids
|
||||
if repeated_ids:
|
||||
logger.error(
|
||||
f"检测到重复处理的记录 ID: {list(repeated_ids)[:5]}...,"
|
||||
"说明数据库更新未生效,终止循环"
|
||||
)
|
||||
break
|
||||
|
||||
batch_success = 0
|
||||
|
||||
# 2. 逐条更新(确保每条都正确处理)
|
||||
for record_id, req_body, resp_body in valid_records:
|
||||
try:
|
||||
# 使用 null() 确保设置的是 SQL NULL 而不是 JSON null
|
||||
result = batch_db.execute(
|
||||
update(Usage)
|
||||
.where(Usage.id == record_id)
|
||||
.values(
|
||||
request_body=null(),
|
||||
response_body=null(),
|
||||
request_body_compressed=compress_json(req_body)
|
||||
if req_body
|
||||
else None,
|
||||
response_body_compressed=compress_json(resp_body)
|
||||
if resp_body
|
||||
else None,
|
||||
)
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
batch_success += 1
|
||||
processed_ids.add(record_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"压缩记录 {record_id} 失败: {e}")
|
||||
continue
|
||||
|
||||
batch_db.commit()
|
||||
|
||||
# 3. 检查是否有实际进展
|
||||
if batch_success == 0:
|
||||
no_progress_count += 1
|
||||
if no_progress_count >= 3:
|
||||
logger.error(
|
||||
f"压缩 body 字段连续 {no_progress_count} 批无进展,"
|
||||
"终止循环以避免死循环"
|
||||
)
|
||||
break
|
||||
else:
|
||||
no_progress_count = 0 # 重置计数
|
||||
|
||||
total_compressed += batch_success
|
||||
logger.debug(
|
||||
f"已压缩 {batch_success} 条记录的 body 字段,累计 {total_compressed} 条"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"压缩 body 字段失败: {e}")
|
||||
try:
|
||||
batch_db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
finally:
|
||||
batch_db.close()
|
||||
|
||||
return total_compressed
|
||||
|
||||
async def _cleanup_compressed_fields(
|
||||
self, db: Session, cutoff_time: datetime, batch_size: int
|
||||
) -> int:
|
||||
"""清理压缩字段(90天后删除压缩的body)
|
||||
|
||||
每批使用短生命周期 session,避免 ORM 缓存问题
|
||||
"""
|
||||
from sqlalchemy import null, update
|
||||
|
||||
total_cleaned = 0
|
||||
|
||||
while True:
|
||||
batch_db = create_session()
|
||||
try:
|
||||
# 查询需要清理压缩字段的记录
|
||||
records_to_clean = (
|
||||
batch_db.query(Usage.id)
|
||||
.filter(Usage.created_at < cutoff_time)
|
||||
.filter(
|
||||
(Usage.request_body_compressed.isnot(None))
|
||||
| (Usage.response_body_compressed.isnot(None))
|
||||
)
|
||||
.limit(batch_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not records_to_clean:
|
||||
break
|
||||
|
||||
record_ids = [r.id for r in records_to_clean]
|
||||
|
||||
# 批量更新,使用 null() 确保设置 SQL NULL
|
||||
result = batch_db.execute(
|
||||
update(Usage)
|
||||
.where(Usage.id.in_(record_ids))
|
||||
.values(
|
||||
request_body_compressed=null(),
|
||||
response_body_compressed=null(),
|
||||
)
|
||||
)
|
||||
|
||||
rows_updated = result.rowcount
|
||||
batch_db.commit()
|
||||
|
||||
if rows_updated == 0:
|
||||
logger.warning("清理压缩字段: rowcount=0,可能存在问题")
|
||||
break
|
||||
|
||||
total_cleaned += rows_updated
|
||||
logger.debug(f"已清理 {rows_updated} 条记录的压缩字段,累计 {total_cleaned} 条")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"清理压缩字段失败: {e}")
|
||||
try:
|
||||
batch_db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
finally:
|
||||
batch_db.close()
|
||||
|
||||
return total_cleaned
|
||||
|
||||
async def _cleanup_header_fields(
|
||||
self, db: Session, cutoff_time: datetime, batch_size: int
|
||||
) -> int:
|
||||
"""清理 request_headers, response_headers 和 provider_request_headers 字段
|
||||
|
||||
每批使用短生命周期 session,避免 ORM 缓存问题
|
||||
"""
|
||||
from sqlalchemy import null, update
|
||||
|
||||
total_cleaned = 0
|
||||
|
||||
while True:
|
||||
batch_db = create_session()
|
||||
try:
|
||||
# 先查询需要清理的记录ID(分批)
|
||||
records_to_clean = (
|
||||
batch_db.query(Usage.id)
|
||||
.filter(Usage.created_at < cutoff_time)
|
||||
.filter(
|
||||
(Usage.request_headers.isnot(None))
|
||||
| (Usage.response_headers.isnot(None))
|
||||
| (Usage.provider_request_headers.isnot(None))
|
||||
)
|
||||
.limit(batch_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not records_to_clean:
|
||||
break
|
||||
|
||||
record_ids = [r.id for r in records_to_clean]
|
||||
|
||||
# 批量更新,使用 null() 确保设置 SQL NULL
|
||||
result = batch_db.execute(
|
||||
update(Usage)
|
||||
.where(Usage.id.in_(record_ids))
|
||||
.values(
|
||||
request_headers=null(),
|
||||
response_headers=null(),
|
||||
provider_request_headers=null(),
|
||||
)
|
||||
)
|
||||
|
||||
rows_updated = result.rowcount
|
||||
batch_db.commit()
|
||||
|
||||
if rows_updated == 0:
|
||||
logger.warning("清理 header 字段: rowcount=0,可能存在问题")
|
||||
break
|
||||
|
||||
total_cleaned += rows_updated
|
||||
logger.debug(
|
||||
f"已清理 {rows_updated} 条记录的 header 字段,累计 {total_cleaned} 条"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"清理 header 字段失败: {e}")
|
||||
try:
|
||||
batch_db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
finally:
|
||||
batch_db.close()
|
||||
|
||||
return total_cleaned
|
||||
|
||||
async def _delete_old_records(self, db: Session, cutoff_time: datetime, batch_size: int) -> int:
|
||||
"""删除过期的完整记录"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 查询要删除的记录ID(分批)
|
||||
records_to_delete = (
|
||||
db.query(Usage.id)
|
||||
.filter(Usage.created_at < cutoff_time)
|
||||
.limit(batch_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not records_to_delete:
|
||||
break
|
||||
|
||||
record_ids = [r.id for r in records_to_delete]
|
||||
|
||||
# 执行删除
|
||||
result = db.execute(
|
||||
delete(Usage)
|
||||
.where(Usage.id.in_(record_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
rows_deleted = result.rowcount
|
||||
db.commit()
|
||||
|
||||
total_deleted += rows_deleted
|
||||
logger.debug(f"已删除 {rows_deleted} 条过期记录,累计 {total_deleted} 条")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"删除过期记录失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
|
||||
return total_deleted
|
||||
|
||||
|
||||
# 全局单例
|
||||
_cleanup_scheduler = None
|
||||
|
||||
|
||||
def get_cleanup_scheduler() -> CleanupScheduler:
|
||||
"""获取清理调度器单例"""
|
||||
global _cleanup_scheduler
|
||||
if _cleanup_scheduler is None:
|
||||
_cleanup_scheduler = CleanupScheduler()
|
||||
return _cleanup_scheduler
|
||||
257
src/services/system/config.py
Normal file
257
src/services/system/config.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
系统配置服务
|
||||
"""
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, SystemConfig
|
||||
|
||||
|
||||
|
||||
class LogLevel(str, Enum):
|
||||
"""日志记录级别"""
|
||||
|
||||
BASIC = "basic" # 仅记录基本信息(tokens、成本等)
|
||||
HEADERS = "headers" # 记录基本信息+请求/响应头(敏感信息会脱敏)
|
||||
FULL = "full" # 记录完整请求和响应(包含body,敏感信息会脱敏)
|
||||
|
||||
|
||||
class SystemConfigService:
|
||||
"""系统配置服务类"""
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIGS = {
|
||||
"request_log_level": {
|
||||
"value": LogLevel.BASIC.value,
|
||||
"description": "请求记录级别:basic(基本信息), headers(含请求头), full(完整请求响应)",
|
||||
},
|
||||
"max_request_body_size": {
|
||||
"value": 1048576, # 1MB
|
||||
"description": "最大请求体记录大小(字节),超过此大小的请求体将被截断(仅影响数据库记录,不影响真实API请求)",
|
||||
},
|
||||
"max_response_body_size": {
|
||||
"value": 1048576, # 1MB
|
||||
"description": "最大响应体记录大小(字节),超过此大小的响应体将被截断(仅影响数据库记录,不影响真实API响应)",
|
||||
},
|
||||
"sensitive_headers": {
|
||||
"value": ["authorization", "x-api-key", "api-key", "cookie", "set-cookie"],
|
||||
"description": "敏感请求头列表,这些请求头会被脱敏处理",
|
||||
},
|
||||
# 分级清理策略
|
||||
"detail_log_retention_days": {
|
||||
"value": 7,
|
||||
"description": "详细日志保留天数,超过此天数后压缩 request_body 和 response_body 到压缩字段",
|
||||
},
|
||||
"compressed_log_retention_days": {
|
||||
"value": 90,
|
||||
"description": "压缩日志保留天数,超过此天数后删除压缩的 body 字段(保留headers和统计)",
|
||||
},
|
||||
"header_retention_days": {
|
||||
"value": 90,
|
||||
"description": "请求头保留天数,超过此天数后清空 request_headers 和 response_headers 字段",
|
||||
},
|
||||
"log_retention_days": {
|
||||
"value": 365,
|
||||
"description": "完整日志保留天数,超过此天数后删除整条记录(保留核心统计)",
|
||||
},
|
||||
"enable_auto_cleanup": {
|
||||
"value": True,
|
||||
"description": "是否启用自动清理任务,每天凌晨执行分级清理",
|
||||
},
|
||||
"cleanup_batch_size": {
|
||||
"value": 1000,
|
||||
"description": "每批次清理的记录数,避免单次操作过大影响数据库性能",
|
||||
},
|
||||
"provider_priority_mode": {
|
||||
"value": "provider",
|
||||
"description": "优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)",
|
||||
},
|
||||
"auto_delete_expired_keys": {
|
||||
"value": False,
|
||||
"description": "是否自动删除过期的API Key(True=物理删除,False=仅禁用),仅管理员可配置",
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]:
|
||||
"""获取系统配置值"""
|
||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if config:
|
||||
return config.value
|
||||
|
||||
# 如果配置不存在,检查默认值
|
||||
if key in cls.DEFAULT_CONFIGS:
|
||||
return cls.DEFAULT_CONFIGS[key]["value"]
|
||||
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def set_config(db: Session, key: str, value: Any, description: str = None) -> SystemConfig:
|
||||
"""设置系统配置值"""
|
||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.value = value
|
||||
if description:
|
||||
config.description = description
|
||||
else:
|
||||
# 创建新配置
|
||||
config = SystemConfig(key=key, value=value, description=description)
|
||||
db.add(config)
|
||||
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def get_default_provider(db: Session) -> Optional[str]:
|
||||
"""
|
||||
获取系统默认提供商
|
||||
优先级:1. 管理员设置的默认提供商 2. 数据库中第一个可用提供商
|
||||
"""
|
||||
# 首先尝试获取管理员设置的默认提供商
|
||||
default_provider = SystemConfigService.get_config(db, "default_provider")
|
||||
if default_provider:
|
||||
return default_provider
|
||||
|
||||
# 如果没有设置,fallback到数据库中第一个可用提供商
|
||||
first_provider = db.query(Provider).filter(Provider.is_active == True).first()
|
||||
|
||||
if first_provider:
|
||||
return first_provider.name
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_default_provider(db: Session, provider_name: str) -> SystemConfig:
|
||||
"""设置系统默认提供商"""
|
||||
return SystemConfigService.set_config(
|
||||
db, "default_provider", provider_name, "系统默认提供商,当用户未设置个人提供商时使用"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_configs(db: Session) -> list:
|
||||
"""获取所有系统配置"""
|
||||
configs = db.query(SystemConfig).all()
|
||||
return [
|
||||
{
|
||||
"key": config.key,
|
||||
"value": config.value,
|
||||
"description": config.description,
|
||||
"updated_at": config.updated_at.isoformat(),
|
||||
}
|
||||
for config in configs
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def delete_config(db: Session, key: str) -> bool:
|
||||
"""删除系统配置"""
|
||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if config:
|
||||
db.delete(config)
|
||||
db.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def init_default_configs(cls, db: Session):
|
||||
"""初始化默认配置"""
|
||||
for key, default_config in cls.DEFAULT_CONFIGS.items():
|
||||
if not db.query(SystemConfig).filter(SystemConfig.key == key).first():
|
||||
config = SystemConfig(
|
||||
key=key,
|
||||
value=default_config["value"],
|
||||
description=default_config["description"],
|
||||
)
|
||||
db.add(config)
|
||||
|
||||
db.commit()
|
||||
logger.info("初始化默认系统配置完成")
|
||||
|
||||
@classmethod
|
||||
def get_log_level(cls, db: Session) -> LogLevel:
|
||||
"""获取日志记录级别"""
|
||||
level = cls.get_config(db, "request_log_level", LogLevel.BASIC.value)
|
||||
if isinstance(level, str):
|
||||
return LogLevel(level)
|
||||
return level
|
||||
|
||||
@classmethod
|
||||
def should_log_headers(cls, db: Session) -> bool:
|
||||
"""是否应该记录请求头"""
|
||||
log_level = cls.get_log_level(db)
|
||||
return log_level in [LogLevel.HEADERS, LogLevel.FULL]
|
||||
|
||||
@classmethod
|
||||
def should_log_body(cls, db: Session) -> bool:
|
||||
"""是否应该记录请求体和响应体"""
|
||||
log_level = cls.get_log_level(db)
|
||||
return log_level == LogLevel.FULL
|
||||
|
||||
@classmethod
|
||||
def should_mask_sensitive_data(cls, db: Session) -> bool:
|
||||
"""是否应该脱敏敏感数据(始终脱敏)"""
|
||||
_ = db # 保持接口一致性
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_sensitive_headers(cls, db: Session) -> list:
|
||||
"""获取敏感请求头列表"""
|
||||
return cls.get_config(db, "sensitive_headers", [])
|
||||
|
||||
@classmethod
|
||||
def mask_sensitive_headers(cls, db: Session, headers: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""脱敏敏感请求头"""
|
||||
if not cls.should_mask_sensitive_data(db):
|
||||
return headers
|
||||
|
||||
sensitive_headers = cls.get_sensitive_headers(db)
|
||||
masked_headers = {}
|
||||
|
||||
for key, value in headers.items():
|
||||
if key.lower() in [h.lower() for h in sensitive_headers]:
|
||||
# 保留前后各4个字符,中间用星号替换
|
||||
if len(str(value)) > 8:
|
||||
masked_value = str(value)[:4] + "****" + str(value)[-4:]
|
||||
else:
|
||||
masked_value = "****"
|
||||
masked_headers[key] = masked_value
|
||||
else:
|
||||
masked_headers[key] = value
|
||||
|
||||
return masked_headers
|
||||
|
||||
@classmethod
|
||||
def truncate_body(cls, db: Session, body: Any, is_request: bool = True) -> Any:
|
||||
"""截断过大的请求体或响应体"""
|
||||
max_size_key = "max_request_body_size" if is_request else "max_response_body_size"
|
||||
max_size = cls.get_config(db, max_size_key, 102400)
|
||||
|
||||
if not body:
|
||||
return body
|
||||
|
||||
# 转换为字符串以计算大小
|
||||
body_str = json.dumps(body) if isinstance(body, (dict, list)) else str(body)
|
||||
|
||||
if len(body_str) > max_size:
|
||||
# 截断并添加提示
|
||||
truncated_str = body_str[:max_size]
|
||||
if isinstance(body, (dict, list)):
|
||||
try:
|
||||
# 尝试保持JSON格式
|
||||
return {
|
||||
"_truncated": True,
|
||||
"_original_size": len(body_str),
|
||||
"_content": truncated_str,
|
||||
}
|
||||
except:
|
||||
pass
|
||||
return truncated_str + f"\n... (truncated, original size: {len(body_str)} bytes)"
|
||||
|
||||
return body
|
||||
187
src/services/system/scheduler.py
Normal file
187
src/services/system/scheduler.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
统一定时任务调度器
|
||||
|
||||
使用 APScheduler 管理所有定时任务,支持时区配置。
|
||||
所有定时任务使用应用时区(APP_TIMEZONE)配置执行时间,
|
||||
数据存储仍然使用 UTC。
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
# 应用时区配置,默认为 Asia/Shanghai
|
||||
APP_TIMEZONE = os.getenv("APP_TIMEZONE", "Asia/Shanghai")
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""统一定时任务调度器"""
|
||||
|
||||
_instance: Optional["TaskScheduler"] = None
|
||||
|
||||
def __init__(self):
|
||||
self.scheduler = AsyncIOScheduler(timezone=APP_TIMEZONE)
|
||||
self._started = False
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "TaskScheduler":
|
||||
"""获取调度器单例"""
|
||||
if cls._instance is None:
|
||||
cls._instance = TaskScheduler()
|
||||
return cls._instance
|
||||
|
||||
def add_cron_job(
|
||||
self,
|
||||
func,
|
||||
hour: int,
|
||||
minute: int = 0,
|
||||
job_id: str = None,
|
||||
name: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
添加 cron 定时任务
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
hour: 执行时间(小时),使用业务时区
|
||||
minute: 执行时间(分钟)
|
||||
job_id: 任务ID
|
||||
name: 任务名称(用于日志)
|
||||
**kwargs: 传递给任务函数的参数
|
||||
"""
|
||||
trigger = CronTrigger(hour=hour, minute=minute, timezone=APP_TIMEZONE)
|
||||
|
||||
job_id = job_id or func.__name__
|
||||
display_name = name or job_id
|
||||
|
||||
self.scheduler.add_job(
|
||||
func,
|
||||
trigger,
|
||||
id=job_id,
|
||||
name=display_name,
|
||||
replace_existing=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"已注册定时任务: {display_name}, "
|
||||
f"执行时间: {hour:02d}:{minute:02d} ({APP_TIMEZONE})"
|
||||
)
|
||||
|
||||
def add_interval_job(
|
||||
self,
|
||||
func,
|
||||
seconds: int = None,
|
||||
minutes: int = None,
|
||||
hours: int = None,
|
||||
job_id: str = None,
|
||||
name: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
添加间隔执行任务
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
seconds: 间隔秒数
|
||||
minutes: 间隔分钟数
|
||||
hours: 间隔小时数
|
||||
job_id: 任务ID
|
||||
name: 任务名称
|
||||
**kwargs: 传递给任务函数的参数
|
||||
"""
|
||||
# 构建 trigger 参数,过滤掉 None 值
|
||||
trigger_kwargs = {}
|
||||
if seconds is not None:
|
||||
trigger_kwargs["seconds"] = seconds
|
||||
if minutes is not None:
|
||||
trigger_kwargs["minutes"] = minutes
|
||||
if hours is not None:
|
||||
trigger_kwargs["hours"] = hours
|
||||
|
||||
trigger = IntervalTrigger(**trigger_kwargs)
|
||||
|
||||
job_id = job_id or func.__name__
|
||||
display_name = name or job_id
|
||||
|
||||
# 计算间隔描述
|
||||
interval_parts = []
|
||||
if hours:
|
||||
interval_parts.append(f"{hours}小时")
|
||||
if minutes:
|
||||
interval_parts.append(f"{minutes}分钟")
|
||||
if seconds:
|
||||
interval_parts.append(f"{seconds}秒")
|
||||
interval_desc = "".join(interval_parts) or "未知间隔"
|
||||
|
||||
self.scheduler.add_job(
|
||||
func,
|
||||
trigger,
|
||||
id=job_id,
|
||||
name=display_name,
|
||||
replace_existing=True,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"已注册间隔任务: {display_name}, 执行间隔: {interval_desc}")
|
||||
|
||||
def start(self):
|
||||
"""启动调度器"""
|
||||
if self._started:
|
||||
logger.warning("调度器已在运行中")
|
||||
return
|
||||
|
||||
self.scheduler.start()
|
||||
self._started = True
|
||||
logger.info(f"定时任务调度器已启动,应用时区: {APP_TIMEZONE}")
|
||||
|
||||
# 打印下次执行时间
|
||||
self._log_next_run_times()
|
||||
|
||||
def stop(self):
|
||||
"""停止调度器"""
|
||||
if not self._started:
|
||||
return
|
||||
|
||||
self.scheduler.shutdown(wait=False)
|
||||
self._started = False
|
||||
logger.info("定时任务调度器已停止")
|
||||
|
||||
def _log_next_run_times(self):
|
||||
"""记录所有任务的下次执行时间"""
|
||||
jobs = self.scheduler.get_jobs()
|
||||
if not jobs:
|
||||
return
|
||||
|
||||
logger.info("已注册的定时任务:")
|
||||
for job in jobs:
|
||||
next_run = job.next_run_time
|
||||
if next_run:
|
||||
# 计算距离下次执行的时间
|
||||
now = datetime.now(next_run.tzinfo)
|
||||
delta = next_run - now
|
||||
hours, remainder = divmod(int(delta.total_seconds()), 3600)
|
||||
minutes = remainder // 60
|
||||
|
||||
logger.info(
|
||||
f" - {job.name}: 下次执行 {next_run.strftime('%Y-%m-%d %H:%M')} "
|
||||
f"({hours}小时{minutes}分钟后)"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""调度器是否在运行"""
|
||||
return self._started
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def get_scheduler() -> TaskScheduler:
|
||||
"""获取调度器单例"""
|
||||
return TaskScheduler.get_instance()
|
||||
436
src/services/system/stats_aggregator.py
Normal file
436
src/services/system/stats_aggregator.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""统计数据聚合服务
|
||||
|
||||
实现预聚合统计,避免每次请求都全表扫描。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
RequestCandidate,
|
||||
StatsDaily,
|
||||
StatsSummary,
|
||||
StatsUserDaily,
|
||||
Usage,
|
||||
)
|
||||
from src.models.database import User as DBUser
|
||||
|
||||
|
||||
class StatsAggregatorService:
|
||||
"""统计数据聚合服务"""
|
||||
|
||||
@staticmethod
|
||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||
"""聚合指定日期的统计数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
date: 要聚合的日期 (会自动转为 UTC 当天开始)
|
||||
|
||||
Returns:
|
||||
StatsDaily 记录
|
||||
"""
|
||||
# 确保日期是 UTC 当天开始
|
||||
day_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=timezone.utc)
|
||||
day_end = day_start + timedelta(days=1)
|
||||
|
||||
# 检查是否已存在该日期的记录
|
||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||
if existing:
|
||||
stats = existing
|
||||
else:
|
||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||
|
||||
# 基础请求统计
|
||||
base_query = db.query(Usage).filter(
|
||||
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
||||
)
|
||||
|
||||
total_requests = base_query.count()
|
||||
|
||||
# 如果没有请求,直接返回空记录
|
||||
if total_requests == 0:
|
||||
stats.total_requests = 0
|
||||
stats.success_requests = 0
|
||||
stats.error_requests = 0
|
||||
stats.input_tokens = 0
|
||||
stats.output_tokens = 0
|
||||
stats.cache_creation_tokens = 0
|
||||
stats.cache_read_tokens = 0
|
||||
stats.total_cost = 0.0
|
||||
stats.actual_total_cost = 0.0
|
||||
stats.input_cost = 0.0
|
||||
stats.output_cost = 0.0
|
||||
stats.cache_creation_cost = 0.0
|
||||
stats.cache_read_cost = 0.0
|
||||
stats.avg_response_time_ms = 0.0
|
||||
stats.fallback_count = 0
|
||||
|
||||
if not existing:
|
||||
db.add(stats)
|
||||
db.commit()
|
||||
return stats
|
||||
|
||||
# 错误请求数
|
||||
error_requests = (
|
||||
base_query.filter(
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||
).count()
|
||||
)
|
||||
|
||||
# Token 和成本聚合
|
||||
aggregated = (
|
||||
db.query(
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
func.sum(Usage.output_tokens).label("output_tokens"),
|
||||
func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
|
||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||
func.sum(Usage.total_cost_usd).label("total_cost"),
|
||||
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost"),
|
||||
func.sum(Usage.input_cost_usd).label("input_cost"),
|
||||
func.sum(Usage.output_cost_usd).label("output_cost"),
|
||||
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
|
||||
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
|
||||
func.avg(Usage.response_time_ms).label("avg_response_time"),
|
||||
)
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
.first()
|
||||
)
|
||||
|
||||
# Fallback 统计 (执行候选数 > 1 的请求数)
|
||||
fallback_subquery = (
|
||||
db.query(
|
||||
RequestCandidate.request_id,
|
||||
func.count(RequestCandidate.id).label("executed_count"),
|
||||
)
|
||||
.filter(
|
||||
and_(
|
||||
RequestCandidate.created_at >= day_start,
|
||||
RequestCandidate.created_at < day_end,
|
||||
RequestCandidate.status.in_(["success", "failed"]),
|
||||
)
|
||||
)
|
||||
.group_by(RequestCandidate.request_id)
|
||||
.subquery()
|
||||
)
|
||||
fallback_count = (
|
||||
db.query(func.count())
|
||||
.select_from(fallback_subquery)
|
||||
.filter(fallback_subquery.c.executed_count > 1)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# 使用维度统计
|
||||
unique_models = (
|
||||
db.query(func.count(func.distinct(Usage.model)))
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
unique_providers = (
|
||||
db.query(func.count(func.distinct(Usage.provider)))
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
# 更新统计记录
|
||||
stats.total_requests = total_requests
|
||||
stats.success_requests = total_requests - error_requests
|
||||
stats.error_requests = error_requests
|
||||
stats.input_tokens = int(aggregated.input_tokens or 0)
|
||||
stats.output_tokens = int(aggregated.output_tokens or 0)
|
||||
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
|
||||
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
|
||||
stats.total_cost = float(aggregated.total_cost or 0)
|
||||
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
|
||||
stats.input_cost = float(aggregated.input_cost or 0)
|
||||
stats.output_cost = float(aggregated.output_cost or 0)
|
||||
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
|
||||
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
|
||||
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
|
||||
stats.fallback_count = fallback_count
|
||||
stats.unique_models = unique_models
|
||||
stats.unique_providers = unique_providers
|
||||
|
||||
if not existing:
|
||||
db.add(stats)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[StatsAggregator] 聚合日期 {day_start.date()} 完成: {total_requests} 请求")
|
||||
return stats
|
||||
|
||||
    @staticmethod
    def aggregate_user_daily_stats(
        db: Session, user_id: str, date: datetime
    ) -> StatsUserDaily:
        """Aggregate a single user's statistics for the given date."""
        day_start = date.replace(hour=0, minute=0, second=0, microsecond=0, tzinfo=timezone.utc)
        day_end = day_start + timedelta(days=1)

        existing = (
            db.query(StatsUserDaily)
            .filter(and_(StatsUserDaily.user_id == user_id, StatsUserDaily.date == day_start))
            .first()
        )

        if existing:
            stats = existing
        else:
            stats = StatsUserDaily(id=str(uuid.uuid4()), user_id=user_id, date=day_start)

        # Per-user request stats
        base_query = db.query(Usage).filter(
            and_(
                Usage.user_id == user_id,
                Usage.created_at >= day_start,
                Usage.created_at < day_end,
            )
        )

        total_requests = base_query.count()

        if total_requests == 0:
            stats.total_requests = 0
            stats.success_requests = 0
            stats.error_requests = 0
            stats.input_tokens = 0
            stats.output_tokens = 0
            stats.cache_creation_tokens = 0
            stats.cache_read_tokens = 0
            stats.total_cost = 0.0

            if not existing:
                db.add(stats)
            db.commit()
            return stats

        error_requests = (
            base_query.filter(
                (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
            ).count()
        )

        aggregated = (
            db.query(
                func.sum(Usage.input_tokens).label("input_tokens"),
                func.sum(Usage.output_tokens).label("output_tokens"),
                func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
                func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
            )
            .filter(
                and_(
                    Usage.user_id == user_id,
                    Usage.created_at >= day_start,
                    Usage.created_at < day_end,
                )
            )
            .first()
        )

        stats.total_requests = total_requests
        stats.success_requests = total_requests - error_requests
        stats.error_requests = error_requests
        stats.input_tokens = int(aggregated.input_tokens or 0)
        stats.output_tokens = int(aggregated.output_tokens or 0)
        stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
        stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
        stats.total_cost = float(aggregated.total_cost or 0)

        if not existing:
            db.add(stats)
        db.commit()
        return stats

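    # Usage sketch (illustrative; `db` and `user.id` are assumed to come from the caller):
    #
    #   StatsAggregatorService.aggregate_user_daily_stats(
    #       db, user_id=user.id, date=datetime.now(timezone.utc) - timedelta(days=1)
    #   )
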
    @staticmethod
    def update_summary(db: Session) -> StatsSummary:
        """Update the global statistics summary.

        Aggregates all data up to and including yesterday (everything before today).
        """
        now = datetime.now(timezone.utc)
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)
        cutoff_date = today  # excludes today

        # Get or create the summary record
        summary = db.query(StatsSummary).first()
        if not summary:
            summary = StatsSummary(id=str(uuid.uuid4()), cutoff_date=cutoff_date)

        # Aggregate historical data from stats_daily
        daily_aggregated = (
            db.query(
                func.sum(StatsDaily.total_requests).label("total_requests"),
                func.sum(StatsDaily.success_requests).label("success_requests"),
                func.sum(StatsDaily.error_requests).label("error_requests"),
                func.sum(StatsDaily.input_tokens).label("input_tokens"),
                func.sum(StatsDaily.output_tokens).label("output_tokens"),
                func.sum(StatsDaily.cache_creation_tokens).label("cache_creation_tokens"),
                func.sum(StatsDaily.cache_read_tokens).label("cache_read_tokens"),
                func.sum(StatsDaily.total_cost).label("total_cost"),
                func.sum(StatsDaily.actual_total_cost).label("actual_total_cost"),
            )
            .filter(StatsDaily.date < cutoff_date)
            .first()
        )

        # User / API key counts
        total_users = db.query(func.count(DBUser.id)).scalar() or 0
        active_users = (
            db.query(func.count(DBUser.id)).filter(DBUser.is_active.is_(True)).scalar() or 0
        )
        total_api_keys = db.query(func.count(ApiKey.id)).scalar() or 0
        active_api_keys = (
            db.query(func.count(ApiKey.id)).filter(ApiKey.is_active.is_(True)).scalar() or 0
        )

        # Update the summary
        summary.cutoff_date = cutoff_date
        summary.all_time_requests = int(daily_aggregated.total_requests or 0)
        summary.all_time_success_requests = int(daily_aggregated.success_requests or 0)
        summary.all_time_error_requests = int(daily_aggregated.error_requests or 0)
        summary.all_time_input_tokens = int(daily_aggregated.input_tokens or 0)
        summary.all_time_output_tokens = int(daily_aggregated.output_tokens or 0)
        summary.all_time_cache_creation_tokens = int(daily_aggregated.cache_creation_tokens or 0)
        summary.all_time_cache_read_tokens = int(daily_aggregated.cache_read_tokens or 0)
        summary.all_time_cost = float(daily_aggregated.total_cost or 0)
        summary.all_time_actual_cost = float(daily_aggregated.actual_total_cost or 0)
        summary.total_users = total_users
        summary.active_users = active_users
        summary.total_api_keys = total_api_keys
        summary.active_api_keys = active_api_keys

        db.add(summary)
        db.commit()

        logger.info(f"[StatsAggregator] Global summary updated, cutoff date: {cutoff_date.date()}")
        return summary

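    # Usage sketch (illustrative; assumes an open Session `db`): the refresh is
    # cheap because it reads pre-aggregated stats_daily rows rather than the raw
    # Usage table:
    #
    #   StatsAggregatorService.update_summary(db)
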
    @staticmethod
    def get_today_realtime_stats(db: Session) -> dict:
        """Get today's real-time stats (for merging with pre-aggregated data)."""
        now = datetime.now(timezone.utc)
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)

        base_query = db.query(Usage).filter(Usage.created_at >= today)

        total_requests = base_query.count()

        if total_requests == 0:
            return {
                "total_requests": 0,
                "success_requests": 0,
                "error_requests": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "cache_creation_tokens": 0,
                "cache_read_tokens": 0,
                "total_cost": 0.0,
                "actual_total_cost": 0.0,
            }

        error_requests = (
            base_query.filter(
                (Usage.status_code >= 400) | (Usage.error_message.isnot(None))
            ).count()
        )

        aggregated = (
            db.query(
                func.sum(Usage.input_tokens).label("input_tokens"),
                func.sum(Usage.output_tokens).label("output_tokens"),
                func.sum(Usage.cache_creation_input_tokens).label("cache_creation_tokens"),
                func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
                func.sum(Usage.total_cost_usd).label("total_cost"),
                func.sum(Usage.actual_total_cost_usd).label("actual_total_cost"),
            )
            .filter(Usage.created_at >= today)
            .first()
        )

        return {
            "total_requests": total_requests,
            "success_requests": total_requests - error_requests,
            "error_requests": error_requests,
            "input_tokens": int(aggregated.input_tokens or 0),
            "output_tokens": int(aggregated.output_tokens or 0),
            "cache_creation_tokens": int(aggregated.cache_creation_tokens or 0),
            "cache_read_tokens": int(aggregated.cache_read_tokens or 0),
            "total_cost": float(aggregated.total_cost or 0),
            "actual_total_cost": float(aggregated.actual_total_cost or 0),
        }

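    # Usage sketch (illustrative): today's slice is recomputed from Usage on
    # every call, so callers that need all-time figures combine it with the
    # pre-aggregated summary:
    #
    #   today = StatsAggregatorService.get_today_realtime_stats(db)
    #   print(today["total_requests"], today["total_cost"])
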
    @staticmethod
    def get_combined_stats(db: Session) -> dict:
        """Get combined statistics (pre-aggregated history + today's real-time data)."""
        summary = db.query(StatsSummary).first()
        today_stats = StatsAggregatorService.get_today_realtime_stats(db)

        if not summary:
            # No pre-aggregated data yet; fall back to today's stats only
            return today_stats

        return {
            "total_requests": summary.all_time_requests + today_stats["total_requests"],
            "success_requests": summary.all_time_success_requests
            + today_stats["success_requests"],
            "error_requests": summary.all_time_error_requests + today_stats["error_requests"],
            "input_tokens": summary.all_time_input_tokens + today_stats["input_tokens"],
            "output_tokens": summary.all_time_output_tokens + today_stats["output_tokens"],
            "cache_creation_tokens": summary.all_time_cache_creation_tokens
            + today_stats["cache_creation_tokens"],
            "cache_read_tokens": summary.all_time_cache_read_tokens
            + today_stats["cache_read_tokens"],
            "total_cost": summary.all_time_cost + today_stats["total_cost"],
            "actual_total_cost": summary.all_time_actual_cost + today_stats["actual_total_cost"],
            "total_users": summary.total_users,
            "active_users": summary.active_users,
            "total_api_keys": summary.total_api_keys,
            "active_api_keys": summary.active_api_keys,
        }

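    # Usage sketch (illustrative; assumes an open Session `db`):
    #
    #   combined = StatsAggregatorService.get_combined_stats(db)
    #   # all-time totals = pre-aggregated history + today's real-time slice
    #   print(combined["total_requests"], combined["actual_total_cost"])
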
    @staticmethod
    def backfill_historical_data(db: Session, days: int = 365) -> int:
        """Backfill historical data (intended for first deployment).

        Args:
            db: Database session
            days: Number of days to backfill

        Returns:
            Number of days backfilled
        """
        now = datetime.now(timezone.utc)
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)

        # Find the earliest Usage record
        earliest = db.query(func.min(Usage.created_at)).scalar()
        if not earliest:
            logger.info("[StatsAggregator] No historical data to backfill")
            return 0

        # Work out the date range that needs backfilling
        earliest_date = earliest.replace(hour=0, minute=0, second=0, microsecond=0)
        start_date = max(earliest_date, today - timedelta(days=days))

        count = 0
        current_date = start_date
        while current_date < today:
            StatsAggregatorService.aggregate_daily_stats(db, current_date)
            count += 1
            current_date += timedelta(days=1)

        # Refresh the summary
        if count > 0:
            StatsAggregatorService.update_summary(db)

        logger.info(f"[StatsAggregator] Historical backfill complete: {count} day(s)")
        return count
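    # Usage sketch (illustrative; typically run once after the first deployment):
    #
    #   days_filled = StatsAggregatorService.backfill_historical_data(db, days=365)
    #   # aggregates each historical day into stats_daily, then refreshes the summary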
142
src/services/system/sync_stats.py
Normal file
142
src/services/system/sync_stats.py
Normal file
@@ -0,0 +1,142 @@
"""
API key statistics sync service

Periodically syncs API key statistics so that they stay consistent with the
actual usage records.
"""

from typing import Optional

from sqlalchemy import func
from sqlalchemy.orm import Session

from src.core.logger import logger
from src.models.database import ApiKey, Usage


class SyncStatsService:
    """API key statistics sync service"""

    # Pagination batch size
    BATCH_SIZE = 100

    @staticmethod
    def sync_api_key_stats(db: Session, api_key_id: Optional[str] = None) -> dict:  # UUID
        """
        Sync API key statistics.

        Args:
            db: Database session
            api_key_id: ID of the API key to sync; if omitted, all keys are synced

        Returns:
            Sync result counters
        """
        result = {"synced": 0, "updated": 0, "errors": 0}

        try:
            # Fetch the API keys to sync (paged to avoid loading large volumes at once)
            if api_key_id:
                api_keys = db.query(ApiKey).filter(ApiKey.id == api_key_id).all()
            else:
                # Page through results instead of loading everything in one query
                offset = 0
                api_keys = []
                while True:
                    batch = db.query(ApiKey).offset(offset).limit(SyncStatsService.BATCH_SIZE).all()
                    if not batch:
                        break
                    api_keys.extend(batch)
                    offset += SyncStatsService.BATCH_SIZE

            for api_key in api_keys:
                try:
                    # Compute the actual usage statistics
                    stats = (
                        db.query(
                            func.count(Usage.id).label("requests"),
                            func.sum(Usage.total_cost_usd).label("cost"),
                        )
                        .filter(Usage.api_key_id == api_key.id)
                        .first()
                    )

                    actual_requests = stats.requests or 0
                    actual_cost = float(stats.cost or 0)

                    # Get the last-used timestamp
                    last_usage = (
                        db.query(Usage.created_at)
                        .filter(Usage.api_key_id == api_key.id)
                        .order_by(Usage.created_at.desc())
                        .first()
                    )

                    # Check whether an update is needed
                    needs_update = False
                    if api_key.total_requests != actual_requests:
                        logger.info(f"API key {api_key.id} request count mismatch: {api_key.total_requests} -> {actual_requests}")
                        api_key.total_requests = actual_requests
                        needs_update = True

                    if abs(api_key.total_cost_usd - actual_cost) > 0.0001:
                        logger.info(f"API key {api_key.id} cost mismatch: {api_key.total_cost_usd} -> {actual_cost}")
                        api_key.total_cost_usd = actual_cost
                        needs_update = True

                    if last_usage and api_key.last_used_at != last_usage[0]:
                        api_key.last_used_at = last_usage[0]
                        needs_update = True

                    result["synced"] += 1
                    if needs_update:
                        result["updated"] += 1
                        logger.info(f"Updated statistics for API key {api_key.id}")

                except Exception as e:
                    logger.error(f"Error syncing statistics for API key {api_key.id}: {e}")
                    result["errors"] += 1
                    # Roll back the failed operation and continue with the remaining keys
                    try:
                        db.rollback()
                    except Exception:
                        pass

            # Commit all changes
            db.commit()
            logger.info(f"Sync complete: {result['synced']} keys processed, {result['updated']} updated, {result['errors']} errors")

        except Exception as e:
            logger.error(f"Error syncing statistics: {e}")
            db.rollback()
            raise

        return result

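    # Usage sketch (illustrative; a scheduler would call this periodically):
    #
    #   result = SyncStatsService.sync_api_key_stats(db)           # all keys
    #   result = SyncStatsService.sync_api_key_stats(db, key.id)   # a single key
    #   # result -> {"synced": ..., "updated": ..., "errors": ...}
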
    @staticmethod
    def get_api_key_real_stats(db: Session, api_key_id: str) -> dict:  # UUID
        """
        Get an API key's actual statistics, computed directly from usage records.

        Args:
            db: Database session
            api_key_id: API key ID

        Returns:
            The actual statistics
        """
        # Compute the actual usage statistics
        stats = (
            db.query(
                func.count(Usage.id).label("requests"),
                func.sum(Usage.total_cost_usd).label("cost"),
                func.max(Usage.created_at).label("last_used"),
            )
            .filter(Usage.api_key_id == api_key_id)
            .first()
        )

        return {
            "total_requests": stats.requests or 0,
            "total_cost_usd": float(stats.cost or 0),
            "last_used_at": stats.last_used,
        }
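    # Usage sketch (illustrative; useful for spot-checking a key's cached counters):
    #
    #   real = SyncStatsService.get_api_key_real_stats(db, api_key_id=key.id)
    #   drift = real["total_requests"] - key.total_requests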