5 Commits

Author SHA1 Message Date
fawney19
5f0c1fb347 refactor: remove unused response normalizer module
- Delete unused ResponseNormalizer class and its initialization logic
- Remove response_normalizer and enable_response_normalization parameters from handlers
- Simplify chat adapter base initialization by removing normalizer setup
- Clean up unused imports in handler modules
2025-12-19 01:20:30 +08:00
fawney19
7b932d7afb refactor: optimize middleware with pure ASGI implementation and enhance security measures
- Replace BaseHTTPMiddleware with pure ASGI implementation in plugin middleware for better streaming response handling
- Add trusted proxy count configuration for client IP extraction in reverse proxy environments
- Implement audit log cleanup scheduler with configurable retention period
- Replace plaintext token logging with SHA256 hash fingerprints for security
- Fix database session lifecycle management in middleware
- Improve request tracing and error logging throughout the system
- Add comprehensive tests for pipeline architecture
2025-12-18 19:07:20 +08:00
fawney19
c7b971cfe7 Merge pull request #29 from fawney19/feature/proxy-support
feat: add HTTP/SOCKS5 proxy support for API endpoints
2025-12-18 16:18:25 +08:00
fawney19
293bb592dc fix: enhance proxy configuration with password preservation and UI improvements
- Add 'enabled' field to ProxyConfig for preserving config when disabled
- Mask proxy password in API responses (return '***' instead of actual password)
- Preserve existing password on update when new password not provided
- Add URL encoding for proxy credentials (handle special chars like @, :, /)
- Enhanced URL validation: block SOCKS4, require valid host, forbid embedded auth
- UI improvements: use Switch component, dynamic password placeholder
- Add confirmation dialog for orphaned credentials (URL empty but has username/password)
- Prevent browser password autofill with randomized IDs and CSS text-security
- Unify ProxyConfig type definition in types.ts
2025-12-18 16:14:37 +08:00
fawney19
3e50c157be feat: add HTTP/SOCKS5 proxy support for API endpoints
- Add proxy field to ProviderEndpoint database model with migration
- Add ProxyConfig Pydantic model for proxy URL validation
- Extend HTTP client pool with create_client_with_proxy method
- Integrate proxy configuration in chat_handler_base.py and cli_handler_base.py
- Update admin API endpoints to support proxy configuration CRUD
- Add proxy configuration UI in frontend EndpointFormDialog

Fixes #28
2025-12-18 14:46:47 +08:00
38 changed files with 974 additions and 345 deletions

View File

