mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling - Add trusted proxy count configuration for client IP extraction in reverse proxy environments - Implement audit log cleanup scheduler with configurable retention period - Replace plaintext token logging with SHA256 hash fingerprints for security - Fix database session lifecycle management in middleware - Improve request tracing and error logging throughout the system - Add comprehensive tests for pipeline architecture
This commit is contained in:
@@ -132,7 +132,7 @@
|
|||||||
type="number"
|
type="number"
|
||||||
min="1"
|
min="1"
|
||||||
max="10000"
|
max="10000"
|
||||||
placeholder="100"
|
placeholder="留空不限制"
|
||||||
class="h-10"
|
class="h-10"
|
||||||
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
||||||
/>
|
/>
|
||||||
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
|
|||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expire_days: undefined,
|
||||||
never_expire: true,
|
never_expire: true,
|
||||||
rate_limit: 100,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
allowed_api_formats: [],
|
allowed_api_formats: [],
|
||||||
@@ -389,7 +389,7 @@ function resetForm() {
|
|||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expire_days: undefined,
|
||||||
never_expire: true,
|
never_expire: true,
|
||||||
rate_limit: 100,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
allowed_api_formats: [],
|
allowed_api_formats: [],
|
||||||
@@ -408,7 +408,7 @@ function loadKeyData() {
|
|||||||
initial_balance_usd: props.apiKey.initial_balance_usd,
|
initial_balance_usd: props.apiKey.initial_balance_usd,
|
||||||
expire_days: props.apiKey.expire_days,
|
expire_days: props.apiKey.expire_days,
|
||||||
never_expire: props.apiKey.never_expire,
|
never_expire: props.apiKey.never_expire,
|
||||||
rate_limit: props.apiKey.rate_limit || 100,
|
rate_limit: props.apiKey.rate_limit,
|
||||||
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
||||||
allowed_providers: props.apiKey.allowed_providers || [],
|
allowed_providers: props.apiKey.allowed_providers || [],
|
||||||
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
</h3>
|
</h3>
|
||||||
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
||||||
<span>{{ detail?.model || '-' }}</span>
|
<span>{{ detail?.model || '-' }}</span>
|
||||||
<template v-if="detail?.target_model">
|
<template v-if="detail?.target_model && detail.target_model !== detail.model">
|
||||||
<svg
|
<svg
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
viewBox="0 0 20 20"
|
viewBox="0 0 20 20"
|
||||||
|
|||||||
@@ -185,32 +185,13 @@
|
|||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
|
|
||||||
<!-- API Key 管理配置 -->
|
<!-- 独立余额 Key 过期管理 -->
|
||||||
<CardSection
|
<CardSection
|
||||||
title="API Key 管理"
|
title="独立余额 Key 过期管理"
|
||||||
description="API Key 相关配置"
|
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
|
||||||
>
|
>
|
||||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||||
<div>
|
<div class="flex items-center h-full">
|
||||||
<Label
|
|
||||||
for="api-key-expire"
|
|
||||||
class="block text-sm font-medium"
|
|
||||||
>
|
|
||||||
API密钥过期天数
|
|
||||||
</Label>
|
|
||||||
<Input
|
|
||||||
id="api-key-expire"
|
|
||||||
v-model.number="systemConfig.api_key_expire_days"
|
|
||||||
type="number"
|
|
||||||
placeholder="0"
|
|
||||||
class="mt-1"
|
|
||||||
/>
|
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
|
||||||
0 表示永不过期
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="flex items-center h-full pt-6">
|
|
||||||
<div class="flex items-center space-x-2">
|
<div class="flex items-center space-x-2">
|
||||||
<Checkbox
|
<Checkbox
|
||||||
id="auto-delete-expired-keys"
|
id="auto-delete-expired-keys"
|
||||||
@@ -224,7 +205,7 @@
|
|||||||
自动删除过期 Key
|
自动删除过期 Key
|
||||||
</Label>
|
</Label>
|
||||||
<p class="text-xs text-muted-foreground">
|
<p class="text-xs text-muted-foreground">
|
||||||
关闭时仅禁用过期 Key
|
关闭时仅禁用过期 Key,不会物理删除
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -448,6 +429,25 @@
|
|||||||
避免单次操作过大影响性能
|
避免单次操作过大影响性能
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="audit-log-retention-days"
|
||||||
|
class="block text-sm font-medium"
|
||||||
|
>
|
||||||
|
审计日志保留天数
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
id="audit-log-retention-days"
|
||||||
|
v-model.number="systemConfig.audit_log_retention_days"
|
||||||
|
type="number"
|
||||||
|
placeholder="30"
|
||||||
|
class="mt-1"
|
||||||
|
/>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
超过后删除审计日志记录
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 清理策略说明 -->
|
<!-- 清理策略说明 -->
|
||||||
@@ -460,6 +460,7 @@
|
|||||||
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
||||||
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
||||||
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
||||||
|
<p>5. <strong>审计日志</strong>: 独立清理,记录用户登录、操作等安全事件</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
@@ -796,8 +797,7 @@ interface SystemConfig {
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
enable_registration: boolean
|
enable_registration: boolean
|
||||||
require_email_verification: boolean
|
require_email_verification: boolean
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
api_key_expire_days: number
|
|
||||||
auto_delete_expired_keys: boolean
|
auto_delete_expired_keys: boolean
|
||||||
// 日志记录
|
// 日志记录
|
||||||
request_log_level: string
|
request_log_level: string
|
||||||
@@ -811,6 +811,7 @@ interface SystemConfig {
|
|||||||
header_retention_days: number
|
header_retention_days: number
|
||||||
log_retention_days: number
|
log_retention_days: number
|
||||||
cleanup_batch_size: number
|
cleanup_batch_size: number
|
||||||
|
audit_log_retention_days: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -845,8 +846,7 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
enable_registration: false,
|
enable_registration: false,
|
||||||
require_email_verification: false,
|
require_email_verification: false,
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
api_key_expire_days: 0,
|
|
||||||
auto_delete_expired_keys: false,
|
auto_delete_expired_keys: false,
|
||||||
// 日志记录
|
// 日志记录
|
||||||
request_log_level: 'basic',
|
request_log_level: 'basic',
|
||||||
@@ -860,6 +860,7 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
header_retention_days: 90,
|
header_retention_days: 90,
|
||||||
log_retention_days: 365,
|
log_retention_days: 365,
|
||||||
cleanup_batch_size: 1000,
|
cleanup_batch_size: 1000,
|
||||||
|
audit_log_retention_days: 30,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算属性:KB 和 字节 之间的转换
|
// 计算属性:KB 和 字节 之间的转换
|
||||||
@@ -901,8 +902,7 @@ async function loadSystemConfig() {
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
'enable_registration',
|
'enable_registration',
|
||||||
'require_email_verification',
|
'require_email_verification',
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
'api_key_expire_days',
|
|
||||||
'auto_delete_expired_keys',
|
'auto_delete_expired_keys',
|
||||||
// 日志记录
|
// 日志记录
|
||||||
'request_log_level',
|
'request_log_level',
|
||||||
@@ -916,6 +916,7 @@ async function loadSystemConfig() {
|
|||||||
'header_retention_days',
|
'header_retention_days',
|
||||||
'log_retention_days',
|
'log_retention_days',
|
||||||
'cleanup_batch_size',
|
'cleanup_batch_size',
|
||||||
|
'audit_log_retention_days',
|
||||||
]
|
]
|
||||||
|
|
||||||
for (const key of configs) {
|
for (const key of configs) {
|
||||||
@@ -960,12 +961,7 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.require_email_verification,
|
value: systemConfig.value.require_email_verification,
|
||||||
description: '是否需要邮箱验证'
|
description: '是否需要邮箱验证'
|
||||||
},
|
},
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
{
|
|
||||||
key: 'api_key_expire_days',
|
|
||||||
value: systemConfig.value.api_key_expire_days,
|
|
||||||
description: 'API密钥过期天数'
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
key: 'auto_delete_expired_keys',
|
key: 'auto_delete_expired_keys',
|
||||||
value: systemConfig.value.auto_delete_expired_keys,
|
value: systemConfig.value.auto_delete_expired_keys,
|
||||||
@@ -1023,6 +1019,11 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.cleanup_batch_size,
|
value: systemConfig.value.cleanup_batch_size,
|
||||||
description: '每批次清理的记录数'
|
description: '每批次清理的记录数'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 'audit_log_retention_days',
|
||||||
|
value: systemConfig.value.audit_log_retention_days,
|
||||||
|
description: '审计日志保留天数'
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const promises = configItems.map(item =>
|
const promises = configItems.map(item =>
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
|||||||
allowed_providers=self.key_data.allowed_providers,
|
allowed_providers=self.key_data.allowed_providers,
|
||||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||||
allowed_models=self.key_data.allowed_models,
|
allowed_models=self.key_data.allowed_models,
|
||||||
rate_limit=self.key_data.rate_limit or 100,
|
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||||
expire_days=self.key_data.expire_days,
|
expire_days=self.key_data.expire_days,
|
||||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||||
is_standalone=True, # 标记为独立Key
|
is_standalone=True, # 标记为独立Key
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await AuthService.verify_token(token, token_type="access")
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ class ApiRequestPipeline:
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -204,7 +204,7 @@ class ApiRequestPipeline:
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
|
|||||||
from src.services.system.audit import audit_service
|
from src.services.system.audit import audit_service
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageTelemetry:
|
class MessageTelemetry:
|
||||||
@@ -399,6 +402,41 @@ class BaseMessageHandler:
|
|||||||
# 创建后台任务,不阻塞当前流
|
# 创建后台任务,不阻塞当前流
|
||||||
asyncio.create_task(_do_update())
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
|
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||||
|
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
||||||
|
|
||||||
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from src.database.database import get_db
|
||||||
|
|
||||||
|
target_request_id = self.request_id
|
||||||
|
provider = ctx.provider_name
|
||||||
|
target_model = ctx.mapped_model
|
||||||
|
|
||||||
|
async def _do_update() -> None:
|
||||||
|
try:
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
try:
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=db,
|
||||||
|
request_id=target_request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=provider,
|
||||||
|
target_model=target_model,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
|
# 创建后台任务,不阻塞当前流
|
||||||
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
def _log_request_error(self, message: str, error: Exception) -> None:
|
def _log_request_error(self, message: str, error: Exception) -> None:
|
||||||
"""记录请求错误日志,对业务异常不打印堆栈
|
"""记录请求错误日志,对业务异常不打印堆栈
|
||||||
|
|
||||||
|
|||||||
@@ -297,11 +297,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
# 创建类型安全的流式上下文
|
# 创建类型安全的流式上下文
|
||||||
ctx = StreamContext(model=model, api_format=api_format)
|
ctx = StreamContext(model=model, api_format=api_format)
|
||||||
|
|
||||||
|
# 创建更新状态的回调闭包(可以访问 ctx)
|
||||||
|
def update_streaming_status() -> None:
|
||||||
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
|
||||||
# 创建流处理器
|
# 创建流处理器
|
||||||
stream_processor = StreamProcessor(
|
stream_processor = StreamProcessor(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
default_parser=self.parser,
|
default_parser=self.parser,
|
||||||
on_streaming_start=self._update_usage_to_streaming,
|
on_streaming_start=update_streaming_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义请求函数
|
# 定义请求函数
|
||||||
|
|||||||
@@ -532,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
async for chunk in stream_response.aiter_raw():
|
async for chunk in stream_response.aiter_raw():
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if not streaming_status_updated:
|
if not streaming_status_updated:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
streaming_status_updated = True
|
streaming_status_updated = True
|
||||||
|
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
@@ -816,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if prefetched_chunks:
|
if prefetched_chunks:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
|
||||||
# 先处理预读的字节块
|
# 先处理预读的字节块
|
||||||
for chunk in prefetched_chunks:
|
for chunk in prefetched_chunks:
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class RedisClientManager:
|
|||||||
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
||||||
remaining = self._circuit_open_until - time.time()
|
remaining = self._circuit_open_until - time.time()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
|
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
|
||||||
remaining,
|
remaining,
|
||||||
self._last_error,
|
self._last_error,
|
||||||
)
|
)
|
||||||
@@ -200,7 +200,7 @@ class RedisClientManager:
|
|||||||
if self._consecutive_failures >= self._circuit_threshold:
|
if self._consecutive_failures >= self._circuit_threshold:
|
||||||
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
|
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
|
||||||
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
||||||
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
||||||
self._consecutive_failures,
|
self._consecutive_failures,
|
||||||
|
|||||||
@@ -105,6 +105,13 @@ class Config:
|
|||||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||||
|
|
||||||
|
# 可信代理配置
|
||||||
|
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||||
|
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||||
|
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||||
|
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||||
|
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||||
|
|
||||||
# 异常处理配置
|
# 异常处理配置
|
||||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||||
# 设置为 False 时,使用全局异常处理器统一处理
|
# 设置为 False 时,使用全局异常处理器统一处理
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ def _log_pool_capacity():
|
|||||||
total_estimated = theoretical * workers
|
total_estimated = theoretical * workers
|
||||||
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
||||||
logger.info(
|
logger.info(
|
||||||
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
|
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
|
||||||
config.db_pool_size,
|
config.db_pool_size,
|
||||||
config.db_max_overflow,
|
config.db_max_overflow,
|
||||||
workers,
|
workers,
|
||||||
@@ -162,7 +162,7 @@ def _log_pool_capacity():
|
|||||||
)
|
)
|
||||||
if total_estimated > safe_limit:
|
if total_estimated > safe_limit:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved),"
|
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved),"
|
||||||
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
||||||
total_estimated,
|
total_estimated,
|
||||||
safe_limit,
|
safe_limit,
|
||||||
@@ -260,7 +260,8 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
|
|||||||
|
|
||||||
2. **管理后台 API**:
|
2. **管理后台 API**:
|
||||||
- 路由层显式调用 db.commit()
|
- 路由层显式调用 db.commit()
|
||||||
- 每个操作独立提交,不依赖中间件
|
- 提交后设置 request.state.tx_committed_by_route = True
|
||||||
|
- 中间件看到此标志后跳过 commit,只负责 close
|
||||||
|
|
||||||
3. **后台任务/调度器**:
|
3. **后台任务/调度器**:
|
||||||
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||||
@@ -271,6 +272,17 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
|
|||||||
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||||
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||||
|
|
||||||
|
路由层提交事务示例
|
||||||
|
==================
|
||||||
|
```python
|
||||||
|
@router.post("/example")
|
||||||
|
async def example(request: Request, db: Session = Depends(get_db)):
|
||||||
|
# ... 业务逻辑 ...
|
||||||
|
db.commit()
|
||||||
|
request.state.tx_committed_by_route = True # 告知中间件已提交
|
||||||
|
return {"message": "success"}
|
||||||
|
```
|
||||||
|
|
||||||
注意事项
|
注意事项
|
||||||
========
|
========
|
||||||
- 本函数不自动提交事务
|
- 本函数不自动提交事务
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ async def initialize_providers():
|
|||||||
# 从数据库加载所有活跃的提供商
|
# 从数据库加载所有活跃的提供商
|
||||||
providers = (
|
providers = (
|
||||||
db.query(Provider)
|
db.query(Provider)
|
||||||
.filter(Provider.is_active == True)
|
.filter(Provider.is_active.is_(True))
|
||||||
.order_by(Provider.provider_priority.asc())
|
.order_by(Provider.provider_priority.asc())
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
@@ -122,6 +122,7 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.info("初始化全局Redis客户端...")
|
logger.info("初始化全局Redis客户端...")
|
||||||
from src.clients.redis_client import get_redis_client
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
|
redis_client = None
|
||||||
try:
|
try:
|
||||||
redis_client = await get_redis_client(require_redis=config.require_redis)
|
redis_client = await get_redis_client(require_redis=config.require_redis)
|
||||||
if redis_client:
|
if redis_client:
|
||||||
@@ -133,6 +134,7 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
||||||
raise
|
raise
|
||||||
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
||||||
|
redis_client = None
|
||||||
|
|
||||||
# 初始化并发管理器(内部会使用Redis)
|
# 初始化并发管理器(内部会使用Redis)
|
||||||
logger.info("初始化并发管理器...")
|
logger.info("初始化并发管理器...")
|
||||||
@@ -312,7 +314,7 @@ if frontend_dist.exists():
|
|||||||
仅对非API路径生效
|
仅对非API路径生效
|
||||||
"""
|
"""
|
||||||
# 如果是API路径,不处理
|
# 如果是API路径,不处理
|
||||||
if full_path.startswith("api/") or full_path.startswith("v1/"):
|
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
|
||||||
raise HTTPException(status_code=404, detail="Not Found")
|
raise HTTPException(status_code=404, detail="Not Found")
|
||||||
|
|
||||||
# 返回index.html,让前端路由处理
|
# 返回index.html,让前端路由处理
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
统一的插件中间件
|
统一的插件中间件(纯 ASGI 实现)
|
||||||
负责协调所有插件的调用
|
负责协调所有插件的调用
|
||||||
|
|
||||||
|
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware,
|
||||||
|
以避免 Starlette 已知的流式响应兼容性问题。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from starlette.requests import Request
|
||||||
from fastapi.responses import JSONResponse
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import Response as StarletteResponse
|
|
||||||
|
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
@@ -18,20 +19,25 @@ from src.plugins.manager import get_plugin_manager
|
|||||||
from src.plugins.rate_limit.base import RateLimitResult
|
from src.plugins.rate_limit.base import RateLimitResult
|
||||||
|
|
||||||
|
|
||||||
|
class PluginMiddleware:
|
||||||
class PluginMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""
|
"""
|
||||||
统一的插件调用中间件
|
统一的插件调用中间件(纯 ASGI 实现)
|
||||||
|
|
||||||
职责:
|
职责:
|
||||||
- 性能监控
|
- 性能监控
|
||||||
- 限流控制 (可选)
|
- 限流控制 (可选)
|
||||||
|
- 数据库会话生命周期管理
|
||||||
|
|
||||||
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
||||||
|
|
||||||
|
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
|
||||||
|
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
|
||||||
|
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
|
||||||
|
- 纯 ASGI 可以直接透传流式响应,无额外开销
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app: Any) -> None:
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
super().__init__(app)
|
self.app = app
|
||||||
self.plugin_manager = get_plugin_manager()
|
self.plugin_manager = get_plugin_manager()
|
||||||
|
|
||||||
# 从配置读取速率限制值
|
# 从配置读取速率限制值
|
||||||
@@ -61,152 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
]
|
]
|
||||||
|
|
||||||
async def dispatch(
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
|
"""ASGI 入口点"""
|
||||||
) -> StarletteResponse:
|
if scope["type"] != "http":
|
||||||
"""处理请求并调用相应插件"""
|
# 非 HTTP 请求(如 WebSocket)直接透传
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 构建 Request 对象以便复用现有逻辑
|
||||||
|
request = Request(scope, receive, send)
|
||||||
|
|
||||||
# 记录请求开始时间
|
# 记录请求开始时间
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 设置 request.state 属性
|
||||||
|
# 注意:Starlette 的 Request 对象总是有 state 属性(State 实例)
|
||||||
request.state.request_id = request.headers.get("x-request-id", "")
|
request.state.request_id = request.headers.get("x-request-id", "")
|
||||||
request.state.start_time = start_time
|
request.state.start_time = start_time
|
||||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||||
request.state.db_managed_by_middleware = True
|
request.state.db_managed_by_middleware = True
|
||||||
|
|
||||||
response = None
|
# 1. 限流检查(在调用下游之前)
|
||||||
exception_to_raise = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 限流插件调用(可选功能)
|
|
||||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||||
if rate_limit_result and not rate_limit_result.allowed:
|
if rate_limit_result and not rate_limit_result.allowed:
|
||||||
# 限流触发,返回429
|
# 限流触发,返回429
|
||||||
headers = rate_limit_result.headers or {}
|
await self._send_rate_limit_response(send, rate_limit_result)
|
||||||
raise HTTPException(
|
return
|
||||||
status_code=429,
|
|
||||||
detail=rate_limit_result.message or "Rate limit exceeded",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 预处理插件调用
|
# 2. 预处理插件调用
|
||||||
await self._call_pre_request_plugins(request)
|
await self._call_pre_request_plugins(request)
|
||||||
|
|
||||||
# 处理请求
|
# 用于捕获响应状态码
|
||||||
response = await call_next(request)
|
response_status_code: int = 0
|
||||||
|
|
||||||
# 3. 提交关键数据库事务(在返回响应前)
|
async def send_wrapper(message: Message) -> None:
|
||||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
nonlocal response_status_code
|
||||||
|
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
response_status_code = message.get("status", 0)
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
# 3. 调用下游应用
|
||||||
|
exception_occurred: Optional[Exception] = None
|
||||||
try:
|
try:
|
||||||
|
await self.app(scope, receive, send_wrapper)
|
||||||
|
except Exception as e:
|
||||||
|
exception_occurred = e
|
||||||
|
# 错误处理插件调用
|
||||||
|
await self._call_error_plugins(request, e, start_time)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 4. 数据库会话清理(无论成功与否)
|
||||||
|
await self._cleanup_db_session(request, exception_occurred)
|
||||||
|
|
||||||
|
# 5. 后处理插件调用(仅在成功时)
|
||||||
|
if not exception_occurred and response_status_code > 0:
|
||||||
|
await self._call_post_request_plugins(request, response_status_code, start_time)
|
||||||
|
|
||||||
|
async def _send_rate_limit_response(
|
||||||
|
self, send: Send, result: RateLimitResult
|
||||||
|
) -> None:
|
||||||
|
"""发送 429 限流响应"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
body = json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"error": {
|
||||||
|
"type": "rate_limit_error",
|
||||||
|
"message": result.message or "Rate limit exceeded",
|
||||||
|
},
|
||||||
|
}).encode("utf-8")
|
||||||
|
|
||||||
|
headers = [(b"content-type", b"application/json")]
|
||||||
|
if result.headers:
|
||||||
|
for key, value in result.headers.items():
|
||||||
|
headers.append((key.lower().encode(), str(value).encode()))
|
||||||
|
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 429,
|
||||||
|
"headers": headers,
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": body,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _cleanup_db_session(
|
||||||
|
self, request: Request, exception: Optional[Exception]
|
||||||
|
) -> None:
|
||||||
|
"""清理数据库会话
|
||||||
|
|
||||||
|
事务策略:
|
||||||
|
- 如果 request.state.tx_committed_by_route 为 True,说明路由已自行提交,中间件只负责 close
|
||||||
|
- 否则由中间件统一 commit/rollback
|
||||||
|
|
||||||
|
这避免了双重提交的问题,同时保持向后兼容。
|
||||||
|
"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
db = getattr(request.state, "db", None)
|
db = getattr(request.state, "db", None)
|
||||||
if isinstance(db, Session):
|
if not isinstance(db, Session):
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否由路由层已经提交
|
||||||
|
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if exception is not None:
|
||||||
|
# 发生异常,回滚事务(无论谁负责提交)
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception as rollback_error:
|
||||||
|
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||||
|
elif not tx_committed_by_route:
|
||||||
|
# 正常完成且路由未自行提交,由中间件提交事务
|
||||||
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
except Exception as commit_error:
|
except Exception as commit_error:
|
||||||
logger.error(f"关键事务提交失败: {commit_error}")
|
logger.error(f"关键事务提交失败: {commit_error}")
|
||||||
try:
|
try:
|
||||||
if isinstance(db, Session):
|
|
||||||
db.rollback()
|
db.rollback()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
await self._call_error_plugins(request, commit_error, start_time)
|
# 如果 tx_committed_by_route 为 True,跳过 commit(路由已提交)
|
||||||
# 返回 500 错误,因为数据可能不一致
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "database_error",
|
|
||||||
"message": "数据保存失败,请重试",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# 跳过后处理插件,直接返回错误响应
|
|
||||||
return response
|
|
||||||
|
|
||||||
# 4. 后处理插件调用(监控等,非关键操作)
|
|
||||||
# 这些操作失败不应影响用户响应
|
|
||||||
await self._call_post_request_plugins(request, response, start_time)
|
|
||||||
|
|
||||||
# 注意:不在此处添加限流响应头,因为在BaseHTTPMiddleware中
|
|
||||||
# 响应返回后修改headers会导致Content-Length不匹配错误
|
|
||||||
# 限流响应头已在返回429错误时正确包含(见上面的HTTPException)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
if str(e) == "No response returned.":
|
|
||||||
db = getattr(request.state, "db", None)
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.error("Downstream handler completed without returning a response")
|
|
||||||
|
|
||||||
await self._call_error_plugins(request, e, start_time)
|
|
||||||
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "internal_error",
|
|
||||||
"message": "Internal server error: downstream handler returned no response.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
exception_to_raise = e
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 回滚数据库事务
|
|
||||||
db = getattr(request.state, "db", None)
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 错误处理插件调用
|
|
||||||
await self._call_error_plugins(request, e, start_time)
|
|
||||||
|
|
||||||
# 尝试提交错误日志
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
exception_to_raise = e
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db = getattr(request.state, "db", None)
|
# 关闭会话,归还连接到连接池
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
try:
|
||||||
db.close()
|
db.close()
|
||||||
except Exception as close_error:
|
except Exception as close_error:
|
||||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
|
||||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||||
|
|
||||||
# 在 finally 块之后处理异常和响应
|
|
||||||
if exception_to_raise:
|
|
||||||
raise exception_to_raise
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _get_client_ip(self, request: Request) -> str:
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
获取客户端 IP 地址,支持代理头
|
获取客户端 IP 地址,支持代理头
|
||||||
|
|
||||||
|
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||||
|
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||||
|
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||||
"""
|
"""
|
||||||
|
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||||
|
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||||
|
|
||||||
# 优先从代理头获取真实 IP
|
# 优先从代理头获取真实 IP
|
||||||
forwarded_for = request.headers.get("x-forwarded-for")
|
forwarded_for = request.headers.get("x-forwarded-for")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For 可能包含多个 IP,取第一个
|
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||||
return forwarded_for.split(",")[0].strip()
|
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||||
|
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||||
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
real_ip = request.headers.get("x-real-ip")
|
real_ip = request.headers.get("x-real-ip")
|
||||||
if real_ip:
|
if real_ip:
|
||||||
@@ -248,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
auth_header = request.headers.get("authorization", "")
|
auth_header = request.headers.get("authorization", "")
|
||||||
api_key = request.headers.get("x-api-key", "")
|
api_key = request.headers.get("x-api-key", "")
|
||||||
|
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.lower().startswith("bearer "):
|
||||||
api_key = auth_header[7:]
|
api_key = auth_header[7:]
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
||||||
import hashlib
|
|
||||||
|
|
||||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
||||||
key = f"llm_api_key:{key_hash}"
|
key = f"llm_api_key:{key_hash}"
|
||||||
request.state.rate_limit_key_type = "api_key"
|
request.state.rate_limit_key_type = "api_key"
|
||||||
@@ -319,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 限流触发,记录日志
|
# 限流触发,记录日志
|
||||||
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
|
logger.warning(
|
||||||
|
"速率限制触发: {}",
|
||||||
|
getattr(request.state, "rate_limit_key_type", "unknown"),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -332,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
async def _call_post_request_plugins(
|
async def _call_post_request_plugins(
|
||||||
self, request: Request, response: StarletteResponse, start_time: float
|
self, request: Request, status_code: int, start_time: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""调用请求后的插件"""
|
"""调用请求后的插件"""
|
||||||
|
|
||||||
@@ -345,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
monitor_labels = {
|
monitor_labels = {
|
||||||
"method": request.method,
|
"method": request.method,
|
||||||
"endpoint": request.url.path,
|
"endpoint": request.url.path,
|
||||||
"status": str(response.status_code),
|
"status": str(status_code),
|
||||||
"status_class": f"{response.status_code // 100}xx",
|
"status_class": f"{status_code // 100}xx",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 记录请求计数
|
# 记录请求计数
|
||||||
@@ -368,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
self, request: Request, error: Exception, start_time: float
|
self, request: Request, error: Exception, start_time: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""调用错误处理插件"""
|
"""调用错误处理插件"""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
@@ -380,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
error=error,
|
error=error,
|
||||||
context={
|
context={
|
||||||
"endpoint": f"{request.method} {request.url.path}",
|
"endpoint": f"{request.method} {request.url.path}",
|
||||||
"request_id": request.state.request_id,
|
"request_id": getattr(request.state, "request_id", ""),
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ JWT认证插件
|
|||||||
支持JWT Bearer token认证
|
支持JWT Bearer token认证
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -46,8 +47,8 @@ class JwtAuthPlugin(AuthPlugin):
|
|||||||
logger.debug("未找到JWT token")
|
logger.debug("未找到JWT token")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 记录认证尝试的详细信息
|
token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
|
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, token_fp={token_fingerprint}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 验证JWT token
|
# 验证JWT token
|
||||||
|
|||||||
@@ -63,14 +63,16 @@ class JWTBlacklistService:
|
|||||||
|
|
||||||
if ttl_seconds <= 0:
|
if ttl_seconds <= 0:
|
||||||
# Token 已经过期,不需要加入黑名单
|
# Token 已经过期,不需要加入黑名单
|
||||||
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
||||||
# 值存储为原因字符串
|
# 值存储为原因字符串
|
||||||
await redis_client.setex(redis_key, ttl_seconds, reason)
|
await redis_client.setex(redis_key, ttl_seconds, reason)
|
||||||
|
|
||||||
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -109,7 +111,8 @@ class JWTBlacklistService:
|
|||||||
if exists:
|
if exists:
|
||||||
# 获取黑名单原因(可选)
|
# 获取黑名单原因(可选)
|
||||||
reason = await redis_client.get(redis_key)
|
reason = await redis_client.get(redis_key)
|
||||||
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -148,9 +151,11 @@ class JWTBlacklistService:
|
|||||||
deleted = await redis_client.delete(redis_key)
|
deleted = await redis_client.delete(redis_key)
|
||||||
|
|
||||||
if deleted:
|
if deleted:
|
||||||
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
|
||||||
|
|
||||||
return bool(deleted)
|
return bool(deleted)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@@ -169,7 +170,8 @@ class AuthService:
|
|||||||
key_record.last_used_at = datetime.now(timezone.utc)
|
key_record.last_used_at = datetime.now(timezone.utc)
|
||||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||||
|
|
||||||
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
|
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||||
|
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||||
return user, key_record
|
return user, key_record
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import create_session
|
from src.database import create_session
|
||||||
from src.models.database import Usage
|
from src.models.database import AuditLog, Usage
|
||||||
from src.services.system.config import SystemConfigService
|
from src.services.system.config import SystemConfigService
|
||||||
from src.services.system.scheduler import get_scheduler
|
from src.services.system.scheduler import get_scheduler
|
||||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||||
@@ -91,6 +91,15 @@ class CleanupScheduler:
|
|||||||
name="Pending状态清理",
|
name="Pending状态清理",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 审计日志清理 - 凌晨 4 点执行
|
||||||
|
scheduler.add_cron_job(
|
||||||
|
self._scheduled_audit_cleanup,
|
||||||
|
hour=4,
|
||||||
|
minute=0,
|
||||||
|
job_id="audit_cleanup",
|
||||||
|
name="审计日志清理",
|
||||||
|
)
|
||||||
|
|
||||||
# 启动时执行一次初始化任务
|
# 启动时执行一次初始化任务
|
||||||
asyncio.create_task(self._run_startup_tasks())
|
asyncio.create_task(self._run_startup_tasks())
|
||||||
|
|
||||||
@@ -145,6 +154,10 @@ class CleanupScheduler:
|
|||||||
"""Pending 清理任务(定时调用)"""
|
"""Pending 清理任务(定时调用)"""
|
||||||
await self._perform_pending_cleanup()
|
await self._perform_pending_cleanup()
|
||||||
|
|
||||||
|
async def _scheduled_audit_cleanup(self):
|
||||||
|
"""审计日志清理任务(定时调用)"""
|
||||||
|
await self._perform_audit_cleanup()
|
||||||
|
|
||||||
# ========== 实际任务实现 ==========
|
# ========== 实际任务实现 ==========
|
||||||
|
|
||||||
async def _perform_stats_aggregation(self, backfill: bool = False):
|
async def _perform_stats_aggregation(self, backfill: bool = False):
|
||||||
@@ -330,6 +343,70 @@ class CleanupScheduler:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
async def _perform_audit_cleanup(self):
|
||||||
|
"""执行审计日志清理任务"""
|
||||||
|
db = create_session()
|
||||||
|
try:
|
||||||
|
# 检查是否启用自动清理
|
||||||
|
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
|
||||||
|
logger.info("自动清理已禁用,跳过审计日志清理")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取审计日志保留天数(默认 30 天,最少 7 天)
|
||||||
|
audit_retention_days = max(
|
||||||
|
SystemConfigService.get_config(db, "audit_log_retention_days", 30),
|
||||||
|
7, # 最少保留 7 天,防止误配置删除所有审计日志
|
||||||
|
)
|
||||||
|
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
|
||||||
|
|
||||||
|
cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
|
||||||
|
|
||||||
|
logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
|
||||||
|
|
||||||
|
total_deleted = 0
|
||||||
|
while True:
|
||||||
|
# 先查询要删除的记录 ID(分批)
|
||||||
|
records_to_delete = (
|
||||||
|
db.query(AuditLog.id)
|
||||||
|
.filter(AuditLog.created_at < cutoff_time)
|
||||||
|
.limit(batch_size)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records_to_delete:
|
||||||
|
break
|
||||||
|
|
||||||
|
record_ids = [r.id for r in records_to_delete]
|
||||||
|
|
||||||
|
# 执行删除
|
||||||
|
result = db.execute(
|
||||||
|
delete(AuditLog)
|
||||||
|
.where(AuditLog.id.in_(record_ids))
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows_deleted = result.rowcount
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
total_deleted += rows_deleted
|
||||||
|
logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted} 条")
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
if total_deleted > 0:
|
||||||
|
logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
|
||||||
|
else:
|
||||||
|
logger.info("无需清理的审计日志")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"审计日志清理失败: {e}")
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
async def _perform_cleanup(self):
|
async def _perform_cleanup(self):
|
||||||
"""执行清理任务"""
|
"""执行清理任务"""
|
||||||
db = create_session()
|
db = create_session()
|
||||||
|
|||||||
@@ -1217,15 +1217,19 @@ class UsageService:
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
status: str,
|
status: str,
|
||||||
error_message: Optional[str] = None,
|
error_message: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
target_model: Optional[str] = None,
|
||||||
) -> Optional[Usage]:
|
) -> Optional[Usage]:
|
||||||
"""
|
"""
|
||||||
快速更新使用记录状态(不更新其他字段)
|
快速更新使用记录状态
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
request_id: 请求ID
|
request_id: 请求ID
|
||||||
status: 新状态 (pending, streaming, completed, failed)
|
status: 新状态 (pending, streaming, completed, failed)
|
||||||
error_message: 错误消息(仅在 failed 状态时使用)
|
error_message: 错误消息(仅在 failed 状态时使用)
|
||||||
|
provider: 提供商名称(可选,streaming 状态时更新)
|
||||||
|
target_model: 映射后的目标模型名(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
更新后的 Usage 记录,如果未找到则返回 None
|
更新后的 Usage 记录,如果未找到则返回 None
|
||||||
@@ -1239,6 +1243,10 @@ class UsageService:
|
|||||||
usage.status = status
|
usage.status = status
|
||||||
if error_message:
|
if error_message:
|
||||||
usage.error_message = error_message
|
usage.error_message = error_message
|
||||||
|
if provider:
|
||||||
|
usage.provider = provider
|
||||||
|
if target_model:
|
||||||
|
usage.target_model = target_model
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -457,7 +457,7 @@ class StreamUsageTracker:
|
|||||||
|
|
||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
# 更新状态为 streaming
|
# 更新状态为 streaming,同时更新 provider
|
||||||
if self.request_id:
|
if self.request_id:
|
||||||
try:
|
try:
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
@@ -465,6 +465,7 @@ class StreamUsageTracker:
|
|||||||
db=self.db,
|
db=self.db,
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
status="streaming",
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|||||||
@@ -210,7 +210,15 @@ class ApiKeyService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
||||||
"""检查速率限制"""
|
"""检查速率限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_allowed, remaining): 是否允许请求,剩余可用次数
|
||||||
|
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
|
||||||
|
"""
|
||||||
|
# 如果 rate_limit 为 None,表示不限制
|
||||||
|
if api_key.rate_limit is None:
|
||||||
|
return True, -1 # -1 表示无限制
|
||||||
|
|
||||||
# 计算时间窗口
|
# 计算时间窗口
|
||||||
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
提供统一的用户认证和授权功能
|
提供统一的用户认证和授权功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
@@ -44,10 +45,17 @@ async def get_current_user(
|
|||||||
payload = await AuthService.verify_token(token, token_type="access")
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
except HTTPException as token_error:
|
except HTTPException as token_error:
|
||||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.error(
|
||||||
|
"Token验证失败: {}: {}, token_fp={}",
|
||||||
|
token_error.status_code,
|
||||||
|
token_error.detail,
|
||||||
|
token_fp,
|
||||||
|
)
|
||||||
raise # 重新抛出原始异常,保持状态码
|
raise # 重新抛出原始异常,保持状态码
|
||||||
except Exception as token_error:
|
except Exception as token_error:
|
||||||
logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.error("Token验证失败: {}, token_fp={}", token_error, token_fp)
|
||||||
raise ForbiddenException("无效的Token")
|
raise ForbiddenException("无效的Token")
|
||||||
|
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
@@ -63,7 +71,8 @@ async def get_current_user(
|
|||||||
raise ForbiddenException("无效的认证凭据")
|
raise ForbiddenException("无效的认证凭据")
|
||||||
|
|
||||||
# 仅在DEBUG模式下记录详细信息
|
# 仅在DEBUG模式下记录详细信息
|
||||||
logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.debug("尝试获取用户: user_id={}, token_fp={}", user_id, token_fp)
|
||||||
|
|
||||||
# 确保user_id是字符串格式(UUID)
|
# 确保user_id是字符串格式(UUID)
|
||||||
if not isinstance(user_id, str):
|
if not isinstance(user_id, str):
|
||||||
|
|||||||
@@ -7,29 +7,47 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
|
||||||
def get_client_ip(request: Request) -> str:
|
def get_client_ip(request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
获取客户端真实IP地址
|
获取客户端真实IP地址
|
||||||
|
|
||||||
按优先级检查:
|
按优先级检查:
|
||||||
1. X-Forwarded-For 头(支持代理链)
|
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
|
||||||
2. X-Real-IP 头(Nginx 代理)
|
2. X-Real-IP 头(Nginx 代理)
|
||||||
3. 直接客户端IP
|
3. 直接客户端IP
|
||||||
|
|
||||||
|
安全说明:
|
||||||
|
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
|
||||||
|
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
|
||||||
|
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: FastAPI Request 对象
|
request: FastAPI Request 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
||||||
"""
|
"""
|
||||||
|
trusted_proxy_count = config.trusted_proxy_count
|
||||||
|
|
||||||
|
# 如果不信任任何代理,直接返回连接 IP
|
||||||
|
if trusted_proxy_count == 0:
|
||||||
|
if request.client and request.client.host:
|
||||||
|
return request.client.host
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
||||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For 格式: "client, proxy1, proxy2",取第一个(真实客户端)
|
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||||
client_ip = forwarded_for.split(",")[0].strip()
|
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||||
if client_ip:
|
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||||
return client_ip
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
||||||
real_ip = request.headers.get("X-Real-IP")
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
@@ -91,20 +109,32 @@ def get_request_metadata(request: Request) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def extract_ip_from_headers(headers: dict) -> str:
|
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
|
||||||
"""
|
"""
|
||||||
从HTTP头字典中提取IP地址(用于中间件等场景)
|
从HTTP头字典中提取IP地址(用于中间件等场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
headers: HTTP头字典
|
headers: HTTP头字典
|
||||||
|
trusted_proxy_count: 可信代理层数,None 时使用配置值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 客户端IP地址
|
str: 客户端IP地址
|
||||||
"""
|
"""
|
||||||
|
if trusted_proxy_count is None:
|
||||||
|
trusted_proxy_count = config.trusted_proxy_count
|
||||||
|
|
||||||
|
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
|
||||||
|
if trusted_proxy_count == 0:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
# 检查 X-Forwarded-For
|
# 检查 X-Forwarded-For
|
||||||
forwarded_for = headers.get("x-forwarded-for", "")
|
forwarded_for = headers.get("x-forwarded-for", "")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
return forwarded_for.split(",")[0].strip()
|
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||||
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
# 检查 X-Real-IP
|
# 检查 X-Real-IP
|
||||||
real_ip = headers.get("x-real-ip", "")
|
real_ip = headers.get("x-real-ip", "")
|
||||||
|
|||||||
@@ -361,3 +361,61 @@ class TestPipelineAdminAuth:
|
|||||||
|
|
||||||
assert result == mock_user
|
assert result == mock_user
|
||||||
assert mock_request.state.user_id == "admin-123"
|
assert mock_request.state.user_id == "admin-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_admin_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "admin-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "admin-123"},
|
||||||
|
) as mock_verify:
|
||||||
|
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||||
|
|
||||||
|
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||||
|
assert result == mock_user
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineUserAuth:
|
||||||
|
"""测试普通用户 JWT 认证"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "user-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "user-123"},
|
||||||
|
) as mock_verify:
|
||||||
|
result = await pipeline._authenticate_user(mock_request, mock_db)
|
||||||
|
|
||||||
|
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||||
|
assert result == mock_user
|
||||||
|
|||||||
Reference in New Issue
Block a user