mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 11:12:28 +08:00
refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling - Add trusted proxy count configuration for client IP extraction in reverse proxy environments - Implement audit log cleanup scheduler with configurable retention period - Replace plaintext token logging with SHA256 hash fingerprints for security - Fix database session lifecycle management in middleware - Improve request tracing and error logging throughout the system - Add comprehensive tests for pipeline architecture
This commit is contained in:
@@ -63,14 +63,16 @@ class JWTBlacklistService:
|
||||
|
||||
if ttl_seconds <= 0:
|
||||
# Token 已经过期,不需要加入黑名单
|
||||
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
|
||||
return True
|
||||
|
||||
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
||||
# 值存储为原因字符串
|
||||
await redis_client.setex(redis_key, ttl_seconds, reason)
|
||||
|
||||
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -109,7 +111,8 @@ class JWTBlacklistService:
|
||||
if exists:
|
||||
# 获取黑名单原因(可选)
|
||||
reason = await redis_client.get(redis_key)
|
||||
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -148,9 +151,11 @@ class JWTBlacklistService:
|
||||
deleted = await redis_client.delete(redis_key)
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
|
||||
else:
|
||||
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
|
||||
|
||||
return bool(deleted)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
@@ -169,7 +170,8 @@ class AuthService:
|
||||
key_record.last_used_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||
|
||||
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
|
||||
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||
return user, key_record
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import create_session
|
||||
from src.models.database import Usage
|
||||
from src.models.database import AuditLog, Usage
|
||||
from src.services.system.config import SystemConfigService
|
||||
from src.services.system.scheduler import get_scheduler
|
||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||
@@ -91,6 +91,15 @@ class CleanupScheduler:
|
||||
name="Pending状态清理",
|
||||
)
|
||||
|
||||
# 审计日志清理 - 凌晨 4 点执行
|
||||
scheduler.add_cron_job(
|
||||
self._scheduled_audit_cleanup,
|
||||
hour=4,
|
||||
minute=0,
|
||||
job_id="audit_cleanup",
|
||||
name="审计日志清理",
|
||||
)
|
||||
|
||||
# 启动时执行一次初始化任务
|
||||
asyncio.create_task(self._run_startup_tasks())
|
||||
|
||||
@@ -145,6 +154,10 @@ class CleanupScheduler:
|
||||
"""Pending 清理任务(定时调用)"""
|
||||
await self._perform_pending_cleanup()
|
||||
|
||||
async def _scheduled_audit_cleanup(self):
|
||||
"""审计日志清理任务(定时调用)"""
|
||||
await self._perform_audit_cleanup()
|
||||
|
||||
# ========== 实际任务实现 ==========
|
||||
|
||||
async def _perform_stats_aggregation(self, backfill: bool = False):
|
||||
@@ -330,6 +343,70 @@ class CleanupScheduler:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_audit_cleanup(self):
|
||||
"""执行审计日志清理任务"""
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用自动清理
|
||||
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
|
||||
logger.info("自动清理已禁用,跳过审计日志清理")
|
||||
return
|
||||
|
||||
# 获取审计日志保留天数(默认 30 天,最少 7 天)
|
||||
audit_retention_days = max(
|
||||
SystemConfigService.get_config(db, "audit_log_retention_days", 30),
|
||||
7, # 最少保留 7 天,防止误配置删除所有审计日志
|
||||
)
|
||||
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
|
||||
|
||||
logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
|
||||
|
||||
total_deleted = 0
|
||||
while True:
|
||||
# 先查询要删除的记录 ID(分批)
|
||||
records_to_delete = (
|
||||
db.query(AuditLog.id)
|
||||
.filter(AuditLog.created_at < cutoff_time)
|
||||
.limit(batch_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not records_to_delete:
|
||||
break
|
||||
|
||||
record_ids = [r.id for r in records_to_delete]
|
||||
|
||||
# 执行删除
|
||||
result = db.execute(
|
||||
delete(AuditLog)
|
||||
.where(AuditLog.id.in_(record_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
rows_deleted = result.rowcount
|
||||
db.commit()
|
||||
|
||||
total_deleted += rows_deleted
|
||||
logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted} 条")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if total_deleted > 0:
|
||||
logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
|
||||
else:
|
||||
logger.info("无需清理的审计日志")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"审计日志清理失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_cleanup(self):
|
||||
"""执行清理任务"""
|
||||
db = create_session()
|
||||
|
||||
@@ -1217,15 +1217,19 @@ class UsageService:
|
||||
request_id: str,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
target_model: Optional[str] = None,
|
||||
) -> Optional[Usage]:
|
||||
"""
|
||||
快速更新使用记录状态(不更新其他字段)
|
||||
快速更新使用记录状态
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
request_id: 请求ID
|
||||
status: 新状态 (pending, streaming, completed, failed)
|
||||
error_message: 错误消息(仅在 failed 状态时使用)
|
||||
provider: 提供商名称(可选,streaming 状态时更新)
|
||||
target_model: 映射后的目标模型名(可选)
|
||||
|
||||
Returns:
|
||||
更新后的 Usage 记录,如果未找到则返回 None
|
||||
@@ -1239,6 +1243,10 @@ class UsageService:
|
||||
usage.status = status
|
||||
if error_message:
|
||||
usage.error_message = error_message
|
||||
if provider:
|
||||
usage.provider = provider
|
||||
if target_model:
|
||||
usage.target_model = target_model
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -457,7 +457,7 @@ class StreamUsageTracker:
|
||||
|
||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||
|
||||
# 更新状态为 streaming
|
||||
# 更新状态为 streaming,同时更新 provider
|
||||
if self.request_id:
|
||||
try:
|
||||
from src.services.usage.service import UsageService
|
||||
@@ -465,6 +465,7 @@ class StreamUsageTracker:
|
||||
db=self.db,
|
||||
request_id=self.request_id,
|
||||
status="streaming",
|
||||
provider=self.provider,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||
|
||||
@@ -210,7 +210,15 @@ class ApiKeyService:
|
||||
|
||||
@staticmethod
|
||||
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
||||
"""检查速率限制"""
|
||||
"""检查速率限制
|
||||
|
||||
Returns:
|
||||
(is_allowed, remaining): 是否允许请求,剩余可用次数
|
||||
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
|
||||
"""
|
||||
# 如果 rate_limit 为 None,表示不限制
|
||||
if api_key.rate_limit is None:
|
||||
return True, -1 # -1 表示无限制
|
||||
|
||||
# 计算时间窗口
|
||||
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
||||
|
||||
Reference in New Issue
Block a user