@@ -0,0 +1,57 @@
"""add proxy field to provider_endpoints
Revision ID: f30f9936f6a2
Revises: 1cc6942cf06f
Create Date: 2025-12-18 06:31:58.451112+00:00
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision = 'f30f9936f6a2'
down_revision = '1cc6942cf06f'
branch_labels = None
depends_on = None
def column_exists(table_name: str, column_name: str) -> bool:
    """Return True when *column_name* already exists on *table_name*."""
    inspector = inspect(op.get_bind())
    existing = {col["name"] for col in inspector.get_columns(table_name)}
    return column_name in existing
def get_column_type(table_name: str, column_name: str) -> str:
    """Return the uppercased SQL type string of a column, or '' if it is absent."""
    inspector = inspect(op.get_bind())
    for column in inspector.get_columns(table_name):
        if column["name"] == column_name:
            return str(column["type"]).upper()
    return ''
def upgrade() -> None:
    """Add the ``proxy`` JSONB column to the ``provider_endpoints`` table.

    Idempotent: skips the add when the column already exists, and converts a
    plain JSON column to JSONB in place when needed (re-run safe).
    """
    if not column_exists('provider_endpoints', 'proxy'):
        # Column is missing: add it directly as JSONB.
        op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True))
    else:
        # Column already exists: check whether a type conversion is required.
        col_type = get_column_type('provider_endpoints', 'proxy')
        if 'JSONB' not in col_type:
            # Plain JSON (or other) type: convert to JSONB via a cast.
            op.execute(
                'ALTER TABLE provider_endpoints '
                'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb'
            )
def downgrade() -> None:
    """Drop the ``proxy`` column if present (reverse of :func:`upgrade`)."""
    if column_exists('provider_endpoints', 'proxy'):
        op.drop_column('provider_endpoints', 'proxy')

View File

@@ -1,5 +1,5 @@
import client from '../client'
import type { ProviderEndpoint } from './types'
import type { ProviderEndpoint, ProxyConfig } from './types'
/**
* 获取指定 Provider 的所有 Endpoints
@@ -38,6 +38,7 @@ export async function createEndpoint(
rate_limit?: number
is_active?: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
}
): Promise<ProviderEndpoint> {
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
@@ -63,6 +64,7 @@ export async function updateEndpoint(
rate_limit: number
is_active: boolean
config: Record<string, any>
proxy: ProxyConfig | null
}>
): Promise<ProviderEndpoint> {
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)

View File

@@ -20,6 +20,16 @@ export const API_FORMAT_LABELS: Record<string, string> = {
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
}
/**
 * Proxy configuration shared by the endpoint create/update forms.
 */
export interface ProxyConfig {
url: string
username?: string
password?: string
enabled?: boolean // when false, the config is kept but the proxy is not used
}
export interface ProviderEndpoint {
id: string
provider_id: string
@@ -41,6 +51,7 @@ export interface ProviderEndpoint {
last_failure_at?: string
is_active: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
total_keys: number
active_keys: number
created_at: string

View File

@@ -132,7 +132,7 @@
type="number"
min="1"
max="10000"
placeholder="100"
placeholder="留空不限制"
class="h-10"
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
/>
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
rate_limit: 100,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
allowed_api_formats: [],
@@ -389,7 +389,7 @@ function resetForm() {
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
rate_limit: 100,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
allowed_api_formats: [],
@@ -408,7 +408,7 @@ function loadKeyData() {
initial_balance_usd: props.apiKey.initial_balance_usd,
expire_days: props.apiKey.expire_days,
never_expire: props.apiKey.never_expire,
rate_limit: props.apiKey.rate_limit || 100,
rate_limit: props.apiKey.rate_limit,
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
allowed_providers: props.apiKey.allowed_providers || [],
allowed_api_formats: props.apiKey.allowed_api_formats || [],

View File

@@ -9,7 +9,7 @@
>
<form
class="space-y-6"
@submit.prevent="handleSubmit"
@submit.prevent="handleSubmit()"
>
<!-- API 配置 -->
<div class="space-y-4">
@@ -132,6 +132,79 @@
</div>
</div>
</div>
<!-- 代理配置 -->
<div class="space-y-4">
<div class="flex items-center justify-between">
<h3 class="text-sm font-medium">
代理配置
</h3>
<div class="flex items-center gap-2">
<Switch v-model="proxyEnabled" />
<span class="text-sm text-muted-foreground">启用代理</span>
</div>
</div>
<div
v-if="proxyEnabled"
class="space-y-4 rounded-lg border p-4"
>
<div class="space-y-2">
<Label for="proxy_url">代理 URL *</Label>
<Input
id="proxy_url"
v-model="form.proxy_url"
placeholder="http://host:port 或 socks5://host:port"
required
:class="proxyUrlError ? 'border-red-500' : ''"
/>
<p
v-if="proxyUrlError"
class="text-xs text-red-500"
>
{{ proxyUrlError }}
</p>
<p
v-else
class="text-xs text-muted-foreground"
>
支持 HTTPHTTPSSOCKS5 代理
</p>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="proxy_user">用户名可选</Label>
<Input
:id="`proxy_user_${formId}`"
:name="`proxy_user_${formId}`"
v-model="form.proxy_username"
placeholder="代理认证用户名"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-2">
<Label :for="`proxy_pass_${formId}`">密码可选</Label>
<Input
:id="`proxy_pass_${formId}`"
:name="`proxy_pass_${formId}`"
v-model="form.proxy_password"
type="text"
:placeholder="passwordPlaceholder"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
/>
</div>
</div>
</div>
</div>
</form>
<template #footer>
@@ -145,12 +218,24 @@
</Button>
<Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
@click="handleSubmit"
@click="handleSubmit()"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
</Button>
</template>
</Dialog>
<!-- 确认清空凭据对话框 -->
<AlertDialog
v-model="showClearCredentialsDialog"
title="清空代理凭据"
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
type="warning"
confirm-text="清空并保存"
cancel-text="返回编辑"
@confirm="confirmClearCredentials"
@cancel="showClearCredentialsDialog = false"
/>
</template>
<script setup lang="ts">
@@ -165,7 +250,9 @@ import {
SelectValue,
SelectContent,
SelectItem,
Switch,
} from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue'
import { Link, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
@@ -194,6 +281,11 @@ const emit = defineEmits<{
const { success, error: showError } = useToast()
const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
// 生成随机 ID 防止浏览器自动填充
const formId = Math.random().toString(36).substring(2, 10)
// 内部状态
const internalOpen = computed(() => props.modelValue)
@@ -207,7 +299,11 @@ const form = ref({
max_retries: 3,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true
is_active: true,
// 代理配置
proxy_url: '',
proxy_username: '',
proxy_password: '',
})
// API 格式列表
@@ -237,6 +333,53 @@ const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}`
})
// Whether a password is already stored server-side
// (the backend returns the '***' marker instead of the real password).
const hasExistingPassword = computed(() => {
if (!props.endpoint?.proxy) return false
const proxy = props.endpoint.proxy as { password?: string }
return proxy?.password === MASKED_PASSWORD
})
// Placeholder text for the proxy-password input: hints that a stored
// password is kept when the field is left empty.
const passwordPlaceholder = computed(() => {
if (hasExistingPassword.value) {
return '已保存密码,留空保持不变'
}
return '代理认证密码'
})
// Validation message for the proxy URL ('' when valid or not applicable).
const proxyUrlError = computed(() => {
// Only validate when the proxy is enabled AND a URL was entered.
if (!proxyEnabled.value || !form.value.proxy_url) {
return ''
}
const url = form.value.proxy_url.trim()
// Reject newline characters outright.
if (/[\n\r]/.test(url)) {
return '代理 URL 包含非法字符'
}
// Scheme whitelist (SOCKS4 deliberately unsupported).
if (!/^(http|https|socks5):\/\//i.test(url)) {
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
}
try {
const parsed = new URL(url)
if (!parsed.host) {
return '代理 URL 必须包含有效的 host'
}
// Credentials embedded in the URL are forbidden; use the dedicated fields.
if (parsed.username || parsed.password) {
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
}
} catch {
return '代理 URL 格式无效'
}
return ''
})
// 组件挂载时加载API格式
onMounted(() => {
loadApiFormats()
@@ -252,14 +395,23 @@ function resetForm() {
max_retries: 3,
max_concurrent: undefined,
rate_limit: undefined,
is_active: true
is_active: true,
proxy_url: '',
proxy_username: '',
proxy_password: '',
}
proxyEnabled.value = false
}
// 原始密码占位符(后端返回的脱敏标记)
const MASKED_PASSWORD = '***'
// 加载端点数据(编辑模式)
function loadEndpointData() {
if (!props.endpoint) return
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
form.value = {
api_format: props.endpoint.api_format,
base_url: props.endpoint.base_url,
@@ -268,8 +420,15 @@ function loadEndpointData() {
max_retries: props.endpoint.max_retries,
max_concurrent: props.endpoint.max_concurrent || undefined,
rate_limit: props.endpoint.rate_limit || undefined,
is_active: props.endpoint.is_active
is_active: props.endpoint.is_active,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
}
// 根据 enabled 字段或 url 存在判断是否启用代理
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
}
// 使用 useFormDialog 统一处理对话框逻辑
@@ -282,12 +441,47 @@ const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
resetForm,
})
// Build the proxy payload sent to the API.
// - With a URL the config is always saved; the `enabled` flag controls whether it is used.
// - Without a URL there is no proxy config at all (null).
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
if (!form.value.proxy_url) {
// No URL entered: no proxy configuration.
return null
}
return {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: proxyEnabled.value, // switch state decides whether the proxy is active
}
}
// 提交表单
const handleSubmit = async () => {
const handleSubmit = async (skipCredentialCheck = false) => {
if (!props.provider && !props.endpoint) return
// 只在开关开启且填写了 URL 时验证
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
showError(proxyUrlError.value, '代理配置错误')
return
}
// 检查:开关开启但没有 URL却有用户名或密码
const hasOrphanedCredentials = proxyEnabled.value
&& !form.value.proxy_url
&& (form.value.proxy_username || form.value.proxy_password)
if (hasOrphanedCredentials && !skipCredentialCheck) {
// 弹出确认对话框
showClearCredentialsDialog.value = true
return
}
loading.value = true
try {
const proxyConfig = buildProxyConfig()
if (isEditMode.value && props.endpoint) {
// 更新端点
await updateEndpoint(props.endpoint.id, {
@@ -297,7 +491,8 @@ const handleSubmit = async () => {
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点已更新', '保存成功')
@@ -313,7 +508,8 @@ const handleSubmit = async () => {
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点创建成功', '成功')
@@ -329,4 +525,12 @@ const handleSubmit = async () => {
loading.value = false
}
}
// Clear orphaned proxy credentials (username/password without a URL) and resubmit.
const confirmClearCredentials = () => {
form.value.proxy_username = ''
form.value.proxy_password = ''
showClearCredentialsDialog.value = false
handleSubmit(true) // skip the credential check and submit directly
}
</script>

View File

@@ -25,7 +25,7 @@
</h3>
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
<span>{{ detail?.model || '-' }}</span>
<template v-if="detail?.target_model">
<template v-if="detail?.target_model && detail.target_model !== detail.model">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"

View File

@@ -185,32 +185,13 @@
</div>
</CardSection>
<!-- API Key 管理配置 -->
<!-- 独立余额 Key 过期管理 -->
<CardSection
title="API Key 管理"
description="API Key 相关配置"
title="独立余额 Key 过期管理"
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
>
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
<div>
<Label
for="api-key-expire"
class="block text-sm font-medium"
>
API密钥过期天数
</Label>
<Input
id="api-key-expire"
v-model.number="systemConfig.api_key_expire_days"
type="number"
placeholder="0"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
0 表示永不过期
</p>
</div>
<div class="flex items-center h-full pt-6">
<div class="flex items-center h-full">
<div class="flex items-center space-x-2">
<Checkbox
id="auto-delete-expired-keys"
@@ -224,7 +205,7 @@
自动删除过期 Key
</Label>
<p class="text-xs text-muted-foreground">
关闭时仅禁用过期 Key
关闭时仅禁用过期 Key不会物理删除
</p>
</div>
</div>
@@ -448,6 +429,25 @@
避免单次操作过大影响性能
</p>
</div>
<div>
<Label
for="audit-log-retention-days"
class="block text-sm font-medium"
>
审计日志保留天数
</Label>
<Input
id="audit-log-retention-days"
v-model.number="systemConfig.audit_log_retention_days"
type="number"
placeholder="30"
class="mt-1"
/>
<p class="mt-1 text-xs text-muted-foreground">
超过后删除审计日志记录
</p>
</div>
</div>
<!-- 清理策略说明 -->
@@ -460,6 +460,7 @@
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储节省空间</p>
<p>3. <strong>统计阶段</strong>: 仅保留 tokens成本等统计信息</p>
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
<p>5. <strong>审计日志</strong>: 独立清理记录用户登录操作等安全事件</p>
</div>
</div>
</CardSection>
@@ -796,8 +797,7 @@ interface SystemConfig {
// 用户注册
enable_registration: boolean
require_email_verification: boolean
// API Key 管理
api_key_expire_days: number
// 独立余额 Key 过期管理
auto_delete_expired_keys: boolean
// 日志记录
request_log_level: string
@@ -811,6 +811,7 @@ interface SystemConfig {
header_retention_days: number
log_retention_days: number
cleanup_batch_size: number
audit_log_retention_days: number
}
const loading = ref(false)
@@ -845,8 +846,7 @@ const systemConfig = ref<SystemConfig>({
// 用户注册
enable_registration: false,
require_email_verification: false,
// API Key 管理
api_key_expire_days: 0,
// 独立余额 Key 过期管理
auto_delete_expired_keys: false,
// 日志记录
request_log_level: 'basic',
@@ -860,6 +860,7 @@ const systemConfig = ref<SystemConfig>({
header_retention_days: 90,
log_retention_days: 365,
cleanup_batch_size: 1000,
audit_log_retention_days: 30,
})
// 计算属性KB 和 字节 之间的转换
@@ -901,8 +902,7 @@ async function loadSystemConfig() {
// 用户注册
'enable_registration',
'require_email_verification',
// API Key 管理
'api_key_expire_days',
// 独立余额 Key 过期管理
'auto_delete_expired_keys',
// 日志记录
'request_log_level',
@@ -916,6 +916,7 @@ async function loadSystemConfig() {
'header_retention_days',
'log_retention_days',
'cleanup_batch_size',
'audit_log_retention_days',
]
for (const key of configs) {
@@ -960,12 +961,7 @@ async function saveSystemConfig() {
value: systemConfig.value.require_email_verification,
description: '是否需要邮箱验证'
},
// API Key 管理
{
key: 'api_key_expire_days',
value: systemConfig.value.api_key_expire_days,
description: 'API密钥过期天数'
},
// 独立余额 Key 过期管理
{
key: 'auto_delete_expired_keys',
value: systemConfig.value.auto_delete_expired_keys,
@@ -1023,6 +1019,11 @@ async function saveSystemConfig() {
value: systemConfig.value.cleanup_batch_size,
description: '每批次清理的记录数'
},
{
key: 'audit_log_retention_days',
value: systemConfig.value.audit_log_retention_days,
description: '审计日志保留天数'
},
]
const promises = configItems.map(item =>

View File

@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
allowed_providers=self.key_data.allowed_providers,
allowed_api_formats=self.key_data.allowed_api_formats,
allowed_models=self.key_data.allowed_models,
rate_limit=self.key_data.rate_limit or 100,
rate_limit=self.key_data.rate_limit, # None 表示不限制
expire_days=self.key_data.expire_days,
initial_balance_usd=self.key_data.initial_balance_usd,
is_standalone=True, # 标记为独立Key

View File

@@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import List
from typing import List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import and_, func
@@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"])
pipeline = ApiRequestPipeline()
def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]:
    """Return a copy of the proxy config with the password replaced by ``***``.

    Empty or ``None`` configs normalize to ``None``.  A falsy password is left
    untouched so callers can distinguish "no password" from "password present".
    The input dict is never mutated.
    """
    if not proxy_config:
        return None
    redacted = {**proxy_config}
    if redacted.get("password"):
        redacted["password"] = "***"
    return redacted
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
async def list_provider_endpoints(
provider_id: str,
@@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
"api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0),
"active_keys": active_keys_map.get(endpoint.id, 0),
"proxy": mask_proxy_password(endpoint.proxy),
}
endpoint_dict.pop("_sa_instance_state", None)
result.append(ProviderEndpointResponse(**endpoint_dict))
@@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
rate_limit=self.endpoint_data.rate_limit,
is_active=True,
config=self.endpoint_data.config,
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
created_at=now,
updated_at=now,
)
@@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in new_endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=new_endpoint.api_format,
proxy=mask_proxy_password(new_endpoint.proxy),
total_keys=0,
active_keys=0,
)
@@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in endpoint_obj.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name,
api_format=endpoint_obj.api_format,
proxy=mask_proxy_password(endpoint_obj.proxy),
total_keys=total_keys,
active_keys=active_keys,
)
@@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
update_data = self.endpoint_data.model_dump(exclude_unset=True)
# 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理
if "proxy" in update_data:
if update_data["proxy"] is not None:
new_proxy = dict(update_data["proxy"])
# 只有当密码字段未提供时才保留原密码(空字符串视为显式清除)
if "password" not in new_proxy and endpoint.proxy:
old_password = endpoint.proxy.get("password")
if old_password:
new_proxy["password"] = old_password
update_data["proxy"] = new_proxy
# proxy 为 None 时保留,用于清除代理配置
for field, value in update_data.items():
setattr(endpoint, field, value)
endpoint.updated_at = datetime.now(timezone.utc)
@@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
endpoint_dict = {
k: v
for k, v in endpoint.__dict__.items()
if k not in {"api_format", "_sa_instance_state"}
if k not in {"api_format", "_sa_instance_state", "proxy"}
}
return ProviderEndpointResponse(
**endpoint_dict,
provider_name=provider.name if provider else "Unknown",
api_format=endpoint.api_format,
proxy=mask_proxy_password(endpoint.proxy),
total_keys=total_keys,
active_keys=active_keys,
)

View File

@@ -140,7 +140,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
if not authorization or not authorization.lower().startswith("bearer "):
return None
token = authorization.replace("Bearer ", "").strip()
token = authorization[7:].strip()
try:
payload = await AuthService.verify_token(token, token_type="access")
user_id = payload.get("user_id")

View File

@@ -177,7 +177,7 @@ class ApiRequestPipeline:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少管理员凭证")
token = authorization.replace("Bearer ", "").strip()
token = authorization[7:].strip()
try:
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:
@@ -204,7 +204,7 @@ class ApiRequestPipeline:
if not authorization or not authorization.lower().startswith("bearer "):
raise HTTPException(status_code=401, detail="缺少用户凭证")
token = authorization.replace("Bearer ", "").strip()
token = authorization[7:].strip()
try:
payload = await self.auth_service.verify_token(token, token_type="access")
except HTTPException:

View File

@@ -28,7 +28,7 @@
from __future__ import annotations
import time
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
from src.services.system.audit import audit_service
from src.services.usage.service import UsageService
if TYPE_CHECKING:
from src.api.handlers.base.stream_context import StreamContext
class MessageTelemetry:
@@ -399,6 +402,41 @@ class BaseMessageHandler:
# 创建后台任务,不阻塞当前流
asyncio.create_task(_do_update())
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
    """Mark the Usage record as ``streaming``, recording provider and target model.

    The database write runs as an asyncio background task so it never blocks
    the streaming response path.

    Args:
        ctx: Stream context carrying ``provider_name`` and ``mapped_model``.
    """
    import asyncio
    from src.database.database import get_db

    # Capture the values now: the closure runs later, after ctx may have changed.
    target_request_id = self.request_id
    provider = ctx.provider_name
    target_model = ctx.mapped_model

    async def _do_update() -> None:
        try:
            # get_db() is a generator-style session factory; take one session.
            db_gen = get_db()
            db = next(db_gen)
            try:
                UsageService.update_usage_status(
                    db=db,
                    request_id=target_request_id,
                    status="streaming",
                    provider=provider,
                    target_model=target_model,
                )
            finally:
                # Always release the session, even if the update raises.
                db.close()
        except Exception as e:
            # Best-effort status update: log and continue, never break the stream.
            logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")

    # Fire-and-forget: do not block the current stream.
    asyncio.create_task(_do_update())
def _log_request_error(self, message: str, error: Exception) -> None:
"""记录请求错误日志,对业务异常不打印堆栈

