refactor: optimize middleware with pure ASGI implementation and enhance security measures

- Replace BaseHTTPMiddleware with a pure ASGI implementation in the plugin middleware for better streaming response handling (see the first sketch below)
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments (see the second sketch below)
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security (see the third sketch below)
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
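
The three sketches below are minimal, hypothetical illustrations of the patterns named in the bullets above, not the repository's actual code; every name they introduce (PluginASGIMiddleware, extract_client_ip, token_fingerprint, the trusted_proxy_count parameter, the x-request-traced header) is assumed for illustration only.

A pure ASGI middleware wraps the downstream app and forwards each send() message as it is produced, so streaming response bodies pass through chunk by chunk instead of being buffered the way BaseHTTPMiddleware buffers them:

# Hypothetical sketch of the pure ASGI middleware pattern, not the repo's class.
from starlette.types import ASGIApp, Message, Receive, Scope, Send


class PluginASGIMiddleware:
    """Pure ASGI middleware: body chunks are forwarded as they arrive."""

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def send_wrapper(message: Message) -> None:
            # Only response metadata is touched; body messages pass through
            # untouched, so streaming responses are never collected into memory.
            if message["type"] == "http.response.start":
                message.setdefault("headers", []).append((b"x-request-traced", b"1"))
            await send(message)

        await self.app(scope, receive, send_wrapper)

Behind reverse proxies, the client IP can be derived from X-Forwarded-For by skipping exactly the number of entries appended by trusted proxies; a sketch of that rule, with the trusted_proxy_count name assumed:

def extract_client_ip(forwarded_for: str | None, peer_ip: str, trusted_proxy_count: int) -> str:
    # With no trusted proxies (or no header), the TCP peer is the client.
    if not forwarded_for or trusted_proxy_count <= 0:
        return peer_ip
    hops = [part.strip() for part in forwarded_for.split(",") if part.strip()]
    # Each trusted proxy appends the address it saw, so counting from the right,
    # the Nth entry is what the outermost trusted proxy saw: the real client.
    if len(hops) >= trusted_proxy_count:
        return hops[-trusted_proxy_count]
    return hops[0] if hops else peer_ip

For logging, a truncated SHA-256 digest is enough to correlate repeated uses of the same token across log lines without ever writing the credential itself; a minimal sketch:

import hashlib


def token_fingerprint(token: str, length: int = 12) -> str:
    # One-way fingerprint: stable for the same token, safe to log, not reversible.
    return hashlib.sha256(token.encode("utf-8")).hexdigest()[:length]


# usage: logger.warning(f"invalid token, fp={token_fingerprint(presented_token)}")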
This commit is contained in:
fawney19
2025-12-18 19:07:20 +08:00
parent c7b971cfe7
commit 7b932d7afb
24 changed files with 497 additions and 219 deletions

View File

@@ -28,7 +28,7 @@
 from __future__ import annotations
 
 import time
-from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
 
 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
 from src.services.system.audit import audit_service
 from src.services.usage.service import UsageService
 
+if TYPE_CHECKING:
+    from src.api.handlers.base.stream_context import StreamContext
+
 
 class MessageTelemetry:
@@ -399,6 +402,41 @@ class BaseMessageHandler:
             # Create a background task so the current stream is not blocked
             asyncio.create_task(_do_update())
 
+    def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
+        """Update the Usage status to streaming, also updating provider and target_model.
+
+        The database update runs in an asyncio background task so it does not block streaming.
+
+        Args:
+            ctx: Streaming context containing provider_name and mapped_model
+        """
+        import asyncio
+        from src.database.database import get_db
+
+        target_request_id = self.request_id
+        provider = ctx.provider_name
+        target_model = ctx.mapped_model
+
+        async def _do_update() -> None:
+            try:
+                db_gen = get_db()
+                db = next(db_gen)
+                try:
+                    UsageService.update_usage_status(
+                        db=db,
+                        request_id=target_request_id,
+                        status="streaming",
+                        provider=provider,
+                        target_model=target_model,
+                    )
+                finally:
+                    db.close()
+            except Exception as e:
+                logger.warning(f"[{target_request_id}] Failed to update Usage status to streaming: {e}")
+
+        # Create a background task so the current stream is not blocked
+        asyncio.create_task(_do_update())
+
     def _log_request_error(self, message: str, error: Exception) -> None:
         """Log a request error; do not print a stack trace for business exceptions.

View File

@@ -297,11 +297,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
         # Create a type-safe streaming context
         ctx = StreamContext(model=model, api_format=api_format)
 
+        # Callback closure that updates the status (it can access ctx)
+        def update_streaming_status() -> None:
+            self._update_usage_to_streaming_with_ctx(ctx)
+
         # Create the stream processor
         stream_processor = StreamProcessor(
             request_id=self.request_id,
             default_parser=self.parser,
-            on_streaming_start=self._update_usage_to_streaming,
+            on_streaming_start=update_streaming_status,
         )
 
         # Define the request function

View File

@@ -532,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
         async for chunk in stream_response.aiter_raw():
             # Update the status to streaming before the first data is emitted
             if not streaming_status_updated:
-                self._update_usage_to_streaming(ctx.request_id)
+                self._update_usage_to_streaming_with_ctx(ctx)
                 streaming_status_updated = True
 
             buffer += chunk
@@ -816,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
         # Update the status to streaming before the first data is emitted
         if prefetched_chunks:
-            self._update_usage_to_streaming(ctx.request_id)
+            self._update_usage_to_streaming_with_ctx(ctx)
 
         # Process the prefetched byte chunks first
         for chunk in prefetched_chunks: