mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 18:22:28 +08:00
refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling - Add trusted proxy count configuration for client IP extraction in reverse proxy environments - Implement audit log cleanup scheduler with configurable retention period - Replace plaintext token logging with SHA256 hash fingerprints for security - Fix database session lifecycle management in middleware - Improve request tracing and error logging throughout the system - Add comprehensive tests for pipeline architecture
This commit is contained in:
@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
||||
allowed_providers=self.key_data.allowed_providers,
|
||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||
allowed_models=self.key_data.allowed_models,
|
||||
rate_limit=self.key_data.rate_limit or 100,
|
||||
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||
expire_days=self.key_data.expire_days,
|
||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||
is_standalone=True, # 标记为独立Key
|
||||
|
||||
@@ -140,7 +140,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
token = authorization[7:].strip()
|
||||
try:
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
@@ -177,7 +177,7 @@ class ApiRequestPipeline:
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
token = authorization[7:].strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
@@ -204,7 +204,7 @@ class ApiRequestPipeline:
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
token = authorization[7:].strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
|
||||
from src.services.system.audit import audit_service
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
|
||||
|
||||
|
||||
class MessageTelemetry:
|
||||
@@ -399,6 +402,41 @@ class BaseMessageHandler:
|
||||
# 创建后台任务,不阻塞当前流
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
||||
|
||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||
|
||||
Args:
|
||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||
"""
|
||||
import asyncio
|
||||
from src.database.database import get_db
|
||||
|
||||
target_request_id = self.request_id
|
||||
provider = ctx.provider_name
|
||||
target_model = ctx.mapped_model
|
||||
|
||||
async def _do_update() -> None:
|
||||
try:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
UsageService.update_usage_status(
|
||||
db=db,
|
||||
request_id=target_request_id,
|
||||
status="streaming",
|
||||
provider=provider,
|
||||
target_model=target_model,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||
|
||||
# 创建后台任务,不阻塞当前流
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
def _log_request_error(self, message: str, error: Exception) -> None:
|
||||
"""记录请求错误日志,对业务异常不打印堆栈
|
||||
|
||||
|
||||
@@ -297,11 +297,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
# 创建类型安全的流式上下文
|
||||
ctx = StreamContext(model=model, api_format=api_format)
|
||||
|
||||
# 创建更新状态的回调闭包(可以访问 ctx)
|
||||
def update_streaming_status() -> None:
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
|
||||
# 创建流处理器
|
||||
stream_processor = StreamProcessor(
|
||||
request_id=self.request_id,
|
||||
default_parser=self.parser,
|
||||
on_streaming_start=self._update_usage_to_streaming,
|
||||
on_streaming_start=update_streaming_status,
|
||||
)
|
||||
|
||||
# 定义请求函数
|
||||
|
||||
@@ -532,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
async for chunk in stream_response.aiter_raw():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
streaming_status_updated = True
|
||||
|
||||
buffer += chunk
|
||||
@@ -816,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_chunks:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
|
||||
# 先处理预读的字节块
|
||||
for chunk in prefetched_chunks:
|
||||
|
||||
Reference in New Issue
Block a user