View File

@@ -64,18 +64,6 @@ class ChatAdapterBase(ApiAdapter):
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
self.response_normalizer = None
# 可选启用响应规范化
self._init_response_normalizer()
def _init_response_normalizer(self):
"""初始化响应规范化器 - 子类可覆盖"""
try:
from src.services.provider.response_normalizer import ResponseNormalizer
self.response_normalizer = ResponseNormalizer()
except ImportError:
pass
async def handle(self, context: ApiRequestContext):
"""处理 Chat API 请求"""
@@ -228,8 +216,6 @@ class ChatAdapterBase(ApiAdapter):
user_agent=user_agent,
start_time=start_time,
allowed_api_formats=self.allowed_api_formats,
response_normalizer=self.response_normalizer,
enable_response_normalization=self.response_normalizer is not None,
adapter_detector=self.detect_capability_requirements,
)

View File

@@ -88,8 +88,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
user_agent: str,
start_time: float,
allowed_api_formats: Optional[list] = None,
response_normalizer: Optional[Any] = None,
enable_response_normalization: bool = False,
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
):
allowed = allowed_api_formats or [self.FORMAT_ID]
@@ -106,8 +104,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
)
self._parser: Optional[ResponseParser] = None
self._request_builder = PassthroughRequestBuilder()
self.response_normalizer = response_normalizer
self.enable_response_normalization = enable_response_normalization
@property
def parser(self) -> ResponseParser:
@@ -297,11 +293,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
# 创建类型安全的流式上下文
ctx = StreamContext(model=model, api_format=api_format)
# 创建更新状态的回调闭包(可以访问 ctx
def update_streaming_status() -> None:
self._update_usage_to_streaming_with_ctx(ctx)
# 创建流处理器
stream_processor = StreamProcessor(
request_id=self.request_id,
default_parser=self.parser,
on_streaming_start=self._update_usage_to_streaming,
on_streaming_start=update_streaming_status,
)
# 定义请求函数
@@ -466,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
pool=config.http_pool_timeout,
)
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
@@ -634,10 +640,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
f"模型={model} -> {mapped_model or '无映射'}")
async with httpx.AsyncClient(
timeout=float(endpoint.timeout),
follow_redirects=True,
) as http_client:
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code

