mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-08 10:42:29 +08:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f0c1fb347 | ||
|
|
7b932d7afb | ||
|
|
c7b971cfe7 | ||
|
|
293bb592dc | ||
|
|
3e50c157be |
@@ -0,0 +1,57 @@
|
|||||||
|
"""add proxy field to provider_endpoints
|
||||||
|
|
||||||
|
Revision ID: f30f9936f6a2
|
||||||
|
Revises: 1cc6942cf06f
|
||||||
|
Create Date: 2025-12-18 06:31:58.451112+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'f30f9936f6a2'
|
||||||
|
down_revision = '1cc6942cf06f'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def column_exists(table_name: str, column_name: str) -> bool:
|
||||||
|
"""检查列是否存在"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
columns = [col['name'] for col in inspector.get_columns(table_name)]
|
||||||
|
return column_name in columns
|
||||||
|
|
||||||
|
|
||||||
|
def get_column_type(table_name: str, column_name: str) -> str:
|
||||||
|
"""获取列的类型"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
for col in inspector.get_columns(table_name):
|
||||||
|
if col['name'] == column_name:
|
||||||
|
return str(col['type']).upper()
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""添加 proxy 字段到 provider_endpoints 表"""
|
||||||
|
if not column_exists('provider_endpoints', 'proxy'):
|
||||||
|
# 字段不存在,直接添加 JSONB 类型
|
||||||
|
op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True))
|
||||||
|
else:
|
||||||
|
# 字段已存在,检查是否需要转换类型
|
||||||
|
col_type = get_column_type('provider_endpoints', 'proxy')
|
||||||
|
if 'JSONB' not in col_type:
|
||||||
|
# 如果是 JSON 类型,转换为 JSONB
|
||||||
|
op.execute(
|
||||||
|
'ALTER TABLE provider_endpoints '
|
||||||
|
'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""移除 proxy 字段"""
|
||||||
|
if column_exists('provider_endpoints', 'proxy'):
|
||||||
|
op.drop_column('provider_endpoints', 'proxy')
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import client from '../client'
|
import client from '../client'
|
||||||
import type { ProviderEndpoint } from './types'
|
import type { ProviderEndpoint, ProxyConfig } from './types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取指定 Provider 的所有 Endpoints
|
* 获取指定 Provider 的所有 Endpoints
|
||||||
@@ -38,6 +38,7 @@ export async function createEndpoint(
|
|||||||
rate_limit?: number
|
rate_limit?: number
|
||||||
is_active?: boolean
|
is_active?: boolean
|
||||||
config?: Record<string, any>
|
config?: Record<string, any>
|
||||||
|
proxy?: ProxyConfig | null
|
||||||
}
|
}
|
||||||
): Promise<ProviderEndpoint> {
|
): Promise<ProviderEndpoint> {
|
||||||
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
|
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
|
||||||
@@ -63,6 +64,7 @@ export async function updateEndpoint(
|
|||||||
rate_limit: number
|
rate_limit: number
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
config: Record<string, any>
|
config: Record<string, any>
|
||||||
|
proxy: ProxyConfig | null
|
||||||
}>
|
}>
|
||||||
): Promise<ProviderEndpoint> {
|
): Promise<ProviderEndpoint> {
|
||||||
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)
|
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)
|
||||||
|
|||||||
@@ -20,6 +20,16 @@ export const API_FORMAT_LABELS: Record<string, string> = {
|
|||||||
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 代理配置类型
|
||||||
|
*/
|
||||||
|
export interface ProxyConfig {
|
||||||
|
url: string
|
||||||
|
username?: string
|
||||||
|
password?: string
|
||||||
|
enabled?: boolean // 是否启用代理(false 时保留配置但不使用)
|
||||||
|
}
|
||||||
|
|
||||||
export interface ProviderEndpoint {
|
export interface ProviderEndpoint {
|
||||||
id: string
|
id: string
|
||||||
provider_id: string
|
provider_id: string
|
||||||
@@ -41,6 +51,7 @@ export interface ProviderEndpoint {
|
|||||||
last_failure_at?: string
|
last_failure_at?: string
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
config?: Record<string, any>
|
config?: Record<string, any>
|
||||||
|
proxy?: ProxyConfig | null
|
||||||
total_keys: number
|
total_keys: number
|
||||||
active_keys: number
|
active_keys: number
|
||||||
created_at: string
|
created_at: string
|
||||||
|
|||||||
@@ -132,7 +132,7 @@
|
|||||||
type="number"
|
type="number"
|
||||||
min="1"
|
min="1"
|
||||||
max="10000"
|
max="10000"
|
||||||
placeholder="100"
|
placeholder="留空不限制"
|
||||||
class="h-10"
|
class="h-10"
|
||||||
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
||||||
/>
|
/>
|
||||||
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
|
|||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expire_days: undefined,
|
||||||
never_expire: true,
|
never_expire: true,
|
||||||
rate_limit: 100,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
allowed_api_formats: [],
|
allowed_api_formats: [],
|
||||||
@@ -389,7 +389,7 @@ function resetForm() {
|
|||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expire_days: undefined,
|
||||||
never_expire: true,
|
never_expire: true,
|
||||||
rate_limit: 100,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
allowed_api_formats: [],
|
allowed_api_formats: [],
|
||||||
@@ -408,7 +408,7 @@ function loadKeyData() {
|
|||||||
initial_balance_usd: props.apiKey.initial_balance_usd,
|
initial_balance_usd: props.apiKey.initial_balance_usd,
|
||||||
expire_days: props.apiKey.expire_days,
|
expire_days: props.apiKey.expire_days,
|
||||||
never_expire: props.apiKey.never_expire,
|
never_expire: props.apiKey.never_expire,
|
||||||
rate_limit: props.apiKey.rate_limit || 100,
|
rate_limit: props.apiKey.rate_limit,
|
||||||
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
||||||
allowed_providers: props.apiKey.allowed_providers || [],
|
allowed_providers: props.apiKey.allowed_providers || [],
|
||||||
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
>
|
>
|
||||||
<form
|
<form
|
||||||
class="space-y-6"
|
class="space-y-6"
|
||||||
@submit.prevent="handleSubmit"
|
@submit.prevent="handleSubmit()"
|
||||||
>
|
>
|
||||||
<!-- API 配置 -->
|
<!-- API 配置 -->
|
||||||
<div class="space-y-4">
|
<div class="space-y-4">
|
||||||
@@ -132,6 +132,79 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 代理配置 -->
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<h3 class="text-sm font-medium">
|
||||||
|
代理配置
|
||||||
|
</h3>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<Switch v-model="proxyEnabled" />
|
||||||
|
<span class="text-sm text-muted-foreground">启用代理</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="proxyEnabled"
|
||||||
|
class="space-y-4 rounded-lg border p-4"
|
||||||
|
>
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label for="proxy_url">代理 URL *</Label>
|
||||||
|
<Input
|
||||||
|
id="proxy_url"
|
||||||
|
v-model="form.proxy_url"
|
||||||
|
placeholder="http://host:port 或 socks5://host:port"
|
||||||
|
required
|
||||||
|
:class="proxyUrlError ? 'border-red-500' : ''"
|
||||||
|
/>
|
||||||
|
<p
|
||||||
|
v-if="proxyUrlError"
|
||||||
|
class="text-xs text-red-500"
|
||||||
|
>
|
||||||
|
{{ proxyUrlError }}
|
||||||
|
</p>
|
||||||
|
<p
|
||||||
|
v-else
|
||||||
|
class="text-xs text-muted-foreground"
|
||||||
|
>
|
||||||
|
支持 HTTP、HTTPS、SOCKS5 代理
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="grid grid-cols-2 gap-4">
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label for="proxy_user">用户名(可选)</Label>
|
||||||
|
<Input
|
||||||
|
:id="`proxy_user_${formId}`"
|
||||||
|
:name="`proxy_user_${formId}`"
|
||||||
|
v-model="form.proxy_username"
|
||||||
|
placeholder="代理认证用户名"
|
||||||
|
autocomplete="off"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
data-1p-ignore="true"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label :for="`proxy_pass_${formId}`">密码(可选)</Label>
|
||||||
|
<Input
|
||||||
|
:id="`proxy_pass_${formId}`"
|
||||||
|
:name="`proxy_pass_${formId}`"
|
||||||
|
v-model="form.proxy_password"
|
||||||
|
type="text"
|
||||||
|
:placeholder="passwordPlaceholder"
|
||||||
|
autocomplete="off"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
data-1p-ignore="true"
|
||||||
|
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
@@ -145,12 +218,24 @@
|
|||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
||||||
@click="handleSubmit"
|
@click="handleSubmit()"
|
||||||
>
|
>
|
||||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
||||||
</Button>
|
</Button>
|
||||||
</template>
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
|
<!-- 确认清空凭据对话框 -->
|
||||||
|
<AlertDialog
|
||||||
|
v-model="showClearCredentialsDialog"
|
||||||
|
title="清空代理凭据"
|
||||||
|
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
|
||||||
|
type="warning"
|
||||||
|
confirm-text="清空并保存"
|
||||||
|
cancel-text="返回编辑"
|
||||||
|
@confirm="confirmClearCredentials"
|
||||||
|
@cancel="showClearCredentialsDialog = false"
|
||||||
|
/>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
@@ -165,7 +250,9 @@ import {
|
|||||||
SelectValue,
|
SelectValue,
|
||||||
SelectContent,
|
SelectContent,
|
||||||
SelectItem,
|
SelectItem,
|
||||||
|
Switch,
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
|
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||||
import { Link, SquarePen } from 'lucide-vue-next'
|
import { Link, SquarePen } from 'lucide-vue-next'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useFormDialog } from '@/composables/useFormDialog'
|
import { useFormDialog } from '@/composables/useFormDialog'
|
||||||
@@ -194,6 +281,11 @@ const emit = defineEmits<{
|
|||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const selectOpen = ref(false)
|
const selectOpen = ref(false)
|
||||||
|
const proxyEnabled = ref(false)
|
||||||
|
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
|
||||||
|
|
||||||
|
// 生成随机 ID 防止浏览器自动填充
|
||||||
|
const formId = Math.random().toString(36).substring(2, 10)
|
||||||
|
|
||||||
// 内部状态
|
// 内部状态
|
||||||
const internalOpen = computed(() => props.modelValue)
|
const internalOpen = computed(() => props.modelValue)
|
||||||
@@ -207,7 +299,11 @@ const form = ref({
|
|||||||
max_retries: 3,
|
max_retries: 3,
|
||||||
max_concurrent: undefined as number | undefined,
|
max_concurrent: undefined as number | undefined,
|
||||||
rate_limit: undefined as number | undefined,
|
rate_limit: undefined as number | undefined,
|
||||||
is_active: true
|
is_active: true,
|
||||||
|
// 代理配置
|
||||||
|
proxy_url: '',
|
||||||
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
})
|
})
|
||||||
|
|
||||||
// API 格式列表
|
// API 格式列表
|
||||||
@@ -237,6 +333,53 @@ const defaultPathPlaceholder = computed(() => {
|
|||||||
return `留空使用默认路径:${defaultPath.value}`
|
return `留空使用默认路径:${defaultPath.value}`
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
|
||||||
|
const hasExistingPassword = computed(() => {
|
||||||
|
if (!props.endpoint?.proxy) return false
|
||||||
|
const proxy = props.endpoint.proxy as { password?: string }
|
||||||
|
return proxy?.password === MASKED_PASSWORD
|
||||||
|
})
|
||||||
|
|
||||||
|
// 密码输入框的 placeholder
|
||||||
|
const passwordPlaceholder = computed(() => {
|
||||||
|
if (hasExistingPassword.value) {
|
||||||
|
return '已保存密码,留空保持不变'
|
||||||
|
}
|
||||||
|
return '代理认证密码'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 代理 URL 验证
|
||||||
|
const proxyUrlError = computed(() => {
|
||||||
|
// 只有启用代理且填写了 URL 时才验证
|
||||||
|
if (!proxyEnabled.value || !form.value.proxy_url) {
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
const url = form.value.proxy_url.trim()
|
||||||
|
|
||||||
|
// 检查禁止的特殊字符
|
||||||
|
if (/[\n\r]/.test(url)) {
|
||||||
|
return '代理 URL 包含非法字符'
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证协议(不支持 SOCKS4)
|
||||||
|
if (!/^(http|https|socks5):\/\//i.test(url)) {
|
||||||
|
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const parsed = new URL(url)
|
||||||
|
if (!parsed.host) {
|
||||||
|
return '代理 URL 必须包含有效的 host'
|
||||||
|
}
|
||||||
|
// 禁止 URL 中内嵌认证信息
|
||||||
|
if (parsed.username || parsed.password) {
|
||||||
|
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
return '代理 URL 格式无效'
|
||||||
|
}
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
// 组件挂载时加载API格式
|
// 组件挂载时加载API格式
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadApiFormats()
|
loadApiFormats()
|
||||||
@@ -252,14 +395,23 @@ function resetForm() {
|
|||||||
max_retries: 3,
|
max_retries: 3,
|
||||||
max_concurrent: undefined,
|
max_concurrent: undefined,
|
||||||
rate_limit: undefined,
|
rate_limit: undefined,
|
||||||
is_active: true
|
is_active: true,
|
||||||
|
proxy_url: '',
|
||||||
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
}
|
}
|
||||||
|
proxyEnabled.value = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 原始密码占位符(后端返回的脱敏标记)
|
||||||
|
const MASKED_PASSWORD = '***'
|
||||||
|
|
||||||
// 加载端点数据(编辑模式)
|
// 加载端点数据(编辑模式)
|
||||||
function loadEndpointData() {
|
function loadEndpointData() {
|
||||||
if (!props.endpoint) return
|
if (!props.endpoint) return
|
||||||
|
|
||||||
|
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
|
||||||
|
|
||||||
form.value = {
|
form.value = {
|
||||||
api_format: props.endpoint.api_format,
|
api_format: props.endpoint.api_format,
|
||||||
base_url: props.endpoint.base_url,
|
base_url: props.endpoint.base_url,
|
||||||
@@ -268,8 +420,15 @@ function loadEndpointData() {
|
|||||||
max_retries: props.endpoint.max_retries,
|
max_retries: props.endpoint.max_retries,
|
||||||
max_concurrent: props.endpoint.max_concurrent || undefined,
|
max_concurrent: props.endpoint.max_concurrent || undefined,
|
||||||
rate_limit: props.endpoint.rate_limit || undefined,
|
rate_limit: props.endpoint.rate_limit || undefined,
|
||||||
is_active: props.endpoint.is_active
|
is_active: props.endpoint.is_active,
|
||||||
|
proxy_url: proxy?.url || '',
|
||||||
|
proxy_username: proxy?.username || '',
|
||||||
|
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
|
||||||
|
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据 enabled 字段或 url 存在判断是否启用代理
|
||||||
|
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 useFormDialog 统一处理对话框逻辑
|
// 使用 useFormDialog 统一处理对话框逻辑
|
||||||
@@ -282,12 +441,47 @@ const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
|
|||||||
resetForm,
|
resetForm,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 构建代理配置
|
||||||
|
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
|
||||||
|
// - 无 URL 时返回 null
|
||||||
|
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
|
||||||
|
if (!form.value.proxy_url) {
|
||||||
|
// 没填 URL,无代理配置
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
url: form.value.proxy_url,
|
||||||
|
username: form.value.proxy_username || undefined,
|
||||||
|
password: form.value.proxy_password || undefined,
|
||||||
|
enabled: proxyEnabled.value, // 开关状态决定是否启用
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 提交表单
|
// 提交表单
|
||||||
const handleSubmit = async () => {
|
const handleSubmit = async (skipCredentialCheck = false) => {
|
||||||
if (!props.provider && !props.endpoint) return
|
if (!props.provider && !props.endpoint) return
|
||||||
|
|
||||||
|
// 只在开关开启且填写了 URL 时验证
|
||||||
|
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
|
||||||
|
showError(proxyUrlError.value, '代理配置错误')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查:开关开启但没有 URL,却有用户名或密码
|
||||||
|
const hasOrphanedCredentials = proxyEnabled.value
|
||||||
|
&& !form.value.proxy_url
|
||||||
|
&& (form.value.proxy_username || form.value.proxy_password)
|
||||||
|
|
||||||
|
if (hasOrphanedCredentials && !skipCredentialCheck) {
|
||||||
|
// 弹出确认对话框
|
||||||
|
showClearCredentialsDialog.value = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
try {
|
try {
|
||||||
|
const proxyConfig = buildProxyConfig()
|
||||||
|
|
||||||
if (isEditMode.value && props.endpoint) {
|
if (isEditMode.value && props.endpoint) {
|
||||||
// 更新端点
|
// 更新端点
|
||||||
await updateEndpoint(props.endpoint.id, {
|
await updateEndpoint(props.endpoint.id, {
|
||||||
@@ -297,7 +491,8 @@ const handleSubmit = async () => {
|
|||||||
max_retries: form.value.max_retries,
|
max_retries: form.value.max_retries,
|
||||||
max_concurrent: form.value.max_concurrent,
|
max_concurrent: form.value.max_concurrent,
|
||||||
rate_limit: form.value.rate_limit,
|
rate_limit: form.value.rate_limit,
|
||||||
is_active: form.value.is_active
|
is_active: form.value.is_active,
|
||||||
|
proxy: proxyConfig,
|
||||||
})
|
})
|
||||||
|
|
||||||
success('端点已更新', '保存成功')
|
success('端点已更新', '保存成功')
|
||||||
@@ -313,7 +508,8 @@ const handleSubmit = async () => {
|
|||||||
max_retries: form.value.max_retries,
|
max_retries: form.value.max_retries,
|
||||||
max_concurrent: form.value.max_concurrent,
|
max_concurrent: form.value.max_concurrent,
|
||||||
rate_limit: form.value.rate_limit,
|
rate_limit: form.value.rate_limit,
|
||||||
is_active: form.value.is_active
|
is_active: form.value.is_active,
|
||||||
|
proxy: proxyConfig,
|
||||||
})
|
})
|
||||||
|
|
||||||
success('端点创建成功', '成功')
|
success('端点创建成功', '成功')
|
||||||
@@ -329,4 +525,12 @@ const handleSubmit = async () => {
|
|||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 确认清空凭据并继续保存
|
||||||
|
const confirmClearCredentials = () => {
|
||||||
|
form.value.proxy_username = ''
|
||||||
|
form.value.proxy_password = ''
|
||||||
|
showClearCredentialsDialog.value = false
|
||||||
|
handleSubmit(true) // 跳过凭据检查,直接提交
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
</h3>
|
</h3>
|
||||||
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
||||||
<span>{{ detail?.model || '-' }}</span>
|
<span>{{ detail?.model || '-' }}</span>
|
||||||
<template v-if="detail?.target_model">
|
<template v-if="detail?.target_model && detail.target_model !== detail.model">
|
||||||
<svg
|
<svg
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
viewBox="0 0 20 20"
|
viewBox="0 0 20 20"
|
||||||
|
|||||||
@@ -185,32 +185,13 @@
|
|||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
|
|
||||||
<!-- API Key 管理配置 -->
|
<!-- 独立余额 Key 过期管理 -->
|
||||||
<CardSection
|
<CardSection
|
||||||
title="API Key 管理"
|
title="独立余额 Key 过期管理"
|
||||||
description="API Key 相关配置"
|
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
|
||||||
>
|
>
|
||||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||||
<div>
|
<div class="flex items-center h-full">
|
||||||
<Label
|
|
||||||
for="api-key-expire"
|
|
||||||
class="block text-sm font-medium"
|
|
||||||
>
|
|
||||||
API密钥过期天数
|
|
||||||
</Label>
|
|
||||||
<Input
|
|
||||||
id="api-key-expire"
|
|
||||||
v-model.number="systemConfig.api_key_expire_days"
|
|
||||||
type="number"
|
|
||||||
placeholder="0"
|
|
||||||
class="mt-1"
|
|
||||||
/>
|
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
|
||||||
0 表示永不过期
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="flex items-center h-full pt-6">
|
|
||||||
<div class="flex items-center space-x-2">
|
<div class="flex items-center space-x-2">
|
||||||
<Checkbox
|
<Checkbox
|
||||||
id="auto-delete-expired-keys"
|
id="auto-delete-expired-keys"
|
||||||
@@ -224,7 +205,7 @@
|
|||||||
自动删除过期 Key
|
自动删除过期 Key
|
||||||
</Label>
|
</Label>
|
||||||
<p class="text-xs text-muted-foreground">
|
<p class="text-xs text-muted-foreground">
|
||||||
关闭时仅禁用过期 Key
|
关闭时仅禁用过期 Key,不会物理删除
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -448,6 +429,25 @@
|
|||||||
避免单次操作过大影响性能
|
避免单次操作过大影响性能
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="audit-log-retention-days"
|
||||||
|
class="block text-sm font-medium"
|
||||||
|
>
|
||||||
|
审计日志保留天数
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
id="audit-log-retention-days"
|
||||||
|
v-model.number="systemConfig.audit_log_retention_days"
|
||||||
|
type="number"
|
||||||
|
placeholder="30"
|
||||||
|
class="mt-1"
|
||||||
|
/>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
超过后删除审计日志记录
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 清理策略说明 -->
|
<!-- 清理策略说明 -->
|
||||||
@@ -460,6 +460,7 @@
|
|||||||
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
||||||
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
||||||
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
||||||
|
<p>5. <strong>审计日志</strong>: 独立清理,记录用户登录、操作等安全事件</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
@@ -796,8 +797,7 @@ interface SystemConfig {
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
enable_registration: boolean
|
enable_registration: boolean
|
||||||
require_email_verification: boolean
|
require_email_verification: boolean
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
api_key_expire_days: number
|
|
||||||
auto_delete_expired_keys: boolean
|
auto_delete_expired_keys: boolean
|
||||||
// 日志记录
|
// 日志记录
|
||||||
request_log_level: string
|
request_log_level: string
|
||||||
@@ -811,6 +811,7 @@ interface SystemConfig {
|
|||||||
header_retention_days: number
|
header_retention_days: number
|
||||||
log_retention_days: number
|
log_retention_days: number
|
||||||
cleanup_batch_size: number
|
cleanup_batch_size: number
|
||||||
|
audit_log_retention_days: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -845,8 +846,7 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
enable_registration: false,
|
enable_registration: false,
|
||||||
require_email_verification: false,
|
require_email_verification: false,
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
api_key_expire_days: 0,
|
|
||||||
auto_delete_expired_keys: false,
|
auto_delete_expired_keys: false,
|
||||||
// 日志记录
|
// 日志记录
|
||||||
request_log_level: 'basic',
|
request_log_level: 'basic',
|
||||||
@@ -860,6 +860,7 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
header_retention_days: 90,
|
header_retention_days: 90,
|
||||||
log_retention_days: 365,
|
log_retention_days: 365,
|
||||||
cleanup_batch_size: 1000,
|
cleanup_batch_size: 1000,
|
||||||
|
audit_log_retention_days: 30,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算属性:KB 和 字节 之间的转换
|
// 计算属性:KB 和 字节 之间的转换
|
||||||
@@ -901,8 +902,7 @@ async function loadSystemConfig() {
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
'enable_registration',
|
'enable_registration',
|
||||||
'require_email_verification',
|
'require_email_verification',
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
'api_key_expire_days',
|
|
||||||
'auto_delete_expired_keys',
|
'auto_delete_expired_keys',
|
||||||
// 日志记录
|
// 日志记录
|
||||||
'request_log_level',
|
'request_log_level',
|
||||||
@@ -916,6 +916,7 @@ async function loadSystemConfig() {
|
|||||||
'header_retention_days',
|
'header_retention_days',
|
||||||
'log_retention_days',
|
'log_retention_days',
|
||||||
'cleanup_batch_size',
|
'cleanup_batch_size',
|
||||||
|
'audit_log_retention_days',
|
||||||
]
|
]
|
||||||
|
|
||||||
for (const key of configs) {
|
for (const key of configs) {
|
||||||
@@ -960,12 +961,7 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.require_email_verification,
|
value: systemConfig.value.require_email_verification,
|
||||||
description: '是否需要邮箱验证'
|
description: '是否需要邮箱验证'
|
||||||
},
|
},
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
{
|
|
||||||
key: 'api_key_expire_days',
|
|
||||||
value: systemConfig.value.api_key_expire_days,
|
|
||||||
description: 'API密钥过期天数'
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
key: 'auto_delete_expired_keys',
|
key: 'auto_delete_expired_keys',
|
||||||
value: systemConfig.value.auto_delete_expired_keys,
|
value: systemConfig.value.auto_delete_expired_keys,
|
||||||
@@ -1023,6 +1019,11 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.cleanup_batch_size,
|
value: systemConfig.value.cleanup_batch_size,
|
||||||
description: '每批次清理的记录数'
|
description: '每批次清理的记录数'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 'audit_log_retention_days',
|
||||||
|
value: systemConfig.value.audit_log_retention_days,
|
||||||
|
description: '审计日志保留天数'
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const promises = configItems.map(item =>
|
const promises = configItems.map(item =>
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
|||||||
allowed_providers=self.key_data.allowed_providers,
|
allowed_providers=self.key_data.allowed_providers,
|
||||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||||
allowed_models=self.key_data.allowed_models,
|
allowed_models=self.key_data.allowed_models,
|
||||||
rate_limit=self.key_data.rate_limit or 100,
|
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||||
expire_days=self.key_data.expire_days,
|
expire_days=self.key_data.expire_days,
|
||||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||||
is_standalone=True, # 标记为独立Key
|
is_standalone=True, # 标记为独立Key
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from sqlalchemy import and_, func
|
from sqlalchemy import and_, func
|
||||||
@@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"])
|
|||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
|
|
||||||
|
def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]:
|
||||||
|
"""对代理配置中的密码进行脱敏处理"""
|
||||||
|
if not proxy_config:
|
||||||
|
return None
|
||||||
|
masked = dict(proxy_config)
|
||||||
|
if masked.get("password"):
|
||||||
|
masked["password"] = "***"
|
||||||
|
return masked
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
||||||
async def list_provider_endpoints(
|
async def list_provider_endpoints(
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
@@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
|||||||
"api_format": endpoint.api_format,
|
"api_format": endpoint.api_format,
|
||||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||||||
|
"proxy": mask_proxy_password(endpoint.proxy),
|
||||||
}
|
}
|
||||||
endpoint_dict.pop("_sa_instance_state", None)
|
endpoint_dict.pop("_sa_instance_state", None)
|
||||||
result.append(ProviderEndpointResponse(**endpoint_dict))
|
result.append(ProviderEndpointResponse(**endpoint_dict))
|
||||||
@@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
rate_limit=self.endpoint_data.rate_limit,
|
rate_limit=self.endpoint_data.rate_limit,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
config=self.endpoint_data.config,
|
config=self.endpoint_data.config,
|
||||||
|
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
@@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in new_endpoint.__dict__.items()
|
for k, v in new_endpoint.__dict__.items()
|
||||||
if k not in {"api_format", "_sa_instance_state"}
|
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||||
}
|
}
|
||||||
return ProviderEndpointResponse(
|
return ProviderEndpointResponse(
|
||||||
**endpoint_dict,
|
**endpoint_dict,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
api_format=new_endpoint.api_format,
|
api_format=new_endpoint.api_format,
|
||||||
|
proxy=mask_proxy_password(new_endpoint.proxy),
|
||||||
total_keys=0,
|
total_keys=0,
|
||||||
active_keys=0,
|
active_keys=0,
|
||||||
)
|
)
|
||||||
@@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in endpoint_obj.__dict__.items()
|
for k, v in endpoint_obj.__dict__.items()
|
||||||
if k not in {"api_format", "_sa_instance_state"}
|
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||||
}
|
}
|
||||||
return ProviderEndpointResponse(
|
return ProviderEndpointResponse(
|
||||||
**endpoint_dict,
|
**endpoint_dict,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
api_format=endpoint_obj.api_format,
|
api_format=endpoint_obj.api_format,
|
||||||
|
proxy=mask_proxy_password(endpoint_obj.proxy),
|
||||||
total_keys=total_keys,
|
total_keys=total_keys,
|
||||||
active_keys=active_keys,
|
active_keys=active_keys,
|
||||||
)
|
)
|
||||||
@@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||||
|
|
||||||
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
||||||
|
# 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理
|
||||||
|
if "proxy" in update_data:
|
||||||
|
if update_data["proxy"] is not None:
|
||||||
|
new_proxy = dict(update_data["proxy"])
|
||||||
|
# 只有当密码字段未提供时才保留原密码(空字符串视为显式清除)
|
||||||
|
if "password" not in new_proxy and endpoint.proxy:
|
||||||
|
old_password = endpoint.proxy.get("password")
|
||||||
|
if old_password:
|
||||||
|
new_proxy["password"] = old_password
|
||||||
|
update_data["proxy"] = new_proxy
|
||||||
|
# proxy 为 None 时保留,用于清除代理配置
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(endpoint, field, value)
|
setattr(endpoint, field, value)
|
||||||
endpoint.updated_at = datetime.now(timezone.utc)
|
endpoint.updated_at = datetime.now(timezone.utc)
|
||||||
@@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in endpoint.__dict__.items()
|
for k, v in endpoint.__dict__.items()
|
||||||
if k not in {"api_format", "_sa_instance_state"}
|
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||||
}
|
}
|
||||||
return ProviderEndpointResponse(
|
return ProviderEndpointResponse(
|
||||||
**endpoint_dict,
|
**endpoint_dict,
|
||||||
provider_name=provider.name if provider else "Unknown",
|
provider_name=provider.name if provider else "Unknown",
|
||||||
api_format=endpoint.api_format,
|
api_format=endpoint.api_format,
|
||||||
|
proxy=mask_proxy_password(endpoint.proxy),
|
||||||
total_keys=total_keys,
|
total_keys=total_keys,
|
||||||
active_keys=active_keys,
|
active_keys=active_keys,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await AuthService.verify_token(token, token_type="access")
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ class ApiRequestPipeline:
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
@@ -204,7 +204,7 @@ class ApiRequestPipeline:
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
|
|||||||
from src.services.system.audit import audit_service
|
from src.services.system.audit import audit_service
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageTelemetry:
|
class MessageTelemetry:
|
||||||
@@ -399,6 +402,41 @@ class BaseMessageHandler:
|
|||||||
# 创建后台任务,不阻塞当前流
|
# 创建后台任务,不阻塞当前流
|
||||||
asyncio.create_task(_do_update())
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
|
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||||
|
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
||||||
|
|
||||||
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from src.database.database import get_db
|
||||||
|
|
||||||
|
target_request_id = self.request_id
|
||||||
|
provider = ctx.provider_name
|
||||||
|
target_model = ctx.mapped_model
|
||||||
|
|
||||||
|
async def _do_update() -> None:
|
||||||
|
try:
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
try:
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=db,
|
||||||
|
request_id=target_request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=provider,
|
||||||
|
target_model=target_model,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
|
# 创建后台任务,不阻塞当前流
|
||||||
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
def _log_request_error(self, message: str, error: Exception) -> None:
|
def _log_request_error(self, message: str, error: Exception) -> None:
|
||||||
"""记录请求错误日志,对业务异常不打印堆栈
|
"""记录请求错误日志,对业务异常不打印堆栈
|
||||||
|
|
||||||
|
|||||||
@@ -64,18 +64,6 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
|
|
||||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||||
self.response_normalizer = None
|
|
||||||
# 可选启用响应规范化
|
|
||||||
self._init_response_normalizer()
|
|
||||||
|
|
||||||
def _init_response_normalizer(self):
|
|
||||||
"""初始化响应规范化器 - 子类可覆盖"""
|
|
||||||
try:
|
|
||||||
from src.services.provider.response_normalizer import ResponseNormalizer
|
|
||||||
|
|
||||||
self.response_normalizer = ResponseNormalizer()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def handle(self, context: ApiRequestContext):
|
async def handle(self, context: ApiRequestContext):
|
||||||
"""处理 Chat API 请求"""
|
"""处理 Chat API 请求"""
|
||||||
@@ -228,8 +216,6 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
allowed_api_formats=self.allowed_api_formats,
|
allowed_api_formats=self.allowed_api_formats,
|
||||||
response_normalizer=self.response_normalizer,
|
|
||||||
enable_response_normalization=self.response_normalizer is not None,
|
|
||||||
adapter_detector=self.detect_capability_requirements,
|
adapter_detector=self.detect_capability_requirements,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -88,8 +88,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
user_agent: str,
|
user_agent: str,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
allowed_api_formats: Optional[list] = None,
|
allowed_api_formats: Optional[list] = None,
|
||||||
response_normalizer: Optional[Any] = None,
|
|
||||||
enable_response_normalization: bool = False,
|
|
||||||
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
||||||
):
|
):
|
||||||
allowed = allowed_api_formats or [self.FORMAT_ID]
|
allowed = allowed_api_formats or [self.FORMAT_ID]
|
||||||
@@ -106,8 +104,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
)
|
)
|
||||||
self._parser: Optional[ResponseParser] = None
|
self._parser: Optional[ResponseParser] = None
|
||||||
self._request_builder = PassthroughRequestBuilder()
|
self._request_builder = PassthroughRequestBuilder()
|
||||||
self.response_normalizer = response_normalizer
|
|
||||||
self.enable_response_normalization = enable_response_normalization
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parser(self) -> ResponseParser:
|
def parser(self) -> ResponseParser:
|
||||||
@@ -297,11 +293,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
# 创建类型安全的流式上下文
|
# 创建类型安全的流式上下文
|
||||||
ctx = StreamContext(model=model, api_format=api_format)
|
ctx = StreamContext(model=model, api_format=api_format)
|
||||||
|
|
||||||
|
# 创建更新状态的回调闭包(可以访问 ctx)
|
||||||
|
def update_streaming_status() -> None:
|
||||||
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
|
||||||
# 创建流处理器
|
# 创建流处理器
|
||||||
stream_processor = StreamProcessor(
|
stream_processor = StreamProcessor(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
default_parser=self.parser,
|
default_parser=self.parser,
|
||||||
on_streaming_start=self._update_usage_to_streaming,
|
on_streaming_start=update_streaming_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义请求函数
|
# 定义请求函数
|
||||||
@@ -466,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
pool=config.http_pool_timeout,
|
pool=config.http_pool_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=timeout_config,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response_ctx = http_client.stream(
|
response_ctx = http_client.stream(
|
||||||
"POST", url, json=provider_payload, headers=provider_headers
|
"POST", url, json=provider_payload, headers=provider_headers
|
||||||
@@ -634,10 +640,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
||||||
f"模型={model} -> {mapped_model or '无映射'}")
|
f"模型={model} -> {mapped_model or '无映射'}")
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
timeout=float(endpoint.timeout),
|
from src.clients.http_client import HTTPClientPool
|
||||||
follow_redirects=True,
|
|
||||||
) as http_client:
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||||
|
)
|
||||||
|
async with http_client:
|
||||||
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
||||||
|
|
||||||
status_code = resp.status_code
|
status_code = resp.status_code
|
||||||
|
|||||||
@@ -454,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"Key=***{key.api_key[-4:]}, "
|
f"Key=***{key.api_key[-4:]}, "
|
||||||
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||||
|
|
||||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=timeout_config,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response_ctx = http_client.stream(
|
response_ctx = http_client.stream(
|
||||||
"POST", url, json=provider_payload, headers=provider_headers
|
"POST", url, json=provider_payload, headers=provider_headers
|
||||||
@@ -526,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
async for chunk in stream_response.aiter_raw():
|
async for chunk in stream_response.aiter_raw():
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if not streaming_status_updated:
|
if not streaming_status_updated:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
streaming_status_updated = True
|
streaming_status_updated = True
|
||||||
|
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
@@ -810,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if prefetched_chunks:
|
if prefetched_chunks:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
|
||||||
# 先处理预读的字节块
|
# 先处理预读的字节块
|
||||||
for chunk in prefetched_chunks:
|
for chunk in prefetched_chunks:
|
||||||
@@ -1419,10 +1425,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"Key=***{key.api_key[-4:]}, "
|
f"Key=***{key.api_key[-4:]}, "
|
||||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
timeout=float(endpoint.timeout),
|
from src.clients.http_client import HTTPClientPool
|
||||||
follow_redirects=True,
|
|
||||||
) as http_client:
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||||
|
)
|
||||||
|
async with http_client:
|
||||||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
||||||
|
|
||||||
status_code = resp.status_code
|
status_code = resp.status_code
|
||||||
|
|||||||
@@ -131,10 +131,5 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
Returns:
|
Returns:
|
||||||
规范化后的响应
|
规范化后的响应
|
||||||
"""
|
"""
|
||||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
# 作为中转站,直接透传响应,不做标准化处理
|
||||||
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
|
||||||
response_data=response,
|
|
||||||
request_id=self.request_id,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -148,17 +148,6 @@ class GeminiChatHandler(ChatHandlerBase):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
规范化后的响应
|
规范化后的响应
|
||||||
|
|
||||||
TODO: 如果需要,实现响应规范化逻辑
|
|
||||||
"""
|
"""
|
||||||
# 可选:使用 response_normalizer 进行规范化
|
# 作为中转站,直接透传响应,不做标准化处理
|
||||||
# if (
|
|
||||||
# self.response_normalizer
|
|
||||||
# and self.response_normalizer.should_normalize(response)
|
|
||||||
# ):
|
|
||||||
# return self.response_normalizer.normalize_gemini_response(
|
|
||||||
# response_data=response,
|
|
||||||
# request_id=self.request_id,
|
|
||||||
# strict=False,
|
|
||||||
# )
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -128,10 +128,5 @@ class OpenAIChatHandler(ChatHandlerBase):
|
|||||||
Returns:
|
Returns:
|
||||||
规范化后的响应
|
规范化后的响应
|
||||||
"""
|
"""
|
||||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
# 作为中转站,直接透传响应,不做标准化处理
|
||||||
return self.response_normalizer.normalize_openai_response(
|
|
||||||
response_data=response,
|
|
||||||
request_id=self.request_id,
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -5,12 +5,55 @@
|
|||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
from urllib.parse import quote, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
根据代理配置构建完整的代理 URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置字典,包含 url, username, password, enabled
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的代理 URL,如 socks5://user:pass@host:port
|
||||||
|
如果 enabled=False 或无配置,返回 None
|
||||||
|
"""
|
||||||
|
if not proxy_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查 enabled 字段,默认为 True(兼容旧数据)
|
||||||
|
if not proxy_config.get("enabled", True):
|
||||||
|
return None
|
||||||
|
|
||||||
|
proxy_url = proxy_config.get("url")
|
||||||
|
if not proxy_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
username = proxy_config.get("username")
|
||||||
|
password = proxy_config.get("password")
|
||||||
|
|
||||||
|
# 只要有用户名就添加认证信息(密码可以为空)
|
||||||
|
if username:
|
||||||
|
parsed = urlparse(proxy_url)
|
||||||
|
# URL 编码用户名和密码,处理特殊字符(如 @, :, /)
|
||||||
|
encoded_username = quote(username, safe="")
|
||||||
|
encoded_password = quote(password, safe="") if password else ""
|
||||||
|
# 重新构建带认证的代理 URL
|
||||||
|
if encoded_password:
|
||||||
|
auth_proxy = f"{parsed.scheme}://{encoded_username}:{encoded_password}@{parsed.netloc}"
|
||||||
|
else:
|
||||||
|
auth_proxy = f"{parsed.scheme}://{encoded_username}@{parsed.netloc}"
|
||||||
|
if parsed.path:
|
||||||
|
auth_proxy += parsed.path
|
||||||
|
return auth_proxy
|
||||||
|
|
||||||
|
return proxy_url
|
||||||
|
|
||||||
|
|
||||||
class HTTPClientPool:
|
class HTTPClientPool:
|
||||||
"""
|
"""
|
||||||
@@ -121,6 +164,44 @@ class HTTPClientPool:
|
|||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_client_with_proxy(
|
||||||
|
cls,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: Optional[httpx.Timeout] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> httpx.AsyncClient:
|
||||||
|
"""
|
||||||
|
创建带代理配置的HTTP客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置字典,包含 url, username, password
|
||||||
|
timeout: 超时配置
|
||||||
|
**kwargs: 其他 httpx.AsyncClient 配置参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的 httpx.AsyncClient 实例
|
||||||
|
"""
|
||||||
|
config: Dict[str, Any] = {
|
||||||
|
"http2": False,
|
||||||
|
"verify": True,
|
||||||
|
"follow_redirects": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeout:
|
||||||
|
config["timeout"] = timeout
|
||||||
|
else:
|
||||||
|
config["timeout"] = httpx.Timeout(10.0, read=300.0)
|
||||||
|
|
||||||
|
# 添加代理配置
|
||||||
|
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||||
|
if proxy_url:
|
||||||
|
config["proxy"] = proxy_url
|
||||||
|
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
||||||
|
|
||||||
|
config.update(kwargs)
|
||||||
|
return httpx.AsyncClient(**config)
|
||||||
|
|
||||||
|
|
||||||
# 便捷访问函数
|
# 便捷访问函数
|
||||||
def get_http_client() -> httpx.AsyncClient:
|
def get_http_client() -> httpx.AsyncClient:
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class RedisClientManager:
|
|||||||
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
||||||
remaining = self._circuit_open_until - time.time()
|
remaining = self._circuit_open_until - time.time()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
|
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
|
||||||
remaining,
|
remaining,
|
||||||
self._last_error,
|
self._last_error,
|
||||||
)
|
)
|
||||||
@@ -200,7 +200,7 @@ class RedisClientManager:
|
|||||||
if self._consecutive_failures >= self._circuit_threshold:
|
if self._consecutive_failures >= self._circuit_threshold:
|
||||||
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
|
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
|
||||||
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
||||||
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
||||||
self._consecutive_failures,
|
self._consecutive_failures,
|
||||||
|
|||||||
@@ -105,6 +105,13 @@ class Config:
|
|||||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||||
|
|
||||||
|
# 可信代理配置
|
||||||
|
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||||
|
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||||
|
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||||
|
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||||
|
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||||
|
|
||||||
# 异常处理配置
|
# 异常处理配置
|
||||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||||
# 设置为 False 时,使用全局异常处理器统一处理
|
# 设置为 False 时,使用全局异常处理器统一处理
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ def _log_pool_capacity():
|
|||||||
total_estimated = theoretical * workers
|
total_estimated = theoretical * workers
|
||||||
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
||||||
logger.info(
|
logger.info(
|
||||||
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
|
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
|
||||||
config.db_pool_size,
|
config.db_pool_size,
|
||||||
config.db_max_overflow,
|
config.db_max_overflow,
|
||||||
workers,
|
workers,
|
||||||
@@ -162,7 +162,7 @@ def _log_pool_capacity():
|
|||||||
)
|
)
|
||||||
if total_estimated > safe_limit:
|
if total_estimated > safe_limit:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved),"
|
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved),"
|
||||||
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
||||||
total_estimated,
|
total_estimated,
|
||||||
safe_limit,
|
safe_limit,
|
||||||
@@ -260,7 +260,8 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
|
|||||||
|
|
||||||
2. **管理后台 API**:
|
2. **管理后台 API**:
|
||||||
- 路由层显式调用 db.commit()
|
- 路由层显式调用 db.commit()
|
||||||
- 每个操作独立提交,不依赖中间件
|
- 提交后设置 request.state.tx_committed_by_route = True
|
||||||
|
- 中间件看到此标志后跳过 commit,只负责 close
|
||||||
|
|
||||||
3. **后台任务/调度器**:
|
3. **后台任务/调度器**:
|
||||||
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||||
@@ -271,6 +272,17 @@ def get_db(request: Request = None) -> Generator[Session, None, None]: # type:
|
|||||||
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||||
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||||
|
|
||||||
|
路由层提交事务示例
|
||||||
|
==================
|
||||||
|
```python
|
||||||
|
@router.post("/example")
|
||||||
|
async def example(request: Request, db: Session = Depends(get_db)):
|
||||||
|
# ... 业务逻辑 ...
|
||||||
|
db.commit()
|
||||||
|
request.state.tx_committed_by_route = True # 告知中间件已提交
|
||||||
|
return {"message": "success"}
|
||||||
|
```
|
||||||
|
|
||||||
注意事项
|
注意事项
|
||||||
========
|
========
|
||||||
- 本函数不自动提交事务
|
- 本函数不自动提交事务
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ async def initialize_providers():
|
|||||||
# 从数据库加载所有活跃的提供商
|
# 从数据库加载所有活跃的提供商
|
||||||
providers = (
|
providers = (
|
||||||
db.query(Provider)
|
db.query(Provider)
|
||||||
.filter(Provider.is_active == True)
|
.filter(Provider.is_active.is_(True))
|
||||||
.order_by(Provider.provider_priority.asc())
|
.order_by(Provider.provider_priority.asc())
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
@@ -122,6 +122,7 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.info("初始化全局Redis客户端...")
|
logger.info("初始化全局Redis客户端...")
|
||||||
from src.clients.redis_client import get_redis_client
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
|
redis_client = None
|
||||||
try:
|
try:
|
||||||
redis_client = await get_redis_client(require_redis=config.require_redis)
|
redis_client = await get_redis_client(require_redis=config.require_redis)
|
||||||
if redis_client:
|
if redis_client:
|
||||||
@@ -133,6 +134,7 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
||||||
raise
|
raise
|
||||||
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
||||||
|
redis_client = None
|
||||||
|
|
||||||
# 初始化并发管理器(内部会使用Redis)
|
# 初始化并发管理器(内部会使用Redis)
|
||||||
logger.info("初始化并发管理器...")
|
logger.info("初始化并发管理器...")
|
||||||
@@ -312,7 +314,7 @@ if frontend_dist.exists():
|
|||||||
仅对非API路径生效
|
仅对非API路径生效
|
||||||
"""
|
"""
|
||||||
# 如果是API路径,不处理
|
# 如果是API路径,不处理
|
||||||
if full_path.startswith("api/") or full_path.startswith("v1/"):
|
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
|
||||||
raise HTTPException(status_code=404, detail="Not Found")
|
raise HTTPException(status_code=404, detail="Not Found")
|
||||||
|
|
||||||
# 返回index.html,让前端路由处理
|
# 返回index.html,让前端路由处理
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
"""
|
"""
|
||||||
统一的插件中间件
|
统一的插件中间件(纯 ASGI 实现)
|
||||||
负责协调所有插件的调用
|
负责协调所有插件的调用
|
||||||
|
|
||||||
|
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware,
|
||||||
|
以避免 Starlette 已知的流式响应兼容性问题。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from starlette.requests import Request
|
||||||
from fastapi.responses import JSONResponse
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import Response as StarletteResponse
|
|
||||||
|
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
@@ -18,20 +19,25 @@ from src.plugins.manager import get_plugin_manager
|
|||||||
from src.plugins.rate_limit.base import RateLimitResult
|
from src.plugins.rate_limit.base import RateLimitResult
|
||||||
|
|
||||||
|
|
||||||
|
class PluginMiddleware:
|
||||||
class PluginMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""
|
"""
|
||||||
统一的插件调用中间件
|
统一的插件调用中间件(纯 ASGI 实现)
|
||||||
|
|
||||||
职责:
|
职责:
|
||||||
- 性能监控
|
- 性能监控
|
||||||
- 限流控制 (可选)
|
- 限流控制 (可选)
|
||||||
|
- 数据库会话生命周期管理
|
||||||
|
|
||||||
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
||||||
|
|
||||||
|
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
|
||||||
|
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
|
||||||
|
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
|
||||||
|
- 纯 ASGI 可以直接透传流式响应,无额外开销
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app: Any) -> None:
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
super().__init__(app)
|
self.app = app
|
||||||
self.plugin_manager = get_plugin_manager()
|
self.plugin_manager = get_plugin_manager()
|
||||||
|
|
||||||
# 从配置读取速率限制值
|
# 从配置读取速率限制值
|
||||||
@@ -61,152 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
]
|
]
|
||||||
|
|
||||||
async def dispatch(
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
|
"""ASGI 入口点"""
|
||||||
) -> StarletteResponse:
|
if scope["type"] != "http":
|
||||||
"""处理请求并调用相应插件"""
|
# 非 HTTP 请求(如 WebSocket)直接透传
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 构建 Request 对象以便复用现有逻辑
|
||||||
|
request = Request(scope, receive, send)
|
||||||
|
|
||||||
# 记录请求开始时间
|
# 记录请求开始时间
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 设置 request.state 属性
|
||||||
|
# 注意:Starlette 的 Request 对象总是有 state 属性(State 实例)
|
||||||
request.state.request_id = request.headers.get("x-request-id", "")
|
request.state.request_id = request.headers.get("x-request-id", "")
|
||||||
request.state.start_time = start_time
|
request.state.start_time = start_time
|
||||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||||
request.state.db_managed_by_middleware = True
|
request.state.db_managed_by_middleware = True
|
||||||
|
|
||||||
response = None
|
# 1. 限流检查(在调用下游之前)
|
||||||
exception_to_raise = None
|
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||||
|
if rate_limit_result and not rate_limit_result.allowed:
|
||||||
|
# 限流触发,返回429
|
||||||
|
await self._send_rate_limit_response(send, rate_limit_result)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 预处理插件调用
|
||||||
|
await self._call_pre_request_plugins(request)
|
||||||
|
|
||||||
|
# 用于捕获响应状态码
|
||||||
|
response_status_code: int = 0
|
||||||
|
|
||||||
|
async def send_wrapper(message: Message) -> None:
|
||||||
|
nonlocal response_status_code
|
||||||
|
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
response_status_code = message.get("status", 0)
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
# 3. 调用下游应用
|
||||||
|
exception_occurred: Optional[Exception] = None
|
||||||
try:
|
try:
|
||||||
# 1. 限流插件调用(可选功能)
|
await self.app(scope, receive, send_wrapper)
|
||||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
|
||||||
if rate_limit_result and not rate_limit_result.allowed:
|
|
||||||
# 限流触发,返回429
|
|
||||||
headers = rate_limit_result.headers or {}
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429,
|
|
||||||
detail=rate_limit_result.message or "Rate limit exceeded",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 预处理插件调用
|
|
||||||
await self._call_pre_request_plugins(request)
|
|
||||||
|
|
||||||
# 处理请求
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# 3. 提交关键数据库事务(在返回响应前)
|
|
||||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
|
||||||
try:
|
|
||||||
db = getattr(request.state, "db", None)
|
|
||||||
if isinstance(db, Session):
|
|
||||||
db.commit()
|
|
||||||
except Exception as commit_error:
|
|
||||||
logger.error(f"关键事务提交失败: {commit_error}")
|
|
||||||
try:
|
|
||||||
if isinstance(db, Session):
|
|
||||||
db.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
await self._call_error_plugins(request, commit_error, start_time)
|
|
||||||
# 返回 500 错误,因为数据可能不一致
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "database_error",
|
|
||||||
"message": "数据保存失败,请重试",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# 跳过后处理插件,直接返回错误响应
|
|
||||||
return response
|
|
||||||
|
|
||||||
# 4. 后处理插件调用(监控等,非关键操作)
|
|
||||||
# 这些操作失败不应影响用户响应
|
|
||||||
await self._call_post_request_plugins(request, response, start_time)
|
|
||||||
|
|
||||||
# 注意:不在此处添加限流响应头,因为在BaseHTTPMiddleware中
|
|
||||||
# 响应返回后修改headers会导致Content-Length不匹配错误
|
|
||||||
# 限流响应头已在返回429错误时正确包含(见上面的HTTPException)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
if str(e) == "No response returned.":
|
|
||||||
db = getattr(request.state, "db", None)
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.error("Downstream handler completed without returning a response")
|
|
||||||
|
|
||||||
await self._call_error_plugins(request, e, start_time)
|
|
||||||
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "internal_error",
|
|
||||||
"message": "Internal server error: downstream handler returned no response.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
exception_to_raise = e
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 回滚数据库事务
|
exception_occurred = e
|
||||||
db = getattr(request.state, "db", None)
|
|
||||||
if isinstance(db, Session):
|
|
||||||
try:
|
|
||||||
db.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 错误处理插件调用
|
# 错误处理插件调用
|
||||||
await self._call_error_plugins(request, e, start_time)
|
await self._call_error_plugins(request, e, start_time)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 4. 数据库会话清理(无论成功与否)
|
||||||
|
await self._cleanup_db_session(request, exception_occurred)
|
||||||
|
|
||||||
# 尝试提交错误日志
|
# 5. 后处理插件调用(仅在成功时)
|
||||||
if isinstance(db, Session):
|
if not exception_occurred and response_status_code > 0:
|
||||||
|
await self._call_post_request_plugins(request, response_status_code, start_time)
|
||||||
|
|
||||||
|
async def _send_rate_limit_response(
|
||||||
|
self, send: Send, result: RateLimitResult
|
||||||
|
) -> None:
|
||||||
|
"""发送 429 限流响应"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
body = json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"error": {
|
||||||
|
"type": "rate_limit_error",
|
||||||
|
"message": result.message or "Rate limit exceeded",
|
||||||
|
},
|
||||||
|
}).encode("utf-8")
|
||||||
|
|
||||||
|
headers = [(b"content-type", b"application/json")]
|
||||||
|
if result.headers:
|
||||||
|
for key, value in result.headers.items():
|
||||||
|
headers.append((key.lower().encode(), str(value).encode()))
|
||||||
|
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 429,
|
||||||
|
"headers": headers,
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": body,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _cleanup_db_session(
|
||||||
|
self, request: Request, exception: Optional[Exception]
|
||||||
|
) -> None:
|
||||||
|
"""清理数据库会话
|
||||||
|
|
||||||
|
事务策略:
|
||||||
|
- 如果 request.state.tx_committed_by_route 为 True,说明路由已自行提交,中间件只负责 close
|
||||||
|
- 否则由中间件统一 commit/rollback
|
||||||
|
|
||||||
|
这避免了双重提交的问题,同时保持向后兼容。
|
||||||
|
"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
db = getattr(request.state, "db", None)
|
||||||
|
if not isinstance(db, Session):
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否由路由层已经提交
|
||||||
|
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if exception is not None:
|
||||||
|
# 发生异常,回滚事务(无论谁负责提交)
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception as rollback_error:
|
||||||
|
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||||
|
elif not tx_committed_by_route:
|
||||||
|
# 正常完成且路由未自行提交,由中间件提交事务
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
except:
|
except Exception as commit_error:
|
||||||
pass
|
logger.error(f"关键事务提交失败: {commit_error}")
|
||||||
|
try:
|
||||||
exception_to_raise = e
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# 如果 tx_committed_by_route 为 True,跳过 commit(路由已提交)
|
||||||
finally:
|
finally:
|
||||||
db = getattr(request.state, "db", None)
|
# 关闭会话,归还连接到连接池
|
||||||
if isinstance(db, Session):
|
try:
|
||||||
try:
|
db.close()
|
||||||
db.close()
|
except Exception as close_error:
|
||||||
except Exception as close_error:
|
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
|
||||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
|
||||||
|
|
||||||
# 在 finally 块之后处理异常和响应
|
|
||||||
if exception_to_raise:
|
|
||||||
raise exception_to_raise
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _get_client_ip(self, request: Request) -> str:
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
获取客户端 IP 地址,支持代理头
|
获取客户端 IP 地址,支持代理头
|
||||||
|
|
||||||
|
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||||
|
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||||
|
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||||
"""
|
"""
|
||||||
|
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||||
|
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||||
|
|
||||||
# 优先从代理头获取真实 IP
|
# 优先从代理头获取真实 IP
|
||||||
forwarded_for = request.headers.get("x-forwarded-for")
|
forwarded_for = request.headers.get("x-forwarded-for")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For 可能包含多个 IP,取第一个
|
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||||
return forwarded_for.split(",")[0].strip()
|
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||||
|
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||||
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
real_ip = request.headers.get("x-real-ip")
|
real_ip = request.headers.get("x-real-ip")
|
||||||
if real_ip:
|
if real_ip:
|
||||||
@@ -248,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
auth_header = request.headers.get("authorization", "")
|
auth_header = request.headers.get("authorization", "")
|
||||||
api_key = request.headers.get("x-api-key", "")
|
api_key = request.headers.get("x-api-key", "")
|
||||||
|
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.lower().startswith("bearer "):
|
||||||
api_key = auth_header[7:]
|
api_key = auth_header[7:]
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
||||||
import hashlib
|
|
||||||
|
|
||||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
||||||
key = f"llm_api_key:{key_hash}"
|
key = f"llm_api_key:{key_hash}"
|
||||||
request.state.rate_limit_key_type = "api_key"
|
request.state.rate_limit_key_type = "api_key"
|
||||||
@@ -319,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 限流触发,记录日志
|
# 限流触发,记录日志
|
||||||
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
|
logger.warning(
|
||||||
|
"速率限制触发: {}",
|
||||||
|
getattr(request.state, "rate_limit_key_type", "unknown"),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -332,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
async def _call_post_request_plugins(
|
async def _call_post_request_plugins(
|
||||||
self, request: Request, response: StarletteResponse, start_time: float
|
self, request: Request, status_code: int, start_time: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""调用请求后的插件"""
|
"""调用请求后的插件"""
|
||||||
|
|
||||||
@@ -345,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
monitor_labels = {
|
monitor_labels = {
|
||||||
"method": request.method,
|
"method": request.method,
|
||||||
"endpoint": request.url.path,
|
"endpoint": request.url.path,
|
||||||
"status": str(response.status_code),
|
"status": str(status_code),
|
||||||
"status_class": f"{response.status_code // 100}xx",
|
"status_class": f"{status_code // 100}xx",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 记录请求计数
|
# 记录请求计数
|
||||||
@@ -368,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
self, request: Request, error: Exception, start_time: float
|
self, request: Request, error: Exception, start_time: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""调用错误处理插件"""
|
"""调用错误处理插件"""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
@@ -380,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
error=error,
|
error=error,
|
||||||
context={
|
context={
|
||||||
"endpoint": f"{request.method} {request.url.path}",
|
"endpoint": f"{request.method} {request.url.path}",
|
||||||
"request_id": request.state.request_id,
|
"request_id": getattr(request.state, "request_id", ""),
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,42 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
|||||||
from src.core.enums import APIFormat, ProviderBillingType
|
from src.core.enums import APIFormat, ProviderBillingType
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyConfig(BaseModel):
|
||||||
|
"""代理配置"""
|
||||||
|
|
||||||
|
url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
|
||||||
|
username: Optional[str] = Field(None, max_length=255, description="代理用户名")
|
||||||
|
password: Optional[str] = Field(None, max_length=500, description="代理密码")
|
||||||
|
enabled: bool = Field(True, description="是否启用代理(false 时保留配置但不使用)")
|
||||||
|
|
||||||
|
@field_validator("url")
|
||||||
|
@classmethod
|
||||||
|
def validate_proxy_url(cls, v: str) -> str:
|
||||||
|
"""验证代理 URL 格式"""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
v = v.strip()
|
||||||
|
|
||||||
|
# 检查禁止的字符(防止注入)
|
||||||
|
if "\n" in v or "\r" in v:
|
||||||
|
raise ValueError("代理 URL 包含非法字符")
|
||||||
|
|
||||||
|
# 验证协议(不支持 SOCKS4)
|
||||||
|
if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
|
||||||
|
raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")
|
||||||
|
|
||||||
|
# 验证 URL 结构
|
||||||
|
parsed = urlparse(v)
|
||||||
|
if not parsed.netloc:
|
||||||
|
raise ValueError("代理 URL 必须包含有效的 host")
|
||||||
|
|
||||||
|
# 禁止 URL 中内嵌认证信息,强制使用独立字段
|
||||||
|
if parsed.username or parsed.password:
|
||||||
|
raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class CreateProviderRequest(BaseModel):
|
class CreateProviderRequest(BaseModel):
|
||||||
"""创建 Provider 请求"""
|
"""创建 Provider 请求"""
|
||||||
|
|
||||||
@@ -165,6 +201,7 @@ class CreateEndpointRequest(BaseModel):
|
|||||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||||
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
|
|
||||||
@field_validator("name")
|
@field_validator("name")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -220,6 +257,7 @@ class UpdateEndpointRequest(BaseModel):
|
|||||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
|
|
||||||
# 复用验证器
|
# 复用验证器
|
||||||
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
||||||
|
|||||||
@@ -538,6 +538,9 @@ class ProviderEndpoint(Base):
|
|||||||
# 额外配置
|
# 额外配置
|
||||||
config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段)
|
config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段)
|
||||||
|
|
||||||
|
# 代理配置
|
||||||
|
proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password}
|
||||||
|
|
||||||
# 时间戳
|
# 时间戳
|
||||||
created_at = Column(
|
created_at = Column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
from src.models.admin_requests import ProxyConfig
|
||||||
|
|
||||||
# ========== ProviderEndpoint CRUD ==========
|
# ========== ProviderEndpoint CRUD ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -30,6 +32,9 @@ class ProviderEndpointCreate(BaseModel):
|
|||||||
# 额外配置
|
# 额外配置
|
||||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
||||||
|
|
||||||
|
# 代理配置
|
||||||
|
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||||
|
|
||||||
@field_validator("api_format")
|
@field_validator("api_format")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_api_format(cls, v: str) -> str:
|
def validate_api_format(cls, v: str) -> str:
|
||||||
@@ -64,6 +69,7 @@ class ProviderEndpointUpdate(BaseModel):
|
|||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
||||||
|
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||||
|
|
||||||
@field_validator("base_url")
|
@field_validator("base_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -104,6 +110,9 @@ class ProviderEndpointResponse(BaseModel):
|
|||||||
# 额外配置
|
# 额外配置
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
# 代理配置(响应中密码已脱敏)
|
||||||
|
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置(密码已脱敏)")
|
||||||
|
|
||||||
# 统计(从 Keys 聚合)
|
# 统计(从 Keys 聚合)
|
||||||
total_keys: int = Field(default=0, description="总 Key 数量")
|
total_keys: int = Field(default=0, description="总 Key 数量")
|
||||||
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ JWT认证插件
|
|||||||
支持JWT Bearer token认证
|
支持JWT Bearer token认证
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -46,8 +47,8 @@ class JwtAuthPlugin(AuthPlugin):
|
|||||||
logger.debug("未找到JWT token")
|
logger.debug("未找到JWT token")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 记录认证尝试的详细信息
|
token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
|
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, token_fp={token_fingerprint}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 验证JWT token
|
# 验证JWT token
|
||||||
|
|||||||
@@ -63,14 +63,16 @@ class JWTBlacklistService:
|
|||||||
|
|
||||||
if ttl_seconds <= 0:
|
if ttl_seconds <= 0:
|
||||||
# Token 已经过期,不需要加入黑名单
|
# Token 已经过期,不需要加入黑名单
|
||||||
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
||||||
# 值存储为原因字符串
|
# 值存储为原因字符串
|
||||||
await redis_client.setex(redis_key, ttl_seconds, reason)
|
await redis_client.setex(redis_key, ttl_seconds, reason)
|
||||||
|
|
||||||
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -109,7 +111,8 @@ class JWTBlacklistService:
|
|||||||
if exists:
|
if exists:
|
||||||
# 获取黑名单原因(可选)
|
# 获取黑名单原因(可选)
|
||||||
reason = await redis_client.get(redis_key)
|
reason = await redis_client.get(redis_key)
|
||||||
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -148,9 +151,11 @@ class JWTBlacklistService:
|
|||||||
deleted = await redis_client.delete(redis_key)
|
deleted = await redis_client.delete(redis_key)
|
||||||
|
|
||||||
if deleted:
|
if deleted:
|
||||||
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
|
||||||
|
|
||||||
return bool(deleted)
|
return bool(deleted)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@@ -169,7 +170,8 @@ class AuthService:
|
|||||||
key_record.last_used_at = datetime.now(timezone.utc)
|
key_record.last_used_at = datetime.now(timezone.utc)
|
||||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||||
|
|
||||||
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
|
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||||
|
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||||
return user, key_record
|
return user, key_record
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
"""响应标准化服务,用于 STANDARD 模式下的响应格式验证和补全"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from src.core.logger import logger
|
|
||||||
from src.models.claude import ClaudeResponse
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseNormalizer:
|
|
||||||
"""响应标准化器 - 用于标准模式下验证和补全响应字段"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def normalize_claude_response(
|
|
||||||
response_data: Dict[str, Any], request_id: Optional[str] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
标准化 Claude API 响应
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_data: 原始响应数据
|
|
||||||
request_id: 请求ID(用于日志)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
标准化后的响应数据(失败时返回原始数据)
|
|
||||||
"""
|
|
||||||
if "error" in response_data:
|
|
||||||
logger.debug(f"[ResponseNormalizer] 检测到错误响应,跳过标准化 | ID:{request_id}")
|
|
||||||
return response_data
|
|
||||||
|
|
||||||
try:
|
|
||||||
validated = ClaudeResponse.model_validate(response_data)
|
|
||||||
normalized = validated.model_dump(mode="json", exclude_none=False)
|
|
||||||
|
|
||||||
logger.debug(f"[ResponseNormalizer] 响应标准化成功 | ID:{request_id}")
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"[ResponseNormalizer] 响应验证失败,透传原始数据 | ID:{request_id}")
|
|
||||||
return response_data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def should_normalize(response_data: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
判断是否需要进行标准化
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_data: 响应数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否需要标准化
|
|
||||||
"""
|
|
||||||
# 错误响应不需要标准化
|
|
||||||
if "error" in response_data:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 已经包含新字段的响应不需要再次标准化
|
|
||||||
if "context_management" in response_data and "container" in response_data:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import create_session
|
from src.database import create_session
|
||||||
from src.models.database import Usage
|
from src.models.database import AuditLog, Usage
|
||||||
from src.services.system.config import SystemConfigService
|
from src.services.system.config import SystemConfigService
|
||||||
from src.services.system.scheduler import get_scheduler
|
from src.services.system.scheduler import get_scheduler
|
||||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||||
@@ -91,6 +91,15 @@ class CleanupScheduler:
|
|||||||
name="Pending状态清理",
|
name="Pending状态清理",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 审计日志清理 - 凌晨 4 点执行
|
||||||
|
scheduler.add_cron_job(
|
||||||
|
self._scheduled_audit_cleanup,
|
||||||
|
hour=4,
|
||||||
|
minute=0,
|
||||||
|
job_id="audit_cleanup",
|
||||||
|
name="审计日志清理",
|
||||||
|
)
|
||||||
|
|
||||||
# 启动时执行一次初始化任务
|
# 启动时执行一次初始化任务
|
||||||
asyncio.create_task(self._run_startup_tasks())
|
asyncio.create_task(self._run_startup_tasks())
|
||||||
|
|
||||||
@@ -145,6 +154,10 @@ class CleanupScheduler:
|
|||||||
"""Pending 清理任务(定时调用)"""
|
"""Pending 清理任务(定时调用)"""
|
||||||
await self._perform_pending_cleanup()
|
await self._perform_pending_cleanup()
|
||||||
|
|
||||||
|
async def _scheduled_audit_cleanup(self):
|
||||||
|
"""审计日志清理任务(定时调用)"""
|
||||||
|
await self._perform_audit_cleanup()
|
||||||
|
|
||||||
# ========== 实际任务实现 ==========
|
# ========== 实际任务实现 ==========
|
||||||
|
|
||||||
async def _perform_stats_aggregation(self, backfill: bool = False):
|
async def _perform_stats_aggregation(self, backfill: bool = False):
|
||||||
@@ -330,6 +343,70 @@ class CleanupScheduler:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
async def _perform_audit_cleanup(self):
|
||||||
|
"""执行审计日志清理任务"""
|
||||||
|
db = create_session()
|
||||||
|
try:
|
||||||
|
# 检查是否启用自动清理
|
||||||
|
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
|
||||||
|
logger.info("自动清理已禁用,跳过审计日志清理")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取审计日志保留天数(默认 30 天,最少 7 天)
|
||||||
|
audit_retention_days = max(
|
||||||
|
SystemConfigService.get_config(db, "audit_log_retention_days", 30),
|
||||||
|
7, # 最少保留 7 天,防止误配置删除所有审计日志
|
||||||
|
)
|
||||||
|
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
|
||||||
|
|
||||||
|
cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
|
||||||
|
|
||||||
|
logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
|
||||||
|
|
||||||
|
total_deleted = 0
|
||||||
|
while True:
|
||||||
|
# 先查询要删除的记录 ID(分批)
|
||||||
|
records_to_delete = (
|
||||||
|
db.query(AuditLog.id)
|
||||||
|
.filter(AuditLog.created_at < cutoff_time)
|
||||||
|
.limit(batch_size)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records_to_delete:
|
||||||
|
break
|
||||||
|
|
||||||
|
record_ids = [r.id for r in records_to_delete]
|
||||||
|
|
||||||
|
# 执行删除
|
||||||
|
result = db.execute(
|
||||||
|
delete(AuditLog)
|
||||||
|
.where(AuditLog.id.in_(record_ids))
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows_deleted = result.rowcount
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
total_deleted += rows_deleted
|
||||||
|
logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted} 条")
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
if total_deleted > 0:
|
||||||
|
logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
|
||||||
|
else:
|
||||||
|
logger.info("无需清理的审计日志")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"审计日志清理失败: {e}")
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
async def _perform_cleanup(self):
|
async def _perform_cleanup(self):
|
||||||
"""执行清理任务"""
|
"""执行清理任务"""
|
||||||
db = create_session()
|
db = create_session()
|
||||||
|
|||||||
@@ -1217,15 +1217,19 @@ class UsageService:
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
status: str,
|
status: str,
|
||||||
error_message: Optional[str] = None,
|
error_message: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None,
|
||||||
|
target_model: Optional[str] = None,
|
||||||
) -> Optional[Usage]:
|
) -> Optional[Usage]:
|
||||||
"""
|
"""
|
||||||
快速更新使用记录状态(不更新其他字段)
|
快速更新使用记录状态
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
request_id: 请求ID
|
request_id: 请求ID
|
||||||
status: 新状态 (pending, streaming, completed, failed)
|
status: 新状态 (pending, streaming, completed, failed)
|
||||||
error_message: 错误消息(仅在 failed 状态时使用)
|
error_message: 错误消息(仅在 failed 状态时使用)
|
||||||
|
provider: 提供商名称(可选,streaming 状态时更新)
|
||||||
|
target_model: 映射后的目标模型名(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
更新后的 Usage 记录,如果未找到则返回 None
|
更新后的 Usage 记录,如果未找到则返回 None
|
||||||
@@ -1239,6 +1243,10 @@ class UsageService:
|
|||||||
usage.status = status
|
usage.status = status
|
||||||
if error_message:
|
if error_message:
|
||||||
usage.error_message = error_message
|
usage.error_message = error_message
|
||||||
|
if provider:
|
||||||
|
usage.provider = provider
|
||||||
|
if target_model:
|
||||||
|
usage.target_model = target_model
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -457,7 +457,7 @@ class StreamUsageTracker:
|
|||||||
|
|
||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
# 更新状态为 streaming
|
# 更新状态为 streaming,同时更新 provider
|
||||||
if self.request_id:
|
if self.request_id:
|
||||||
try:
|
try:
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
@@ -465,6 +465,7 @@ class StreamUsageTracker:
|
|||||||
db=self.db,
|
db=self.db,
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
status="streaming",
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|||||||
@@ -210,7 +210,15 @@ class ApiKeyService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
||||||
"""检查速率限制"""
|
"""检查速率限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_allowed, remaining): 是否允许请求,剩余可用次数
|
||||||
|
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
|
||||||
|
"""
|
||||||
|
# 如果 rate_limit 为 None,表示不限制
|
||||||
|
if api_key.rate_limit is None:
|
||||||
|
return True, -1 # -1 表示无限制
|
||||||
|
|
||||||
# 计算时间窗口
|
# 计算时间窗口
|
||||||
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
提供统一的用户认证和授权功能
|
提供统一的用户认证和授权功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
@@ -44,10 +45,17 @@ async def get_current_user(
|
|||||||
payload = await AuthService.verify_token(token, token_type="access")
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
except HTTPException as token_error:
|
except HTTPException as token_error:
|
||||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.error(
|
||||||
|
"Token验证失败: {}: {}, token_fp={}",
|
||||||
|
token_error.status_code,
|
||||||
|
token_error.detail,
|
||||||
|
token_fp,
|
||||||
|
)
|
||||||
raise # 重新抛出原始异常,保持状态码
|
raise # 重新抛出原始异常,保持状态码
|
||||||
except Exception as token_error:
|
except Exception as token_error:
|
||||||
logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.error("Token验证失败: {}, token_fp={}", token_error, token_fp)
|
||||||
raise ForbiddenException("无效的Token")
|
raise ForbiddenException("无效的Token")
|
||||||
|
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
@@ -63,7 +71,8 @@ async def get_current_user(
|
|||||||
raise ForbiddenException("无效的认证凭据")
|
raise ForbiddenException("无效的认证凭据")
|
||||||
|
|
||||||
# 仅在DEBUG模式下记录详细信息
|
# 仅在DEBUG模式下记录详细信息
|
||||||
logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.debug("尝试获取用户: user_id={}, token_fp={}", user_id, token_fp)
|
||||||
|
|
||||||
# 确保user_id是字符串格式(UUID)
|
# 确保user_id是字符串格式(UUID)
|
||||||
if not isinstance(user_id, str):
|
if not isinstance(user_id, str):
|
||||||
|
|||||||
@@ -7,29 +7,47 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
|
||||||
def get_client_ip(request: Request) -> str:
|
def get_client_ip(request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
获取客户端真实IP地址
|
获取客户端真实IP地址
|
||||||
|
|
||||||
按优先级检查:
|
按优先级检查:
|
||||||
1. X-Forwarded-For 头(支持代理链)
|
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
|
||||||
2. X-Real-IP 头(Nginx 代理)
|
2. X-Real-IP 头(Nginx 代理)
|
||||||
3. 直接客户端IP
|
3. 直接客户端IP
|
||||||
|
|
||||||
|
安全说明:
|
||||||
|
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
|
||||||
|
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
|
||||||
|
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: FastAPI Request 对象
|
request: FastAPI Request 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
||||||
"""
|
"""
|
||||||
|
trusted_proxy_count = config.trusted_proxy_count
|
||||||
|
|
||||||
|
# 如果不信任任何代理,直接返回连接 IP
|
||||||
|
if trusted_proxy_count == 0:
|
||||||
|
if request.client and request.client.host:
|
||||||
|
return request.client.host
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
||||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For 格式: "client, proxy1, proxy2",取第一个(真实客户端)
|
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||||
client_ip = forwarded_for.split(",")[0].strip()
|
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||||
if client_ip:
|
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||||
return client_ip
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
||||||
real_ip = request.headers.get("X-Real-IP")
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
@@ -91,20 +109,32 @@ def get_request_metadata(request: Request) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def extract_ip_from_headers(headers: dict) -> str:
|
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
|
||||||
"""
|
"""
|
||||||
从HTTP头字典中提取IP地址(用于中间件等场景)
|
从HTTP头字典中提取IP地址(用于中间件等场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
headers: HTTP头字典
|
headers: HTTP头字典
|
||||||
|
trusted_proxy_count: 可信代理层数,None 时使用配置值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 客户端IP地址
|
str: 客户端IP地址
|
||||||
"""
|
"""
|
||||||
|
if trusted_proxy_count is None:
|
||||||
|
trusted_proxy_count = config.trusted_proxy_count
|
||||||
|
|
||||||
|
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
|
||||||
|
if trusted_proxy_count == 0:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
# 检查 X-Forwarded-For
|
# 检查 X-Forwarded-For
|
||||||
forwarded_for = headers.get("x-forwarded-for", "")
|
forwarded_for = headers.get("x-forwarded-for", "")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
return forwarded_for.split(",")[0].strip()
|
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||||
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
# 检查 X-Real-IP
|
# 检查 X-Real-IP
|
||||||
real_ip = headers.get("x-real-ip", "")
|
real_ip = headers.get("x-real-ip", "")
|
||||||
|
|||||||
@@ -361,3 +361,61 @@ class TestPipelineAdminAuth:
|
|||||||
|
|
||||||
assert result == mock_user
|
assert result == mock_user
|
||||||
assert mock_request.state.user_id == "admin-123"
|
assert mock_request.state.user_id == "admin-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_admin_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "admin-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "admin-123"},
|
||||||
|
) as mock_verify:
|
||||||
|
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||||
|
|
||||||
|
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||||
|
assert result == mock_user
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineUserAuth:
|
||||||
|
"""测试普通用户 JWT 认证"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "user-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "user-123"},
|
||||||
|
) as mock_verify:
|
||||||
|
result = await pipeline._authenticate_user(mock_request, mock_db)
|
||||||
|
|
||||||
|
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||||
|
assert result == mock_user
|
||||||
|
|||||||
Reference in New Issue
Block a user