mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-05 17:22:28 +08:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f0c1fb347 | ||
|
|
7b932d7afb | ||
|
|
c7b971cfe7 | ||
|
|
293bb592dc | ||
|
|
3e50c157be |
@@ -0,0 +1,57 @@
|
||||
"""add proxy field to provider_endpoints
|
||||
|
||||
Revision ID: f30f9936f6a2
|
||||
Revises: 1cc6942cf06f
|
||||
Create Date: 2025-12-18 06:31:58.451112+00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy import inspect
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'f30f9936f6a2'
|
||||
down_revision = '1cc6942cf06f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(table_name: str, column_name: str) -> bool:
|
||||
"""检查列是否存在"""
|
||||
bind = op.get_bind()
|
||||
inspector = inspect(bind)
|
||||
columns = [col['name'] for col in inspector.get_columns(table_name)]
|
||||
return column_name in columns
|
||||
|
||||
|
||||
def get_column_type(table_name: str, column_name: str) -> str:
|
||||
"""获取列的类型"""
|
||||
bind = op.get_bind()
|
||||
inspector = inspect(bind)
|
||||
for col in inspector.get_columns(table_name):
|
||||
if col['name'] == column_name:
|
||||
return str(col['type']).upper()
|
||||
return ''
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""添加 proxy 字段到 provider_endpoints 表"""
|
||||
if not column_exists('provider_endpoints', 'proxy'):
|
||||
# 字段不存在,直接添加 JSONB 类型
|
||||
op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True))
|
||||
else:
|
||||
# 字段已存在,检查是否需要转换类型
|
||||
col_type = get_column_type('provider_endpoints', 'proxy')
|
||||
if 'JSONB' not in col_type:
|
||||
# 如果是 JSON 类型,转换为 JSONB
|
||||
op.execute(
|
||||
'ALTER TABLE provider_endpoints '
|
||||
'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb'
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""移除 proxy 字段"""
|
||||
if column_exists('provider_endpoints', 'proxy'):
|
||||
op.drop_column('provider_endpoints', 'proxy')
|
||||
@@ -1,5 +1,5 @@
|
||||
import client from '../client'
|
||||
import type { ProviderEndpoint } from './types'
|
||||
import type { ProviderEndpoint, ProxyConfig } from './types'
|
||||
|
||||
/**
|
||||
* 获取指定 Provider 的所有 Endpoints
|
||||
@@ -38,6 +38,7 @@ export async function createEndpoint(
|
||||
rate_limit?: number
|
||||
is_active?: boolean
|
||||
config?: Record<string, any>
|
||||
proxy?: ProxyConfig | null
|
||||
}
|
||||
): Promise<ProviderEndpoint> {
|
||||
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
|
||||
@@ -63,6 +64,7 @@ export async function updateEndpoint(
|
||||
rate_limit: number
|
||||
is_active: boolean
|
||||
config: Record<string, any>
|
||||
proxy: ProxyConfig | null
|
||||
}>
|
||||
): Promise<ProviderEndpoint> {
|
||||
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)
|
||||
|
||||
@@ -20,6 +20,16 @@ export const API_FORMAT_LABELS: Record<string, string> = {
|
||||
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||
}
|
||||
|
||||
/**
|
||||
* 代理配置类型
|
||||
*/
|
||||
export interface ProxyConfig {
|
||||
url: string
|
||||
username?: string
|
||||
password?: string
|
||||
enabled?: boolean // 是否启用代理(false 时保留配置但不使用)
|
||||
}
|
||||
|
||||
export interface ProviderEndpoint {
|
||||
id: string
|
||||
provider_id: string
|
||||
@@ -41,6 +51,7 @@ export interface ProviderEndpoint {
|
||||
last_failure_at?: string
|
||||
is_active: boolean
|
||||
config?: Record<string, any>
|
||||
proxy?: ProxyConfig | null
|
||||
total_keys: number
|
||||
active_keys: number
|
||||
created_at: string
|
||||
|
||||
@@ -132,7 +132,7 @@
|
||||
type="number"
|
||||
min="1"
|
||||
max="10000"
|
||||
placeholder="100"
|
||||
placeholder="留空不限制"
|
||||
class="h-10"
|
||||
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
||||
/>
|
||||
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
|
||||
initial_balance_usd: 10,
|
||||
expire_days: undefined,
|
||||
never_expire: true,
|
||||
rate_limit: 100,
|
||||
rate_limit: undefined,
|
||||
auto_delete_on_expiry: false,
|
||||
allowed_providers: [],
|
||||
allowed_api_formats: [],
|
||||
@@ -389,7 +389,7 @@ function resetForm() {
|
||||
initial_balance_usd: 10,
|
||||
expire_days: undefined,
|
||||
never_expire: true,
|
||||
rate_limit: 100,
|
||||
rate_limit: undefined,
|
||||
auto_delete_on_expiry: false,
|
||||
allowed_providers: [],
|
||||
allowed_api_formats: [],
|
||||
@@ -408,7 +408,7 @@ function loadKeyData() {
|
||||
initial_balance_usd: props.apiKey.initial_balance_usd,
|
||||
expire_days: props.apiKey.expire_days,
|
||||
never_expire: props.apiKey.never_expire,
|
||||
rate_limit: props.apiKey.rate_limit || 100,
|
||||
rate_limit: props.apiKey.rate_limit,
|
||||
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
||||
allowed_providers: props.apiKey.allowed_providers || [],
|
||||
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
>
|
||||
<form
|
||||
class="space-y-6"
|
||||
@submit.prevent="handleSubmit"
|
||||
@submit.prevent="handleSubmit()"
|
||||
>
|
||||
<!-- API 配置 -->
|
||||
<div class="space-y-4">
|
||||
@@ -132,6 +132,79 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 代理配置 -->
|
||||
<div class="space-y-4">
|
||||
<div class="flex items-center justify-between">
|
||||
<h3 class="text-sm font-medium">
|
||||
代理配置
|
||||
</h3>
|
||||
<div class="flex items-center gap-2">
|
||||
<Switch v-model="proxyEnabled" />
|
||||
<span class="text-sm text-muted-foreground">启用代理</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="proxyEnabled"
|
||||
class="space-y-4 rounded-lg border p-4"
|
||||
>
|
||||
<div class="space-y-2">
|
||||
<Label for="proxy_url">代理 URL *</Label>
|
||||
<Input
|
||||
id="proxy_url"
|
||||
v-model="form.proxy_url"
|
||||
placeholder="http://host:port 或 socks5://host:port"
|
||||
required
|
||||
:class="proxyUrlError ? 'border-red-500' : ''"
|
||||
/>
|
||||
<p
|
||||
v-if="proxyUrlError"
|
||||
class="text-xs text-red-500"
|
||||
>
|
||||
{{ proxyUrlError }}
|
||||
</p>
|
||||
<p
|
||||
v-else
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
支持 HTTP、HTTPS、SOCKS5 代理
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 gap-4">
|
||||
<div class="space-y-2">
|
||||
<Label for="proxy_user">用户名(可选)</Label>
|
||||
<Input
|
||||
:id="`proxy_user_${formId}`"
|
||||
:name="`proxy_user_${formId}`"
|
||||
v-model="form.proxy_username"
|
||||
placeholder="代理认证用户名"
|
||||
autocomplete="off"
|
||||
data-form-type="other"
|
||||
data-lpignore="true"
|
||||
data-1p-ignore="true"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="space-y-2">
|
||||
<Label :for="`proxy_pass_${formId}`">密码(可选)</Label>
|
||||
<Input
|
||||
:id="`proxy_pass_${formId}`"
|
||||
:name="`proxy_pass_${formId}`"
|
||||
v-model="form.proxy_password"
|
||||
type="text"
|
||||
:placeholder="passwordPlaceholder"
|
||||
autocomplete="off"
|
||||
data-form-type="other"
|
||||
data-lpignore="true"
|
||||
data-1p-ignore="true"
|
||||
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<template #footer>
|
||||
@@ -145,12 +218,24 @@
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
||||
@click="handleSubmit"
|
||||
@click="handleSubmit()"
|
||||
>
|
||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 确认清空凭据对话框 -->
|
||||
<AlertDialog
|
||||
v-model="showClearCredentialsDialog"
|
||||
title="清空代理凭据"
|
||||
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
|
||||
type="warning"
|
||||
confirm-text="清空并保存"
|
||||
cancel-text="返回编辑"
|
||||
@confirm="confirmClearCredentials"
|
||||
@cancel="showClearCredentialsDialog = false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
@@ -165,7 +250,9 @@ import {
|
||||
SelectValue,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
Switch,
|
||||
} from '@/components/ui'
|
||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||
import { Link, SquarePen } from 'lucide-vue-next'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
@@ -194,6 +281,11 @@ const emit = defineEmits<{
|
||||
const { success, error: showError } = useToast()
|
||||
const loading = ref(false)
|
||||
const selectOpen = ref(false)
|
||||
const proxyEnabled = ref(false)
|
||||
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
|
||||
|
||||
// 生成随机 ID 防止浏览器自动填充
|
||||
const formId = Math.random().toString(36).substring(2, 10)
|
||||
|
||||
// 内部状态
|
||||
const internalOpen = computed(() => props.modelValue)
|
||||
@@ -207,7 +299,11 @@ const form = ref({
|
||||
max_retries: 3,
|
||||
max_concurrent: undefined as number | undefined,
|
||||
rate_limit: undefined as number | undefined,
|
||||
is_active: true
|
||||
is_active: true,
|
||||
// 代理配置
|
||||
proxy_url: '',
|
||||
proxy_username: '',
|
||||
proxy_password: '',
|
||||
})
|
||||
|
||||
// API 格式列表
|
||||
@@ -237,6 +333,53 @@ const defaultPathPlaceholder = computed(() => {
|
||||
return `留空使用默认路径:${defaultPath.value}`
|
||||
})
|
||||
|
||||
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
|
||||
const hasExistingPassword = computed(() => {
|
||||
if (!props.endpoint?.proxy) return false
|
||||
const proxy = props.endpoint.proxy as { password?: string }
|
||||
return proxy?.password === MASKED_PASSWORD
|
||||
})
|
||||
|
||||
// 密码输入框的 placeholder
|
||||
const passwordPlaceholder = computed(() => {
|
||||
if (hasExistingPassword.value) {
|
||||
return '已保存密码,留空保持不变'
|
||||
}
|
||||
return '代理认证密码'
|
||||
})
|
||||
|
||||
// 代理 URL 验证
|
||||
const proxyUrlError = computed(() => {
|
||||
// 只有启用代理且填写了 URL 时才验证
|
||||
if (!proxyEnabled.value || !form.value.proxy_url) {
|
||||
return ''
|
||||
}
|
||||
const url = form.value.proxy_url.trim()
|
||||
|
||||
// 检查禁止的特殊字符
|
||||
if (/[\n\r]/.test(url)) {
|
||||
return '代理 URL 包含非法字符'
|
||||
}
|
||||
|
||||
// 验证协议(不支持 SOCKS4)
|
||||
if (!/^(http|https|socks5):\/\//i.test(url)) {
|
||||
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
|
||||
}
|
||||
try {
|
||||
const parsed = new URL(url)
|
||||
if (!parsed.host) {
|
||||
return '代理 URL 必须包含有效的 host'
|
||||
}
|
||||
// 禁止 URL 中内嵌认证信息
|
||||
if (parsed.username || parsed.password) {
|
||||
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
|
||||
}
|
||||
} catch {
|
||||
return '代理 URL 格式无效'
|
||||
}
|
||||
return ''
|
||||
})
|
||||
|
||||
// 组件挂载时加载API格式
|
||||
onMounted(() => {
|
||||
loadApiFormats()
|
||||
@@ -252,14 +395,23 @@ function resetForm() {
|
||||
max_retries: 3,
|
||||
max_concurrent: undefined,
|
||||
rate_limit: undefined,
|
||||
is_active: true
|
||||
is_active: true,
|
||||
proxy_url: '',
|
||||
proxy_username: '',
|
||||
proxy_password: '',
|
||||
}
|
||||
proxyEnabled.value = false
|
||||
}
|
||||
|
||||
// 原始密码占位符(后端返回的脱敏标记)
|
||||
const MASKED_PASSWORD = '***'
|
||||
|
||||
// 加载端点数据(编辑模式)
|
||||
function loadEndpointData() {
|
||||
if (!props.endpoint) return
|
||||
|
||||
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
|
||||
|
||||
form.value = {
|
||||
api_format: props.endpoint.api_format,
|
||||
base_url: props.endpoint.base_url,
|
||||
@@ -268,8 +420,15 @@ function loadEndpointData() {
|
||||
max_retries: props.endpoint.max_retries,
|
||||
max_concurrent: props.endpoint.max_concurrent || undefined,
|
||||
rate_limit: props.endpoint.rate_limit || undefined,
|
||||
is_active: props.endpoint.is_active
|
||||
is_active: props.endpoint.is_active,
|
||||
proxy_url: proxy?.url || '',
|
||||
proxy_username: proxy?.username || '',
|
||||
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
|
||||
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
|
||||
}
|
||||
|
||||
// 根据 enabled 字段或 url 存在判断是否启用代理
|
||||
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
|
||||
}
|
||||
|
||||
// 使用 useFormDialog 统一处理对话框逻辑
|
||||
@@ -282,12 +441,47 @@ const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
|
||||
resetForm,
|
||||
})
|
||||
|
||||
// 构建代理配置
|
||||
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
|
||||
// - 无 URL 时返回 null
|
||||
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
|
||||
if (!form.value.proxy_url) {
|
||||
// 没填 URL,无代理配置
|
||||
return null
|
||||
}
|
||||
return {
|
||||
url: form.value.proxy_url,
|
||||
username: form.value.proxy_username || undefined,
|
||||
password: form.value.proxy_password || undefined,
|
||||
enabled: proxyEnabled.value, // 开关状态决定是否启用
|
||||
}
|
||||
}
|
||||
|
||||
// 提交表单
|
||||
const handleSubmit = async () => {
|
||||
const handleSubmit = async (skipCredentialCheck = false) => {
|
||||
if (!props.provider && !props.endpoint) return
|
||||
|
||||
// 只在开关开启且填写了 URL 时验证
|
||||
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
|
||||
showError(proxyUrlError.value, '代理配置错误')
|
||||
return
|
||||
}
|
||||
|
||||
// 检查:开关开启但没有 URL,却有用户名或密码
|
||||
const hasOrphanedCredentials = proxyEnabled.value
|
||||
&& !form.value.proxy_url
|
||||
&& (form.value.proxy_username || form.value.proxy_password)
|
||||
|
||||
if (hasOrphanedCredentials && !skipCredentialCheck) {
|
||||
// 弹出确认对话框
|
||||
showClearCredentialsDialog.value = true
|
||||
return
|
||||
}
|
||||
|
||||
loading.value = true
|
||||
try {
|
||||
const proxyConfig = buildProxyConfig()
|
||||
|
||||
if (isEditMode.value && props.endpoint) {
|
||||
// 更新端点
|
||||
await updateEndpoint(props.endpoint.id, {
|
||||
@@ -297,7 +491,8 @@ const handleSubmit = async () => {
|
||||
max_retries: form.value.max_retries,
|
||||
max_concurrent: form.value.max_concurrent,
|
||||
rate_limit: form.value.rate_limit,
|
||||
is_active: form.value.is_active
|
||||
is_active: form.value.is_active,
|
||||
proxy: proxyConfig,
|
||||
})
|
||||
|
||||
success('端点已更新', '保存成功')
|
||||
@@ -313,7 +508,8 @@ const handleSubmit = async () => {
|
||||
max_retries: form.value.max_retries,
|
||||
max_concurrent: form.value.max_concurrent,
|
||||
rate_limit: form.value.rate_limit,
|
||||
is_active: form.value.is_active
|
||||
is_active: form.value.is_active,
|
||||
proxy: proxyConfig,
|
||||
})
|
||||
|
||||
success('端点创建成功', '成功')
|
||||
@@ -329,4 +525,12 @@ const handleSubmit = async () => {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 确认清空凭据并继续保存
|
||||
const confirmClearCredentials = () => {
|
||||
form.value.proxy_username = ''
|
||||
form.value.proxy_password = ''
|
||||
showClearCredentialsDialog.value = false
|
||||
handleSubmit(true) // 跳过凭据检查,直接提交
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
</h3>
|
||||
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
||||
<span>{{ detail?.model || '-' }}</span>
|
||||
<template v-if="detail?.target_model">
|
||||
<template v-if="detail?.target_model && detail.target_model !== detail.model">
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 20 20"
|
||||
|
||||
@@ -185,32 +185,13 @@
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- API Key 管理配置 -->
|
||||
<!-- 独立余额 Key 过期管理 -->
|
||||
<CardSection
|
||||
title="API Key 管理"
|
||||
description="API Key 相关配置"
|
||||
title="独立余额 Key 过期管理"
|
||||
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
|
||||
>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||
<div>
|
||||
<Label
|
||||
for="api-key-expire"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
API密钥过期天数
|
||||
</Label>
|
||||
<Input
|
||||
id="api-key-expire"
|
||||
v-model.number="systemConfig.api_key_expire_days"
|
||||
type="number"
|
||||
placeholder="0"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
0 表示永不过期
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="flex items-center h-full pt-6">
|
||||
<div class="flex items-center h-full">
|
||||
<div class="flex items-center space-x-2">
|
||||
<Checkbox
|
||||
id="auto-delete-expired-keys"
|
||||
@@ -224,7 +205,7 @@
|
||||
自动删除过期 Key
|
||||
</Label>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
关闭时仅禁用过期 Key
|
||||
关闭时仅禁用过期 Key,不会物理删除
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -448,6 +429,25 @@
|
||||
避免单次操作过大影响性能
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label
|
||||
for="audit-log-retention-days"
|
||||
class="block text-sm font-medium"
|
||||
>
|
||||
审计日志保留天数
|
||||
</Label>
|
||||
<Input
|
||||
id="audit-log-retention-days"
|
||||
v-model.number="systemConfig.audit_log_retention_days"
|
||||
type="number"
|
||||
placeholder="30"
|
||||
class="mt-1"
|
||||
/>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
超过后删除审计日志记录
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 清理策略说明 -->
|
||||
@@ -460,6 +460,7 @@
|
||||
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
||||
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
||||
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
||||
<p>5. <strong>审计日志</strong>: 独立清理,记录用户登录、操作等安全事件</p>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
@@ -796,8 +797,7 @@ interface SystemConfig {
|
||||
// 用户注册
|
||||
enable_registration: boolean
|
||||
require_email_verification: boolean
|
||||
// API Key 管理
|
||||
api_key_expire_days: number
|
||||
// 独立余额 Key 过期管理
|
||||
auto_delete_expired_keys: boolean
|
||||
// 日志记录
|
||||
request_log_level: string
|
||||
@@ -811,6 +811,7 @@ interface SystemConfig {
|
||||
header_retention_days: number
|
||||
log_retention_days: number
|
||||
cleanup_batch_size: number
|
||||
audit_log_retention_days: number
|
||||
}
|
||||
|
||||
const loading = ref(false)
|
||||
@@ -845,8 +846,7 @@ const systemConfig = ref<SystemConfig>({
|
||||
// 用户注册
|
||||
enable_registration: false,
|
||||
require_email_verification: false,
|
||||
// API Key 管理
|
||||
api_key_expire_days: 0,
|
||||
// 独立余额 Key 过期管理
|
||||
auto_delete_expired_keys: false,
|
||||
// 日志记录
|
||||
request_log_level: 'basic',
|
||||
@@ -860,6 +860,7 @@ const systemConfig = ref<SystemConfig>({
|
||||
header_retention_days: 90,
|
||||
log_retention_days: 365,
|
||||
cleanup_batch_size: 1000,
|
||||
audit_log_retention_days: 30,
|
||||
})
|
||||
|
||||
// 计算属性:KB 和 字节 之间的转换
|
||||
@@ -901,8 +902,7 @@ async function loadSystemConfig() {
|
||||
// 用户注册
|
||||
'enable_registration',
|
||||
'require_email_verification',
|
||||
// API Key 管理
|
||||
'api_key_expire_days',
|
||||
// 独立余额 Key 过期管理
|
||||
'auto_delete_expired_keys',
|
||||
// 日志记录
|
||||
'request_log_level',
|
||||
@@ -916,6 +916,7 @@ async function loadSystemConfig() {
|
||||
'header_retention_days',
|
||||
'log_retention_days',
|
||||
'cleanup_batch_size',
|
||||
'audit_log_retention_days',
|
||||
]
|
||||
|
||||
for (const key of configs) {
|
||||
@@ -960,12 +961,7 @@ async function saveSystemConfig() {
|
||||
value: systemConfig.value.require_email_verification,
|
||||
description: '是否需要邮箱验证'
|
||||
},
|
||||
// API Key 管理
|
||||
{
|
||||
key: 'api_key_expire_days',
|
||||
value: systemConfig.value.api_key_expire_days,
|
||||
description: 'API密钥过期天数'
|
||||
},
|
||||
// 独立余额 Key 过期管理
|
||||
{
|
||||
key: 'auto_delete_expired_keys',
|
||||
value: systemConfig.value.auto_delete_expired_keys,
|
||||
@@ -1023,6 +1019,11 @@ async function saveSystemConfig() {
|
||||
value: systemConfig.value.cleanup_batch_size,
|
||||
description: '每批次清理的记录数'
|
||||
},
|
||||
{
|
||||
key: 'audit_log_retention_days',
|
||||
value: systemConfig.value.audit_log_retention_days,
|
||||
description: '审计日志保留天数'
|
||||
},
|
||||
]
|
||||
|
||||
const promises = configItems.map(item =>
|
||||
|
||||
@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
||||
allowed_providers=self.key_data.allowed_providers,
|
||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||
allowed_models=self.key_data.allowed_models,
|
||||
rate_limit=self.key_data.rate_limit or 100,
|
||||
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||
expire_days=self.key_data.expire_days,
|
||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||
is_standalone=True, # 标记为独立Key
|
||||
|
||||
@@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from sqlalchemy import and_, func
|
||||
@@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
|
||||
|
||||
def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]:
|
||||
"""对代理配置中的密码进行脱敏处理"""
|
||||
if not proxy_config:
|
||||
return None
|
||||
masked = dict(proxy_config)
|
||||
if masked.get("password"):
|
||||
masked["password"] = "***"
|
||||
return masked
|
||||
|
||||
|
||||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
||||
async def list_provider_endpoints(
|
||||
provider_id: str,
|
||||
@@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
||||
"api_format": endpoint.api_format,
|
||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||||
"proxy": mask_proxy_password(endpoint.proxy),
|
||||
}
|
||||
endpoint_dict.pop("_sa_instance_state", None)
|
||||
result.append(ProviderEndpointResponse(**endpoint_dict))
|
||||
@@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||||
rate_limit=self.endpoint_data.rate_limit,
|
||||
is_active=True,
|
||||
config=self.endpoint_data.config,
|
||||
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
@@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in new_endpoint.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name,
|
||||
api_format=new_endpoint.api_format,
|
||||
proxy=mask_proxy_password(new_endpoint.proxy),
|
||||
total_keys=0,
|
||||
active_keys=0,
|
||||
)
|
||||
@@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in endpoint_obj.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name,
|
||||
api_format=endpoint_obj.api_format,
|
||||
proxy=mask_proxy_password(endpoint_obj.proxy),
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
)
|
||||
@@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||
|
||||
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
||||
# 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理
|
||||
if "proxy" in update_data:
|
||||
if update_data["proxy"] is not None:
|
||||
new_proxy = dict(update_data["proxy"])
|
||||
# 只有当密码字段未提供时才保留原密码(空字符串视为显式清除)
|
||||
if "password" not in new_proxy and endpoint.proxy:
|
||||
old_password = endpoint.proxy.get("password")
|
||||
if old_password:
|
||||
new_proxy["password"] = old_password
|
||||
update_data["proxy"] = new_proxy
|
||||
# proxy 为 None 时保留,用于清除代理配置
|
||||
for field, value in update_data.items():
|
||||
setattr(endpoint, field, value)
|
||||
endpoint.updated_at = datetime.now(timezone.utc)
|
||||
@@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
||||
endpoint_dict = {
|
||||
k: v
|
||||
for k, v in endpoint.__dict__.items()
|
||||
if k not in {"api_format", "_sa_instance_state"}
|
||||
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||
}
|
||||
return ProviderEndpointResponse(
|
||||
**endpoint_dict,
|
||||
provider_name=provider.name if provider else "Unknown",
|
||||
api_format=endpoint.api_format,
|
||||
proxy=mask_proxy_password(endpoint.proxy),
|
||||
total_keys=total_keys,
|
||||
active_keys=active_keys,
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
token = authorization[7:].strip()
|
||||
try:
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
@@ -177,7 +177,7 @@ class ApiRequestPipeline:
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
token = authorization[7:].strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
@@ -204,7 +204,7 @@ class ApiRequestPipeline:
|
||||
if not authorization or not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
token = authorization[7:].strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
|
||||
from src.services.system.audit import audit_service
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
|
||||
|
||||
|
||||
class MessageTelemetry:
|
||||
@@ -399,6 +402,41 @@ class BaseMessageHandler:
|
||||
# 创建后台任务,不阻塞当前流
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
||||
|
||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||
|
||||
Args:
|
||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||
"""
|
||||
import asyncio
|
||||
from src.database.database import get_db
|
||||
|
||||
target_request_id = self.request_id
|
||||
provider = ctx.provider_name
|
||||
target_model = ctx.mapped_model
|
||||
|
||||
async def _do_update() -> None:
|
||||
try:
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
try:
|
||||
UsageService.update_usage_status(
|
||||
db=db,
|
||||
request_id=target_request_id,
|
||||
status="streaming",
|
||||
provider=provider,
|
||||
target_model=target_model,
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||
|
||||
# 创建后台任务,不阻塞当前流
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
def _log_request_error(self, message: str, error: Exception) -> None:
|
||||
"""记录请求错误日志,对业务异常不打印堆栈
|
||||
|
||||
|
||||
@@ -64,18 +64,6 @@ class ChatAdapterBase(ApiAdapter):
|
||||
|
||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||
self.response_normalizer = None
|
||||
# 可选启用响应规范化
|
||||
self._init_response_normalizer()
|
||||
|
||||
def _init_response_normalizer(self):
|
||||
"""初始化响应规范化器 - 子类可覆盖"""
|
||||
try:
|
||||
from src.services.provider.response_normalizer import ResponseNormalizer
|
||||
|
||||
self.response_normalizer = ResponseNormalizer()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
async def handle(self, context: ApiRequestContext):
|
||||
"""处理 Chat API 请求"""
|
||||
@@ -228,8 +216,6 @@ class ChatAdapterBase(ApiAdapter):
|
||||
user_agent=user_agent,
|
||||
start_time=start_time,
|
||||
allowed_api_formats=self.allowed_api_formats,
|
||||
response_normalizer=self.response_normalizer,
|
||||
enable_response_normalization=self.response_normalizer is not None,
|
||||
adapter_detector=self.detect_capability_requirements,
|
||||
)
|
||||
|
||||
|
||||
@@ -88,8 +88,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
user_agent: str,
|
||||
start_time: float,
|
||||
allowed_api_formats: Optional[list] = None,
|
||||
response_normalizer: Optional[Any] = None,
|
||||
enable_response_normalization: bool = False,
|
||||
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
||||
):
|
||||
allowed = allowed_api_formats or [self.FORMAT_ID]
|
||||
@@ -106,8 +104,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
)
|
||||
self._parser: Optional[ResponseParser] = None
|
||||
self._request_builder = PassthroughRequestBuilder()
|
||||
self.response_normalizer = response_normalizer
|
||||
self.enable_response_normalization = enable_response_normalization
|
||||
|
||||
@property
|
||||
def parser(self) -> ResponseParser:
|
||||
@@ -297,11 +293,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
# 创建类型安全的流式上下文
|
||||
ctx = StreamContext(model=model, api_format=api_format)
|
||||
|
||||
# 创建更新状态的回调闭包(可以访问 ctx)
|
||||
def update_streaming_status() -> None:
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
|
||||
# 创建流处理器
|
||||
stream_processor = StreamProcessor(
|
||||
request_id=self.request_id,
|
||||
default_parser=self.parser,
|
||||
on_streaming_start=self._update_usage_to_streaming,
|
||||
on_streaming_start=update_streaming_status,
|
||||
)
|
||||
|
||||
# 定义请求函数
|
||||
@@ -466,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
pool=config.http_pool_timeout,
|
||||
)
|
||||
|
||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=timeout_config,
|
||||
)
|
||||
try:
|
||||
response_ctx = http_client.stream(
|
||||
"POST", url, json=provider_payload, headers=provider_headers
|
||||
@@ -634,10 +640,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
||||
f"模型={model} -> {mapped_model or '无映射'}")
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=float(endpoint.timeout),
|
||||
follow_redirects=True,
|
||||
) as http_client:
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||
)
|
||||
async with http_client:
|
||||
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
||||
|
||||
status_code = resp.status_code
|
||||
|
||||
@@ -454,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
f"Key=***{key.api_key[-4:]}, "
|
||||
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||
|
||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=timeout_config,
|
||||
)
|
||||
try:
|
||||
response_ctx = http_client.stream(
|
||||
"POST", url, json=provider_payload, headers=provider_headers
|
||||
@@ -526,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
async for chunk in stream_response.aiter_raw():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
streaming_status_updated = True
|
||||
|
||||
buffer += chunk
|
||||
@@ -810,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_chunks:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
self._update_usage_to_streaming_with_ctx(ctx)
|
||||
|
||||
# 先处理预读的字节块
|
||||
for chunk in prefetched_chunks:
|
||||
@@ -1419,10 +1425,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
f"Key=***{key.api_key[-4:]}, "
|
||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=float(endpoint.timeout),
|
||||
follow_redirects=True,
|
||||
) as http_client:
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
proxy_config=endpoint.proxy,
|
||||
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||
)
|
||||
async with http_client:
|
||||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
||||
|
||||
status_code = resp.status_code
|
||||
|
||||
@@ -131,10 +131,5 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
Returns:
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
)
|
||||
return result
|
||||
# 作为中转站,直接透传响应,不做标准化处理
|
||||
return response
|
||||
|
||||
@@ -148,17 +148,6 @@ class GeminiChatHandler(ChatHandlerBase):
|
||||
|
||||
Returns:
|
||||
规范化后的响应
|
||||
|
||||
TODO: 如果需要,实现响应规范化逻辑
|
||||
"""
|
||||
# 可选:使用 response_normalizer 进行规范化
|
||||
# if (
|
||||
# self.response_normalizer
|
||||
# and self.response_normalizer.should_normalize(response)
|
||||
# ):
|
||||
# return self.response_normalizer.normalize_gemini_response(
|
||||
# response_data=response,
|
||||
# request_id=self.request_id,
|
||||
# strict=False,
|
||||
# )
|
||||
# 作为中转站,直接透传响应,不做标准化处理
|
||||
return response
|
||||
|
||||
@@ -128,10 +128,5 @@ class OpenAIChatHandler(ChatHandlerBase):
|
||||
Returns:
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_openai_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
strict=False,
|
||||
)
|
||||
# 作为中转站,直接透传响应,不做标准化处理
|
||||
return response
|
||||
|
||||
@@ -5,12 +5,55 @@
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
根据代理配置构建完整的代理 URL
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置字典,包含 url, username, password, enabled
|
||||
|
||||
Returns:
|
||||
完整的代理 URL,如 socks5://user:pass@host:port
|
||||
如果 enabled=False 或无配置,返回 None
|
||||
"""
|
||||
if not proxy_config:
|
||||
return None
|
||||
|
||||
# 检查 enabled 字段,默认为 True(兼容旧数据)
|
||||
if not proxy_config.get("enabled", True):
|
||||
return None
|
||||
|
||||
proxy_url = proxy_config.get("url")
|
||||
if not proxy_url:
|
||||
return None
|
||||
|
||||
username = proxy_config.get("username")
|
||||
password = proxy_config.get("password")
|
||||
|
||||
# 只要有用户名就添加认证信息(密码可以为空)
|
||||
if username:
|
||||
parsed = urlparse(proxy_url)
|
||||
# URL 编码用户名和密码,处理特殊字符(如 @, :, /)
|
||||
encoded_username = quote(username, safe="")
|
||||
encoded_password = quote(password, safe="") if password else ""
|
||||
# 重新构建带认证的代理 URL
|
||||
if encoded_password:
|
||||
auth_proxy = f"{parsed.scheme}://{encoded_username}:{encoded_password}@{parsed.netloc}"
|
||||
else:
|
||||
auth_proxy = f"{parsed.scheme}://{encoded_username}@{parsed.netloc}"
|
||||
if parsed.path:
|
||||
auth_proxy += parsed.path
|
||||
return auth_proxy
|
||||
|
||||
return proxy_url
|
||||
|
||||
|
||||
class HTTPClientPool:
|
||||
"""
|
||||
@@ -121,6 +164,44 @@ class HTTPClientPool:
|
||||
finally:
|
||||
await client.aclose()
|
||||
|
||||
@classmethod
|
||||
def create_client_with_proxy(
|
||||
cls,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
timeout: Optional[httpx.Timeout] = None,
|
||||
**kwargs: Any,
|
||||
) -> httpx.AsyncClient:
|
||||
"""
|
||||
创建带代理配置的HTTP客户端
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置字典,包含 url, username, password
|
||||
timeout: 超时配置
|
||||
**kwargs: 其他 httpx.AsyncClient 配置参数
|
||||
|
||||
Returns:
|
||||
配置好的 httpx.AsyncClient 实例
|
||||
"""
|
||||
config: Dict[str, Any] = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
if timeout:
|
||||
config["timeout"] = timeout
|
||||
else:
|
||||
config["timeout"] = httpx.Timeout(10.0, read=300.0)
|
||||
|
||||
# 添加代理配置
|
||||
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||
if proxy_url:
|
||||
config["proxy"] = proxy_url
|
||||
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
||||
|
||||
config.update(kwargs)
|
||||
return httpx.AsyncClient(**config)
|
||||
|
||||
|
||||
# 便捷访问函数
|
||||
def get_http_client() -> httpx.AsyncClient:
|
||||
|
||||
@@ -120,7 +120,7 @@ class RedisClientManager:
|
||||
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
||||
remaining = self._circuit_open_until - time.time()
|
||||
logger.warning(
|
||||
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
|
||||
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
|
||||
remaining,
|
||||
self._last_error,
|
||||
)
|
||||
@@ -200,7 +200,7 @@ class RedisClientManager:
|
||||
if self._consecutive_failures >= self._circuit_threshold:
|
||||
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
||||
logger.warning(
|
||||
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
|
||||
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
|
||||
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
||||
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
||||
self._consecutive_failures,
|
||||
|
||||
@@ -105,6 +105,13 @@ class Config:
|
||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||
|
||||
# 可信代理配置
|
||||
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||
|
||||
# 异常处理配置
|
||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||
# 设置为 False 时,使用全局异常处理器统一处理
|
||||
|
||||
@@ -153,7 +153,7 @@ def _log_pool_capacity():
|
||||
total_estimated = theoretical * workers
|
||||
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
||||
logger.info(
|
||||
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
|
||||
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
|
||||
config.db_pool_size,
|
||||
config.db_max_overflow,
|
||||
workers,
|
||||
@@ -162,7 +162,7 @@ def _log_pool_capacity():
|
||||
)
|
||||
if total_estimated > safe_limit:
|
||||
logger.warning(
|
||||
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved),"
|
||||
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved),"
|
||||
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
||||
total_estimated,
|
||||
safe_limit,
|
||||
@@ -260,7 +260,8 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
|
||||
|
||||
2. **管理后台 API**:
|
||||
- 路由层显式调用 db.commit()
|
||||
- 每个操作独立提交,不依赖中间件
|
||||
- 提交后设置 request.state.tx_committed_by_route = True
|
||||
- 中间件看到此标志后跳过 commit,只负责 close
|
||||
|
||||
3. **后台任务/调度器**:
|
||||
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||
@@ -271,6 +272,17 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
|
||||
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||
|
||||
路由层提交事务示例
|
||||
==================
|
||||
```python
|
||||
@router.post("/example")
|
||||
async def example(request: Request, db: Session = Depends(get_db)):
|
||||
# ... 业务逻辑 ...
|
||||
db.commit()
|
||||
request.state.tx_committed_by_route = True # 告知中间件已提交
|
||||
return {"message": "success"}
|
||||
```
|
||||
|
||||
注意事项
|
||||
========
|
||||
- 本函数不自动提交事务
|
||||
|
||||
@@ -49,7 +49,7 @@ async def initialize_providers():
|
||||
# 从数据库加载所有活跃的提供商
|
||||
providers = (
|
||||
db.query(Provider)
|
||||
.filter(Provider.is_active == True)
|
||||
.filter(Provider.is_active.is_(True))
|
||||
.order_by(Provider.provider_priority.asc())
|
||||
.all()
|
||||
)
|
||||
@@ -122,6 +122,7 @@ async def lifespan(app: FastAPI):
|
||||
logger.info("初始化全局Redis客户端...")
|
||||
from src.clients.redis_client import get_redis_client
|
||||
|
||||
redis_client = None
|
||||
try:
|
||||
redis_client = await get_redis_client(require_redis=config.require_redis)
|
||||
if redis_client:
|
||||
@@ -133,6 +134,7 @@ async def lifespan(app: FastAPI):
|
||||
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
||||
raise
|
||||
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
||||
redis_client = None
|
||||
|
||||
# 初始化并发管理器(内部会使用Redis)
|
||||
logger.info("初始化并发管理器...")
|
||||
@@ -312,7 +314,7 @@ if frontend_dist.exists():
|
||||
仅对非API路径生效
|
||||
"""
|
||||
# 如果是API路径,不处理
|
||||
if full_path.startswith("api/") or full_path.startswith("v1/"):
|
||||
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
# 返回index.html,让前端路由处理
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
"""
|
||||
统一的插件中间件
|
||||
统一的插件中间件(纯 ASGI 实现)
|
||||
负责协调所有插件的调用
|
||||
|
||||
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware,
|
||||
以避免 Starlette 已知的流式响应兼容性问题。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
from starlette.requests import Request
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
@@ -18,20 +19,25 @@ from src.plugins.manager import get_plugin_manager
|
||||
from src.plugins.rate_limit.base import RateLimitResult
|
||||
|
||||
|
||||
|
||||
class PluginMiddleware(BaseHTTPMiddleware):
|
||||
class PluginMiddleware:
|
||||
"""
|
||||
统一的插件调用中间件
|
||||
统一的插件调用中间件(纯 ASGI 实现)
|
||||
|
||||
职责:
|
||||
- 性能监控
|
||||
- 限流控制 (可选)
|
||||
- 数据库会话生命周期管理
|
||||
|
||||
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
||||
|
||||
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
|
||||
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
|
||||
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
|
||||
- 纯 ASGI 可以直接透传流式响应,无额外开销
|
||||
"""
|
||||
|
||||
def __init__(self, app: Any) -> None:
|
||||
super().__init__(app)
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
self.plugin_manager = get_plugin_manager()
|
||||
|
||||
# 从配置读取速率限制值
|
||||
@@ -61,152 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
"/v1/completions",
|
||||
]
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
|
||||
) -> StarletteResponse:
|
||||
"""处理请求并调用相应插件"""
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""ASGI 入口点"""
|
||||
if scope["type"] != "http":
|
||||
# 非 HTTP 请求(如 WebSocket)直接透传
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# 构建 Request 对象以便复用现有逻辑
|
||||
request = Request(scope, receive, send)
|
||||
|
||||
# 记录请求开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 设置 request.state 属性
|
||||
# 注意:Starlette 的 Request 对象总是有 state 属性(State 实例)
|
||||
request.state.request_id = request.headers.get("x-request-id", "")
|
||||
request.state.start_time = start_time
|
||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||
request.state.db_managed_by_middleware = True
|
||||
|
||||
response = None
|
||||
exception_to_raise = None
|
||||
# 1. 限流检查(在调用下游之前)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
# 限流触发,返回429
|
||||
await self._send_rate_limit_response(send, rate_limit_result)
|
||||
return
|
||||
|
||||
# 2. 预处理插件调用
|
||||
await self._call_pre_request_plugins(request)
|
||||
|
||||
# 用于捕获响应状态码
|
||||
response_status_code: int = 0
|
||||
|
||||
async def send_wrapper(message: Message) -> None:
|
||||
nonlocal response_status_code
|
||||
|
||||
if message["type"] == "http.response.start":
|
||||
response_status_code = message.get("status", 0)
|
||||
|
||||
await send(message)
|
||||
|
||||
# 3. 调用下游应用
|
||||
exception_occurred: Optional[Exception] = None
|
||||
try:
|
||||
# 1. 限流插件调用(可选功能)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
# 限流触发,返回429
|
||||
headers = rate_limit_result.headers or {}
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=rate_limit_result.message or "Rate limit exceeded",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# 2. 预处理插件调用
|
||||
await self._call_pre_request_plugins(request)
|
||||
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 3. 提交关键数据库事务(在返回响应前)
|
||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
||||
try:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
try:
|
||||
if isinstance(db, Session):
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
await self._call_error_plugins(request, commit_error, start_time)
|
||||
# 返回 500 错误,因为数据可能不一致
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "database_error",
|
||||
"message": "数据保存失败,请重试",
|
||||
},
|
||||
},
|
||||
)
|
||||
# 跳过后处理插件,直接返回错误响应
|
||||
return response
|
||||
|
||||
# 4. 后处理插件调用(监控等,非关键操作)
|
||||
# 这些操作失败不应影响用户响应
|
||||
await self._call_post_request_plugins(request, response, start_time)
|
||||
|
||||
# 注意:不在此处添加限流响应头,因为在BaseHTTPMiddleware中
|
||||
# 响应返回后修改headers会导致Content-Length不匹配错误
|
||||
# 限流响应头已在返回429错误时正确包含(见上面的HTTPException)
|
||||
|
||||
except RuntimeError as e:
|
||||
if str(e) == "No response returned.":
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error("Downstream handler completed without returning a response")
|
||||
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "internal_error",
|
||||
"message": "Internal server error: downstream handler returned no response.",
|
||||
},
|
||||
},
|
||||
)
|
||||
else:
|
||||
exception_to_raise = e
|
||||
|
||||
await self.app(scope, receive, send_wrapper)
|
||||
except Exception as e:
|
||||
# 回滚数据库事务
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
exception_occurred = e
|
||||
# 错误处理插件调用
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
raise
|
||||
finally:
|
||||
# 4. 数据库会话清理(无论成功与否)
|
||||
await self._cleanup_db_session(request, exception_occurred)
|
||||
|
||||
# 尝试提交错误日志
|
||||
if isinstance(db, Session):
|
||||
# 5. 后处理插件调用(仅在成功时)
|
||||
if not exception_occurred and response_status_code > 0:
|
||||
await self._call_post_request_plugins(request, response_status_code, start_time)
|
||||
|
||||
async def _send_rate_limit_response(
|
||||
self, send: Send, result: RateLimitResult
|
||||
) -> None:
|
||||
"""发送 429 限流响应"""
|
||||
import json
|
||||
|
||||
body = json.dumps({
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "rate_limit_error",
|
||||
"message": result.message or "Rate limit exceeded",
|
||||
},
|
||||
}).encode("utf-8")
|
||||
|
||||
headers = [(b"content-type", b"application/json")]
|
||||
if result.headers:
|
||||
for key, value in result.headers.items():
|
||||
headers.append((key.lower().encode(), str(value).encode()))
|
||||
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 429,
|
||||
"headers": headers,
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
})
|
||||
|
||||
async def _cleanup_db_session(
|
||||
self, request: Request, exception: Optional[Exception]
|
||||
) -> None:
|
||||
"""清理数据库会话
|
||||
|
||||
事务策略:
|
||||
- 如果 request.state.tx_committed_by_route 为 True,说明路由已自行提交,中间件只负责 close
|
||||
- 否则由中间件统一 commit/rollback
|
||||
|
||||
这避免了双重提交的问题,同时保持向后兼容。
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
db = getattr(request.state, "db", None)
|
||||
if not isinstance(db, Session):
|
||||
return
|
||||
|
||||
# 检查是否由路由层已经提交
|
||||
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
|
||||
|
||||
try:
|
||||
if exception is not None:
|
||||
# 发生异常,回滚事务(无论谁负责提交)
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception as rollback_error:
|
||||
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||
elif not tx_committed_by_route:
|
||||
# 正常完成且路由未自行提交,由中间件提交事务
|
||||
try:
|
||||
db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
exception_to_raise = e
|
||||
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
# 如果 tx_committed_by_route 为 True,跳过 commit(路由已提交)
|
||||
finally:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
# 在 finally 块之后处理异常和响应
|
||||
if exception_to_raise:
|
||||
raise exception_to_raise
|
||||
|
||||
return response
|
||||
# 关闭会话,归还连接到连接池
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""
|
||||
获取客户端 IP 地址,支持代理头
|
||||
|
||||
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||
"""
|
||||
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||
|
||||
# 优先从代理头获取真实 IP
|
||||
forwarded_for = request.headers.get("x-forwarded-for")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 可能包含多个 IP,取第一个
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
@@ -248,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
api_key = request.headers.get("x-api-key", "")
|
||||
|
||||
if auth_header.startswith("Bearer "):
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
api_key = auth_header[7:]
|
||||
|
||||
if api_key:
|
||||
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
||||
import hashlib
|
||||
|
||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
||||
key = f"llm_api_key:{key_hash}"
|
||||
request.state.rate_limit_key_type = "api_key"
|
||||
@@ -319,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
else:
|
||||
# 限流触发,记录日志
|
||||
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
|
||||
logger.warning(
|
||||
"速率限制触发: {}",
|
||||
getattr(request.state, "rate_limit_key_type", "unknown"),
|
||||
)
|
||||
return result
|
||||
return None
|
||||
except Exception as e:
|
||||
@@ -332,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
pass
|
||||
|
||||
async def _call_post_request_plugins(
|
||||
self, request: Request, response: StarletteResponse, start_time: float
|
||||
self, request: Request, status_code: int, start_time: float
|
||||
) -> None:
|
||||
"""调用请求后的插件"""
|
||||
|
||||
@@ -345,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
monitor_labels = {
|
||||
"method": request.method,
|
||||
"endpoint": request.url.path,
|
||||
"status": str(response.status_code),
|
||||
"status_class": f"{response.status_code // 100}xx",
|
||||
"status": str(status_code),
|
||||
"status_class": f"{status_code // 100}xx",
|
||||
}
|
||||
|
||||
# 记录请求计数
|
||||
@@ -368,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
self, request: Request, error: Exception, start_time: float
|
||||
) -> None:
|
||||
"""调用错误处理插件"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
@@ -380,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
error=error,
|
||||
context={
|
||||
"endpoint": f"{request.method} {request.url.path}",
|
||||
"request_id": request.state.request_id,
|
||||
"request_id": getattr(request.state, "request_id", ""),
|
||||
"duration": duration,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -13,6 +13,42 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from src.core.enums import APIFormat, ProviderBillingType
|
||||
|
||||
|
||||
class ProxyConfig(BaseModel):
|
||||
"""代理配置"""
|
||||
|
||||
url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
|
||||
username: Optional[str] = Field(None, max_length=255, description="代理用户名")
|
||||
password: Optional[str] = Field(None, max_length=500, description="代理密码")
|
||||
enabled: bool = Field(True, description="是否启用代理(false 时保留配置但不使用)")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def validate_proxy_url(cls, v: str) -> str:
|
||||
"""验证代理 URL 格式"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
v = v.strip()
|
||||
|
||||
# 检查禁止的字符(防止注入)
|
||||
if "\n" in v or "\r" in v:
|
||||
raise ValueError("代理 URL 包含非法字符")
|
||||
|
||||
# 验证协议(不支持 SOCKS4)
|
||||
if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
|
||||
raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")
|
||||
|
||||
# 验证 URL 结构
|
||||
parsed = urlparse(v)
|
||||
if not parsed.netloc:
|
||||
raise ValueError("代理 URL 必须包含有效的 host")
|
||||
|
||||
# 禁止 URL 中内嵌认证信息,强制使用独立字段
|
||||
if parsed.username or parsed.password:
|
||||
raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class CreateProviderRequest(BaseModel):
|
||||
"""创建 Provider 请求"""
|
||||
|
||||
@@ -165,6 +201,7 @@ class CreateEndpointRequest(BaseModel):
|
||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@@ -220,6 +257,7 @@ class UpdateEndpointRequest(BaseModel):
|
||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||
|
||||
# 复用验证器
|
||||
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
||||
|
||||
@@ -538,6 +538,9 @@ class ProviderEndpoint(Base):
|
||||
# 额外配置
|
||||
config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段)
|
||||
|
||||
# 代理配置
|
||||
proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password}
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from src.models.admin_requests import ProxyConfig
|
||||
|
||||
# ========== ProviderEndpoint CRUD ==========
|
||||
|
||||
|
||||
@@ -30,6 +32,9 @@ class ProviderEndpointCreate(BaseModel):
|
||||
# 额外配置
|
||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
||||
|
||||
# 代理配置
|
||||
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||
|
||||
@field_validator("api_format")
|
||||
@classmethod
|
||||
def validate_api_format(cls, v: str) -> str:
|
||||
@@ -64,6 +69,7 @@ class ProviderEndpointUpdate(BaseModel):
|
||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
||||
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
@@ -104,6 +110,9 @@ class ProviderEndpointResponse(BaseModel):
|
||||
# 额外配置
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 代理配置(响应中密码已脱敏)
|
||||
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置(密码已脱敏)")
|
||||
|
||||
# 统计(从 Keys 聚合)
|
||||
total_keys: int = Field(default=0, description="总 Key 数量")
|
||||
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
||||
|
||||
@@ -3,6 +3,7 @@ JWT认证插件
|
||||
支持JWT Bearer token认证
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
@@ -46,8 +47,8 @@ class JwtAuthPlugin(AuthPlugin):
|
||||
logger.debug("未找到JWT token")
|
||||
return None
|
||||
|
||||
# 记录认证尝试的详细信息
|
||||
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
|
||||
token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, token_fp={token_fingerprint}")
|
||||
|
||||
try:
|
||||
# 验证JWT token
|
||||
|
||||
@@ -63,14 +63,16 @@ class JWTBlacklistService:
|
||||
|
||||
if ttl_seconds <= 0:
|
||||
# Token 已经过期,不需要加入黑名单
|
||||
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
|
||||
return True
|
||||
|
||||
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
||||
# 值存储为原因字符串
|
||||
await redis_client.setex(redis_key, ttl_seconds, reason)
|
||||
|
||||
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -109,7 +111,8 @@ class JWTBlacklistService:
|
||||
if exists:
|
||||
# 获取黑名单原因(可选)
|
||||
reason = await redis_client.get(redis_key)
|
||||
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -148,9 +151,11 @@ class JWTBlacklistService:
|
||||
deleted = await redis_client.delete(redis_key)
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
|
||||
else:
|
||||
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
|
||||
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
|
||||
|
||||
return bool(deleted)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import hashlib
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
@@ -169,7 +170,8 @@ class AuthService:
|
||||
key_record.last_used_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||
|
||||
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
|
||||
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||
return user, key_record
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
"""响应标准化服务,用于 STANDARD 模式下的响应格式验证和补全"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.claude import ClaudeResponse
|
||||
|
||||
|
||||
|
||||
class ResponseNormalizer:
|
||||
"""响应标准化器 - 用于标准模式下验证和补全响应字段"""
|
||||
|
||||
@staticmethod
|
||||
def normalize_claude_response(
|
||||
response_data: Dict[str, Any], request_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
标准化 Claude API 响应
|
||||
|
||||
Args:
|
||||
response_data: 原始响应数据
|
||||
request_id: 请求ID(用于日志)
|
||||
|
||||
Returns:
|
||||
标准化后的响应数据(失败时返回原始数据)
|
||||
"""
|
||||
if "error" in response_data:
|
||||
logger.debug(f"[ResponseNormalizer] 检测到错误响应,跳过标准化 | ID:{request_id}")
|
||||
return response_data
|
||||
|
||||
try:
|
||||
validated = ClaudeResponse.model_validate(response_data)
|
||||
normalized = validated.model_dump(mode="json", exclude_none=False)
|
||||
|
||||
logger.debug(f"[ResponseNormalizer] 响应标准化成功 | ID:{request_id}")
|
||||
return normalized
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"[ResponseNormalizer] 响应验证失败,透传原始数据 | ID:{request_id}")
|
||||
return response_data
|
||||
|
||||
@staticmethod
|
||||
def should_normalize(response_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
判断是否需要进行标准化
|
||||
|
||||
Args:
|
||||
response_data: 响应数据
|
||||
|
||||
Returns:
|
||||
是否需要标准化
|
||||
"""
|
||||
# 错误响应不需要标准化
|
||||
if "error" in response_data:
|
||||
return False
|
||||
|
||||
# 已经包含新字段的响应不需要再次标准化
|
||||
if "context_management" in response_data and "container" in response_data:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import create_session
|
||||
from src.models.database import Usage
|
||||
from src.models.database import AuditLog, Usage
|
||||
from src.services.system.config import SystemConfigService
|
||||
from src.services.system.scheduler import get_scheduler
|
||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||
@@ -91,6 +91,15 @@ class CleanupScheduler:
|
||||
name="Pending状态清理",
|
||||
)
|
||||
|
||||
# 审计日志清理 - 凌晨 4 点执行
|
||||
scheduler.add_cron_job(
|
||||
self._scheduled_audit_cleanup,
|
||||
hour=4,
|
||||
minute=0,
|
||||
job_id="audit_cleanup",
|
||||
name="审计日志清理",
|
||||
)
|
||||
|
||||
# 启动时执行一次初始化任务
|
||||
asyncio.create_task(self._run_startup_tasks())
|
||||
|
||||
@@ -145,6 +154,10 @@ class CleanupScheduler:
|
||||
"""Pending 清理任务(定时调用)"""
|
||||
await self._perform_pending_cleanup()
|
||||
|
||||
async def _scheduled_audit_cleanup(self):
|
||||
"""审计日志清理任务(定时调用)"""
|
||||
await self._perform_audit_cleanup()
|
||||
|
||||
# ========== 实际任务实现 ==========
|
||||
|
||||
async def _perform_stats_aggregation(self, backfill: bool = False):
|
||||
@@ -330,6 +343,70 @@ class CleanupScheduler:
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_audit_cleanup(self):
|
||||
"""执行审计日志清理任务"""
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用自动清理
|
||||
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
|
||||
logger.info("自动清理已禁用,跳过审计日志清理")
|
||||
return
|
||||
|
||||
# 获取审计日志保留天数(默认 30 天,最少 7 天)
|
||||
audit_retention_days = max(
|
||||
SystemConfigService.get_config(db, "audit_log_retention_days", 30),
|
||||
7, # 最少保留 7 天,防止误配置删除所有审计日志
|
||||
)
|
||||
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
|
||||
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
|
||||
|
||||
logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
|
||||
|
||||
total_deleted = 0
|
||||
while True:
|
||||
# 先查询要删除的记录 ID(分批)
|
||||
records_to_delete = (
|
||||
db.query(AuditLog.id)
|
||||
.filter(AuditLog.created_at < cutoff_time)
|
||||
.limit(batch_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not records_to_delete:
|
||||
break
|
||||
|
||||
record_ids = [r.id for r in records_to_delete]
|
||||
|
||||
# 执行删除
|
||||
result = db.execute(
|
||||
delete(AuditLog)
|
||||
.where(AuditLog.id.in_(record_ids))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
|
||||
rows_deleted = result.rowcount
|
||||
db.commit()
|
||||
|
||||
total_deleted += rows_deleted
|
||||
logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted} 条")
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if total_deleted > 0:
|
||||
logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
|
||||
else:
|
||||
logger.info("无需清理的审计日志")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"审计日志清理失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_cleanup(self):
|
||||
"""执行清理任务"""
|
||||
db = create_session()
|
||||
|
||||
@@ -1217,15 +1217,19 @@ class UsageService:
|
||||
request_id: str,
|
||||
status: str,
|
||||
error_message: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
target_model: Optional[str] = None,
|
||||
) -> Optional[Usage]:
|
||||
"""
|
||||
快速更新使用记录状态(不更新其他字段)
|
||||
快速更新使用记录状态
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
request_id: 请求ID
|
||||
status: 新状态 (pending, streaming, completed, failed)
|
||||
error_message: 错误消息(仅在 failed 状态时使用)
|
||||
provider: 提供商名称(可选,streaming 状态时更新)
|
||||
target_model: 映射后的目标模型名(可选)
|
||||
|
||||
Returns:
|
||||
更新后的 Usage 记录,如果未找到则返回 None
|
||||
@@ -1239,6 +1243,10 @@ class UsageService:
|
||||
usage.status = status
|
||||
if error_message:
|
||||
usage.error_message = error_message
|
||||
if provider:
|
||||
usage.provider = provider
|
||||
if target_model:
|
||||
usage.target_model = target_model
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -457,7 +457,7 @@ class StreamUsageTracker:
|
||||
|
||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||
|
||||
# 更新状态为 streaming
|
||||
# 更新状态为 streaming,同时更新 provider
|
||||
if self.request_id:
|
||||
try:
|
||||
from src.services.usage.service import UsageService
|
||||
@@ -465,6 +465,7 @@ class StreamUsageTracker:
|
||||
db=self.db,
|
||||
request_id=self.request_id,
|
||||
status="streaming",
|
||||
provider=self.provider,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||
|
||||
@@ -210,7 +210,15 @@ class ApiKeyService:
|
||||
|
||||
@staticmethod
|
||||
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
||||
"""检查速率限制"""
|
||||
"""检查速率限制
|
||||
|
||||
Returns:
|
||||
(is_allowed, remaining): 是否允许请求,剩余可用次数
|
||||
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
|
||||
"""
|
||||
# 如果 rate_limit 为 None,表示不限制
|
||||
if api_key.rate_limit is None:
|
||||
return True, -1 # -1 表示无限制
|
||||
|
||||
# 计算时间窗口
|
||||
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
提供统一的用户认证和授权功能
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, Header, HTTPException, status
|
||||
@@ -44,10 +45,17 @@ async def get_current_user(
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
except HTTPException as token_error:
|
||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
||||
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||
logger.error(
|
||||
"Token验证失败: {}: {}, token_fp={}",
|
||||
token_error.status_code,
|
||||
token_error.detail,
|
||||
token_fp,
|
||||
)
|
||||
raise # 重新抛出原始异常,保持状态码
|
||||
except Exception as token_error:
|
||||
logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
|
||||
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||
logger.error("Token验证失败: {}, token_fp={}", token_error, token_fp)
|
||||
raise ForbiddenException("无效的Token")
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
@@ -63,7 +71,8 @@ async def get_current_user(
|
||||
raise ForbiddenException("无效的认证凭据")
|
||||
|
||||
# 仅在DEBUG模式下记录详细信息
|
||||
logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
|
||||
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||
logger.debug("尝试获取用户: user_id={}, token_fp={}", user_id, token_fp)
|
||||
|
||||
# 确保user_id是字符串格式(UUID)
|
||||
if not isinstance(user_id, str):
|
||||
|
||||
@@ -7,29 +7,47 @@ from typing import Optional
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from src.config import config
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址
|
||||
|
||||
按优先级检查:
|
||||
1. X-Forwarded-For 头(支持代理链)
|
||||
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
|
||||
2. X-Real-IP 头(Nginx 代理)
|
||||
3. 直接客户端IP
|
||||
|
||||
安全说明:
|
||||
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
|
||||
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
|
||||
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
|
||||
Returns:
|
||||
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
||||
"""
|
||||
trusted_proxy_count = config.trusted_proxy_count
|
||||
|
||||
# 如果不信任任何代理,直接返回连接 IP
|
||||
if trusted_proxy_count == 0:
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2",取第一个(真实客户端)
|
||||
client_ip = forwarded_for.split(",")[0].strip()
|
||||
if client_ip:
|
||||
return client_ip
|
||||
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
@@ -91,20 +109,32 @@ def get_request_metadata(request: Request) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def extract_ip_from_headers(headers: dict) -> str:
|
||||
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
|
||||
"""
|
||||
从HTTP头字典中提取IP地址(用于中间件等场景)
|
||||
|
||||
Args:
|
||||
headers: HTTP头字典
|
||||
trusted_proxy_count: 可信代理层数,None 时使用配置值
|
||||
|
||||
Returns:
|
||||
str: 客户端IP地址
|
||||
"""
|
||||
if trusted_proxy_count is None:
|
||||
trusted_proxy_count = config.trusted_proxy_count
|
||||
|
||||
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
|
||||
if trusted_proxy_count == 0:
|
||||
return "unknown"
|
||||
|
||||
# 检查 X-Forwarded-For
|
||||
forwarded_for = headers.get("x-forwarded-for", "")
|
||||
if forwarded_for:
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||
if len(ips) > trusted_proxy_count:
|
||||
return ips[-(trusted_proxy_count + 1)]
|
||||
elif ips:
|
||||
return ips[0]
|
||||
|
||||
# 检查 X-Real-IP
|
||||
real_ip = headers.get("x-real-ip", "")
|
||||
|
||||
@@ -361,3 +361,61 @@ class TestPipelineAdminAuth:
|
||||
|
||||
assert result == mock_user
|
||||
assert mock_request.state.user_id == "admin-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_admin_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "admin-123"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"verify_token",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"user_id": "admin-123"},
|
||||
) as mock_verify:
|
||||
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||
|
||||
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||
assert result == mock_user
|
||||
|
||||
|
||||
class TestPipelineUserAuth:
|
||||
"""测试普通用户 JWT 认证"""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> ApiRequestPipeline:
|
||||
return ApiRequestPipeline()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_user_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "user-123"
|
||||
mock_user.is_active = True
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||
mock_request.state = MagicMock()
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||
|
||||
with patch.object(
|
||||
pipeline.auth_service,
|
||||
"verify_token",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"user_id": "user-123"},
|
||||
) as mock_verify:
|
||||
result = await pipeline._authenticate_user(mock_request, mock_db)
|
||||
|
||||
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||
assert result == mock_user
|
||||
|
||||
Reference in New Issue
Block a user