View File

@@ -454,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"Key=***{key.api_key[-4:]}, "
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=timeout_config,
)
try:
response_ctx = http_client.stream(
"POST", url, json=provider_payload, headers=provider_headers
@@ -526,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
async for chunk in stream_response.aiter_raw():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming(ctx.request_id)
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
@@ -810,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 在第一次输出数据前更新状态为 streaming
if prefetched_chunks:
self._update_usage_to_streaming(ctx.request_id)
self._update_usage_to_streaming_with_ctx(ctx)
# 先处理预读的字节块
for chunk in prefetched_chunks:
@@ -1419,10 +1425,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"Key=***{key.api_key[-4:]}, "
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
async with httpx.AsyncClient(
timeout=float(endpoint.timeout),
follow_redirects=True,
) as http_client:
# 创建 HTTP 客户端(支持代理配置)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
timeout=httpx.Timeout(float(endpoint.timeout)),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code

View File

@@ -131,10 +131,5 @@ class ClaudeChatHandler(ChatHandlerBase):
Returns:
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
response_data=response,
request_id=self.request_id,
)
return result
# 作为中转站,直接透传响应,不做标准化处理
return response

View File

@@ -148,17 +148,6 @@ class GeminiChatHandler(ChatHandlerBase):
Returns:
规范化后的响应
TODO: 如果需要,实现响应规范化逻辑
"""
# 可选:使用 response_normalizer 进行规范化
# if (
# self.response_normalizer
# and self.response_normalizer.should_normalize(response)
# ):
# return self.response_normalizer.normalize_gemini_response(
# response_data=response,
# request_id=self.request_id,
# strict=False,
# )
# 作为中转站,直接透传响应,不做标准化处理
return response

View File

@@ -128,10 +128,5 @@ class OpenAIChatHandler(ChatHandlerBase):
Returns:
规范化后的响应
"""
if self.response_normalizer and self.response_normalizer.should_normalize(response):
return self.response_normalizer.normalize_openai_response(
response_data=response,
request_id=self.request_id,
strict=False,
)
# 作为中转站,直接透传响应,不做标准化处理
return response

View File

@@ -5,12 +5,55 @@
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from urllib.parse import quote, urlparse
import httpx
from src.core.logger import logger
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
    """Assemble the final proxy URL from a stored proxy config.

    Args:
        proxy_config: Dict with ``url``, ``username``, ``password``, ``enabled``.

    Returns:
        A URL such as ``socks5://user:pass@host:port``, or ``None`` when the
        config is empty, disabled, or has no URL.  Credentials are
        percent-encoded so characters like ``@``, ``:`` and ``/`` are safe.
    """
    if not proxy_config:
        return None
    # A missing ``enabled`` key defaults to True for backward compatibility.
    if not proxy_config.get("enabled", True):
        return None

    base_url = proxy_config.get("url")
    if not base_url:
        return None

    username = proxy_config.get("username")
    if not username:
        # Without a username no auth section is added (a lone password is ignored).
        return base_url

    # Rebuild the URL with percent-encoded credentials in the authority part.
    password = proxy_config.get("password")
    parsed = urlparse(base_url)
    credentials = quote(username, safe="")
    if password:
        credentials = f"{credentials}:{quote(password, safe='')}"
    rebuilt = f"{parsed.scheme}://{credentials}@{parsed.netloc}"
    if parsed.path:
        rebuilt += parsed.path
    return rebuilt
class HTTPClientPool:
"""
@@ -121,6 +164,44 @@ class HTTPClientPool:
finally:
await client.aclose()
@classmethod
def create_client_with_proxy(
    cls,
    proxy_config: Optional[Dict[str, Any]] = None,
    timeout: Optional[httpx.Timeout] = None,
    **kwargs: Any,
) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient``, optionally routed through a proxy.

    Args:
        proxy_config: Proxy dict (``url``/``username``/``password``/``enabled``).
        timeout: Timeout settings; defaults to 10s with a 300s read timeout.
        **kwargs: Extra ``httpx.AsyncClient`` options (override the defaults).

    Returns:
        A configured ``httpx.AsyncClient`` instance.
    """
    options: Dict[str, Any] = {
        "http2": False,
        "verify": True,
        "follow_redirects": True,
        # Long read timeout to accommodate slow streaming responses.
        "timeout": timeout if timeout else httpx.Timeout(10.0, read=300.0),
    }

    # Attach the proxy only when the config resolves to a usable URL.
    proxy_url = build_proxy_url(proxy_config) if proxy_config else None
    if proxy_url:
        options["proxy"] = proxy_url
        logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")

    # Caller-supplied options take precedence over the defaults above.
    options.update(kwargs)
    return httpx.AsyncClient(**options)
# 便捷访问函数
def get_http_client() -> httpx.AsyncClient:

View File

@@ -120,7 +120,7 @@ class RedisClientManager:
if self._circuit_open_until and time.time() < self._circuit_open_until:
remaining = self._circuit_open_until - time.time()
logger.warning(
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
remaining,
self._last_error,
)
@@ -200,7 +200,7 @@ class RedisClientManager:
if self._consecutive_failures >= self._circuit_threshold:
self._circuit_open_until = time.time() + self._circuit_reset_seconds
logger.warning(
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
self._consecutive_failures,

View File

@@ -105,6 +105,13 @@ class Config:
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
# 可信代理配置
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1即信任最近一层代理
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
# 异常处理配置
# 设置为 True 时ProxyException 会传播到路由层以便记录 provider_request_headers
# 设置为 False 时,使用全局异常处理器统一处理

View File

@@ -153,7 +153,7 @@ def _log_pool_capacity():
total_estimated = theoretical * workers
safe_limit = config.pg_max_connections - config.pg_reserved_connections
logger.info(
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
config.db_pool_size,
config.db_max_overflow,
workers,
@@ -162,7 +162,7 @@ def _log_pool_capacity():
)
if total_estimated > safe_limit:
logger.warning(
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved)"
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved)"
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
total_estimated,
safe_limit,
@@ -260,7 +260,8 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
2. **管理后台 API**
- 路由层显式调用 db.commit()
- 每个操作独立提交,不依赖中间件
- 提交后设置 request.state.tx_committed_by_route = True
- 中间件看到此标志后跳过 commit只负责 close
3. **后台任务/调度器**
- 使用独立 Session通过 create_session() 或 next(get_db())
@@ -271,6 +272,17 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
路由层提交事务示例
==================
```python
@router.post("/example")
async def example(request: Request, db: Session = Depends(get_db)):
# ... 业务逻辑 ...
db.commit()
request.state.tx_committed_by_route = True # 告知中间件已提交
return {"message": "success"}
```
注意事项
========
- 本函数不自动提交事务

View File

@@ -49,7 +49,7 @@ async def initialize_providers():
# 从数据库加载所有活跃的提供商
providers = (
db.query(Provider)
.filter(Provider.is_active == True)
.filter(Provider.is_active.is_(True))
.order_by(Provider.provider_priority.asc())
.all()
)
@@ -122,6 +122,7 @@ async def lifespan(app: FastAPI):
logger.info("初始化全局Redis客户端...")
from src.clients.redis_client import get_redis_client
redis_client = None
try:
redis_client = await get_redis_client(require_redis=config.require_redis)
if redis_client:
@@ -133,6 +134,7 @@ async def lifespan(app: FastAPI):
logger.exception("[ERROR] Redis连接失败应用启动中止")
raise
logger.warning(f"Redis连接失败但配置允许降级将继续使用内存模式: {e}")
redis_client = None
# 初始化并发管理器内部会使用Redis
logger.info("初始化并发管理器...")
@@ -312,7 +314,7 @@ if frontend_dist.exists():
仅对非API路径生效
"""
# 如果是API路径不处理
if full_path.startswith("api/") or full_path.startswith("v1/"):
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
raise HTTPException(status_code=404, detail="Not Found")
# 返回index.html让前端路由处理

View File

@@ -1,16 +1,17 @@
"""
统一的插件中间件
统一的插件中间件(纯 ASGI 实现)
负责协调所有插件的调用
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware
以避免 Starlette 已知的流式响应兼容性问题。
"""
import hashlib
import time
from typing import Any, Awaitable, Callable, Optional
from typing import Optional
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from src.config import config
from src.core.logger import logger
@@ -18,20 +19,25 @@ from src.plugins.manager import get_plugin_manager
from src.plugins.rate_limit.base import RateLimitResult
class PluginMiddleware(BaseHTTPMiddleware):
class PluginMiddleware:
"""
统一的插件调用中间件
统一的插件调用中间件(纯 ASGI 实现)
职责:
- 性能监控
- 限流控制 (可选)
- 数据库会话生命周期管理
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
- 纯 ASGI 可以直接透传流式响应,无额外开销
"""
def __init__(self, app: Any) -> None:
super().__init__(app)
def __init__(self, app: ASGIApp) -> None:
self.app = app
self.plugin_manager = get_plugin_manager()
# 从配置读取速率限制值
@@ -61,152 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
"/v1/completions",
]
async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
) -> StarletteResponse:
"""处理请求并调用相应插件"""
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""ASGI 入口点"""
if scope["type"] != "http":
# 非 HTTP 请求(如 WebSocket直接透传
await self.app(scope, receive, send)
return
# 构建 Request 对象以便复用现有逻辑
request = Request(scope, receive, send)
# 记录请求开始时间
start_time = time.time()
# 设置 request.state 属性
# 注意Starlette 的 Request 对象总是有 state 属性State 实例)
request.state.request_id = request.headers.get("x-request-id", "")
request.state.start_time = start_time
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
request.state.db_managed_by_middleware = True
response = None
exception_to_raise = None
# 1. 限流检查(在调用下游之前)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
await self._send_rate_limit_response(send, rate_limit_result)
return
# 2. 预处理插件调用
await self._call_pre_request_plugins(request)
# 用于捕获响应状态码
response_status_code: int = 0
async def send_wrapper(message: Message) -> None:
nonlocal response_status_code
if message["type"] == "http.response.start":
response_status_code = message.get("status", 0)
await send(message)
# 3. 调用下游应用
exception_occurred: Optional[Exception] = None
try:
# 1. 限流插件调用(可选功能)
rate_limit_result = await self._call_rate_limit_plugins(request)
if rate_limit_result and not rate_limit_result.allowed:
# 限流触发返回429
headers = rate_limit_result.headers or {}
raise HTTPException(
status_code=429,
detail=rate_limit_result.message or "Rate limit exceeded",
headers=headers,
)
# 2. 预处理插件调用
await self._call_pre_request_plugins(request)
# 处理请求
response = await call_next(request)
# 3. 提交关键数据库事务(在返回响应前)
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
try:
db = getattr(request.state, "db", None)
if isinstance(db, Session):
db.commit()
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
try:
if isinstance(db, Session):
db.rollback()
except Exception:
pass
await self._call_error_plugins(request, commit_error, start_time)
# 返回 500 错误,因为数据可能不一致
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "database_error",
"message": "数据保存失败,请重试",
},
},
)
# 跳过后处理插件,直接返回错误响应
return response
# 4. 后处理插件调用(监控等,非关键操作)
# 这些操作失败不应影响用户响应
await self._call_post_request_plugins(request, response, start_time)
# 注意不在此处添加限流响应头因为在BaseHTTPMiddleware中
# 响应返回后修改headers会导致Content-Length不匹配错误
# 限流响应头已在返回429错误时正确包含见上面的HTTPException
except RuntimeError as e:
if str(e) == "No response returned.":
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback()
except Exception:
pass
logger.error("Downstream handler completed without returning a response")
await self._call_error_plugins(request, e, start_time)
if isinstance(db, Session):
try:
db.commit()
except Exception:
pass
response = JSONResponse(
status_code=500,
content={
"type": "error",
"error": {
"type": "internal_error",
"message": "Internal server error: downstream handler returned no response.",
},
},
)
else:
exception_to_raise = e
await self.app(scope, receive, send_wrapper)
except Exception as e:
# 回滚数据库事务
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.rollback()
except Exception:
pass
exception_occurred = e
# 错误处理插件调用
await self._call_error_plugins(request, e, start_time)
raise
finally:
# 4. 数据库会话清理(无论成功与否)
await self._cleanup_db_session(request, exception_occurred)
# 尝试提交错误日志
if isinstance(db, Session):
# 5. 后处理插件调用(仅在成功时)
if not exception_occurred and response_status_code > 0:
await self._call_post_request_plugins(request, response_status_code, start_time)
async def _send_rate_limit_response(
self, send: Send, result: RateLimitResult
) -> None:
"""发送 429 限流响应"""
import json
body = json.dumps({
"type": "error",
"error": {
"type": "rate_limit_error",
"message": result.message or "Rate limit exceeded",
},
}).encode("utf-8")
headers = [(b"content-type", b"application/json")]
if result.headers:
for key, value in result.headers.items():
headers.append((key.lower().encode(), str(value).encode()))
await send({
"type": "http.response.start",
"status": 429,
"headers": headers,
})
await send({
"type": "http.response.body",
"body": body,
})
async def _cleanup_db_session(
self, request: Request, exception: Optional[Exception]
) -> None:
"""清理数据库会话
事务策略:
- 如果 request.state.tx_committed_by_route 为 True说明路由已自行提交中间件只负责 close
- 否则由中间件统一 commit/rollback
这避免了双重提交的问题,同时保持向后兼容。
"""
from sqlalchemy.orm import Session
db = getattr(request.state, "db", None)
if not isinstance(db, Session):
return
# 检查是否由路由层已经提交
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
try:
if exception is not None:
# 发生异常,回滚事务(无论谁负责提交)
try:
db.rollback()
except Exception as rollback_error:
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
elif not tx_committed_by_route:
# 正常完成且路由未自行提交,由中间件提交事务
try:
db.commit()
except:
pass
exception_to_raise = e
except Exception as commit_error:
logger.error(f"关键事务提交失败: {commit_error}")
try:
db.rollback()
except Exception:
pass
# 如果 tx_committed_by_route 为 True跳过 commit路由已提交
finally:
db = getattr(request.state, "db", None)
if isinstance(db, Session):
try:
db.close()
except Exception as close_error:
# 连接池会处理连接的回收,这里的异常不应影响响应
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
# 在 finally 块之后处理异常和响应
if exception_to_raise:
raise exception_to_raise
return response
# 关闭会话,归还连接到连接池
try:
db.close()
except Exception as close_error:
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
def _get_client_ip(self, request: Request) -> str:
"""
获取客户端 IP 地址,支持代理头
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
仅当服务部署在可信代理(如 Nginx、CloudFlare后面时才安全。
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
"""
# 从配置获取可信代理层数(默认为 1即信任最近一层代理
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
# 优先从代理头获取真实 IP
forwarded_for = request.headers.get("x-forwarded-for")
if forwarded_for:
# X-Forwarded-For 可能包含多个 IP取第一个
return forwarded_for.split(",")[0].strip()
# X-Forwarded-For 格式: "client, proxy1, proxy2"
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
ips = [ip.strip() for ip in forwarded_for.split(",")]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
real_ip = request.headers.get("x-real-ip")
if real_ip:
@@ -248,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
auth_header = request.headers.get("authorization", "")
api_key = request.headers.get("x-api-key", "")
if auth_header.startswith("Bearer "):
if auth_header.lower().startswith("bearer "):
api_key = auth_header[7:]
if api_key:
# 使用 API Key 的哈希作为限制 key避免日志泄露完整 key
import hashlib
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
key = f"llm_api_key:{key_hash}"
request.state.rate_limit_key_type = "api_key"
@@ -319,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
)
else:
# 限流触发,记录日志
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
logger.warning(
"速率限制触发: {}",
getattr(request.state, "rate_limit_key_type", "unknown"),
)
return result
return None
except Exception as e:
@@ -332,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
pass
async def _call_post_request_plugins(
self, request: Request, response: StarletteResponse, start_time: float
self, request: Request, status_code: int, start_time: float
) -> None:
"""调用请求后的插件"""
@@ -345,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
monitor_labels = {
"method": request.method,
"endpoint": request.url.path,
"status": str(response.status_code),
"status_class": f"{response.status_code // 100}xx",
"status": str(status_code),
"status_class": f"{status_code // 100}xx",
}
# 记录请求计数
@@ -368,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
self, request: Request, error: Exception, start_time: float
) -> None:
"""调用错误处理插件"""
from fastapi import HTTPException
duration = time.time() - start_time
@@ -380,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
error=error,
context={
"endpoint": f"{request.method} {request.url.path}",
"request_id": request.state.request_id,
"request_id": getattr(request.state, "request_id", ""),
"duration": duration,
},
)

View File

@@ -13,6 +13,42 @@ from pydantic import BaseModel, Field, field_validator, model_validator
from src.core.enums import APIFormat, ProviderBillingType
class ProxyConfig(BaseModel):
    """Per-endpoint proxy configuration.

    Credentials are stored in dedicated fields; embedding them in the URL
    is rejected by the validator below. ``enabled=False`` keeps the config
    around without using it.
    """

    url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
    username: Optional[str] = Field(None, max_length=255, description="代理用户名")
    password: Optional[str] = Field(None, max_length=500, description="代理密码")
    # When False the proxy config is preserved but not applied.
    enabled: bool = Field(True, description="是否启用代理false 时保留配置但不使用)")

    @field_validator("url")
    @classmethod
    def validate_proxy_url(cls, v: str) -> str:
        """Validate the proxy URL format.

        Raises:
            ValueError: on CR/LF characters, an unsupported scheme (SOCKS4
                is deliberately not supported), a missing host, or
                credentials embedded directly in the URL.
        """
        from urllib.parse import urlparse
        v = v.strip()
        # Reject CR/LF to prevent injection into logs/headers.
        if "\n" in v or "\r" in v:
            raise ValueError("代理 URL 包含非法字符")
        # Only http/https/socks5 schemes are accepted.
        if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
            raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")
        # Structural check: the URL must carry a host part.
        parsed = urlparse(v)
        if not parsed.netloc:
            raise ValueError("代理 URL 必须包含有效的 host")
        # Force credentials through the dedicated username/password fields.
        if parsed.username or parsed.password:
            raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")
        return v
class CreateProviderRequest(BaseModel):
"""创建 Provider 请求"""
@@ -165,6 +201,7 @@ class CreateEndpointRequest(BaseModel):
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
@field_validator("name")
@classmethod
@@ -220,6 +257,7 @@ class UpdateEndpointRequest(BaseModel):
rpm_limit: Optional[int] = Field(None, ge=0)
concurrent_limit: Optional[int] = Field(None, ge=0)
config: Optional[Dict[str, Any]] = None
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
# 复用验证器
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)

View File

@@ -538,6 +538,9 @@ class ProviderEndpoint(Base):
# 额外配置
config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段)
# 代理配置
proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password}
# 时间戳
created_at = Column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False

View File

@@ -8,6 +8,8 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
from src.models.admin_requests import ProxyConfig
# ========== ProviderEndpoint CRUD ==========
@@ -30,6 +32,9 @@ class ProviderEndpointCreate(BaseModel):
# 额外配置
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置JSON")
# 代理配置
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
@field_validator("api_format")
@classmethod
def validate_api_format(cls, v: str) -> str:
@@ -64,6 +69,7 @@ class ProviderEndpointUpdate(BaseModel):
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
is_active: Optional[bool] = Field(default=None, description="是否启用")
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
@field_validator("base_url")
@classmethod
@@ -104,6 +110,9 @@ class ProviderEndpointResponse(BaseModel):
# 额外配置
config: Optional[Dict[str, Any]] = None
# 代理配置(响应中密码已脱敏)
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置(密码已脱敏)")
# 统计(从 Keys 聚合)
total_keys: int = Field(default=0, description="总 Key 数量")
active_keys: int = Field(default=0, description="活跃 Key 数量")

View File

@@ -3,6 +3,7 @@ JWT认证插件
支持JWT Bearer token认证
"""
import hashlib
from typing import Optional
from fastapi import Request
@@ -46,8 +47,8 @@ class JwtAuthPlugin(AuthPlugin):
logger.debug("未找到JWT token")
return None
# 记录认证尝试的详细信息
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, token_fp={token_fingerprint}")
try:
# 验证JWT token

View File

@@ -63,14 +63,16 @@ class JWTBlacklistService:
if ttl_seconds <= 0:
# Token 已经过期,不需要加入黑名单
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
return True
# 存储到 Redis设置 TTL 为 Token 过期时间
# 值存储为原因字符串
await redis_client.setex(redis_key, ttl_seconds, reason)
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
return True
except Exception as e:
@@ -109,7 +111,8 @@ class JWTBlacklistService:
if exists:
# 获取黑名单原因(可选)
reason = await redis_client.get(redis_key)
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
return True
return False
@@ -148,9 +151,11 @@ class JWTBlacklistService:
deleted = await redis_client.delete(redis_key)
if deleted:
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
else:
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
return bool(deleted)

View File

@@ -3,6 +3,7 @@
"""
import os
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional
@@ -169,7 +170,8 @@ class AuthService:
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
return user, key_record
@staticmethod

View File

@@ -1,61 +0,0 @@
"""响应标准化服务,用于 STANDARD 模式下的响应格式验证和补全"""
from typing import Any, Dict, Optional
from src.core.logger import logger
from src.models.claude import ClaudeResponse
class ResponseNormalizer:
    """Response normalizer used in STANDARD mode to validate and backfill fields."""

    @staticmethod
    def normalize_claude_response(
        response_data: Dict[str, Any], request_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Normalize a Claude API response payload.

        Args:
            response_data: Raw response payload.
            request_id: Request id, used only in log lines.

        Returns:
            The normalized payload; on validation failure or for error
            responses, the original payload is passed through unchanged.
        """
        # Error payloads are never normalized.
        if "error" in response_data:
            logger.debug(f"[ResponseNormalizer] 检测到错误响应,跳过标准化 | ID:{request_id}")
            return response_data
        try:
            model = ClaudeResponse.model_validate(response_data)
            normalized = model.model_dump(mode="json", exclude_none=False)
        except Exception:
            # Best effort: fall back to the untouched payload.
            logger.debug(f"[ResponseNormalizer] 响应验证失败,透传原始数据 | ID:{request_id}")
            return response_data
        logger.debug(f"[ResponseNormalizer] 响应标准化成功 | ID:{request_id}")
        return normalized

    @staticmethod
    def should_normalize(response_data: Dict[str, Any]) -> bool:
        """Decide whether a response payload needs normalization.

        Args:
            response_data: Response payload to inspect.

        Returns:
            True when normalization should be attempted.
        """
        # Error responses pass through untouched.
        if "error" in response_data:
            return False
        # A payload already carrying both new fields is considered normalized.
        has_new_fields = (
            "context_management" in response_data and "container" in response_data
        )
        return not has_new_fields

View File

@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
from src.core.logger import logger
from src.database import create_session
from src.models.database import Usage
from src.models.database import AuditLog, Usage
from src.services.system.config import SystemConfigService
from src.services.system.scheduler import get_scheduler
from src.services.system.stats_aggregator import StatsAggregatorService
@@ -91,6 +91,15 @@ class CleanupScheduler:
name="Pending状态清理",
)
# 审计日志清理 - 凌晨 4 点执行
scheduler.add_cron_job(
self._scheduled_audit_cleanup,
hour=4,
minute=0,
job_id="audit_cleanup",
name="审计日志清理",
)
# 启动时执行一次初始化任务
asyncio.create_task(self._run_startup_tasks())
@@ -145,6 +154,10 @@ class CleanupScheduler:
"""Pending 清理任务(定时调用)"""
await self._perform_pending_cleanup()
async def _scheduled_audit_cleanup(self):
"""审计日志清理任务(定时调用)"""
await self._perform_audit_cleanup()
# ========== 实际任务实现 ==========
async def _perform_stats_aggregation(self, backfill: bool = False):
@@ -330,6 +343,70 @@ class CleanupScheduler:
finally:
db.close()
async def _perform_audit_cleanup(self):
    """Delete audit-log rows older than the configured retention period.

    Runs as the daily scheduled job. Deletes in id-batches with a commit
    per batch so no single transaction grows unbounded; honors the global
    `enable_auto_cleanup` switch and enforces a 7-day minimum retention.
    """
    db = create_session()
    try:
        # Respect the global auto-cleanup switch.
        if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
            logger.info("自动清理已禁用,跳过审计日志清理")
            return
        # Retention in days (default 30). Floor of 7 days guards against a
        # misconfiguration deleting the entire audit history.
        audit_retention_days = max(
            SystemConfigService.get_config(db, "audit_log_retention_days", 30),
            7,  # minimum retention, see above
        )
        batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
        cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
        logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
        total_deleted = 0
        while True:
            # Select a batch of expired row ids first, then delete by id,
            # keeping each DELETE small and index-friendly.
            records_to_delete = (
                db.query(AuditLog.id)
                .filter(AuditLog.created_at < cutoff_time)
                .limit(batch_size)
                .all()
            )
            if not records_to_delete:
                break
            record_ids = [r.id for r in records_to_delete]
            # Bulk delete; synchronize_session=False skips in-session
            # bookkeeping (safe: nothing from AuditLog is held in session).
            result = db.execute(
                delete(AuditLog)
                .where(AuditLog.id.in_(record_ids))
                .execution_options(synchronize_session=False)
            )
            rows_deleted = result.rowcount
            db.commit()
            total_deleted += rows_deleted
            logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted}")
            # Brief pause between batches to reduce database pressure.
            await asyncio.sleep(0.1)
        if total_deleted > 0:
            logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
        else:
            logger.info("无需清理的审计日志")
    except Exception as e:
        logger.exception(f"审计日志清理失败: {e}")
        try:
            db.rollback()
        except Exception:
            pass
    finally:
        # Always return the connection to the pool.
        db.close()
async def _perform_cleanup(self):
"""执行清理任务"""
db = create_session()

View File

@@ -1217,15 +1217,19 @@ class UsageService:
request_id: str,
status: str,
error_message: Optional[str] = None,
provider: Optional[str] = None,
target_model: Optional[str] = None,
) -> Optional[Usage]:
"""
快速更新使用记录状态(不更新其他字段)
快速更新使用记录状态
Args:
db: 数据库会话
request_id: 请求ID
status: 新状态 (pending, streaming, completed, failed)
error_message: 错误消息(仅在 failed 状态时使用)
provider: 提供商名称可选streaming 状态时更新)
target_model: 映射后的目标模型名(可选)
Returns:
更新后的 Usage 记录,如果未找到则返回 None
@@ -1239,6 +1243,10 @@ class UsageService:
usage.status = status
if error_message:
usage.error_message = error_message
if provider:
usage.provider = provider
if target_model:
usage.target_model = target_model
db.commit()

View File

@@ -457,7 +457,7 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
# 更新状态为 streaming
# 更新状态为 streaming,同时更新 provider
if self.request_id:
try:
from src.services.usage.service import UsageService
@@ -465,6 +465,7 @@ class StreamUsageTracker:
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")

View File

@@ -210,7 +210,15 @@ class ApiKeyService:
@staticmethod
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
"""检查速率限制"""
"""检查速率限制
Returns:
(is_allowed, remaining): 是否允许请求,剩余可用次数
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
"""
# 如果 rate_limit 为 None表示不限制
if api_key.rate_limit is None:
return True, -1 # -1 表示无限制
# 计算时间窗口
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)

View File

@@ -3,6 +3,7 @@
提供统一的用户认证和授权功能
"""
import hashlib
from typing import Optional
from fastapi import Depends, Header, HTTPException, status
@@ -44,10 +45,17 @@ async def get_current_user(
payload = await AuthService.verify_token(token, token_type="access")
except HTTPException as token_error:
# 保持原始的HTTP状态码如401 Unauthorized不要转换为403
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.error(
"Token验证失败: {}: {}, token_fp={}",
token_error.status_code,
token_error.detail,
token_fp,
)
raise # 重新抛出原始异常,保持状态码
except Exception as token_error:
logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.error("Token验证失败: {}, token_fp={}", token_error, token_fp)
raise ForbiddenException("无效的Token")
user_id = payload.get("user_id")
@@ -63,7 +71,8 @@ async def get_current_user(
raise ForbiddenException("无效的认证凭据")
# 仅在DEBUG模式下记录详细信息
logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
logger.debug("尝试获取用户: user_id={}, token_fp={}", user_id, token_fp)
# 确保user_id是字符串格式UUID
if not isinstance(user_id, str):

View File

@@ -7,29 +7,47 @@ from typing import Optional
from fastapi import Request
from src.config import config
def get_client_ip(request: Request) -> str:
"""
获取客户端真实IP地址
按优先级检查:
1. X-Forwarded-For 头(支持代理链)
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取
2. X-Real-IP 头Nginx 代理)
3. 直接客户端IP
安全说明:
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
Args:
request: FastAPI Request 对象
Returns:
str: 客户端IP地址如果无法获取则返回 "unknown"
"""
trusted_proxy_count = config.trusted_proxy_count
# 如果不信任任何代理,直接返回连接 IP
if trusted_proxy_count == 0:
if request.client and request.client.host:
return request.client.host
return "unknown"
# 优先检查 X-Forwarded-For 头(可能包含代理链)
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For 格式: "client, proxy1, proxy2",取第一个(真实客户端)
client_ip = forwarded_for.split(",")[0].strip()
if client_ip:
return client_ip
# X-Forwarded-For 格式: "client, proxy1, proxy2"
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# 检查 X-Real-IP 头(通常由 Nginx 设置)
real_ip = request.headers.get("X-Real-IP")
@@ -91,20 +109,32 @@ def get_request_metadata(request: Request) -> dict:
}
def extract_ip_from_headers(headers: dict) -> str:
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
"""
从HTTP头字典中提取IP地址用于中间件等场景
Args:
headers: HTTP头字典
trusted_proxy_count: 可信代理层数None 时使用配置值
Returns:
str: 客户端IP地址
"""
if trusted_proxy_count is None:
trusted_proxy_count = config.trusted_proxy_count
# 如果不信任任何代理,返回 unknown调用方需要用其他方式获取连接 IP
if trusted_proxy_count == 0:
return "unknown"
# 检查 X-Forwarded-For
forwarded_for = headers.get("x-forwarded-for", "")
if forwarded_for:
return forwarded_for.split(",")[0].strip()
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
if len(ips) > trusted_proxy_count:
return ips[-(trusted_proxy_count + 1)]
elif ips:
return ips[0]
# 检查 X-Real-IP
real_ip = headers.get("x-real-ip", "")

View File

@@ -361,3 +361,61 @@ class TestPipelineAdminAuth:
assert result == mock_user
assert mock_request.state.user_id == "admin-123"
@pytest.mark.asyncio
async def test_authenticate_admin_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
    """A lowercase 'bearer ' prefix must be parsed just like 'Bearer '."""
    admin = MagicMock()
    admin.id = "admin-123"
    admin.is_active = True

    req = MagicMock()
    req.headers = {"authorization": "bearer valid-token"}
    req.state = MagicMock()

    db_stub = MagicMock()
    db_stub.query.return_value.filter.return_value.first.return_value = admin

    with patch.object(
        pipeline.auth_service,
        "verify_token",
        new_callable=AsyncMock,
        return_value={"user_id": "admin-123"},
    ) as verify_mock:
        outcome = await pipeline._authenticate_admin(req, db_stub)

    # Token must be extracted despite the lowercase scheme.
    verify_mock.assert_awaited_once_with("valid-token", token_type="access")
    assert outcome == admin
class TestPipelineUserAuth:
    """Tests for regular-user JWT authentication."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        return ApiRequestPipeline()

    @pytest.mark.asyncio
    async def test_authenticate_user_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
        """A lowercase 'bearer ' prefix must be parsed just like 'Bearer '."""
        user = MagicMock()
        user.id = "user-123"
        user.is_active = True

        req = MagicMock()
        req.headers = {"authorization": "bearer valid-token"}
        req.state = MagicMock()

        db_stub = MagicMock()
        db_stub.query.return_value.filter.return_value.first.return_value = user

        with patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            return_value={"user_id": "user-123"},
        ) as verify_mock:
            outcome = await pipeline._authenticate_user(req, db_stub)

        # Token must be extracted despite the lowercase scheme.
        verify_mock.assert_awaited_once_with("valid-token", token_type="access")
        assert outcome == user