mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-05 09:12:27 +08:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f0c1fb347 | ||
|
|
7b932d7afb | ||
|
|
c7b971cfe7 | ||
|
|
293bb592dc | ||
|
|
3e50c157be | ||
|
|
21587449c8 | ||
|
|
3d0ab353d3 | ||
|
|
b2a857c164 | ||
|
|
4d1d863916 | ||
|
|
b579420690 | ||
|
|
9d5c84f9d3 | ||
|
|
53e6a82480 | ||
|
|
bd11ebdbd5 | ||
|
|
1dac4cb156 |
@@ -20,10 +20,10 @@ depends_on = None
|
|||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Create ENUM types
|
# Create ENUM types (with IF NOT EXISTS for idempotency)
|
||||||
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
|
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
|
||||||
op.execute(
|
op.execute(
|
||||||
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
|
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==================== users ====================
|
# ==================== users ====================
|
||||||
@@ -35,7 +35,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"role",
|
"role",
|
||||||
sa.Enum("admin", "user", name="userrole", create_type=False),
|
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="user",
|
server_default="user",
|
||||||
),
|
),
|
||||||
@@ -67,7 +67,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("website", sa.String(500), nullable=True),
|
sa.Column("website", sa.String(500), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"billing_type",
|
"billing_type",
|
||||||
sa.Enum(
|
postgresql.ENUM(
|
||||||
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""add proxy field to provider_endpoints
|
||||||
|
|
||||||
|
Revision ID: f30f9936f6a2
|
||||||
|
Revises: 1cc6942cf06f
|
||||||
|
Create Date: 2025-12-18 06:31:58.451112+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy import inspect
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'f30f9936f6a2'
|
||||||
|
down_revision = '1cc6942cf06f'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def column_exists(table_name: str, column_name: str) -> bool:
|
||||||
|
"""检查列是否存在"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
columns = [col['name'] for col in inspector.get_columns(table_name)]
|
||||||
|
return column_name in columns
|
||||||
|
|
||||||
|
|
||||||
|
def get_column_type(table_name: str, column_name: str) -> str:
|
||||||
|
"""获取列的类型"""
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = inspect(bind)
|
||||||
|
for col in inspector.get_columns(table_name):
|
||||||
|
if col['name'] == column_name:
|
||||||
|
return str(col['type']).upper()
|
||||||
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""添加 proxy 字段到 provider_endpoints 表"""
|
||||||
|
if not column_exists('provider_endpoints', 'proxy'):
|
||||||
|
# 字段不存在,直接添加 JSONB 类型
|
||||||
|
op.add_column('provider_endpoints', sa.Column('proxy', JSONB(), nullable=True))
|
||||||
|
else:
|
||||||
|
# 字段已存在,检查是否需要转换类型
|
||||||
|
col_type = get_column_type('provider_endpoints', 'proxy')
|
||||||
|
if 'JSONB' not in col_type:
|
||||||
|
# 如果是 JSON 类型,转换为 JSONB
|
||||||
|
op.execute(
|
||||||
|
'ALTER TABLE provider_endpoints '
|
||||||
|
'ALTER COLUMN proxy TYPE JSONB USING proxy::jsonb'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""移除 proxy 字段"""
|
||||||
|
if column_exists('provider_endpoints', 'proxy'):
|
||||||
|
op.drop_column('provider_endpoints', 'proxy')
|
||||||
@@ -124,6 +124,27 @@ export interface ModelExport {
|
|||||||
config?: any
|
config?: any
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Provider 模型查询响应
|
||||||
|
export interface ProviderModelsQueryResponse {
|
||||||
|
success: boolean
|
||||||
|
data: {
|
||||||
|
models: Array<{
|
||||||
|
id: string
|
||||||
|
object?: string
|
||||||
|
created?: number
|
||||||
|
owned_by?: string
|
||||||
|
display_name?: string
|
||||||
|
api_format?: string
|
||||||
|
}>
|
||||||
|
error?: string
|
||||||
|
}
|
||||||
|
provider: {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export interface ConfigImportRequest extends ConfigExportData {
|
export interface ConfigImportRequest extends ConfigExportData {
|
||||||
merge_mode: 'skip' | 'overwrite' | 'error'
|
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||||
}
|
}
|
||||||
@@ -356,5 +377,14 @@ export const adminApi = {
|
|||||||
data
|
data
|
||||||
)
|
)
|
||||||
return response.data
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 查询 Provider 可用模型(从上游 API 获取)
|
||||||
|
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
|
||||||
|
const response = await apiClient.post<ProviderModelsQueryResponse>(
|
||||||
|
'/api/admin/provider-query/models',
|
||||||
|
{ provider_id: providerId, api_key_id: apiKeyId }
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import client from '../client'
|
import client from '../client'
|
||||||
import type { ProviderEndpoint } from './types'
|
import type { ProviderEndpoint, ProxyConfig } from './types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取指定 Provider 的所有 Endpoints
|
* 获取指定 Provider 的所有 Endpoints
|
||||||
@@ -38,6 +38,7 @@ export async function createEndpoint(
|
|||||||
rate_limit?: number
|
rate_limit?: number
|
||||||
is_active?: boolean
|
is_active?: boolean
|
||||||
config?: Record<string, any>
|
config?: Record<string, any>
|
||||||
|
proxy?: ProxyConfig | null
|
||||||
}
|
}
|
||||||
): Promise<ProviderEndpoint> {
|
): Promise<ProviderEndpoint> {
|
||||||
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
|
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/endpoints`, data)
|
||||||
@@ -63,6 +64,7 @@ export async function updateEndpoint(
|
|||||||
rate_limit: number
|
rate_limit: number
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
config: Record<string, any>
|
config: Record<string, any>
|
||||||
|
proxy: ProxyConfig | null
|
||||||
}>
|
}>
|
||||||
): Promise<ProviderEndpoint> {
|
): Promise<ProviderEndpoint> {
|
||||||
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)
|
const response = await client.put(`/api/admin/endpoints/${endpointId}`, data)
|
||||||
|
|||||||
@@ -1,3 +1,35 @@
|
|||||||
|
// API 格式常量
|
||||||
|
export const API_FORMATS = {
|
||||||
|
CLAUDE: 'CLAUDE',
|
||||||
|
CLAUDE_CLI: 'CLAUDE_CLI',
|
||||||
|
OPENAI: 'OPENAI',
|
||||||
|
OPENAI_CLI: 'OPENAI_CLI',
|
||||||
|
GEMINI: 'GEMINI',
|
||||||
|
GEMINI_CLI: 'GEMINI_CLI',
|
||||||
|
} as const
|
||||||
|
|
||||||
|
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
|
||||||
|
|
||||||
|
// API 格式显示名称映射(按品牌分组:API 在前,CLI 在后)
|
||||||
|
export const API_FORMAT_LABELS: Record<string, string> = {
|
||||||
|
[API_FORMATS.CLAUDE]: 'Claude',
|
||||||
|
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
|
||||||
|
[API_FORMATS.OPENAI]: 'OpenAI',
|
||||||
|
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
|
||||||
|
[API_FORMATS.GEMINI]: 'Gemini',
|
||||||
|
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 代理配置类型
|
||||||
|
*/
|
||||||
|
export interface ProxyConfig {
|
||||||
|
url: string
|
||||||
|
username?: string
|
||||||
|
password?: string
|
||||||
|
enabled?: boolean // 是否启用代理(false 时保留配置但不使用)
|
||||||
|
}
|
||||||
|
|
||||||
export interface ProviderEndpoint {
|
export interface ProviderEndpoint {
|
||||||
id: string
|
id: string
|
||||||
provider_id: string
|
provider_id: string
|
||||||
@@ -19,6 +51,7 @@ export interface ProviderEndpoint {
|
|||||||
last_failure_at?: string
|
last_failure_at?: string
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
config?: Record<string, any>
|
config?: Record<string, any>
|
||||||
|
proxy?: ProxyConfig | null
|
||||||
total_keys: number
|
total_keys: number
|
||||||
active_keys: number
|
active_keys: number
|
||||||
created_at: string
|
created_at: string
|
||||||
@@ -214,6 +247,7 @@ export interface ConcurrencyStatus {
|
|||||||
export interface ProviderModelAlias {
|
export interface ProviderModelAlias {
|
||||||
name: string
|
name: string
|
||||||
priority: number // 优先级(数字越小优先级越高)
|
priority: number // 优先级(数字越小优先级越高)
|
||||||
|
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Model {
|
export interface Model {
|
||||||
|
|||||||
@@ -34,11 +34,10 @@ const buttonClass = computed(() => {
|
|||||||
'inline-flex items-center justify-center rounded-xl text-sm font-semibold transition-all duration-200 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 active:scale-[0.98]'
|
'inline-flex items-center justify-center rounded-xl text-sm font-semibold transition-all duration-200 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 active:scale-[0.98]'
|
||||||
|
|
||||||
const variantClasses = {
|
const variantClasses = {
|
||||||
default:
|
default: 'bg-primary text-white hover:bg-primary/90',
|
||||||
'bg-primary text-white shadow-[0_20px_35px_rgba(204,120,92,0.35)] hover:bg-primary/90 hover:shadow-[0_25px_45px_rgba(204,120,92,0.45)]',
|
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85',
|
||||||
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85 shadow-sm',
|
|
||||||
outline:
|
outline:
|
||||||
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 shadow-sm backdrop-blur transition-all',
|
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 backdrop-blur transition-all',
|
||||||
secondary:
|
secondary:
|
||||||
'bg-secondary text-secondary-foreground shadow-inner hover:bg-secondary/80',
|
'bg-secondary text-secondary-foreground shadow-inner hover:bg-secondary/80',
|
||||||
ghost: 'hover:bg-accent hover:text-accent-foreground',
|
ghost: 'hover:bg-accent hover:text-accent-foreground',
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
<Teleport to="body">
|
<Teleport to="body">
|
||||||
<div
|
<div
|
||||||
v-if="isOpen"
|
v-if="isOpen"
|
||||||
class="fixed inset-0 overflow-y-auto"
|
class="fixed inset-0 overflow-y-auto pointer-events-none"
|
||||||
:style="{ zIndex: containerZIndex }"
|
:style="{ zIndex: containerZIndex }"
|
||||||
>
|
>
|
||||||
<!-- 背景遮罩 -->
|
<!-- 背景遮罩 -->
|
||||||
@@ -16,13 +16,13 @@
|
|||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
v-if="isOpen"
|
v-if="isOpen"
|
||||||
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity"
|
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity pointer-events-auto"
|
||||||
:style="{ zIndex: backdropZIndex }"
|
:style="{ zIndex: backdropZIndex }"
|
||||||
@click="handleClose"
|
@click="handleClose"
|
||||||
/>
|
/>
|
||||||
</Transition>
|
</Transition>
|
||||||
|
|
||||||
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0 pointer-events-none">
|
||||||
<!-- 对话框内容 -->
|
<!-- 对话框内容 -->
|
||||||
<Transition
|
<Transition
|
||||||
enter-active-class="duration-300 ease-out"
|
enter-active-class="duration-300 ease-out"
|
||||||
@@ -34,7 +34,7 @@
|
|||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
v-if="isOpen"
|
v-if="isOpen"
|
||||||
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border"
|
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border pointer-events-auto"
|
||||||
:style="{ zIndex: contentZIndex }"
|
:style="{ zIndex: contentZIndex }"
|
||||||
:class="maxWidthClass"
|
:class="maxWidthClass"
|
||||||
@click.stop
|
@click.stop
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ const props = withDefaults(defineProps<Props>(), {
|
|||||||
|
|
||||||
const contentClass = computed(() =>
|
const contentClass = computed(() =>
|
||||||
cn(
|
cn(
|
||||||
'z-[100] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
|
'z-[200] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
|
||||||
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
|
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
|
||||||
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
|
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
|
||||||
props.class
|
props.class
|
||||||
|
|||||||
@@ -132,7 +132,7 @@
|
|||||||
type="number"
|
type="number"
|
||||||
min="1"
|
min="1"
|
||||||
max="10000"
|
max="10000"
|
||||||
placeholder="100"
|
placeholder="留空不限制"
|
||||||
class="h-10"
|
class="h-10"
|
||||||
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
@update:model-value="(v) => form.rate_limit = parseNumberInput(v, { min: 1, max: 10000 })"
|
||||||
/>
|
/>
|
||||||
@@ -376,7 +376,7 @@ const form = ref<StandaloneKeyFormData>({
|
|||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expire_days: undefined,
|
||||||
never_expire: true,
|
never_expire: true,
|
||||||
rate_limit: 100,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
allowed_api_formats: [],
|
allowed_api_formats: [],
|
||||||
@@ -389,7 +389,7 @@ function resetForm() {
|
|||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expire_days: undefined,
|
||||||
never_expire: true,
|
never_expire: true,
|
||||||
rate_limit: 100,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
allowed_api_formats: [],
|
allowed_api_formats: [],
|
||||||
@@ -408,7 +408,7 @@ function loadKeyData() {
|
|||||||
initial_balance_usd: props.apiKey.initial_balance_usd,
|
initial_balance_usd: props.apiKey.initial_balance_usd,
|
||||||
expire_days: props.apiKey.expire_days,
|
expire_days: props.apiKey.expire_days,
|
||||||
never_expire: props.apiKey.never_expire,
|
never_expire: props.apiKey.never_expire,
|
||||||
rate_limit: props.apiKey.rate_limit || 100,
|
rate_limit: props.apiKey.rate_limit,
|
||||||
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
||||||
allowed_providers: props.apiKey.allowed_providers || [],
|
allowed_providers: props.apiKey.allowed_providers || [],
|
||||||
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
allowed_api_formats: props.apiKey.allowed_api_formats || [],
|
||||||
|
|||||||
@@ -396,15 +396,13 @@ interface ProviderGroup {
|
|||||||
|
|
||||||
const groupedModels = computed(() => {
|
const groupedModels = computed(() => {
|
||||||
let models = allModels.value.filter(m => !m.deprecated)
|
let models = allModels.value.filter(m => !m.deprecated)
|
||||||
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
models = models.filter(model =>
|
models = models.filter(model => {
|
||||||
model.providerId.toLowerCase().includes(query) ||
|
const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
|
||||||
model.providerName.toLowerCase().includes(query) ||
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
model.modelId.toLowerCase().includes(query) ||
|
})
|
||||||
model.modelName.toLowerCase().includes(query) ||
|
|
||||||
model.family?.toLowerCase().includes(query)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按提供商分组
|
// 按提供商分组
|
||||||
@@ -425,10 +423,12 @@ const groupedModels = computed(() => {
|
|||||||
|
|
||||||
// 如果有搜索词,把提供商名称/ID匹配的排在前面
|
// 如果有搜索词,把提供商名称/ID匹配的排在前面
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result.sort((a, b) => {
|
result.sort((a, b) => {
|
||||||
const aProviderMatch = a.providerId.toLowerCase().includes(query) || a.providerName.toLowerCase().includes(query)
|
const aText = `${a.providerId} ${a.providerName}`.toLowerCase()
|
||||||
const bProviderMatch = b.providerId.toLowerCase().includes(query) || b.providerName.toLowerCase().includes(query)
|
const bText = `${b.providerId} ${b.providerName}`.toLowerCase()
|
||||||
|
const aProviderMatch = keywords.some(k => aText.includes(k))
|
||||||
|
const bProviderMatch = keywords.some(k => bText.includes(k))
|
||||||
if (aProviderMatch && !bProviderMatch) return -1
|
if (aProviderMatch && !bProviderMatch) return -1
|
||||||
if (!aProviderMatch && bProviderMatch) return 1
|
if (!aProviderMatch && bProviderMatch) return 1
|
||||||
return a.providerName.localeCompare(b.providerName)
|
return a.providerName.localeCompare(b.providerName)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
>
|
>
|
||||||
<form
|
<form
|
||||||
class="space-y-6"
|
class="space-y-6"
|
||||||
@submit.prevent="handleSubmit"
|
@submit.prevent="handleSubmit()"
|
||||||
>
|
>
|
||||||
<!-- API 配置 -->
|
<!-- API 配置 -->
|
||||||
<div class="space-y-4">
|
<div class="space-y-4">
|
||||||
@@ -132,6 +132,79 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 代理配置 -->
|
||||||
|
<div class="space-y-4">
|
||||||
|
<div class="flex items-center justify-between">
|
||||||
|
<h3 class="text-sm font-medium">
|
||||||
|
代理配置
|
||||||
|
</h3>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<Switch v-model="proxyEnabled" />
|
||||||
|
<span class="text-sm text-muted-foreground">启用代理</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="proxyEnabled"
|
||||||
|
class="space-y-4 rounded-lg border p-4"
|
||||||
|
>
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label for="proxy_url">代理 URL *</Label>
|
||||||
|
<Input
|
||||||
|
id="proxy_url"
|
||||||
|
v-model="form.proxy_url"
|
||||||
|
placeholder="http://host:port 或 socks5://host:port"
|
||||||
|
required
|
||||||
|
:class="proxyUrlError ? 'border-red-500' : ''"
|
||||||
|
/>
|
||||||
|
<p
|
||||||
|
v-if="proxyUrlError"
|
||||||
|
class="text-xs text-red-500"
|
||||||
|
>
|
||||||
|
{{ proxyUrlError }}
|
||||||
|
</p>
|
||||||
|
<p
|
||||||
|
v-else
|
||||||
|
class="text-xs text-muted-foreground"
|
||||||
|
>
|
||||||
|
支持 HTTP、HTTPS、SOCKS5 代理
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="grid grid-cols-2 gap-4">
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label for="proxy_user">用户名(可选)</Label>
|
||||||
|
<Input
|
||||||
|
:id="`proxy_user_${formId}`"
|
||||||
|
:name="`proxy_user_${formId}`"
|
||||||
|
v-model="form.proxy_username"
|
||||||
|
placeholder="代理认证用户名"
|
||||||
|
autocomplete="off"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
data-1p-ignore="true"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label :for="`proxy_pass_${formId}`">密码(可选)</Label>
|
||||||
|
<Input
|
||||||
|
:id="`proxy_pass_${formId}`"
|
||||||
|
:name="`proxy_pass_${formId}`"
|
||||||
|
v-model="form.proxy_password"
|
||||||
|
type="text"
|
||||||
|
:placeholder="passwordPlaceholder"
|
||||||
|
autocomplete="off"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
data-1p-ignore="true"
|
||||||
|
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
@@ -145,12 +218,24 @@
|
|||||||
</Button>
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
||||||
@click="handleSubmit"
|
@click="handleSubmit()"
|
||||||
>
|
>
|
||||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
||||||
</Button>
|
</Button>
|
||||||
</template>
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
|
<!-- 确认清空凭据对话框 -->
|
||||||
|
<AlertDialog
|
||||||
|
v-model="showClearCredentialsDialog"
|
||||||
|
title="清空代理凭据"
|
||||||
|
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
|
||||||
|
type="warning"
|
||||||
|
confirm-text="清空并保存"
|
||||||
|
cancel-text="返回编辑"
|
||||||
|
@confirm="confirmClearCredentials"
|
||||||
|
@cancel="showClearCredentialsDialog = false"
|
||||||
|
/>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
@@ -165,7 +250,9 @@ import {
|
|||||||
SelectValue,
|
SelectValue,
|
||||||
SelectContent,
|
SelectContent,
|
||||||
SelectItem,
|
SelectItem,
|
||||||
|
Switch,
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
|
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||||
import { Link, SquarePen } from 'lucide-vue-next'
|
import { Link, SquarePen } from 'lucide-vue-next'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useFormDialog } from '@/composables/useFormDialog'
|
import { useFormDialog } from '@/composables/useFormDialog'
|
||||||
@@ -194,6 +281,11 @@ const emit = defineEmits<{
|
|||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const selectOpen = ref(false)
|
const selectOpen = ref(false)
|
||||||
|
const proxyEnabled = ref(false)
|
||||||
|
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
|
||||||
|
|
||||||
|
// 生成随机 ID 防止浏览器自动填充
|
||||||
|
const formId = Math.random().toString(36).substring(2, 10)
|
||||||
|
|
||||||
// 内部状态
|
// 内部状态
|
||||||
const internalOpen = computed(() => props.modelValue)
|
const internalOpen = computed(() => props.modelValue)
|
||||||
@@ -207,7 +299,11 @@ const form = ref({
|
|||||||
max_retries: 3,
|
max_retries: 3,
|
||||||
max_concurrent: undefined as number | undefined,
|
max_concurrent: undefined as number | undefined,
|
||||||
rate_limit: undefined as number | undefined,
|
rate_limit: undefined as number | undefined,
|
||||||
is_active: true
|
is_active: true,
|
||||||
|
// 代理配置
|
||||||
|
proxy_url: '',
|
||||||
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
})
|
})
|
||||||
|
|
||||||
// API 格式列表
|
// API 格式列表
|
||||||
@@ -237,6 +333,53 @@ const defaultPathPlaceholder = computed(() => {
|
|||||||
return `留空使用默认路径:${defaultPath.value}`
|
return `留空使用默认路径:${defaultPath.value}`
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
|
||||||
|
const hasExistingPassword = computed(() => {
|
||||||
|
if (!props.endpoint?.proxy) return false
|
||||||
|
const proxy = props.endpoint.proxy as { password?: string }
|
||||||
|
return proxy?.password === MASKED_PASSWORD
|
||||||
|
})
|
||||||
|
|
||||||
|
// 密码输入框的 placeholder
|
||||||
|
const passwordPlaceholder = computed(() => {
|
||||||
|
if (hasExistingPassword.value) {
|
||||||
|
return '已保存密码,留空保持不变'
|
||||||
|
}
|
||||||
|
return '代理认证密码'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 代理 URL 验证
|
||||||
|
const proxyUrlError = computed(() => {
|
||||||
|
// 只有启用代理且填写了 URL 时才验证
|
||||||
|
if (!proxyEnabled.value || !form.value.proxy_url) {
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
const url = form.value.proxy_url.trim()
|
||||||
|
|
||||||
|
// 检查禁止的特殊字符
|
||||||
|
if (/[\n\r]/.test(url)) {
|
||||||
|
return '代理 URL 包含非法字符'
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证协议(不支持 SOCKS4)
|
||||||
|
if (!/^(http|https|socks5):\/\//i.test(url)) {
|
||||||
|
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const parsed = new URL(url)
|
||||||
|
if (!parsed.host) {
|
||||||
|
return '代理 URL 必须包含有效的 host'
|
||||||
|
}
|
||||||
|
// 禁止 URL 中内嵌认证信息
|
||||||
|
if (parsed.username || parsed.password) {
|
||||||
|
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
return '代理 URL 格式无效'
|
||||||
|
}
|
||||||
|
return ''
|
||||||
|
})
|
||||||
|
|
||||||
// 组件挂载时加载API格式
|
// 组件挂载时加载API格式
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadApiFormats()
|
loadApiFormats()
|
||||||
@@ -252,14 +395,23 @@ function resetForm() {
|
|||||||
max_retries: 3,
|
max_retries: 3,
|
||||||
max_concurrent: undefined,
|
max_concurrent: undefined,
|
||||||
rate_limit: undefined,
|
rate_limit: undefined,
|
||||||
is_active: true
|
is_active: true,
|
||||||
|
proxy_url: '',
|
||||||
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
}
|
}
|
||||||
|
proxyEnabled.value = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 原始密码占位符(后端返回的脱敏标记)
|
||||||
|
const MASKED_PASSWORD = '***'
|
||||||
|
|
||||||
// 加载端点数据(编辑模式)
|
// 加载端点数据(编辑模式)
|
||||||
function loadEndpointData() {
|
function loadEndpointData() {
|
||||||
if (!props.endpoint) return
|
if (!props.endpoint) return
|
||||||
|
|
||||||
|
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
|
||||||
|
|
||||||
form.value = {
|
form.value = {
|
||||||
api_format: props.endpoint.api_format,
|
api_format: props.endpoint.api_format,
|
||||||
base_url: props.endpoint.base_url,
|
base_url: props.endpoint.base_url,
|
||||||
@@ -268,8 +420,15 @@ function loadEndpointData() {
|
|||||||
max_retries: props.endpoint.max_retries,
|
max_retries: props.endpoint.max_retries,
|
||||||
max_concurrent: props.endpoint.max_concurrent || undefined,
|
max_concurrent: props.endpoint.max_concurrent || undefined,
|
||||||
rate_limit: props.endpoint.rate_limit || undefined,
|
rate_limit: props.endpoint.rate_limit || undefined,
|
||||||
is_active: props.endpoint.is_active
|
is_active: props.endpoint.is_active,
|
||||||
|
proxy_url: proxy?.url || '',
|
||||||
|
proxy_username: proxy?.username || '',
|
||||||
|
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
|
||||||
|
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据 enabled 字段或 url 存在判断是否启用代理
|
||||||
|
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 useFormDialog 统一处理对话框逻辑
|
// 使用 useFormDialog 统一处理对话框逻辑
|
||||||
@@ -282,12 +441,47 @@ const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
|
|||||||
resetForm,
|
resetForm,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 构建代理配置
|
||||||
|
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
|
||||||
|
// - 无 URL 时返回 null
|
||||||
|
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
|
||||||
|
if (!form.value.proxy_url) {
|
||||||
|
// 没填 URL,无代理配置
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
url: form.value.proxy_url,
|
||||||
|
username: form.value.proxy_username || undefined,
|
||||||
|
password: form.value.proxy_password || undefined,
|
||||||
|
enabled: proxyEnabled.value, // 开关状态决定是否启用
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 提交表单
|
// 提交表单
|
||||||
const handleSubmit = async () => {
|
const handleSubmit = async (skipCredentialCheck = false) => {
|
||||||
if (!props.provider && !props.endpoint) return
|
if (!props.provider && !props.endpoint) return
|
||||||
|
|
||||||
|
// 只在开关开启且填写了 URL 时验证
|
||||||
|
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
|
||||||
|
showError(proxyUrlError.value, '代理配置错误')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查:开关开启但没有 URL,却有用户名或密码
|
||||||
|
const hasOrphanedCredentials = proxyEnabled.value
|
||||||
|
&& !form.value.proxy_url
|
||||||
|
&& (form.value.proxy_username || form.value.proxy_password)
|
||||||
|
|
||||||
|
if (hasOrphanedCredentials && !skipCredentialCheck) {
|
||||||
|
// 弹出确认对话框
|
||||||
|
showClearCredentialsDialog.value = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
loading.value = true
|
loading.value = true
|
||||||
try {
|
try {
|
||||||
|
const proxyConfig = buildProxyConfig()
|
||||||
|
|
||||||
if (isEditMode.value && props.endpoint) {
|
if (isEditMode.value && props.endpoint) {
|
||||||
// 更新端点
|
// 更新端点
|
||||||
await updateEndpoint(props.endpoint.id, {
|
await updateEndpoint(props.endpoint.id, {
|
||||||
@@ -297,7 +491,8 @@ const handleSubmit = async () => {
|
|||||||
max_retries: form.value.max_retries,
|
max_retries: form.value.max_retries,
|
||||||
max_concurrent: form.value.max_concurrent,
|
max_concurrent: form.value.max_concurrent,
|
||||||
rate_limit: form.value.rate_limit,
|
rate_limit: form.value.rate_limit,
|
||||||
is_active: form.value.is_active
|
is_active: form.value.is_active,
|
||||||
|
proxy: proxyConfig,
|
||||||
})
|
})
|
||||||
|
|
||||||
success('端点已更新', '保存成功')
|
success('端点已更新', '保存成功')
|
||||||
@@ -313,7 +508,8 @@ const handleSubmit = async () => {
|
|||||||
max_retries: form.value.max_retries,
|
max_retries: form.value.max_retries,
|
||||||
max_concurrent: form.value.max_concurrent,
|
max_concurrent: form.value.max_concurrent,
|
||||||
rate_limit: form.value.rate_limit,
|
rate_limit: form.value.rate_limit,
|
||||||
is_active: form.value.is_active
|
is_active: form.value.is_active,
|
||||||
|
proxy: proxyConfig,
|
||||||
})
|
})
|
||||||
|
|
||||||
success('端点创建成功', '成功')
|
success('端点创建成功', '成功')
|
||||||
@@ -329,4 +525,12 @@ const handleSubmit = async () => {
|
|||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 确认清空凭据并继续保存
|
||||||
|
const confirmClearCredentials = () => {
|
||||||
|
form.value.proxy_username = ''
|
||||||
|
form.value.proxy_password = ''
|
||||||
|
showClearCredentialsDialog.value = false
|
||||||
|
handleSubmit(true) // 跳过凭据检查,直接提交
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -312,8 +312,41 @@
|
|||||||
|
|
||||||
<template #footer>
|
<template #footer>
|
||||||
<div class="flex items-center justify-between w-full">
|
<div class="flex items-center justify-between w-full">
|
||||||
<div class="text-xs text-muted-foreground">
|
<div class="flex items-center gap-4">
|
||||||
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
|
<div class="text-xs text-muted-foreground">
|
||||||
|
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center gap-2 pl-4 border-l border-border">
|
||||||
|
<span class="text-xs text-muted-foreground">调度:</span>
|
||||||
|
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||||
|
:class="[
|
||||||
|
schedulingMode === 'fixed_order'
|
||||||
|
? 'bg-primary text-primary-foreground shadow-sm'
|
||||||
|
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||||
|
]"
|
||||||
|
title="严格按优先级顺序,不考虑缓存"
|
||||||
|
@click="schedulingMode = 'fixed_order'"
|
||||||
|
>
|
||||||
|
固定顺序
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||||
|
:class="[
|
||||||
|
schedulingMode === 'cache_affinity'
|
||||||
|
? 'bg-primary text-primary-foreground shadow-sm'
|
||||||
|
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||||
|
]"
|
||||||
|
title="优先使用已缓存的Provider,利用Prompt Cache"
|
||||||
|
@click="schedulingMode = 'cache_affinity'"
|
||||||
|
>
|
||||||
|
缓存亲和
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex gap-2">
|
<div class="flex gap-2">
|
||||||
<Button
|
<Button
|
||||||
@@ -410,6 +443,9 @@ const saving = ref(false)
|
|||||||
// Key 优先级编辑状态
|
// Key 优先级编辑状态
|
||||||
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
||||||
|
|
||||||
|
// 调度模式状态
|
||||||
|
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
|
||||||
|
|
||||||
// 可用的 API 格式
|
// 可用的 API 格式
|
||||||
const availableFormats = computed(() => {
|
const availableFormats = computed(() => {
|
||||||
return Object.keys(keysByFormat.value).sort()
|
return Object.keys(keysByFormat.value).sort()
|
||||||
@@ -433,11 +469,18 @@ watch(internalOpen, async (open) => {
|
|||||||
// 加载当前的优先级模式配置
|
// 加载当前的优先级模式配置
|
||||||
async function loadCurrentPriorityMode() {
|
async function loadCurrentPriorityMode() {
|
||||||
try {
|
try {
|
||||||
const response = await adminApi.getSystemConfig('provider_priority_mode')
|
const [priorityResponse, schedulingResponse] = await Promise.all([
|
||||||
const currentMode = response.value || 'provider'
|
adminApi.getSystemConfig('provider_priority_mode'),
|
||||||
|
adminApi.getSystemConfig('scheduling_mode')
|
||||||
|
])
|
||||||
|
const currentMode = priorityResponse.value || 'provider'
|
||||||
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
||||||
|
|
||||||
|
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
|
||||||
|
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
|
||||||
} catch {
|
} catch {
|
||||||
activeMainTab.value = 'provider'
|
activeMainTab.value = 'provider'
|
||||||
|
schedulingMode.value = 'cache_affinity'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -611,11 +654,19 @@ async function save() {
|
|||||||
|
|
||||||
const newMode = activeMainTab.value === 'key' ? 'global_key' : 'provider'
|
const newMode = activeMainTab.value === 'key' ? 'global_key' : 'provider'
|
||||||
|
|
||||||
await adminApi.updateSystemConfig(
|
// 保存优先级模式和调度模式
|
||||||
'provider_priority_mode',
|
await Promise.all([
|
||||||
newMode,
|
adminApi.updateSystemConfig(
|
||||||
'Provider/Key 优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)'
|
'provider_priority_mode',
|
||||||
)
|
newMode,
|
||||||
|
'Provider/Key 优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)'
|
||||||
|
),
|
||||||
|
adminApi.updateSystemConfig(
|
||||||
|
'scheduling_mode',
|
||||||
|
schedulingMode.value,
|
||||||
|
'调度模式:fixed_order(固定顺序模式) 或 cache_affinity(缓存亲和模式)'
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
const providerUpdates = sortedProviders.value.map((provider, index) =>
|
const providerUpdates = sortedProviders.value.map((provider, index) =>
|
||||||
updateProvider(provider.id, { provider_priority: index + 1 })
|
updateProvider(provider.id, { provider_priority: index + 1 })
|
||||||
|
|||||||
@@ -526,7 +526,14 @@
|
|||||||
@edit-model="handleEditModel"
|
@edit-model="handleEditModel"
|
||||||
@delete-model="handleDeleteModel"
|
@delete-model="handleDeleteModel"
|
||||||
@batch-assign="handleBatchAssign"
|
@batch-assign="handleBatchAssign"
|
||||||
@manage-alias="handleManageAlias"
|
/>
|
||||||
|
|
||||||
|
<!-- 模型名称映射 -->
|
||||||
|
<ModelAliasesTab
|
||||||
|
v-if="provider"
|
||||||
|
:key="`aliases-${provider.id}`"
|
||||||
|
:provider="provider"
|
||||||
|
@refresh="handleRelatedDataRefresh"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@@ -629,16 +636,6 @@
|
|||||||
@update:open="batchAssignDialogOpen = $event"
|
@update:open="batchAssignDialogOpen = $event"
|
||||||
@changed="handleBatchAssignChanged"
|
@changed="handleBatchAssignChanged"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- 模型别名管理对话框 -->
|
|
||||||
<ModelAliasDialog
|
|
||||||
v-if="open && provider"
|
|
||||||
:open="aliasDialogOpen"
|
|
||||||
:provider-id="provider.id"
|
|
||||||
:model="aliasEditingModel"
|
|
||||||
@update:open="aliasDialogOpen = $event"
|
|
||||||
@saved="handleAliasSaved"
|
|
||||||
/>
|
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
@@ -667,8 +664,8 @@ import {
|
|||||||
KeyFormDialog,
|
KeyFormDialog,
|
||||||
KeyAllowedModelsDialog,
|
KeyAllowedModelsDialog,
|
||||||
ModelsTab,
|
ModelsTab,
|
||||||
BatchAssignModelsDialog,
|
ModelAliasesTab,
|
||||||
ModelAliasDialog
|
BatchAssignModelsDialog
|
||||||
} from '@/features/providers/components'
|
} from '@/features/providers/components'
|
||||||
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
||||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||||
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
|
|||||||
const modelToDelete = ref<Model | null>(null)
|
const modelToDelete = ref<Model | null>(null)
|
||||||
const batchAssignDialogOpen = ref(false)
|
const batchAssignDialogOpen = ref(false)
|
||||||
|
|
||||||
// 别名管理相关状态
|
|
||||||
const aliasDialogOpen = ref(false)
|
|
||||||
const aliasEditingModel = ref<Model | null>(null)
|
|
||||||
|
|
||||||
// 拖动排序相关状态
|
// 拖动排序相关状态
|
||||||
const dragState = ref({
|
const dragState = ref({
|
||||||
isDragging: false,
|
isDragging: false,
|
||||||
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
|
|||||||
deleteKeyConfirmOpen.value ||
|
deleteKeyConfirmOpen.value ||
|
||||||
modelFormDialogOpen.value ||
|
modelFormDialogOpen.value ||
|
||||||
deleteModelConfirmOpen.value ||
|
deleteModelConfirmOpen.value ||
|
||||||
batchAssignDialogOpen.value ||
|
batchAssignDialogOpen.value
|
||||||
aliasDialogOpen.value
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 监听 providerId 变化
|
// 监听 providerId 变化
|
||||||
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
|
|||||||
keyAllowedModelsDialogOpen.value = false
|
keyAllowedModelsDialogOpen.value = false
|
||||||
deleteKeyConfirmOpen.value = false
|
deleteKeyConfirmOpen.value = false
|
||||||
batchAssignDialogOpen.value = false
|
batchAssignDialogOpen.value = false
|
||||||
aliasDialogOpen.value = false
|
|
||||||
|
|
||||||
// 重置临时数据
|
// 重置临时数据
|
||||||
endpointToEdit.value = null
|
endpointToEdit.value = null
|
||||||
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
|
|||||||
emit('refresh')
|
emit('refresh')
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理管理映射 - 打开别名对话框
|
|
||||||
function handleManageAlias(model: Model) {
|
|
||||||
aliasEditingModel.value = model
|
|
||||||
aliasDialogOpen.value = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理别名保存完成
|
|
||||||
async function handleAliasSaved() {
|
|
||||||
aliasEditingModel.value = null
|
|
||||||
await loadProvider()
|
|
||||||
emit('refresh')
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理模型保存完成
|
// 处理模型保存完成
|
||||||
async function handleModelSaved() {
|
async function handleModelSaved() {
|
||||||
editingModel.value = null
|
editingModel.value = null
|
||||||
|
|||||||
@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
|
|||||||
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
||||||
|
|
||||||
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
||||||
|
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -165,15 +165,6 @@
|
|||||||
>
|
>
|
||||||
<Edit class="w-3.5 h-3.5" />
|
<Edit class="w-3.5 h-3.5" />
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
size="icon"
|
|
||||||
class="h-8 w-8"
|
|
||||||
title="管理映射"
|
|
||||||
@click="openAliasDialog(model)"
|
|
||||||
>
|
|
||||||
<Tag class="w-3.5 h-3.5" />
|
|
||||||
</Button>
|
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
@@ -218,7 +209,7 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
|
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
@@ -233,7 +224,6 @@ const emit = defineEmits<{
|
|||||||
'editModel': [model: Model]
|
'editModel': [model: Model]
|
||||||
'deleteModel': [model: Model]
|
'deleteModel': [model: Model]
|
||||||
'batchAssign': []
|
'batchAssign': []
|
||||||
'manageAlias': [model: Model]
|
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { error: showError, success: showSuccess } = useToast()
|
const { error: showError, success: showSuccess } = useToast()
|
||||||
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
|
|||||||
emit('batchAssign')
|
emit('batchAssign')
|
||||||
}
|
}
|
||||||
|
|
||||||
// 打开别名管理对话框
|
|
||||||
function openAliasDialog(model: Model) {
|
|
||||||
emit('manageAlias', model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 切换模型启用状态
|
// 切换模型启用状态
|
||||||
async function toggleModelActive(model: Model) {
|
async function toggleModelActive(model: Model) {
|
||||||
if (togglingModelId.value) return
|
if (togglingModelId.value) return
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
</h3>
|
</h3>
|
||||||
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
<div class="flex items-center gap-1 text-sm font-mono text-muted-foreground bg-muted px-2 py-0.5 rounded">
|
||||||
<span>{{ detail?.model || '-' }}</span>
|
<span>{{ detail?.model || '-' }}</span>
|
||||||
<template v-if="detail?.target_model">
|
<template v-if="detail?.target_model && detail.target_model !== detail.model">
|
||||||
<svg
|
<svg
|
||||||
xmlns="http://www.w3.org/2000/svg"
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
viewBox="0 0 20 20"
|
viewBox="0 0 20 20"
|
||||||
|
|||||||
@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
|
|||||||
const filteredApiKeys = computed(() => {
|
const filteredApiKeys = computed(() => {
|
||||||
let result = apiKeys.value
|
let result = apiKeys.value
|
||||||
|
|
||||||
// 搜索筛选
|
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(key =>
|
result = result.filter(key => {
|
||||||
(key.name && key.name.toLowerCase().includes(query)) ||
|
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
|
||||||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
(key.username && key.username.toLowerCase().includes(query)) ||
|
})
|
||||||
(key.user_email && key.user_email.toLowerCase().includes(query))
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 状态筛选
|
// 状态筛选
|
||||||
|
|||||||
@@ -1002,13 +1002,13 @@ async function batchRemoveSelectedProviders() {
|
|||||||
const filteredGlobalModels = computed(() => {
|
const filteredGlobalModels = computed(() => {
|
||||||
let result = globalModels.value
|
let result = globalModels.value
|
||||||
|
|
||||||
// 搜索
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(m =>
|
result = result.filter(m => {
|
||||||
m.name.toLowerCase().includes(query) ||
|
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||||
m.display_name?.toLowerCase().includes(query)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 能力筛选
|
// 能力筛选
|
||||||
|
|||||||
@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
|
|||||||
const filteredProviders = computed(() => {
|
const filteredProviders = computed(() => {
|
||||||
let result = [...providers.value]
|
let result = [...providers.value]
|
||||||
|
|
||||||
// 搜索筛选
|
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value.trim()) {
|
if (searchQuery.value.trim()) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(p =>
|
result = result.filter(p => {
|
||||||
p.display_name.toLowerCase().includes(query) ||
|
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
|
||||||
p.name.toLowerCase().includes(query)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 排序
|
// 排序
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
<template #actions>
|
<template #actions>
|
||||||
<Button
|
<Button
|
||||||
:disabled="loading"
|
:disabled="loading"
|
||||||
|
class="shadow-none hover:shadow-none"
|
||||||
@click="saveSystemConfig"
|
@click="saveSystemConfig"
|
||||||
>
|
>
|
||||||
{{ loading ? '保存中...' : '保存所有配置' }}
|
{{ loading ? '保存中...' : '保存所有配置' }}
|
||||||
@@ -184,32 +185,13 @@
|
|||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
|
|
||||||
<!-- API Key 管理配置 -->
|
<!-- 独立余额 Key 过期管理 -->
|
||||||
<CardSection
|
<CardSection
|
||||||
title="API Key 管理"
|
title="独立余额 Key 过期管理"
|
||||||
description="API Key 相关配置"
|
description="独立余额 Key 的过期处理策略(普通用户 Key 不会过期)"
|
||||||
>
|
>
|
||||||
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
<div class="grid grid-cols-1 md:grid-cols-2 gap-6">
|
||||||
<div>
|
<div class="flex items-center h-full">
|
||||||
<Label
|
|
||||||
for="api-key-expire"
|
|
||||||
class="block text-sm font-medium"
|
|
||||||
>
|
|
||||||
API密钥过期天数
|
|
||||||
</Label>
|
|
||||||
<Input
|
|
||||||
id="api-key-expire"
|
|
||||||
v-model.number="systemConfig.api_key_expire_days"
|
|
||||||
type="number"
|
|
||||||
placeholder="0"
|
|
||||||
class="mt-1"
|
|
||||||
/>
|
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
|
||||||
0 表示永不过期
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="flex items-center h-full pt-6">
|
|
||||||
<div class="flex items-center space-x-2">
|
<div class="flex items-center space-x-2">
|
||||||
<Checkbox
|
<Checkbox
|
||||||
id="auto-delete-expired-keys"
|
id="auto-delete-expired-keys"
|
||||||
@@ -223,7 +205,7 @@
|
|||||||
自动删除过期 Key
|
自动删除过期 Key
|
||||||
</Label>
|
</Label>
|
||||||
<p class="text-xs text-muted-foreground">
|
<p class="text-xs text-muted-foreground">
|
||||||
关闭时仅禁用过期 Key
|
关闭时仅禁用过期 Key,不会物理删除
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -447,6 +429,25 @@
|
|||||||
避免单次操作过大影响性能
|
避免单次操作过大影响性能
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label
|
||||||
|
for="audit-log-retention-days"
|
||||||
|
class="block text-sm font-medium"
|
||||||
|
>
|
||||||
|
审计日志保留天数
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
id="audit-log-retention-days"
|
||||||
|
v-model.number="systemConfig.audit_log_retention_days"
|
||||||
|
type="number"
|
||||||
|
placeholder="30"
|
||||||
|
class="mt-1"
|
||||||
|
/>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
超过后删除审计日志记录
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 清理策略说明 -->
|
<!-- 清理策略说明 -->
|
||||||
@@ -459,323 +460,310 @@
|
|||||||
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
<p>2. <strong>压缩日志阶段</strong>: body 字段被压缩存储,节省空间</p>
|
||||||
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
<p>3. <strong>统计阶段</strong>: 仅保留 tokens、成本等统计信息</p>
|
||||||
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
<p>4. <strong>归档删除</strong>: 超过保留期限后完全删除记录</p>
|
||||||
|
<p>5. <strong>审计日志</strong>: 独立清理,记录用户登录、操作等安全事件</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 导入配置对话框 -->
|
<!-- 导入配置对话框 -->
|
||||||
<Dialog v-model:open="importDialogOpen">
|
<Dialog
|
||||||
<DialogContent class="max-w-lg">
|
v-model:open="importDialogOpen"
|
||||||
<DialogHeader>
|
title="导入配置"
|
||||||
<DialogTitle>导入配置</DialogTitle>
|
description="选择冲突处理模式并确认导入"
|
||||||
<DialogDescription>
|
>
|
||||||
选择冲突处理模式并确认导入
|
<div class="space-y-4">
|
||||||
</DialogDescription>
|
<div
|
||||||
</DialogHeader>
|
v-if="importPreview"
|
||||||
|
class="p-3 bg-muted rounded-lg text-sm"
|
||||||
<div class="space-y-4 py-4">
|
>
|
||||||
<div
|
<p class="font-medium mb-2">
|
||||||
v-if="importPreview"
|
配置预览
|
||||||
class="p-3 bg-muted rounded-lg text-sm"
|
</p>
|
||||||
>
|
<ul class="space-y-1 text-muted-foreground">
|
||||||
<p class="font-medium mb-2">
|
<li>全局模型: {{ importPreview.global_models?.length || 0 }} 个</li>
|
||||||
配置预览
|
<li>提供商: {{ importPreview.providers?.length || 0 }} 个</li>
|
||||||
</p>
|
<li>
|
||||||
<ul class="space-y-1 text-muted-foreground">
|
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个
|
||||||
<li>全局模型: {{ importPreview.global_models?.length || 0 }} 个</li>
|
</li>
|
||||||
<li>提供商: {{ importPreview.providers?.length || 0 }} 个</li>
|
<li>
|
||||||
<li>
|
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }} 个
|
||||||
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个
|
</li>
|
||||||
</li>
|
</ul>
|
||||||
<li>
|
|
||||||
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }} 个
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
|
||||||
<Select v-model="mergeMode">
|
|
||||||
<SelectTrigger>
|
|
||||||
<SelectValue />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
<SelectItem value="skip">
|
|
||||||
跳过 - 保留现有配置
|
|
||||||
</SelectItem>
|
|
||||||
<SelectItem value="overwrite">
|
|
||||||
覆盖 - 用导入配置替换
|
|
||||||
</SelectItem>
|
|
||||||
<SelectItem value="error">
|
|
||||||
报错 - 遇到冲突时中止
|
|
||||||
</SelectItem>
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
|
||||||
<template v-if="mergeMode === 'skip'">
|
|
||||||
已存在的配置将被保留,仅导入新配置
|
|
||||||
</template>
|
|
||||||
<template v-else-if="mergeMode === 'overwrite'">
|
|
||||||
已存在的配置将被导入的配置覆盖
|
|
||||||
</template>
|
|
||||||
<template v-else>
|
|
||||||
如果发现任何冲突,导入将中止并回滚
|
|
||||||
</template>
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="p-3 bg-yellow-500/10 border border-yellow-500/20 rounded-lg">
|
|
||||||
<p class="text-sm text-yellow-600 dark:text-yellow-400">
|
|
||||||
注意:相同的 API Keys 会自动跳过,不会创建重复记录。
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<DialogFooter>
|
<div>
|
||||||
<Button
|
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||||
variant="outline"
|
<Select
|
||||||
@click="importDialogOpen = false"
|
v-model="mergeMode"
|
||||||
|
v-model:open="mergeModeSelectOpen"
|
||||||
>
|
>
|
||||||
取消
|
<SelectTrigger>
|
||||||
</Button>
|
<SelectValue />
|
||||||
<Button
|
</SelectTrigger>
|
||||||
:disabled="importLoading"
|
<SelectContent>
|
||||||
@click="confirmImport"
|
<SelectItem value="skip">
|
||||||
>
|
跳过 - 保留现有配置
|
||||||
{{ importLoading ? '导入中...' : '确认导入' }}
|
</SelectItem>
|
||||||
</Button>
|
<SelectItem value="overwrite">
|
||||||
</DialogFooter>
|
覆盖 - 用导入配置替换
|
||||||
</DialogContent>
|
</SelectItem>
|
||||||
|
<SelectItem value="error">
|
||||||
|
报错 - 遇到冲突时中止
|
||||||
|
</SelectItem>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
<template v-if="mergeMode === 'skip'">
|
||||||
|
已存在的配置将被保留,仅导入新配置
|
||||||
|
</template>
|
||||||
|
<template v-else-if="mergeMode === 'overwrite'">
|
||||||
|
已存在的配置将被导入的配置覆盖
|
||||||
|
</template>
|
||||||
|
<template v-else>
|
||||||
|
如果发现任何冲突,导入将中止并回滚
|
||||||
|
</template>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p class="text-xs text-muted-foreground">
|
||||||
|
注意:相同的 API Keys 会自动跳过,不会创建重复记录。
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template #footer>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
@click="importDialogOpen = false; mergeModeSelectOpen = false"
|
||||||
|
>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
:disabled="importLoading"
|
||||||
|
@click="confirmImport"
|
||||||
|
>
|
||||||
|
{{ importLoading ? '导入中...' : '确认导入' }}
|
||||||
|
</Button>
|
||||||
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
<!-- 导入结果对话框 -->
|
<!-- 导入结果对话框 -->
|
||||||
<Dialog v-model:open="importResultDialogOpen">
|
<Dialog
|
||||||
<DialogContent class="max-w-lg">
|
v-model:open="importResultDialogOpen"
|
||||||
<DialogHeader>
|
title="导入完成"
|
||||||
<DialogTitle>导入完成</DialogTitle>
|
>
|
||||||
</DialogHeader>
|
<div
|
||||||
|
v-if="importResult"
|
||||||
<div
|
class="space-y-4"
|
||||||
v-if="importResult"
|
>
|
||||||
class="space-y-4 py-4"
|
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||||
>
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
全局模型
|
||||||
<p class="font-medium">
|
</p>
|
||||||
全局模型
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importResult.stats.global_models.created }},
|
||||||
<p class="text-muted-foreground">
|
更新: {{ importResult.stats.global_models.updated }},
|
||||||
创建: {{ importResult.stats.global_models.created }},
|
跳过: {{ importResult.stats.global_models.skipped }}
|
||||||
更新: {{ importResult.stats.global_models.updated }},
|
</p>
|
||||||
跳过: {{ importResult.stats.global_models.skipped }}
|
</div>
|
||||||
</p>
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
</div>
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
提供商
|
||||||
<p class="font-medium">
|
</p>
|
||||||
提供商
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importResult.stats.providers.created }},
|
||||||
<p class="text-muted-foreground">
|
更新: {{ importResult.stats.providers.updated }},
|
||||||
创建: {{ importResult.stats.providers.created }},
|
跳过: {{ importResult.stats.providers.skipped }}
|
||||||
更新: {{ importResult.stats.providers.updated }},
|
</p>
|
||||||
跳过: {{ importResult.stats.providers.skipped }}
|
</div>
|
||||||
</p>
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
</div>
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
端点
|
||||||
<p class="font-medium">
|
</p>
|
||||||
端点
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importResult.stats.endpoints.created }},
|
||||||
<p class="text-muted-foreground">
|
更新: {{ importResult.stats.endpoints.updated }},
|
||||||
创建: {{ importResult.stats.endpoints.created }},
|
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||||
更新: {{ importResult.stats.endpoints.updated }},
|
</p>
|
||||||
跳过: {{ importResult.stats.endpoints.skipped }}
|
</div>
|
||||||
</p>
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
</div>
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
API Keys
|
||||||
<p class="font-medium">
|
</p>
|
||||||
API Keys
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importResult.stats.keys.created }},
|
||||||
<p class="text-muted-foreground">
|
跳过: {{ importResult.stats.keys.skipped }}
|
||||||
创建: {{ importResult.stats.keys.created }},
|
</p>
|
||||||
跳过: {{ importResult.stats.keys.skipped }}
|
</div>
|
||||||
</p>
|
<div class="p-3 bg-muted rounded-lg col-span-2">
|
||||||
</div>
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg col-span-2">
|
模型配置
|
||||||
<p class="font-medium">
|
</p>
|
||||||
模型配置
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importResult.stats.models.created }},
|
||||||
<p class="text-muted-foreground">
|
更新: {{ importResult.stats.models.updated }},
|
||||||
创建: {{ importResult.stats.models.created }},
|
跳过: {{ importResult.stats.models.skipped }}
|
||||||
更新: {{ importResult.stats.models.updated }},
|
|
||||||
跳过: {{ importResult.stats.models.skipped }}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div
|
|
||||||
v-if="importResult.stats.errors.length > 0"
|
|
||||||
class="p-3 bg-red-500/10 border border-red-500/20 rounded-lg"
|
|
||||||
>
|
|
||||||
<p class="font-medium text-red-600 dark:text-red-400 mb-2">
|
|
||||||
警告信息
|
|
||||||
</p>
|
</p>
|
||||||
<ul class="text-sm text-red-600 dark:text-red-400 space-y-1">
|
|
||||||
<li
|
|
||||||
v-for="(err, index) in importResult.stats.errors"
|
|
||||||
:key="index"
|
|
||||||
>
|
|
||||||
{{ err }}
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<DialogFooter>
|
<div
|
||||||
<Button @click="importResultDialogOpen = false">
|
v-if="importResult.stats.errors.length > 0"
|
||||||
确定
|
class="p-3 bg-destructive/10 rounded-lg"
|
||||||
</Button>
|
>
|
||||||
</DialogFooter>
|
<p class="font-medium text-destructive mb-2">
|
||||||
</DialogContent>
|
警告信息
|
||||||
|
</p>
|
||||||
|
<ul class="text-sm text-destructive space-y-1">
|
||||||
|
<li
|
||||||
|
v-for="(err, index) in importResult.stats.errors"
|
||||||
|
:key="index"
|
||||||
|
>
|
||||||
|
{{ err }}
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template #footer>
|
||||||
|
<Button @click="importResultDialogOpen = false">
|
||||||
|
确定
|
||||||
|
</Button>
|
||||||
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
<!-- 用户数据导入对话框 -->
|
<!-- 用户数据导入对话框 -->
|
||||||
<Dialog v-model:open="importUsersDialogOpen">
|
<Dialog
|
||||||
<DialogContent class="max-w-lg">
|
v-model:open="importUsersDialogOpen"
|
||||||
<DialogHeader>
|
title="导入用户数据"
|
||||||
<DialogTitle>导入用户数据</DialogTitle>
|
description="选择冲突处理模式并确认导入"
|
||||||
<DialogDescription>
|
>
|
||||||
选择冲突处理模式并确认导入
|
<div class="space-y-4">
|
||||||
</DialogDescription>
|
<div
|
||||||
</DialogHeader>
|
v-if="importUsersPreview"
|
||||||
|
class="p-3 bg-muted rounded-lg text-sm"
|
||||||
<div class="space-y-4 py-4">
|
>
|
||||||
<div
|
<p class="font-medium mb-2">
|
||||||
v-if="importUsersPreview"
|
数据预览
|
||||||
class="p-3 bg-muted rounded-lg text-sm"
|
</p>
|
||||||
>
|
<ul class="space-y-1 text-muted-foreground">
|
||||||
<p class="font-medium mb-2">
|
<li>用户: {{ importUsersPreview.users?.length || 0 }} 个</li>
|
||||||
数据预览
|
<li>
|
||||||
</p>
|
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
||||||
<ul class="space-y-1 text-muted-foreground">
|
</li>
|
||||||
<li>用户: {{ importUsersPreview.users?.length || 0 }} 个</li>
|
</ul>
|
||||||
<li>
|
|
||||||
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div>
|
|
||||||
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
|
||||||
<Select v-model="usersMergeMode">
|
|
||||||
<SelectTrigger>
|
|
||||||
<SelectValue />
|
|
||||||
</SelectTrigger>
|
|
||||||
<SelectContent>
|
|
||||||
<SelectItem value="skip">
|
|
||||||
跳过 - 保留现有用户
|
|
||||||
</SelectItem>
|
|
||||||
<SelectItem value="overwrite">
|
|
||||||
覆盖 - 用导入数据替换
|
|
||||||
</SelectItem>
|
|
||||||
<SelectItem value="error">
|
|
||||||
报错 - 遇到冲突时中止
|
|
||||||
</SelectItem>
|
|
||||||
</SelectContent>
|
|
||||||
</Select>
|
|
||||||
<p class="mt-1 text-xs text-muted-foreground">
|
|
||||||
<template v-if="usersMergeMode === 'skip'">
|
|
||||||
已存在的用户将被保留,仅导入新用户
|
|
||||||
</template>
|
|
||||||
<template v-else-if="usersMergeMode === 'overwrite'">
|
|
||||||
已存在的用户将被导入的数据覆盖
|
|
||||||
</template>
|
|
||||||
<template v-else>
|
|
||||||
如果发现任何冲突,导入将中止并回滚
|
|
||||||
</template>
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="p-3 bg-yellow-500/10 border border-yellow-500/20 rounded-lg">
|
|
||||||
<p class="text-sm text-yellow-600 dark:text-yellow-400">
|
|
||||||
注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作。
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<DialogFooter>
|
<div>
|
||||||
<Button
|
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||||
variant="outline"
|
<Select
|
||||||
@click="importUsersDialogOpen = false"
|
v-model="usersMergeMode"
|
||||||
|
v-model:open="usersMergeModeSelectOpen"
|
||||||
>
|
>
|
||||||
取消
|
<SelectTrigger>
|
||||||
</Button>
|
<SelectValue />
|
||||||
<Button
|
</SelectTrigger>
|
||||||
:disabled="importUsersLoading"
|
<SelectContent>
|
||||||
@click="confirmImportUsers"
|
<SelectItem value="skip">
|
||||||
>
|
跳过 - 保留现有用户
|
||||||
{{ importUsersLoading ? '导入中...' : '确认导入' }}
|
</SelectItem>
|
||||||
</Button>
|
<SelectItem value="overwrite">
|
||||||
</DialogFooter>
|
覆盖 - 用导入数据替换
|
||||||
</DialogContent>
|
</SelectItem>
|
||||||
|
<SelectItem value="error">
|
||||||
|
报错 - 遇到冲突时中止
|
||||||
|
</SelectItem>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
<template v-if="usersMergeMode === 'skip'">
|
||||||
|
已存在的用户将被保留,仅导入新用户
|
||||||
|
</template>
|
||||||
|
<template v-else-if="usersMergeMode === 'overwrite'">
|
||||||
|
已存在的用户将被导入的数据覆盖
|
||||||
|
</template>
|
||||||
|
<template v-else>
|
||||||
|
如果发现任何冲突,导入将中止并回滚
|
||||||
|
</template>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p class="text-xs text-muted-foreground">
|
||||||
|
注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作。
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template #footer>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
@click="importUsersDialogOpen = false; usersMergeModeSelectOpen = false"
|
||||||
|
>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
:disabled="importUsersLoading"
|
||||||
|
@click="confirmImportUsers"
|
||||||
|
>
|
||||||
|
{{ importUsersLoading ? '导入中...' : '确认导入' }}
|
||||||
|
</Button>
|
||||||
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
<!-- 用户数据导入结果对话框 -->
|
<!-- 用户数据导入结果对话框 -->
|
||||||
<Dialog v-model:open="importUsersResultDialogOpen">
|
<Dialog
|
||||||
<DialogContent class="max-w-lg">
|
v-model:open="importUsersResultDialogOpen"
|
||||||
<DialogHeader>
|
title="用户数据导入完成"
|
||||||
<DialogTitle>用户数据导入完成</DialogTitle>
|
>
|
||||||
</DialogHeader>
|
<div
|
||||||
|
v-if="importUsersResult"
|
||||||
<div
|
class="space-y-4"
|
||||||
v-if="importUsersResult"
|
>
|
||||||
class="space-y-4 py-4"
|
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||||
>
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
用户
|
||||||
<p class="font-medium">
|
</p>
|
||||||
用户
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importUsersResult.stats.users.created }},
|
||||||
<p class="text-muted-foreground">
|
更新: {{ importUsersResult.stats.users.updated }},
|
||||||
创建: {{ importUsersResult.stats.users.created }},
|
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||||
更新: {{ importUsersResult.stats.users.updated }},
|
</p>
|
||||||
跳过: {{ importUsersResult.stats.users.skipped }}
|
</div>
|
||||||
</p>
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
</div>
|
<p class="font-medium">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
API Keys
|
||||||
<p class="font-medium">
|
</p>
|
||||||
API Keys
|
<p class="text-muted-foreground">
|
||||||
</p>
|
创建: {{ importUsersResult.stats.api_keys.created }},
|
||||||
<p class="text-muted-foreground">
|
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||||
创建: {{ importUsersResult.stats.api_keys.created }},
|
|
||||||
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div
|
|
||||||
v-if="importUsersResult.stats.errors.length > 0"
|
|
||||||
class="p-3 bg-red-500/10 border border-red-500/20 rounded-lg"
|
|
||||||
>
|
|
||||||
<p class="font-medium text-red-600 dark:text-red-400 mb-2">
|
|
||||||
警告信息
|
|
||||||
</p>
|
</p>
|
||||||
<ul class="text-sm text-red-600 dark:text-red-400 space-y-1">
|
|
||||||
<li
|
|
||||||
v-for="(err, index) in importUsersResult.stats.errors"
|
|
||||||
:key="index"
|
|
||||||
>
|
|
||||||
{{ err }}
|
|
||||||
</li>
|
|
||||||
</ul>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<DialogFooter>
|
<div
|
||||||
<Button @click="importUsersResultDialogOpen = false">
|
v-if="importUsersResult.stats.errors.length > 0"
|
||||||
确定
|
class="p-3 bg-destructive/10 rounded-lg"
|
||||||
</Button>
|
>
|
||||||
</DialogFooter>
|
<p class="font-medium text-destructive mb-2">
|
||||||
</DialogContent>
|
警告信息
|
||||||
|
</p>
|
||||||
|
<ul class="text-sm text-destructive space-y-1">
|
||||||
|
<li
|
||||||
|
v-for="(err, index) in importUsersResult.stats.errors"
|
||||||
|
:key="index"
|
||||||
|
>
|
||||||
|
{{ err }}
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<template #footer>
|
||||||
|
<Button @click="importUsersResultDialogOpen = false">
|
||||||
|
确定
|
||||||
|
</Button>
|
||||||
|
</template>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
</PageContainer>
|
</PageContainer>
|
||||||
</template>
|
</template>
|
||||||
@@ -794,11 +782,6 @@ import SelectContent from '@/components/ui/select-content.vue'
|
|||||||
import SelectItem from '@/components/ui/select-item.vue'
|
import SelectItem from '@/components/ui/select-item.vue'
|
||||||
import {
|
import {
|
||||||
Dialog,
|
Dialog,
|
||||||
DialogContent,
|
|
||||||
DialogHeader,
|
|
||||||
DialogTitle,
|
|
||||||
DialogDescription,
|
|
||||||
DialogFooter
|
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
|
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
@@ -814,8 +797,7 @@ interface SystemConfig {
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
enable_registration: boolean
|
enable_registration: boolean
|
||||||
require_email_verification: boolean
|
require_email_verification: boolean
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
api_key_expire_days: number
|
|
||||||
auto_delete_expired_keys: boolean
|
auto_delete_expired_keys: boolean
|
||||||
// 日志记录
|
// 日志记录
|
||||||
request_log_level: string
|
request_log_level: string
|
||||||
@@ -829,6 +811,7 @@ interface SystemConfig {
|
|||||||
header_retention_days: number
|
header_retention_days: number
|
||||||
log_retention_days: number
|
log_retention_days: number
|
||||||
cleanup_batch_size: number
|
cleanup_batch_size: number
|
||||||
|
audit_log_retention_days: number
|
||||||
}
|
}
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -843,6 +826,7 @@ const configFileInput = ref<HTMLInputElement | null>(null)
|
|||||||
const importPreview = ref<ConfigExportData | null>(null)
|
const importPreview = ref<ConfigExportData | null>(null)
|
||||||
const importResult = ref<ConfigImportResponse | null>(null)
|
const importResult = ref<ConfigImportResponse | null>(null)
|
||||||
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||||
|
const mergeModeSelectOpen = ref(false)
|
||||||
|
|
||||||
// 用户数据导出/导入相关
|
// 用户数据导出/导入相关
|
||||||
const exportUsersLoading = ref(false)
|
const exportUsersLoading = ref(false)
|
||||||
@@ -853,6 +837,7 @@ const usersFileInput = ref<HTMLInputElement | null>(null)
|
|||||||
const importUsersPreview = ref<UsersExportData | null>(null)
|
const importUsersPreview = ref<UsersExportData | null>(null)
|
||||||
const importUsersResult = ref<UsersImportResponse | null>(null)
|
const importUsersResult = ref<UsersImportResponse | null>(null)
|
||||||
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||||
|
const usersMergeModeSelectOpen = ref(false)
|
||||||
|
|
||||||
const systemConfig = ref<SystemConfig>({
|
const systemConfig = ref<SystemConfig>({
|
||||||
// 基础配置
|
// 基础配置
|
||||||
@@ -861,8 +846,7 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
enable_registration: false,
|
enable_registration: false,
|
||||||
require_email_verification: false,
|
require_email_verification: false,
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
api_key_expire_days: 0,
|
|
||||||
auto_delete_expired_keys: false,
|
auto_delete_expired_keys: false,
|
||||||
// 日志记录
|
// 日志记录
|
||||||
request_log_level: 'basic',
|
request_log_level: 'basic',
|
||||||
@@ -876,6 +860,7 @@ const systemConfig = ref<SystemConfig>({
|
|||||||
header_retention_days: 90,
|
header_retention_days: 90,
|
||||||
log_retention_days: 365,
|
log_retention_days: 365,
|
||||||
cleanup_batch_size: 1000,
|
cleanup_batch_size: 1000,
|
||||||
|
audit_log_retention_days: 30,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算属性:KB 和 字节 之间的转换
|
// 计算属性:KB 和 字节 之间的转换
|
||||||
@@ -917,8 +902,7 @@ async function loadSystemConfig() {
|
|||||||
// 用户注册
|
// 用户注册
|
||||||
'enable_registration',
|
'enable_registration',
|
||||||
'require_email_verification',
|
'require_email_verification',
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
'api_key_expire_days',
|
|
||||||
'auto_delete_expired_keys',
|
'auto_delete_expired_keys',
|
||||||
// 日志记录
|
// 日志记录
|
||||||
'request_log_level',
|
'request_log_level',
|
||||||
@@ -932,6 +916,7 @@ async function loadSystemConfig() {
|
|||||||
'header_retention_days',
|
'header_retention_days',
|
||||||
'log_retention_days',
|
'log_retention_days',
|
||||||
'cleanup_batch_size',
|
'cleanup_batch_size',
|
||||||
|
'audit_log_retention_days',
|
||||||
]
|
]
|
||||||
|
|
||||||
for (const key of configs) {
|
for (const key of configs) {
|
||||||
@@ -976,12 +961,7 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.require_email_verification,
|
value: systemConfig.value.require_email_verification,
|
||||||
description: '是否需要邮箱验证'
|
description: '是否需要邮箱验证'
|
||||||
},
|
},
|
||||||
// API Key 管理
|
// 独立余额 Key 过期管理
|
||||||
{
|
|
||||||
key: 'api_key_expire_days',
|
|
||||||
value: systemConfig.value.api_key_expire_days,
|
|
||||||
description: 'API密钥过期天数'
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
key: 'auto_delete_expired_keys',
|
key: 'auto_delete_expired_keys',
|
||||||
value: systemConfig.value.auto_delete_expired_keys,
|
value: systemConfig.value.auto_delete_expired_keys,
|
||||||
@@ -1039,6 +1019,11 @@ async function saveSystemConfig() {
|
|||||||
value: systemConfig.value.cleanup_batch_size,
|
value: systemConfig.value.cleanup_batch_size,
|
||||||
description: '每批次清理的记录数'
|
description: '每批次清理的记录数'
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 'audit_log_retention_days',
|
||||||
|
value: systemConfig.value.audit_log_retention_days,
|
||||||
|
description: '审计日志保留天数'
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const promises = configItems.map(item =>
|
const promises = configItems.map(item =>
|
||||||
@@ -1136,6 +1121,7 @@ async function confirmImport() {
|
|||||||
})
|
})
|
||||||
importResult.value = result
|
importResult.value = result
|
||||||
importDialogOpen.value = false
|
importDialogOpen.value = false
|
||||||
|
mergeModeSelectOpen.value = false
|
||||||
importResultDialogOpen.value = true
|
importResultDialogOpen.value = true
|
||||||
success('配置导入成功')
|
success('配置导入成功')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
@@ -1224,6 +1210,7 @@ async function confirmImportUsers() {
|
|||||||
})
|
})
|
||||||
importUsersResult.value = result
|
importUsersResult.value = result
|
||||||
importUsersDialogOpen.value = false
|
importUsersDialogOpen.value = false
|
||||||
|
usersMergeModeSelectOpen.value = false
|
||||||
importUsersResultDialogOpen.value = true
|
importUsersResultDialogOpen.value = true
|
||||||
success('用户数据导入成功')
|
success('用户数据导入成功')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
|
|||||||
@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
|
|||||||
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
filtered = filtered.filter(
|
filtered = filtered.filter(u => {
|
||||||
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
|
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
|
||||||
)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filterRole.value !== 'all') {
|
if (filterRole.value !== 'all') {
|
||||||
|
|||||||
@@ -103,7 +103,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
||||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
||||||
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
平均响应
|
平均响应
|
||||||
@@ -114,7 +114,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
||||||
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
错误率
|
错误率
|
||||||
@@ -128,7 +128,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
||||||
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
转移次数
|
转移次数
|
||||||
@@ -142,7 +142,7 @@
|
|||||||
v-if="costStats"
|
v-if="costStats"
|
||||||
class="relative p-3 sm:p-4 border-manilla/40"
|
class="relative p-3 sm:p-4 border-manilla/40"
|
||||||
>
|
>
|
||||||
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
实际成本
|
实际成本
|
||||||
@@ -180,7 +180,7 @@
|
|||||||
</div>
|
</div>
|
||||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
||||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
||||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
缓存命中率
|
缓存命中率
|
||||||
@@ -191,7 +191,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
||||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
缓存读取
|
缓存读取
|
||||||
@@ -202,7 +202,7 @@
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
||||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
缓存创建
|
缓存创建
|
||||||
@@ -216,7 +216,7 @@
|
|||||||
v-if="tokenBreakdown"
|
v-if="tokenBreakdown"
|
||||||
class="relative p-3 sm:p-4 border-manilla/40"
|
class="relative p-3 sm:p-4 border-manilla/40"
|
||||||
>
|
>
|
||||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
总Token
|
总Token
|
||||||
@@ -254,16 +254,16 @@
|
|||||||
<Card class="overflow-hidden p-4 flex flex-col flex-1 min-h-0 h-full max-h-[280px] sm:max-h-none">
|
<Card class="overflow-hidden p-4 flex flex-col flex-1 min-h-0 h-full max-h-[280px] sm:max-h-none">
|
||||||
<div
|
<div
|
||||||
v-if="loadingAnnouncements"
|
v-if="loadingAnnouncements"
|
||||||
class="py-8 text-center"
|
class="flex-1 flex items-center justify-center"
|
||||||
>
|
>
|
||||||
<Loader2 class="h-5 w-5 animate-spin mx-auto text-muted-foreground" />
|
<Loader2 class="h-5 w-5 animate-spin text-muted-foreground" />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
v-else-if="announcements.length === 0"
|
v-else-if="announcements.length === 0"
|
||||||
class="py-8 text-center"
|
class="flex-1 flex flex-col items-center justify-center"
|
||||||
>
|
>
|
||||||
<Bell class="h-8 w-8 mx-auto text-muted-foreground/40" />
|
<Bell class="h-8 w-8 text-muted-foreground/40" />
|
||||||
<p class="mt-2 text-xs text-muted-foreground">
|
<p class="mt-2 text-xs text-muted-foreground">
|
||||||
暂无公告
|
暂无公告
|
||||||
</p>
|
</p>
|
||||||
@@ -793,9 +793,8 @@ const statCardGlows = [
|
|||||||
'bg-kraft/30'
|
'bg-kraft/30'
|
||||||
]
|
]
|
||||||
|
|
||||||
const getStatIconColor = (index: number): string => {
|
const getStatIconColor = (_index: number): string => {
|
||||||
const colors = ['text-book-cloth', 'text-kraft', 'text-book-cloth', 'text-kraft']
|
return 'text-muted-foreground'
|
||||||
return colors[index % colors.length]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统计数据
|
// 统计数据
|
||||||
|
|||||||
@@ -474,13 +474,13 @@ async function toggleCapability(modelName: string, capName: string) {
|
|||||||
const filteredModels = computed(() => {
|
const filteredModels = computed(() => {
|
||||||
let result = models.value
|
let result = models.value
|
||||||
|
|
||||||
// 搜索
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(m =>
|
result = result.filter(m => {
|
||||||
m.name.toLowerCase().includes(query) ||
|
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||||
m.display_name?.toLowerCase().includes(query)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 能力筛选
|
// 能力筛选
|
||||||
|
|||||||
@@ -62,6 +62,7 @@
|
|||||||
<Button
|
<Button
|
||||||
type="submit"
|
type="submit"
|
||||||
:disabled="savingProfile"
|
:disabled="savingProfile"
|
||||||
|
class="shadow-none hover:shadow-none"
|
||||||
>
|
>
|
||||||
{{ savingProfile ? '保存中...' : '保存修改' }}
|
{{ savingProfile ? '保存中...' : '保存修改' }}
|
||||||
</Button>
|
</Button>
|
||||||
@@ -107,6 +108,7 @@
|
|||||||
<Button
|
<Button
|
||||||
type="submit"
|
type="submit"
|
||||||
:disabled="changingPassword"
|
:disabled="changingPassword"
|
||||||
|
class="shadow-none hover:shadow-none"
|
||||||
>
|
>
|
||||||
{{ changingPassword ? '修改中...' : '修改密码' }}
|
{{ changingPassword ? '修改中...' : '修改密码' }}
|
||||||
</Button>
|
</Button>
|
||||||
@@ -320,6 +322,7 @@
|
|||||||
import { ref, onMounted } from 'vue'
|
import { ref, onMounted } from 'vue'
|
||||||
import { useAuthStore } from '@/stores/auth'
|
import { useAuthStore } from '@/stores/auth'
|
||||||
import { meApi, type Profile } from '@/api/me'
|
import { meApi, type Profile } from '@/api/me'
|
||||||
|
import { useDarkMode, type ThemeMode } from '@/composables/useDarkMode'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
@@ -338,6 +341,7 @@ import { log } from '@/utils/logger'
|
|||||||
|
|
||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
const { setThemeMode } = useDarkMode()
|
||||||
|
|
||||||
const profile = ref<Profile | null>(null)
|
const profile = ref<Profile | null>(null)
|
||||||
|
|
||||||
@@ -375,20 +379,8 @@ function handleThemeChange(value: string) {
|
|||||||
themeSelectOpen.value = false
|
themeSelectOpen.value = false
|
||||||
updatePreferences()
|
updatePreferences()
|
||||||
|
|
||||||
// 应用主题
|
// 使用 useDarkMode 统一切换主题
|
||||||
if (value === 'dark') {
|
setThemeMode(value as ThemeMode)
|
||||||
document.documentElement.classList.add('dark')
|
|
||||||
} else if (value === 'light') {
|
|
||||||
document.documentElement.classList.remove('dark')
|
|
||||||
} else {
|
|
||||||
// system: 跟随系统
|
|
||||||
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches
|
|
||||||
if (prefersDark) {
|
|
||||||
document.documentElement.classList.add('dark')
|
|
||||||
} else {
|
|
||||||
document.documentElement.classList.remove('dark')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleLanguageChange(value: string) {
|
function handleLanguageChange(value: string) {
|
||||||
@@ -418,10 +410,16 @@ async function loadProfile() {
|
|||||||
async function loadPreferences() {
|
async function loadPreferences() {
|
||||||
try {
|
try {
|
||||||
const prefs = await meApi.getPreferences()
|
const prefs = await meApi.getPreferences()
|
||||||
|
|
||||||
|
// 主题以本地 localStorage 为准(useDarkMode 在应用启动时已初始化)
|
||||||
|
// 这样可以避免刷新页面时主题被服务端旧值覆盖
|
||||||
|
const { themeMode: currentThemeMode } = useDarkMode()
|
||||||
|
const localTheme = currentThemeMode.value
|
||||||
|
|
||||||
preferencesForm.value = {
|
preferencesForm.value = {
|
||||||
avatar_url: prefs.avatar_url || '',
|
avatar_url: prefs.avatar_url || '',
|
||||||
bio: prefs.bio || '',
|
bio: prefs.bio || '',
|
||||||
theme: prefs.theme || 'light',
|
theme: localTheme, // 使用本地主题,而非服务端返回值
|
||||||
language: prefs.language || 'zh-CN',
|
language: prefs.language || 'zh-CN',
|
||||||
timezone: prefs.timezone || 'Asia/Shanghai',
|
timezone: prefs.timezone || 'Asia/Shanghai',
|
||||||
notifications: {
|
notifications: {
|
||||||
@@ -431,11 +429,12 @@ async function loadPreferences() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 应用主题
|
// 如果本地主题和服务端不一致,同步到服务端(静默更新,不提示用户)
|
||||||
if (preferencesForm.value.theme === 'dark') {
|
const serverTheme = prefs.theme || 'light'
|
||||||
document.documentElement.classList.add('dark')
|
if (localTheme !== serverTheme) {
|
||||||
} else if (preferencesForm.value.theme === 'light') {
|
meApi.updatePreferences({ theme: localTheme }).catch(() => {
|
||||||
document.documentElement.classList.remove('dark')
|
// 静默失败,不影响用户体验
|
||||||
|
})
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log.error('加载偏好设置失败:', error)
|
log.error('加载偏好设置失败:', error)
|
||||||
|
|||||||
@@ -3,10 +3,8 @@
|
|||||||
A proxy server that enables AI models to work with multiple API providers.
|
A proxy server that enables AI models to work with multiple API providers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
# 注意: dotenv 加载已统一移至 src/config/settings.py
|
||||||
|
# 不要在此处重复加载
|
||||||
# Load environment variables from .env file
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src._version import __version__
|
from src._version import __version__
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
|
|||||||
from .endpoints import router as endpoints_router
|
from .endpoints import router as endpoints_router
|
||||||
from .models import router as models_router
|
from .models import router as models_router
|
||||||
from .monitoring import router as monitoring_router
|
from .monitoring import router as monitoring_router
|
||||||
|
from .provider_query import router as provider_query_router
|
||||||
from .provider_strategy import router as provider_strategy_router
|
from .provider_strategy import router as provider_strategy_router
|
||||||
from .providers import router as providers_router
|
from .providers import router as providers_router
|
||||||
from .security import router as security_router
|
from .security import router as security_router
|
||||||
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
|
|||||||
router.include_router(adaptive_router)
|
router.include_router(adaptive_router)
|
||||||
router.include_router(models_router)
|
router.include_router(models_router)
|
||||||
router.include_router(security_router)
|
router.include_router(security_router)
|
||||||
|
router.include_router(provider_query_router)
|
||||||
|
|
||||||
__all__ = ["router"]
|
__all__ = ["router"]
|
||||||
|
|||||||
@@ -223,7 +223,7 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
|||||||
allowed_providers=self.key_data.allowed_providers,
|
allowed_providers=self.key_data.allowed_providers,
|
||||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||||
allowed_models=self.key_data.allowed_models,
|
allowed_models=self.key_data.allowed_models,
|
||||||
rate_limit=self.key_data.rate_limit or 100,
|
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||||
expire_days=self.key_data.expire_days,
|
expire_days=self.key_data.expire_days,
|
||||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||||
is_standalone=True, # 标记为独立Key
|
is_standalone=True, # 标记为独立Key
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ ProviderEndpoint CRUD 管理 API
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from sqlalchemy import and_, func
|
from sqlalchemy import and_, func
|
||||||
@@ -27,6 +27,16 @@ router = APIRouter(tags=["Endpoint Management"])
|
|||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
|
|
||||||
|
def mask_proxy_password(proxy_config: Optional[dict]) -> Optional[dict]:
|
||||||
|
"""对代理配置中的密码进行脱敏处理"""
|
||||||
|
if not proxy_config:
|
||||||
|
return None
|
||||||
|
masked = dict(proxy_config)
|
||||||
|
if masked.get("password"):
|
||||||
|
masked["password"] = "***"
|
||||||
|
return masked
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
@router.get("/providers/{provider_id}/endpoints", response_model=List[ProviderEndpointResponse])
|
||||||
async def list_provider_endpoints(
|
async def list_provider_endpoints(
|
||||||
provider_id: str,
|
provider_id: str,
|
||||||
@@ -153,6 +163,7 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
|
|||||||
"api_format": endpoint.api_format,
|
"api_format": endpoint.api_format,
|
||||||
"total_keys": total_keys_map.get(endpoint.id, 0),
|
"total_keys": total_keys_map.get(endpoint.id, 0),
|
||||||
"active_keys": active_keys_map.get(endpoint.id, 0),
|
"active_keys": active_keys_map.get(endpoint.id, 0),
|
||||||
|
"proxy": mask_proxy_password(endpoint.proxy),
|
||||||
}
|
}
|
||||||
endpoint_dict.pop("_sa_instance_state", None)
|
endpoint_dict.pop("_sa_instance_state", None)
|
||||||
result.append(ProviderEndpointResponse(**endpoint_dict))
|
result.append(ProviderEndpointResponse(**endpoint_dict))
|
||||||
@@ -202,6 +213,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
rate_limit=self.endpoint_data.rate_limit,
|
rate_limit=self.endpoint_data.rate_limit,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
config=self.endpoint_data.config,
|
config=self.endpoint_data.config,
|
||||||
|
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
@@ -215,12 +227,13 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in new_endpoint.__dict__.items()
|
for k, v in new_endpoint.__dict__.items()
|
||||||
if k not in {"api_format", "_sa_instance_state"}
|
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||||
}
|
}
|
||||||
return ProviderEndpointResponse(
|
return ProviderEndpointResponse(
|
||||||
**endpoint_dict,
|
**endpoint_dict,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
api_format=new_endpoint.api_format,
|
api_format=new_endpoint.api_format,
|
||||||
|
proxy=mask_proxy_password(new_endpoint.proxy),
|
||||||
total_keys=0,
|
total_keys=0,
|
||||||
active_keys=0,
|
active_keys=0,
|
||||||
)
|
)
|
||||||
@@ -259,12 +272,13 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in endpoint_obj.__dict__.items()
|
for k, v in endpoint_obj.__dict__.items()
|
||||||
if k not in {"api_format", "_sa_instance_state"}
|
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||||
}
|
}
|
||||||
return ProviderEndpointResponse(
|
return ProviderEndpointResponse(
|
||||||
**endpoint_dict,
|
**endpoint_dict,
|
||||||
provider_name=provider.name,
|
provider_name=provider.name,
|
||||||
api_format=endpoint_obj.api_format,
|
api_format=endpoint_obj.api_format,
|
||||||
|
proxy=mask_proxy_password(endpoint_obj.proxy),
|
||||||
total_keys=total_keys,
|
total_keys=total_keys,
|
||||||
active_keys=active_keys,
|
active_keys=active_keys,
|
||||||
)
|
)
|
||||||
@@ -284,6 +298,17 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
|
||||||
|
|
||||||
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
update_data = self.endpoint_data.model_dump(exclude_unset=True)
|
||||||
|
# 把 proxy 转换为 dict 存储,支持显式设置为 None 清除代理
|
||||||
|
if "proxy" in update_data:
|
||||||
|
if update_data["proxy"] is not None:
|
||||||
|
new_proxy = dict(update_data["proxy"])
|
||||||
|
# 只有当密码字段未提供时才保留原密码(空字符串视为显式清除)
|
||||||
|
if "password" not in new_proxy and endpoint.proxy:
|
||||||
|
old_password = endpoint.proxy.get("password")
|
||||||
|
if old_password:
|
||||||
|
new_proxy["password"] = old_password
|
||||||
|
update_data["proxy"] = new_proxy
|
||||||
|
# proxy 为 None 时保留,用于清除代理配置
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(endpoint, field, value)
|
setattr(endpoint, field, value)
|
||||||
endpoint.updated_at = datetime.now(timezone.utc)
|
endpoint.updated_at = datetime.now(timezone.utc)
|
||||||
@@ -311,12 +336,13 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
endpoint_dict = {
|
endpoint_dict = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in endpoint.__dict__.items()
|
for k, v in endpoint.__dict__.items()
|
||||||
if k not in {"api_format", "_sa_instance_state"}
|
if k not in {"api_format", "_sa_instance_state", "proxy"}
|
||||||
}
|
}
|
||||||
return ProviderEndpointResponse(
|
return ProviderEndpointResponse(
|
||||||
**endpoint_dict,
|
**endpoint_dict,
|
||||||
provider_name=provider.name if provider else "Unknown",
|
provider_name=provider.name if provider else "Unknown",
|
||||||
api_format=endpoint.api_format,
|
api_format=endpoint.api_format,
|
||||||
|
proxy=mask_proxy_password(endpoint.proxy),
|
||||||
total_keys=total_keys,
|
total_keys=total_keys,
|
||||||
active_keys=active_keys,
|
active_keys=active_keys,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ from src.core.logger import logger
|
|||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.database import ApiKey, User
|
from src.models.database import ApiKey, User
|
||||||
from src.services.cache.affinity_manager import get_affinity_manager
|
from src.services.cache.affinity_manager import get_affinity_manager
|
||||||
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
|
from src.services.cache.aware_scheduler import CacheAwareScheduler, get_cache_aware_scheduler
|
||||||
|
from src.services.system.config import SystemConfigService
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
|
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
@@ -250,7 +251,22 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
|
|||||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
# 读取系统配置,确保监控接口与编排器使用一致的模式
|
||||||
|
priority_mode = SystemConfigService.get_config(
|
||||||
|
context.db,
|
||||||
|
"provider_priority_mode",
|
||||||
|
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||||
|
)
|
||||||
|
scheduling_mode = SystemConfigService.get_config(
|
||||||
|
context.db,
|
||||||
|
"scheduling_mode",
|
||||||
|
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||||
|
)
|
||||||
|
scheduler = await get_cache_aware_scheduler(
|
||||||
|
redis_client,
|
||||||
|
priority_mode=priority_mode,
|
||||||
|
scheduling_mode=scheduling_mode,
|
||||||
|
)
|
||||||
stats = await scheduler.get_stats()
|
stats = await scheduler.get_stats()
|
||||||
logger.info("缓存统计信息查询成功")
|
logger.info("缓存统计信息查询成功")
|
||||||
context.add_audit_metadata(
|
context.add_audit_metadata(
|
||||||
@@ -270,7 +286,22 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
|
|||||||
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
|
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
|
||||||
try:
|
try:
|
||||||
redis_client = get_redis_client_sync()
|
redis_client = get_redis_client_sync()
|
||||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
# 读取系统配置,确保监控接口与编排器使用一致的模式
|
||||||
|
priority_mode = SystemConfigService.get_config(
|
||||||
|
context.db,
|
||||||
|
"provider_priority_mode",
|
||||||
|
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||||
|
)
|
||||||
|
scheduling_mode = SystemConfigService.get_config(
|
||||||
|
context.db,
|
||||||
|
"scheduling_mode",
|
||||||
|
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||||
|
)
|
||||||
|
scheduler = await get_cache_aware_scheduler(
|
||||||
|
redis_client,
|
||||||
|
priority_mode=priority_mode,
|
||||||
|
scheduling_mode=scheduling_mode,
|
||||||
|
)
|
||||||
stats = await scheduler.get_stats()
|
stats = await scheduler.get_stats()
|
||||||
payload = self._format_prometheus(stats)
|
payload = self._format_prometheus(stats)
|
||||||
context.add_audit_metadata(
|
context.add_audit_metadata(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from datetime import datetime
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.api.base.admin_adapter import AdminApiAdapter
|
from src.api.base.admin_adapter import AdminApiAdapter
|
||||||
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
|
|||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
finished_at: Optional[datetime] = None
|
finished_at: Optional[datetime] = None
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
class RequestTraceResponse(BaseModel):
|
class RequestTraceResponse(BaseModel):
|
||||||
|
|||||||
@@ -1,46 +1,28 @@
|
|||||||
"""
|
"""
|
||||||
Provider Query API 端点
|
Provider Query API 端点
|
||||||
用于查询提供商的余额、使用记录等信息
|
用于查询提供商的模型列表等信息
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
import asyncio
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
import httpx
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from src.core.crypto import crypto_service
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database.database import get_db
|
from src.database.database import get_db
|
||||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
from src.models.database import Provider, ProviderEndpoint, User
|
||||||
|
|
||||||
# 初始化适配器注册
|
|
||||||
from src.plugins.provider_query import init # noqa
|
|
||||||
from src.plugins.provider_query import get_query_registry
|
|
||||||
from src.plugins.provider_query.base import QueryCapability
|
|
||||||
from src.utils.auth_utils import get_current_user
|
from src.utils.auth_utils import get_current_user
|
||||||
|
|
||||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||||
|
|
||||||
|
|
||||||
# ============ Request/Response Models ============
|
# ============ Request/Response Models ============
|
||||||
|
|
||||||
|
|
||||||
class BalanceQueryRequest(BaseModel):
|
|
||||||
"""余额查询请求"""
|
|
||||||
|
|
||||||
provider_id: str
|
|
||||||
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
|
|
||||||
|
|
||||||
|
|
||||||
class UsageSummaryQueryRequest(BaseModel):
|
|
||||||
"""使用汇总查询请求"""
|
|
||||||
|
|
||||||
provider_id: str
|
|
||||||
api_key_id: Optional[str] = None
|
|
||||||
period: str = "month" # day, week, month, year
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsQueryRequest(BaseModel):
|
class ModelsQueryRequest(BaseModel):
|
||||||
"""模型列表查询请求"""
|
"""模型列表查询请求"""
|
||||||
|
|
||||||
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
|
|||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
||||||
|
|
||||||
@router.get("/adapters")
|
async def _fetch_openai_models(
|
||||||
async def list_adapters(
|
client: httpx.AsyncClient,
|
||||||
current_user: User = Depends(get_current_user),
|
base_url: str,
|
||||||
):
|
api_key: str,
|
||||||
"""
|
api_format: str,
|
||||||
获取所有可用的查询适配器
|
extra_headers: Optional[dict] = None,
|
||||||
|
) -> tuple[list, Optional[str]]:
|
||||||
|
"""获取 OpenAI 格式的模型列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
适配器列表
|
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||||
"""
|
"""
|
||||||
registry = get_query_registry()
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
adapters = registry.list_adapters()
|
if extra_headers:
|
||||||
|
# 防止 extra_headers 覆盖 Authorization
|
||||||
|
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||||
|
headers.update(safe_headers)
|
||||||
|
|
||||||
return {"success": True, "data": adapters}
|
# 构建 /v1/models URL
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url}/v1/models"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
models = []
|
||||||
|
if "data" in data:
|
||||||
|
models = data["data"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
models = data
|
||||||
|
# 为每个模型添加 api_format 字段
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = api_format
|
||||||
|
return models, None
|
||||||
|
else:
|
||||||
|
# 记录详细的错误信息
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
@router.get("/capabilities/{provider_id}")
|
async def _fetch_claude_models(
|
||||||
async def get_provider_capabilities(
|
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||||
provider_id: str,
|
) -> tuple[list, Optional[str]]:
|
||||||
db: AsyncSession = Depends(get_db),
|
"""获取 Claude 格式的模型列表
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取提供商支持的查询能力
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_id: 提供商 ID
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
支持的查询能力列表
|
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||||
"""
|
"""
|
||||||
# 获取提供商
|
headers = {
|
||||||
from sqlalchemy import select
|
"x-api-key": api_key,
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
"anthropic-version": "2023-06-01",
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
registry = get_query_registry()
|
|
||||||
capabilities = registry.get_capabilities_for_provider(provider.name)
|
|
||||||
|
|
||||||
if capabilities is None:
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"provider_id": provider_id,
|
|
||||||
"provider_name": provider.name,
|
|
||||||
"capabilities": [],
|
|
||||||
"has_adapter": False,
|
|
||||||
"message": "No query adapter available for this provider",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"provider_id": provider_id,
|
|
||||||
"provider_name": provider.name,
|
|
||||||
"capabilities": [c.name for c in capabilities],
|
|
||||||
"has_adapter": True,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 构建 /v1/models URL
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url}/v1/models"
|
||||||
|
|
||||||
@router.post("/balance")
|
try:
|
||||||
async def query_balance(
|
response = await client.get(models_url, headers=headers)
|
||||||
request: BalanceQueryRequest,
|
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||||
db: AsyncSession = Depends(get_db),
|
if response.status_code == 200:
|
||||||
current_user: User = Depends(get_current_user),
|
data = response.json()
|
||||||
):
|
models = []
|
||||||
"""
|
if "data" in data:
|
||||||
查询提供商余额
|
models = data["data"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
models = data
|
||||||
|
# 为每个模型添加 api_format 字段
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = api_format
|
||||||
|
return models, None
|
||||||
|
else:
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 查询请求
|
async def _fetch_gemini_models(
|
||||||
|
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||||
|
) -> tuple[list, Optional[str]]:
|
||||||
|
"""获取 Gemini 格式的模型列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
余额信息
|
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select
|
# 兼容 base_url 已包含 /v1beta 的情况
|
||||||
from sqlalchemy.orm import selectinload
|
base_url_clean = base_url.rstrip("/")
|
||||||
|
if base_url_clean.endswith("/v1beta"):
|
||||||
|
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||||
|
|
||||||
# 获取提供商及其端点
|
try:
|
||||||
result = await db.execute(
|
response = await client.get(models_url)
|
||||||
select(Provider)
|
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
if response.status_code == 200:
|
||||||
.where(Provider.id == request.provider_id)
|
data = response.json()
|
||||||
)
|
if "models" in data:
|
||||||
provider = result.scalar_one_or_none()
|
# 转换为统一格式
|
||||||
|
return [
|
||||||
if not provider:
|
{
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
"id": m.get("name", "").replace("models/", ""),
|
||||||
|
"owned_by": "google",
|
||||||
# 获取 API Key
|
"display_name": m.get("displayName", ""),
|
||||||
api_key_value = None
|
"api_format": api_format,
|
||||||
endpoint_config = None
|
|
||||||
|
|
||||||
if request.api_key_id:
|
|
||||||
# 查找指定的 API Key
|
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.id == request.api_key_id:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {
|
|
||||||
"base_url": endpoint.base_url,
|
|
||||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
|
||||||
}
|
}
|
||||||
break
|
for m in data["models"]
|
||||||
if api_key_value:
|
], None
|
||||||
break
|
return [], None
|
||||||
|
else:
|
||||||
if not api_key_value:
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
else:
|
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||||
# 使用第一个可用的 API Key
|
return [], error_msg
|
||||||
for endpoint in provider.endpoints:
|
except Exception as e:
|
||||||
if endpoint.is_active and endpoint.api_keys:
|
error_msg = f"Request error: {str(e)}"
|
||||||
for api_key in endpoint.api_keys:
|
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||||
if api_key.is_active:
|
return [], error_msg
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {
|
|
||||||
"base_url": endpoint.base_url,
|
|
||||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
|
||||||
|
|
||||||
# 查询余额
|
|
||||||
registry = get_query_registry()
|
|
||||||
query_result = await registry.query_provider_balance(
|
|
||||||
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
|
|
||||||
)
|
|
||||||
|
|
||||||
if not query_result.success:
|
|
||||||
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": query_result.success,
|
|
||||||
"data": query_result.to_dict(),
|
|
||||||
"provider": {
|
|
||||||
"id": provider.id,
|
|
||||||
"name": provider.name,
|
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/usage-summary")
|
|
||||||
async def query_usage_summary(
|
|
||||||
request: UsageSummaryQueryRequest,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
查询提供商使用汇总
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 查询请求
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
使用汇总信息
|
|
||||||
"""
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
# 获取提供商及其端点
|
|
||||||
result = await db.execute(
|
|
||||||
select(Provider)
|
|
||||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
|
||||||
.where(Provider.id == request.provider_id)
|
|
||||||
)
|
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
# 获取 API Key(逻辑同上)
|
|
||||||
api_key_value = None
|
|
||||||
endpoint_config = None
|
|
||||||
|
|
||||||
if request.api_key_id:
|
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.id == request.api_key_id:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
|
||||||
else:
|
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
if endpoint.is_active and endpoint.api_keys:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.is_active:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
|
||||||
|
|
||||||
# 查询使用汇总
|
|
||||||
registry = get_query_registry()
|
|
||||||
query_result = await registry.query_provider_usage(
|
|
||||||
provider_type=provider.name,
|
|
||||||
api_key=api_key_value,
|
|
||||||
period=request.period,
|
|
||||||
endpoint_config=endpoint_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": query_result.success,
|
|
||||||
"data": query_result.to_dict(),
|
|
||||||
"provider": {
|
|
||||||
"id": provider.id,
|
|
||||||
"name": provider.name,
|
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/models")
|
@router.post("/models")
|
||||||
async def query_available_models(
|
async def query_available_models(
|
||||||
request: ModelsQueryRequest,
|
request: ModelsQueryRequest,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
查询提供商可用模型
|
查询提供商可用模型
|
||||||
|
|
||||||
|
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
||||||
|
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
||||||
|
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
||||||
|
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 查询请求
|
request: 查询请求
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模型列表
|
所有端点的模型列表(合并)
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
# 获取提供商及其端点
|
# 获取提供商及其端点
|
||||||
result = await db.execute(
|
provider = (
|
||||||
select(Provider)
|
db.query(Provider)
|
||||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||||
.where(Provider.id == request.provider_id)
|
.filter(Provider.id == request.provider_id)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
# 获取 API Key
|
# 收集所有活跃端点的配置
|
||||||
api_key_value = None
|
endpoint_configs: list[dict] = []
|
||||||
endpoint_config = None
|
|
||||||
|
|
||||||
if request.api_key_id:
|
if request.api_key_id:
|
||||||
|
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
||||||
for endpoint in provider.endpoints:
|
for endpoint in provider.endpoints:
|
||||||
for api_key in endpoint.api_keys:
|
for api_key in endpoint.api_keys:
|
||||||
if api_key.id == request.api_key_id:
|
if api_key.id == request.api_key_id:
|
||||||
api_key_value = api_key.api_key
|
try:
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to decrypt API key: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
|
endpoint_configs.append({
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
})
|
||||||
break
|
break
|
||||||
if api_key_value:
|
if endpoint_configs:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not api_key_value:
|
if not endpoint_configs:
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
raise HTTPException(status_code=404, detail="API Key not found")
|
||||||
else:
|
else:
|
||||||
|
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
||||||
for endpoint in provider.endpoints:
|
for endpoint in provider.endpoints:
|
||||||
if endpoint.is_active and endpoint.api_keys:
|
if not endpoint.is_active or not endpoint.api_keys:
|
||||||
for api_key in endpoint.api_keys:
|
continue
|
||||||
if api_key.is_active:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
# 找第一个可用的 Key
|
||||||
|
for api_key in endpoint.api_keys:
|
||||||
|
if api_key.is_active:
|
||||||
|
try:
|
||||||
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to decrypt API key: {e}")
|
||||||
|
continue # 尝试下一个 Key
|
||||||
|
endpoint_configs.append({
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
})
|
||||||
|
break # 只取第一个可用的 Key
|
||||||
|
|
||||||
|
if not endpoint_configs:
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||||
|
|
||||||
# 查询模型
|
# 并发请求所有端点的模型列表
|
||||||
registry = get_query_registry()
|
all_models: list = []
|
||||||
adapter = registry.get_adapter_for_provider(provider.name)
|
errors: list[str] = []
|
||||||
|
|
||||||
if not adapter:
|
async def fetch_endpoint_models(
|
||||||
raise HTTPException(
|
client: httpx.AsyncClient, config: dict
|
||||||
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
|
) -> tuple[list, Optional[str]]:
|
||||||
|
base_url = config["base_url"]
|
||||||
|
if not base_url:
|
||||||
|
return [], None
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
api_format = config["api_format"]
|
||||||
|
api_key_value = config["api_key"]
|
||||||
|
extra_headers = config["extra_headers"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
||||||
|
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
||||||
|
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
||||||
|
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
||||||
|
else:
|
||||||
|
return await _fetch_openai_models(
|
||||||
|
client, base_url, api_key_value, api_format, extra_headers
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||||
|
return [], f"{api_format}: {str(e)}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||||
)
|
)
|
||||||
|
for models, error in results:
|
||||||
|
all_models.extend(models)
|
||||||
|
if error:
|
||||||
|
errors.append(error)
|
||||||
|
|
||||||
query_result = await adapter.query_available_models(
|
# 按 model id 去重(保留第一个)
|
||||||
api_key=api_key_value, endpoint_config=endpoint_config
|
seen_ids: set[str] = set()
|
||||||
)
|
unique_models: list = []
|
||||||
|
for model in all_models:
|
||||||
|
model_id = model.get("id")
|
||||||
|
if model_id and model_id not in seen_ids:
|
||||||
|
seen_ids.add(model_id)
|
||||||
|
unique_models.append(model)
|
||||||
|
|
||||||
|
error = "; ".join(errors) if errors else None
|
||||||
|
if not unique_models and not error:
|
||||||
|
error = "No models returned from any endpoint"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": query_result.success,
|
"success": len(unique_models) > 0,
|
||||||
"data": query_result.to_dict(),
|
"data": {"models": unique_models, "error": error},
|
||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
"display_name": provider.display_name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cache/{provider_id}")
|
|
||||||
async def clear_query_cache(
|
|
||||||
provider_id: str,
|
|
||||||
api_key_id: Optional[str] = None,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
清除查询缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_id: 提供商 ID
|
|
||||||
api_key_id: 可选,指定清除某个 API Key 的缓存
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
清除结果
|
|
||||||
"""
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
# 获取提供商
|
|
||||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
registry = get_query_registry()
|
|
||||||
adapter = registry.get_adapter_for_provider(provider.name)
|
|
||||||
|
|
||||||
if adapter:
|
|
||||||
if api_key_id:
|
|
||||||
# 获取 API Key 值来清除缓存
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
|
|
||||||
api_key = result.scalar_one_or_none()
|
|
||||||
if api_key:
|
|
||||||
adapter.clear_cache(api_key.api_key)
|
|
||||||
else:
|
|
||||||
adapter.clear_cache()
|
|
||||||
|
|
||||||
return {"success": True, "message": "Cache cleared successfully"}
|
|
||||||
|
|||||||
@@ -852,7 +852,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
|||||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||||
|
|
||||||
cache_service = get_cache_invalidation_service()
|
cache_service = get_cache_invalidation_service()
|
||||||
cache_service.invalidate_all()
|
cache_service.clear_all_caches()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": "配置导入成功",
|
"message": "配置导入成功",
|
||||||
|
|||||||
@@ -140,9 +140,9 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await AuthService.verify_token(token)
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
|
|||||||
|
|
||||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
from ..models.database import SystemConfig
|
from src.models.database import SystemConfig
|
||||||
|
|
||||||
db = context.db
|
db = context.db
|
||||||
payload = context.ensure_json_body()
|
payload = context.ensure_json_body()
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ from enum import Enum
|
|||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.core.exceptions import QuotaExceededException
|
from src.core.exceptions import QuotaExceededException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||||
from src.services.auth.service import AuthService
|
from src.services.auth.service import AuthService
|
||||||
from src.services.cache.user_cache import UserCacheService
|
|
||||||
from src.services.system.audit import AuditService
|
from src.services.system.audit import AuditService
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
|
|
||||||
@@ -178,9 +177,9 @@ class ApiRequestPipeline:
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
raise HTTPException(status_code=401, detail="缺少管理员凭证")
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await self.auth_service.verify_token(token)
|
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||||
|
|
||||||
# 使用缓存查询用户
|
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||||
|
|
||||||
@@ -205,9 +204,9 @@ class ApiRequestPipeline:
|
|||||||
if not authorization or not authorization.lower().startswith("bearer "):
|
if not authorization or not authorization.lower().startswith("bearer "):
|
||||||
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
raise HTTPException(status_code=401, detail="缺少用户凭证")
|
||||||
|
|
||||||
token = authorization.replace("Bearer ", "").strip()
|
token = authorization[7:].strip()
|
||||||
try:
|
try:
|
||||||
payload = await self.auth_service.verify_token(token)
|
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||||
|
|
||||||
# 使用缓存查询用户
|
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
if not user or not user.is_active:
|
if not user or not user.is_active:
|
||||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||||
|
|
||||||
@@ -242,11 +241,15 @@ class ApiRequestPipeline:
|
|||||||
status_code: Optional[int] = None,
|
status_code: Optional[int] = None,
|
||||||
error: Optional[str] = None,
|
error: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""记录审计事件
|
||||||
|
|
||||||
|
事务策略:复用请求级 Session,不单独提交。
|
||||||
|
审计记录随主事务一起提交,由中间件统一管理。
|
||||||
|
"""
|
||||||
if not getattr(adapter, "audit_log_enabled", True):
|
if not getattr(adapter, "audit_log_enabled", True):
|
||||||
return
|
return
|
||||||
|
|
||||||
bind = context.db.get_bind()
|
if context.db is None:
|
||||||
if bind is None:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
||||||
@@ -266,11 +269,11 @@ class ApiRequestPipeline:
|
|||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
|
|
||||||
SessionMaker = sessionmaker(bind=bind)
|
|
||||||
audit_session = SessionMaker()
|
|
||||||
try:
|
try:
|
||||||
|
# 复用请求级 Session,不创建新的连接
|
||||||
|
# 审计记录随主事务一起提交,由中间件统一管理
|
||||||
self.audit_service.log_event(
|
self.audit_service.log_event(
|
||||||
db=audit_session,
|
db=context.db,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
||||||
user_id=context.user.id if context.user else None,
|
user_id=context.user.id if context.user else None,
|
||||||
@@ -282,12 +285,9 @@ class ApiRequestPipeline:
|
|||||||
error_message=error,
|
error_message=error,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
audit_session.commit()
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
audit_session.rollback()
|
# 审计失败不应影响主请求,仅记录警告
|
||||||
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
||||||
finally:
|
|
||||||
audit_session.close()
|
|
||||||
|
|
||||||
def _build_audit_metadata(
|
def _build_audit_metadata(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
|||||||
)
|
)
|
||||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||||
# 需要转回业务时区再取日期,才能与日期序列匹配
|
# 需要转回业务时区再取日期,才能与日期序列匹配
|
||||||
|
def _to_business_date_str(value: datetime) -> str:
|
||||||
|
if value.tzinfo is None:
|
||||||
|
value_utc = value.replace(tzinfo=timezone.utc)
|
||||||
|
else:
|
||||||
|
value_utc = value.astimezone(timezone.utc)
|
||||||
|
return value_utc.astimezone(app_tz).date().isoformat()
|
||||||
|
|
||||||
stats_map = {
|
stats_map = {
|
||||||
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
|
_to_business_date_str(stat.date): {
|
||||||
"requests": stat.total_requests,
|
"requests": stat.total_requests,
|
||||||
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
||||||
"cost": stat.total_cost,
|
"cost": stat.total_cost,
|
||||||
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
|||||||
"unique_providers": today_unique_providers,
|
"unique_providers": today_unique_providers,
|
||||||
"fallback_count": today_fallback_count,
|
"fallback_count": today_fallback_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
|
||||||
|
yesterday_date = today_local.date() - timedelta(days=1)
|
||||||
|
historical_end = min(end_date_local.date(), yesterday_date)
|
||||||
|
missing_dates: list[str] = []
|
||||||
|
cursor = start_date_local.date()
|
||||||
|
while cursor <= historical_end:
|
||||||
|
date_str = cursor.isoformat()
|
||||||
|
if date_str not in stats_map:
|
||||||
|
missing_dates.append(date_str)
|
||||||
|
cursor += timedelta(days=1)
|
||||||
|
|
||||||
|
if missing_dates:
|
||||||
|
for date_str in missing_dates[-7:]:
|
||||||
|
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
|
||||||
|
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
|
||||||
|
stats_map[date_str] = {
|
||||||
|
"requests": computed["total_requests"],
|
||||||
|
"tokens": (
|
||||||
|
computed["input_tokens"]
|
||||||
|
+ computed["output_tokens"]
|
||||||
|
+ computed["cache_creation_tokens"]
|
||||||
|
+ computed["cache_read_tokens"]
|
||||||
|
),
|
||||||
|
"cost": computed["total_cost"],
|
||||||
|
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
|
||||||
|
if computed["avg_response_time_ms"]
|
||||||
|
else 0,
|
||||||
|
"unique_models": computed["unique_models"],
|
||||||
|
"unique_providers": computed["unique_providers"],
|
||||||
|
"fallback_count": computed["fallback_count"],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# 普通用户:仍需实时查询(用户级预聚合可选)
|
# 普通用户:仍需实时查询(用户级预聚合可选)
|
||||||
query = db.query(Usage).filter(
|
query = db.query(Usage).filter(
|
||||||
|
|||||||
@@ -28,7 +28,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -43,6 +43,9 @@ from src.services.provider.format import normalize_api_format
|
|||||||
from src.services.system.audit import audit_service
|
from src.services.system.audit import audit_service
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageTelemetry:
|
class MessageTelemetry:
|
||||||
@@ -399,6 +402,41 @@ class BaseMessageHandler:
|
|||||||
# 创建后台任务,不阻塞当前流
|
# 创建后台任务,不阻塞当前流
|
||||||
asyncio.create_task(_do_update())
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
|
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||||
|
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
||||||
|
|
||||||
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from src.database.database import get_db
|
||||||
|
|
||||||
|
target_request_id = self.request_id
|
||||||
|
provider = ctx.provider_name
|
||||||
|
target_model = ctx.mapped_model
|
||||||
|
|
||||||
|
async def _do_update() -> None:
|
||||||
|
try:
|
||||||
|
db_gen = get_db()
|
||||||
|
db = next(db_gen)
|
||||||
|
try:
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=db,
|
||||||
|
request_id=target_request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=provider,
|
||||||
|
target_model=target_model,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{target_request_id}] 更新 Usage 状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
|
# 创建后台任务,不阻塞当前流
|
||||||
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
def _log_request_error(self, message: str, error: Exception) -> None:
|
def _log_request_error(self, message: str, error: Exception) -> None:
|
||||||
"""记录请求错误日志,对业务异常不打印堆栈
|
"""记录请求错误日志,对业务异常不打印堆栈
|
||||||
|
|
||||||
@@ -411,9 +449,10 @@ class BaseMessageHandler:
|
|||||||
QuotaExceededException,
|
QuotaExceededException,
|
||||||
RateLimitException,
|
RateLimitException,
|
||||||
ModelNotSupportedException,
|
ModelNotSupportedException,
|
||||||
|
UpstreamClientException,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException, UpstreamClientException)):
|
||||||
# 业务异常:简洁日志,不打印堆栈
|
# 业务异常:简洁日志,不打印堆栈
|
||||||
logger.error(f"{message}: [{type(error).__name__}] {error}")
|
logger.error(f"{message}: [{type(error).__name__}] {error}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -64,18 +64,6 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
|
|
||||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||||
self.response_normalizer = None
|
|
||||||
# 可选启用响应规范化
|
|
||||||
self._init_response_normalizer()
|
|
||||||
|
|
||||||
def _init_response_normalizer(self):
|
|
||||||
"""初始化响应规范化器 - 子类可覆盖"""
|
|
||||||
try:
|
|
||||||
from src.services.provider.response_normalizer import ResponseNormalizer
|
|
||||||
|
|
||||||
self.response_normalizer = ResponseNormalizer()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def handle(self, context: ApiRequestContext):
|
async def handle(self, context: ApiRequestContext):
|
||||||
"""处理 Chat API 请求"""
|
"""处理 Chat API 请求"""
|
||||||
@@ -228,8 +216,6 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
allowed_api_formats=self.allowed_api_formats,
|
allowed_api_formats=self.allowed_api_formats,
|
||||||
response_normalizer=self.response_normalizer,
|
|
||||||
enable_response_normalization=self.response_normalizer is not None,
|
|
||||||
adapter_detector=self.detect_capability_requirements,
|
adapter_detector=self.detect_capability_requirements,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -88,8 +88,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
user_agent: str,
|
user_agent: str,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
allowed_api_formats: Optional[list] = None,
|
allowed_api_formats: Optional[list] = None,
|
||||||
response_normalizer: Optional[Any] = None,
|
|
||||||
enable_response_normalization: bool = False,
|
|
||||||
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
adapter_detector: Optional[Callable[[Dict[str, str], Optional[Dict[str, Any]]], Dict[str, bool]]] = None,
|
||||||
):
|
):
|
||||||
allowed = allowed_api_formats or [self.FORMAT_ID]
|
allowed = allowed_api_formats or [self.FORMAT_ID]
|
||||||
@@ -106,8 +104,6 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
)
|
)
|
||||||
self._parser: Optional[ResponseParser] = None
|
self._parser: Optional[ResponseParser] = None
|
||||||
self._request_builder = PassthroughRequestBuilder()
|
self._request_builder = PassthroughRequestBuilder()
|
||||||
self.response_normalizer = response_normalizer
|
|
||||||
self.enable_response_normalization = enable_response_normalization
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parser(self) -> ResponseParser:
|
def parser(self) -> ResponseParser:
|
||||||
@@ -266,8 +262,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
# 使用 select_provider_model_name 支持别名功能
|
# 使用 select_provider_model_name 支持别名功能
|
||||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||||
|
# 传入 api_format 用于过滤适用的别名作用域
|
||||||
affinity_key = self.api_key.id if self.api_key else None
|
affinity_key = self.api_key.id if self.api_key else None
|
||||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
mapped_name = mapping.model.select_provider_model_name(
|
||||||
|
affinity_key, api_format=self.FORMAT_ID
|
||||||
|
)
|
||||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
@@ -294,11 +293,15 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
# 创建类型安全的流式上下文
|
# 创建类型安全的流式上下文
|
||||||
ctx = StreamContext(model=model, api_format=api_format)
|
ctx = StreamContext(model=model, api_format=api_format)
|
||||||
|
|
||||||
|
# 创建更新状态的回调闭包(可以访问 ctx)
|
||||||
|
def update_streaming_status() -> None:
|
||||||
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
|
||||||
# 创建流处理器
|
# 创建流处理器
|
||||||
stream_processor = StreamProcessor(
|
stream_processor = StreamProcessor(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
default_parser=self.parser,
|
default_parser=self.parser,
|
||||||
on_streaming_start=self._update_usage_to_streaming,
|
on_streaming_start=update_streaming_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义请求函数
|
# 定义请求函数
|
||||||
@@ -463,7 +466,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
pool=config.http_pool_timeout,
|
pool=config.http_pool_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=timeout_config,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response_ctx = http_client.stream(
|
response_ctx = http_client.stream(
|
||||||
"POST", url, json=provider_payload, headers=provider_headers
|
"POST", url, json=provider_payload, headers=provider_headers
|
||||||
@@ -631,10 +640,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
logger.info(f" [{self.request_id}] 发送非流式请求: Provider={provider.name}, "
|
||||||
f"模型={model} -> {mapped_model or '无映射'}")
|
f"模型={model} -> {mapped_model or '无映射'}")
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
timeout=float(endpoint.timeout),
|
from src.clients.http_client import HTTPClientPool
|
||||||
follow_redirects=True,
|
|
||||||
) as http_client:
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||||
|
)
|
||||||
|
async with http_client:
|
||||||
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
||||||
|
|
||||||
status_code = resp.status_code
|
status_code = resp.status_code
|
||||||
|
|||||||
@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
# 使用 select_provider_model_name 支持别名功能
|
# 使用 select_provider_model_name 支持别名功能
|
||||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||||
|
# 传入 api_format 用于过滤适用的别名作用域
|
||||||
affinity_key = self.api_key.id if self.api_key else None
|
affinity_key = self.api_key.id if self.api_key else None
|
||||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
mapped_name = mapping.model.select_provider_model_name(
|
||||||
|
affinity_key, api_format=self.FORMAT_ID
|
||||||
|
)
|
||||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
@@ -451,7 +454,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"Key=***{key.api_key[-4:]}, "
|
f"Key=***{key.api_key[-4:]}, "
|
||||||
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
f"原始模型={ctx.model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||||
|
|
||||||
http_client = httpx.AsyncClient(timeout=timeout_config, follow_redirects=True)
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
|
from src.clients.http_client import HTTPClientPool
|
||||||
|
|
||||||
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=timeout_config,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
response_ctx = http_client.stream(
|
response_ctx = http_client.stream(
|
||||||
"POST", url, json=provider_payload, headers=provider_headers
|
"POST", url, json=provider_payload, headers=provider_headers
|
||||||
@@ -523,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
async for chunk in stream_response.aiter_raw():
|
async for chunk in stream_response.aiter_raw():
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if not streaming_status_updated:
|
if not streaming_status_updated:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
streaming_status_updated = True
|
streaming_status_updated = True
|
||||||
|
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
@@ -807,7 +816,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if prefetched_chunks:
|
if prefetched_chunks:
|
||||||
self._update_usage_to_streaming(ctx.request_id)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
|
||||||
# 先处理预读的字节块
|
# 先处理预读的字节块
|
||||||
for chunk in prefetched_chunks:
|
for chunk in prefetched_chunks:
|
||||||
@@ -1416,10 +1425,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
f"Key=***{key.api_key[-4:]}, "
|
f"Key=***{key.api_key[-4:]}, "
|
||||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}")
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
# 创建 HTTP 客户端(支持代理配置)
|
||||||
timeout=float(endpoint.timeout),
|
from src.clients.http_client import HTTPClientPool
|
||||||
follow_redirects=True,
|
|
||||||
) as http_client:
|
http_client = HTTPClientPool.create_client_with_proxy(
|
||||||
|
proxy_config=endpoint.proxy,
|
||||||
|
timeout=httpx.Timeout(float(endpoint.timeout)),
|
||||||
|
)
|
||||||
|
async with http_client:
|
||||||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
||||||
|
|
||||||
status_code = resp.status_code
|
status_code = resp.status_code
|
||||||
|
|||||||
@@ -131,10 +131,5 @@ class ClaudeChatHandler(ChatHandlerBase):
|
|||||||
Returns:
|
Returns:
|
||||||
规范化后的响应
|
规范化后的响应
|
||||||
"""
|
"""
|
||||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
# 作为中转站,直接透传响应,不做标准化处理
|
||||||
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
|
||||||
response_data=response,
|
|
||||||
request_id=self.request_id,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -148,17 +148,6 @@ class GeminiChatHandler(ChatHandlerBase):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
规范化后的响应
|
规范化后的响应
|
||||||
|
|
||||||
TODO: 如果需要,实现响应规范化逻辑
|
|
||||||
"""
|
"""
|
||||||
# 可选:使用 response_normalizer 进行规范化
|
# 作为中转站,直接透传响应,不做标准化处理
|
||||||
# if (
|
|
||||||
# self.response_normalizer
|
|
||||||
# and self.response_normalizer.should_normalize(response)
|
|
||||||
# ):
|
|
||||||
# return self.response_normalizer.normalize_gemini_response(
|
|
||||||
# response_data=response,
|
|
||||||
# request_id=self.request_id,
|
|
||||||
# strict=False,
|
|
||||||
# )
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -128,10 +128,5 @@ class OpenAIChatHandler(ChatHandlerBase):
|
|||||||
Returns:
|
Returns:
|
||||||
规范化后的响应
|
规范化后的响应
|
||||||
"""
|
"""
|
||||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
# 作为中转站,直接透传响应,不做标准化处理
|
||||||
return self.response_normalizer.normalize_openai_response(
|
|
||||||
response_data=response,
|
|
||||||
request_id=self.request_id,
|
|
||||||
strict=False,
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -5,12 +5,55 @@
|
|||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
from urllib.parse import quote, urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
根据代理配置构建完整的代理 URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置字典,包含 url, username, password, enabled
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的代理 URL,如 socks5://user:pass@host:port
|
||||||
|
如果 enabled=False 或无配置,返回 None
|
||||||
|
"""
|
||||||
|
if not proxy_config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查 enabled 字段,默认为 True(兼容旧数据)
|
||||||
|
if not proxy_config.get("enabled", True):
|
||||||
|
return None
|
||||||
|
|
||||||
|
proxy_url = proxy_config.get("url")
|
||||||
|
if not proxy_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
username = proxy_config.get("username")
|
||||||
|
password = proxy_config.get("password")
|
||||||
|
|
||||||
|
# 只要有用户名就添加认证信息(密码可以为空)
|
||||||
|
if username:
|
||||||
|
parsed = urlparse(proxy_url)
|
||||||
|
# URL 编码用户名和密码,处理特殊字符(如 @, :, /)
|
||||||
|
encoded_username = quote(username, safe="")
|
||||||
|
encoded_password = quote(password, safe="") if password else ""
|
||||||
|
# 重新构建带认证的代理 URL
|
||||||
|
if encoded_password:
|
||||||
|
auth_proxy = f"{parsed.scheme}://{encoded_username}:{encoded_password}@{parsed.netloc}"
|
||||||
|
else:
|
||||||
|
auth_proxy = f"{parsed.scheme}://{encoded_username}@{parsed.netloc}"
|
||||||
|
if parsed.path:
|
||||||
|
auth_proxy += parsed.path
|
||||||
|
return auth_proxy
|
||||||
|
|
||||||
|
return proxy_url
|
||||||
|
|
||||||
|
|
||||||
class HTTPClientPool:
|
class HTTPClientPool:
|
||||||
"""
|
"""
|
||||||
@@ -121,6 +164,44 @@ class HTTPClientPool:
|
|||||||
finally:
|
finally:
|
||||||
await client.aclose()
|
await client.aclose()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_client_with_proxy(
|
||||||
|
cls,
|
||||||
|
proxy_config: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: Optional[httpx.Timeout] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> httpx.AsyncClient:
|
||||||
|
"""
|
||||||
|
创建带代理配置的HTTP客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
proxy_config: 代理配置字典,包含 url, username, password
|
||||||
|
timeout: 超时配置
|
||||||
|
**kwargs: 其他 httpx.AsyncClient 配置参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的 httpx.AsyncClient 实例
|
||||||
|
"""
|
||||||
|
config: Dict[str, Any] = {
|
||||||
|
"http2": False,
|
||||||
|
"verify": True,
|
||||||
|
"follow_redirects": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeout:
|
||||||
|
config["timeout"] = timeout
|
||||||
|
else:
|
||||||
|
config["timeout"] = httpx.Timeout(10.0, read=300.0)
|
||||||
|
|
||||||
|
# 添加代理配置
|
||||||
|
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||||
|
if proxy_url:
|
||||||
|
config["proxy"] = proxy_url
|
||||||
|
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
||||||
|
|
||||||
|
config.update(kwargs)
|
||||||
|
return httpx.AsyncClient(**config)
|
||||||
|
|
||||||
|
|
||||||
# 便捷访问函数
|
# 便捷访问函数
|
||||||
def get_http_client() -> httpx.AsyncClient:
|
def get_http_client() -> httpx.AsyncClient:
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ class RedisClientManager:
|
|||||||
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
if self._circuit_open_until and time.time() < self._circuit_open_until:
|
||||||
remaining = self._circuit_open_until - time.time()
|
remaining = self._circuit_open_until - time.time()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Redis 客户端处于熔断状态,跳过初始化,剩余 %.1f 秒 (last_error: %s)",
|
"Redis 客户端处于熔断状态,跳过初始化,剩余 {:.1f} 秒 (last_error: {})",
|
||||||
remaining,
|
remaining,
|
||||||
self._last_error,
|
self._last_error,
|
||||||
)
|
)
|
||||||
@@ -200,7 +200,7 @@ class RedisClientManager:
|
|||||||
if self._consecutive_failures >= self._circuit_threshold:
|
if self._consecutive_failures >= self._circuit_threshold:
|
||||||
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
self._circuit_open_until = time.time() + self._circuit_reset_seconds
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Redis 初始化连续失败 %s 次,开启熔断 %s 秒。"
|
"Redis 初始化连续失败 {} 次,开启熔断 {} 秒。"
|
||||||
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
"熔断期间以下功能将降级: 缓存亲和性、分布式并发控制、RPM限流。"
|
||||||
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
"可通过管理 API /api/admin/system/redis/reset-circuit 手动重置。",
|
||||||
self._consecutive_failures,
|
self._consecutive_failures,
|
||||||
@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
|
|||||||
|
|
||||||
if _redis_manager is None:
|
if _redis_manager is None:
|
||||||
_redis_manager = RedisClientManager()
|
_redis_manager = RedisClientManager()
|
||||||
|
# 如果尚未连接(例如启动时降级、或 close() 后),尝试重新初始化。
|
||||||
|
# initialize() 内部包含熔断器逻辑,避免频繁重试导致抖动。
|
||||||
|
if _redis_manager.get_client() is None:
|
||||||
await _redis_manager.initialize(require_redis=require_redis)
|
await _redis_manager.initialize(require_redis=require_redis)
|
||||||
|
|
||||||
return _redis_manager.get_client()
|
return _redis_manager.get_client()
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ class CacheSize:
|
|||||||
class ConcurrencyDefaults:
|
class ConcurrencyDefaults:
|
||||||
"""并发控制默认值"""
|
"""并发控制默认值"""
|
||||||
|
|
||||||
# 自适应并发初始限制(保守值)
|
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
||||||
INITIAL_LIMIT = 3
|
INITIAL_LIMIT = 50
|
||||||
|
|
||||||
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
|
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
|
||||||
COOLDOWN_AFTER_429_MINUTES = 5
|
COOLDOWN_AFTER_429_MINUTES = 5
|
||||||
@@ -67,13 +67,14 @@ class ConcurrencyDefaults:
|
|||||||
MIN_SAMPLES_FOR_DECISION = 5
|
MIN_SAMPLES_FOR_DECISION = 5
|
||||||
|
|
||||||
# 扩容步长 - 每次扩容增加的并发数
|
# 扩容步长 - 每次扩容增加的并发数
|
||||||
INCREASE_STEP = 1
|
INCREASE_STEP = 2
|
||||||
|
|
||||||
# 缩容乘数 - 遇到 429 时的缩容比例
|
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
|
||||||
DECREASE_MULTIPLIER = 0.7
|
# 0.85 表示降到触发 429 时并发数的 85%
|
||||||
|
DECREASE_MULTIPLIER = 0.85
|
||||||
|
|
||||||
# 最大并发限制上限
|
# 最大并发限制上限
|
||||||
MAX_CONCURRENT_LIMIT = 100
|
MAX_CONCURRENT_LIMIT = 200
|
||||||
|
|
||||||
# 最小并发限制下限
|
# 最小并发限制下限
|
||||||
MIN_CONCURRENT_LIMIT = 1
|
MIN_CONCURRENT_LIMIT = 1
|
||||||
@@ -85,6 +86,11 @@ class ConcurrencyDefaults:
|
|||||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||||
PROBE_INCREASE_MIN_REQUESTS = 10
|
PROBE_INCREASE_MIN_REQUESTS = 10
|
||||||
|
|
||||||
|
# === 缓存用户预留比例 ===
|
||||||
|
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
|
||||||
|
# 0.1 表示缓存用户预留 10%,新用户可用 90%
|
||||||
|
CACHE_RESERVATION_RATIO = 0.1
|
||||||
|
|
||||||
|
|
||||||
class CircuitBreakerDefaults:
|
class CircuitBreakerDefaults:
|
||||||
"""熔断器配置默认值(滑动窗口 + 半开状态模式)
|
"""熔断器配置默认值(滑动窗口 + 半开状态模式)
|
||||||
|
|||||||
@@ -105,6 +105,13 @@ class Config:
|
|||||||
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
self.llm_api_rate_limit = int(os.getenv("LLM_API_RATE_LIMIT", "100"))
|
||||||
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
self.public_api_rate_limit = int(os.getenv("PUBLIC_API_RATE_LIMIT", "60"))
|
||||||
|
|
||||||
|
# 可信代理配置
|
||||||
|
# TRUSTED_PROXY_COUNT: 信任的代理层数(默认 1,即信任最近一层代理)
|
||||||
|
# 设置为 0 表示不信任任何代理头,直接使用连接 IP
|
||||||
|
# 当服务部署在 Nginx/CloudFlare 等反向代理后面时,设置为对应的代理层数
|
||||||
|
# 如果服务直接暴露公网,应设置为 0 以防止 IP 伪造
|
||||||
|
self.trusted_proxy_count = int(os.getenv("TRUSTED_PROXY_COUNT", "1"))
|
||||||
|
|
||||||
# 异常处理配置
|
# 异常处理配置
|
||||||
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
# 设置为 True 时,ProxyException 会传播到路由层以便记录 provider_request_headers
|
||||||
# 设置为 False 时,使用全局异常处理器统一处理
|
# 设置为 False 时,使用全局异常处理器统一处理
|
||||||
@@ -122,9 +129,9 @@ class Config:
|
|||||||
|
|
||||||
# 并发控制配置
|
# 并发控制配置
|
||||||
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL(秒),防止死锁
|
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL(秒),防止死锁
|
||||||
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%)
|
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 10%,新用户可用 90%)
|
||||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
|
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
||||||
|
|
||||||
# HTTP 请求超时配置(秒)
|
# HTTP 请求超时配置(秒)
|
||||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||||
|
|||||||
@@ -46,6 +46,11 @@ class BatchCommitter:
|
|||||||
|
|
||||||
def mark_dirty(self, session: Session):
|
def mark_dirty(self, session: Session):
|
||||||
"""标记 Session 有待提交的更改"""
|
"""标记 Session 有待提交的更改"""
|
||||||
|
# 请求级事务由中间件统一 commit/rollback;避免后台任务在请求中途误提交。
|
||||||
|
if session is None:
|
||||||
|
return
|
||||||
|
if session.info.get("managed_by_middleware"):
|
||||||
|
return
|
||||||
self._pending_sessions.add(session)
|
self._pending_sessions.add(session)
|
||||||
|
|
||||||
async def _batch_commit_loop(self):
|
async def _batch_commit_loop(self):
|
||||||
|
|||||||
@@ -1,168 +0,0 @@
|
|||||||
"""
|
|
||||||
统一的请求上下文
|
|
||||||
|
|
||||||
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
|
|
||||||
这确保了数据在各层之间传递时不会丢失。
|
|
||||||
|
|
||||||
使用方式:
|
|
||||||
1. Pipeline 层创建 RequestContext
|
|
||||||
2. 各层通过 context 访问和更新信息
|
|
||||||
3. Adapter 层使用 context 记录 Usage
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RequestContext:
|
|
||||||
"""
|
|
||||||
请求上下文 - 贯穿整个请求生命周期
|
|
||||||
|
|
||||||
设计原则:
|
|
||||||
1. 在请求开始时创建,包含所有已知信息
|
|
||||||
2. 在请求执行过程中逐步填充 Provider 信息
|
|
||||||
3. 在请求结束时用于记录 Usage
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ==================== 请求标识 ====================
|
|
||||||
request_id: str
|
|
||||||
|
|
||||||
# ==================== 认证信息 ====================
|
|
||||||
user: Any # User model
|
|
||||||
api_key: Any # ApiKey model
|
|
||||||
db: Any # Database session
|
|
||||||
|
|
||||||
# ==================== 请求信息 ====================
|
|
||||||
api_format: str # CLAUDE, OPENAI, GEMINI, etc.
|
|
||||||
model: str # 用户请求的模型名
|
|
||||||
is_stream: bool = False
|
|
||||||
|
|
||||||
# ==================== 原始请求 ====================
|
|
||||||
original_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
original_body: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# ==================== 客户端信息 ====================
|
|
||||||
client_ip: str = "unknown"
|
|
||||||
user_agent: str = ""
|
|
||||||
|
|
||||||
# ==================== 计时 ====================
|
|
||||||
start_time: float = field(default_factory=time.time)
|
|
||||||
|
|
||||||
# ==================== Provider 信息(请求执行后填充)====================
|
|
||||||
provider_name: Optional[str] = None
|
|
||||||
provider_id: Optional[str] = None
|
|
||||||
endpoint_id: Optional[str] = None
|
|
||||||
provider_api_key_id: Optional[str] = None
|
|
||||||
|
|
||||||
# ==================== 模型映射信息 ====================
|
|
||||||
resolved_model: Optional[str] = None # 映射后的模型名
|
|
||||||
original_model: Optional[str] = None # 原始模型名(用于价格计算)
|
|
||||||
|
|
||||||
# ==================== 请求/响应头 ====================
|
|
||||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
provider_response_headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# ==================== 追踪信息 ====================
|
|
||||||
attempt_id: Optional[str] = None
|
|
||||||
|
|
||||||
# ==================== 能力需求 ====================
|
|
||||||
capability_requirements: Dict[str, bool] = field(default_factory=dict)
|
|
||||||
# 运行时计算的能力需求,来源于:
|
|
||||||
# 1. 用户 model_capability_settings
|
|
||||||
# 2. 用户 ApiKey.force_capabilities
|
|
||||||
# 3. 请求头 X-Require-Capability
|
|
||||||
# 4. 失败重试时动态添加
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
db: Any,
|
|
||||||
user: Any,
|
|
||||||
api_key: Any,
|
|
||||||
api_format: str,
|
|
||||||
model: str,
|
|
||||||
is_stream: bool = False,
|
|
||||||
original_headers: Optional[Dict[str, str]] = None,
|
|
||||||
original_body: Optional[Dict[str, Any]] = None,
|
|
||||||
client_ip: str = "unknown",
|
|
||||||
user_agent: str = "",
|
|
||||||
request_id: Optional[str] = None,
|
|
||||||
) -> "RequestContext":
|
|
||||||
"""创建请求上下文"""
|
|
||||||
return cls(
|
|
||||||
request_id=request_id or str(uuid.uuid4()),
|
|
||||||
db=db,
|
|
||||||
user=user,
|
|
||||||
api_key=api_key,
|
|
||||||
api_format=api_format,
|
|
||||||
model=model,
|
|
||||||
is_stream=is_stream,
|
|
||||||
original_headers=original_headers or {},
|
|
||||||
original_body=original_body or {},
|
|
||||||
client_ip=client_ip,
|
|
||||||
user_agent=user_agent,
|
|
||||||
original_model=model, # 初始时原始模型等于请求模型
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_provider_info(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
provider_name: str,
|
|
||||||
provider_id: str,
|
|
||||||
endpoint_id: str,
|
|
||||||
provider_api_key_id: str,
|
|
||||||
resolved_model: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
"""更新 Provider 信息(请求执行后调用)"""
|
|
||||||
self.provider_name = provider_name
|
|
||||||
self.provider_id = provider_id
|
|
||||||
self.endpoint_id = endpoint_id
|
|
||||||
self.provider_api_key_id = provider_api_key_id
|
|
||||||
if resolved_model:
|
|
||||||
self.resolved_model = resolved_model
|
|
||||||
|
|
||||||
def update_headers(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
request_headers: Optional[Dict[str, str]] = None,
|
|
||||||
response_headers: Optional[Dict[str, str]] = None,
|
|
||||||
) -> None:
|
|
||||||
"""更新请求/响应头"""
|
|
||||||
if request_headers:
|
|
||||||
self.provider_request_headers = request_headers
|
|
||||||
if response_headers:
|
|
||||||
self.provider_response_headers = response_headers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def elapsed_ms(self) -> int:
|
|
||||||
"""计算已经过的时间(毫秒)"""
|
|
||||||
return int((time.time() - self.start_time) * 1000)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def effective_model(self) -> str:
|
|
||||||
"""获取有效的模型名(映射后优先)"""
|
|
||||||
return self.resolved_model or self.model
|
|
||||||
|
|
||||||
@property
|
|
||||||
def billing_model(self) -> str:
|
|
||||||
"""获取计费模型名(原始模型优先)"""
|
|
||||||
return self.original_model or self.model
|
|
||||||
|
|
||||||
def to_metadata_dict(self) -> Dict[str, Any]:
|
|
||||||
"""转换为元数据字典(用于 Usage 记录)"""
|
|
||||||
return {
|
|
||||||
"api_format": self.api_format,
|
|
||||||
"provider": self.provider_name or "unknown",
|
|
||||||
"model": self.effective_model,
|
|
||||||
"original_model": self.billing_model,
|
|
||||||
"provider_id": self.provider_id,
|
|
||||||
"provider_endpoint_id": self.endpoint_id,
|
|
||||||
"provider_api_key_id": self.provider_api_key_id,
|
|
||||||
"provider_request_headers": self.provider_request_headers,
|
|
||||||
"provider_response_headers": self.provider_response_headers,
|
|
||||||
"attempt_id": self.attempt_id,
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
输出策略:
|
输出策略:
|
||||||
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
|
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
|
||||||
- 文件: 始终保存 DEBUG 级别,保留30天,每日轮转
|
- 文件: 始终保存 DEBUG 级别,保留30天,按大小轮转 (100MB)
|
||||||
|
|
||||||
使用方式:
|
使用方式:
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
@@ -72,12 +72,15 @@ def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
|
|||||||
|
|
||||||
|
|
||||||
if IS_DOCKER:
|
if IS_DOCKER:
|
||||||
|
# 生产环境:禁用 backtrace 和 diagnose,减少日志噪音
|
||||||
logger.add(
|
logger.add(
|
||||||
sys.stdout,
|
sys.stdout,
|
||||||
format=CONSOLE_FORMAT_PROD,
|
format=CONSOLE_FORMAT_PROD,
|
||||||
level=LOG_LEVEL,
|
level=LOG_LEVEL,
|
||||||
filter=_log_filter, # type: ignore[arg-type]
|
filter=_log_filter, # type: ignore[arg-type]
|
||||||
colorize=False,
|
colorize=False,
|
||||||
|
backtrace=False,
|
||||||
|
diagnose=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.add(
|
logger.add(
|
||||||
@@ -92,30 +95,37 @@ if not DISABLE_FILE_LOG:
|
|||||||
log_dir = PROJECT_ROOT / "logs"
|
log_dir = PROJECT_ROOT / "logs"
|
||||||
log_dir.mkdir(exist_ok=True)
|
log_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# 文件日志通用配置
|
||||||
|
file_log_config = {
|
||||||
|
"format": FILE_FORMAT,
|
||||||
|
"filter": _log_filter,
|
||||||
|
"rotation": "100 MB",
|
||||||
|
"retention": "30 days",
|
||||||
|
"compression": "gz",
|
||||||
|
"enqueue": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"catch": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 生产环境禁用详细堆栈
|
||||||
|
if IS_DOCKER:
|
||||||
|
file_log_config["backtrace"] = False
|
||||||
|
file_log_config["diagnose"] = False
|
||||||
|
|
||||||
# 主日志文件 - 所有级别
|
# 主日志文件 - 所有级别
|
||||||
logger.add(
|
logger.add(
|
||||||
log_dir / "app.log",
|
log_dir / "app.log",
|
||||||
format=FILE_FORMAT,
|
|
||||||
level="DEBUG",
|
level="DEBUG",
|
||||||
filter=_log_filter, # type: ignore[arg-type]
|
**file_log_config, # type: ignore[arg-type]
|
||||||
rotation="00:00",
|
|
||||||
retention="30 days",
|
|
||||||
compression="gz",
|
|
||||||
enqueue=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 错误日志文件 - 仅 ERROR 及以上
|
# 错误日志文件 - 仅 ERROR 及以上
|
||||||
|
error_log_config = file_log_config.copy()
|
||||||
|
error_log_config["rotation"] = "50 MB"
|
||||||
logger.add(
|
logger.add(
|
||||||
log_dir / "error.log",
|
log_dir / "error.log",
|
||||||
format=FILE_FORMAT,
|
|
||||||
level="ERROR",
|
level="ERROR",
|
||||||
filter=_log_filter, # type: ignore[arg-type]
|
**error_log_config, # type: ignore[arg-type]
|
||||||
rotation="00:00",
|
|
||||||
retention="30 days",
|
|
||||||
compression="gz",
|
|
||||||
enqueue=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Generator, Optional
|
from typing import AsyncGenerator, Generator, Optional
|
||||||
|
|
||||||
|
from starlette.requests import Request
|
||||||
from sqlalchemy import create_engine, event
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.ext.asyncio import (
|
from sqlalchemy.ext.asyncio import (
|
||||||
@@ -150,9 +151,22 @@ def _log_pool_capacity():
|
|||||||
theoretical = config.db_pool_size + config.db_max_overflow
|
theoretical = config.db_pool_size + config.db_max_overflow
|
||||||
workers = max(1, config.worker_processes)
|
workers = max(1, config.worker_processes)
|
||||||
total_estimated = theoretical * workers
|
total_estimated = theoretical * workers
|
||||||
logger.info("数据库连接池配置")
|
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
||||||
if total_estimated > config.db_pool_warn_threshold:
|
logger.info(
|
||||||
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
|
"数据库连接池配置: pool_size={}, max_overflow={}, workers={}, total_estimated={}, safe_limit={}",
|
||||||
|
config.db_pool_size,
|
||||||
|
config.db_max_overflow,
|
||||||
|
workers,
|
||||||
|
total_estimated,
|
||||||
|
safe_limit,
|
||||||
|
)
|
||||||
|
if total_estimated > safe_limit:
|
||||||
|
logger.warning(
|
||||||
|
"数据库连接池总需求可能超过 PostgreSQL 限制: {} > {} (pg_max_connections - reserved),"
|
||||||
|
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
||||||
|
total_estimated,
|
||||||
|
safe_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _ensure_async_engine() -> AsyncEngine:
|
def _ensure_async_engine() -> AsyncEngine:
|
||||||
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
|
|||||||
# 创建异步引擎
|
# 创建异步引擎
|
||||||
_async_engine = create_async_engine(
|
_async_engine = create_async_engine(
|
||||||
ASYNC_DATABASE_URL,
|
ASYNC_DATABASE_URL,
|
||||||
poolclass=QueuePool, # 使用队列连接池
|
# AsyncEngine 不能使用 QueuePool;默认使用 AsyncAdaptedQueuePool
|
||||||
pool_size=config.db_pool_size,
|
pool_size=config.db_pool_size,
|
||||||
max_overflow=config.db_max_overflow,
|
max_overflow=config.db_max_overflow,
|
||||||
pool_timeout=config.db_pool_timeout,
|
pool_timeout=config.db_pool_timeout,
|
||||||
@@ -209,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
|
|||||||
|
|
||||||
|
|
||||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""获取异步数据库会话"""
|
"""获取异步数据库会话
|
||||||
|
|
||||||
|
.. deprecated::
|
||||||
|
此方法已废弃,项目统一使用同步 Session。
|
||||||
|
未来版本可能移除此方法。请使用 get_db() 代替。
|
||||||
|
"""
|
||||||
|
import warnings
|
||||||
|
warnings.warn(
|
||||||
|
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
# 确保异步引擎已初始化
|
# 确保异步引擎已初始化
|
||||||
_ensure_async_engine()
|
_ensure_async_engine()
|
||||||
|
|
||||||
@@ -220,16 +245,73 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
def get_db() -> Generator[Session, None, None]:
|
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
|
||||||
"""获取数据库会话
|
"""获取数据库会话
|
||||||
|
|
||||||
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback)
|
事务策略说明
|
||||||
这里只负责会话的创建和关闭,不自动提交
|
============
|
||||||
|
本项目采用**混合事务管理**策略:
|
||||||
|
|
||||||
|
1. **LLM 请求路径**:
|
||||||
|
- 由 PluginMiddleware 统一管理事务
|
||||||
|
- Service 层使用 db.flush() 使更改可见,但不提交
|
||||||
|
- 请求结束时由中间件统一 commit 或 rollback
|
||||||
|
- 例外:UsageService.record_usage() 会显式 commit,因为使用记录需要立即持久化
|
||||||
|
|
||||||
|
2. **管理后台 API**:
|
||||||
|
- 路由层显式调用 db.commit()
|
||||||
|
- 提交后设置 request.state.tx_committed_by_route = True
|
||||||
|
- 中间件看到此标志后跳过 commit,只负责 close
|
||||||
|
|
||||||
|
3. **后台任务/调度器**:
|
||||||
|
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||||
|
- 自行管理事务生命周期
|
||||||
|
|
||||||
|
使用方式
|
||||||
|
========
|
||||||
|
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||||
|
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||||
|
|
||||||
|
路由层提交事务示例
|
||||||
|
==================
|
||||||
|
```python
|
||||||
|
@router.post("/example")
|
||||||
|
async def example(request: Request, db: Session = Depends(get_db)):
|
||||||
|
# ... 业务逻辑 ...
|
||||||
|
db.commit()
|
||||||
|
request.state.tx_committed_by_route = True # 告知中间件已提交
|
||||||
|
return {"message": "success"}
|
||||||
|
```
|
||||||
|
|
||||||
|
注意事项
|
||||||
|
========
|
||||||
|
- 本函数不自动提交事务
|
||||||
|
- 异常时会自动回滚
|
||||||
|
- 中间件管理模式下,session 关闭由中间件负责
|
||||||
"""
|
"""
|
||||||
|
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
|
||||||
|
if request is not None:
|
||||||
|
existing_db = getattr(getattr(request, "state", None), "db", None)
|
||||||
|
if isinstance(existing_db, Session):
|
||||||
|
yield existing_db
|
||||||
|
return
|
||||||
|
|
||||||
# 确保引擎已初始化
|
# 确保引擎已初始化
|
||||||
_ensure_engine()
|
_ensure_engine()
|
||||||
|
|
||||||
db = _SessionLocal()
|
db = _SessionLocal()
|
||||||
|
|
||||||
|
# 如果中间件声明会统一管理会话生命周期,则把 session 绑定到 request.state,
|
||||||
|
# 并由中间件负责 commit/rollback/close(这里不关闭,避免流式响应提前释放会话)。
|
||||||
|
managed_by_middleware = bool(
|
||||||
|
request is not None
|
||||||
|
and hasattr(request, "state")
|
||||||
|
and getattr(request.state, "db_managed_by_middleware", False)
|
||||||
|
)
|
||||||
|
if managed_by_middleware:
|
||||||
|
request.state.db = db
|
||||||
|
db.info["managed_by_middleware"] = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
# 不再自动 commit,由业务代码显式管理事务
|
# 不再自动 commit,由业务代码显式管理事务
|
||||||
@@ -241,12 +323,13 @@ def get_db() -> Generator[Session, None, None]:
|
|||||||
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
try:
|
if not managed_by_middleware:
|
||||||
db.close() # 确保连接返回池
|
try:
|
||||||
except Exception as close_error:
|
db.close() # 确保连接返回池
|
||||||
# 记录关闭错误(如 IllegalStateChangeError)
|
except Exception as close_error:
|
||||||
# 连接池会处理连接的回收
|
# 记录关闭错误(如 IllegalStateChangeError)
|
||||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
# 连接池会处理连接的回收
|
||||||
|
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||||
|
|
||||||
|
|
||||||
def create_session() -> Session:
|
def create_session() -> Session:
|
||||||
@@ -336,7 +419,7 @@ def init_admin_user(db: Session):
|
|||||||
admin.set_password(config.admin_password)
|
admin.set_password(config.admin_password)
|
||||||
|
|
||||||
db.add(admin)
|
db.add(admin)
|
||||||
db.commit() # 刷新以获取ID,但不提交
|
db.flush() # 分配ID,但不提交事务(由外层 init_db 统一 commit)
|
||||||
|
|
||||||
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
|
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
15
src/main.py
15
src/main.py
@@ -3,7 +3,6 @@
|
|||||||
采用模块化架构设计
|
采用模块化架构设计
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -39,20 +38,18 @@ async def initialize_providers():
|
|||||||
"""从数据库初始化提供商(仅用于日志记录)"""
|
"""从数据库初始化提供商(仅用于日志记录)"""
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.core.enums import APIFormat
|
from src.database.database import create_session
|
||||||
from src.database import get_db
|
|
||||||
from src.models.database import Provider
|
from src.models.database import Provider
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建数据库会话
|
# 创建数据库会话
|
||||||
db_gen = get_db()
|
db: Session = create_session()
|
||||||
db: Session = next(db_gen)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 从数据库加载所有活跃的提供商
|
# 从数据库加载所有活跃的提供商
|
||||||
providers = (
|
providers = (
|
||||||
db.query(Provider)
|
db.query(Provider)
|
||||||
.filter(Provider.is_active == True)
|
.filter(Provider.is_active.is_(True))
|
||||||
.order_by(Provider.provider_priority.asc())
|
.order_by(Provider.provider_priority.asc())
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
@@ -75,7 +72,7 @@ async def initialize_providers():
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("从数据库初始化提供商失败")
|
logger.exception("从数据库初始化提供商失败")
|
||||||
|
|
||||||
|
|
||||||
@@ -125,6 +122,7 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.info("初始化全局Redis客户端...")
|
logger.info("初始化全局Redis客户端...")
|
||||||
from src.clients.redis_client import get_redis_client
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
|
redis_client = None
|
||||||
try:
|
try:
|
||||||
redis_client = await get_redis_client(require_redis=config.require_redis)
|
redis_client = await get_redis_client(require_redis=config.require_redis)
|
||||||
if redis_client:
|
if redis_client:
|
||||||
@@ -136,6 +134,7 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
logger.exception("[ERROR] Redis连接失败,应用启动中止")
|
||||||
raise
|
raise
|
||||||
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
logger.warning(f"Redis连接失败,但配置允许降级,将继续使用内存模式: {e}")
|
||||||
|
redis_client = None
|
||||||
|
|
||||||
# 初始化并发管理器(内部会使用Redis)
|
# 初始化并发管理器(内部会使用Redis)
|
||||||
logger.info("初始化并发管理器...")
|
logger.info("初始化并发管理器...")
|
||||||
@@ -315,7 +314,7 @@ if frontend_dist.exists():
|
|||||||
仅对非API路径生效
|
仅对非API路径生效
|
||||||
"""
|
"""
|
||||||
# 如果是API路径,不处理
|
# 如果是API路径,不处理
|
||||||
if full_path.startswith("api/") or full_path.startswith("v1/"):
|
if full_path in {"api", "v1"} or full_path.startswith(("api/", "v1/")):
|
||||||
raise HTTPException(status_code=404, detail="Not Found")
|
raise HTTPException(status_code=404, detail="Not Found")
|
||||||
|
|
||||||
# 返回index.html,让前端路由处理
|
# 返回index.html,让前端路由处理
|
||||||
|
|||||||
@@ -1,38 +1,43 @@
|
|||||||
"""
|
"""
|
||||||
统一的插件中间件
|
统一的插件中间件(纯 ASGI 实现)
|
||||||
负责协调所有插件的调用
|
负责协调所有插件的调用
|
||||||
|
|
||||||
|
注意:使用纯 ASGI middleware 而非 BaseHTTPMiddleware,
|
||||||
|
以避免 Starlette 已知的流式响应兼容性问题。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from typing import Any, Awaitable, Callable, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from starlette.requests import Request
|
||||||
from fastapi.responses import JSONResponse
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import Response as StarletteResponse
|
|
||||||
|
|
||||||
from src.config import config
|
from src.config import config
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import get_db
|
|
||||||
from src.plugins.manager import get_plugin_manager
|
from src.plugins.manager import get_plugin_manager
|
||||||
from src.plugins.rate_limit.base import RateLimitResult
|
from src.plugins.rate_limit.base import RateLimitResult
|
||||||
|
|
||||||
|
|
||||||
|
class PluginMiddleware:
|
||||||
class PluginMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""
|
"""
|
||||||
统一的插件调用中间件
|
统一的插件调用中间件(纯 ASGI 实现)
|
||||||
|
|
||||||
职责:
|
职责:
|
||||||
- 性能监控
|
- 性能监控
|
||||||
- 限流控制 (可选)
|
- 限流控制 (可选)
|
||||||
|
- 数据库会话生命周期管理
|
||||||
|
|
||||||
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
注意: 认证由各路由通过 Depends() 显式声明,不在中间件层处理
|
||||||
|
|
||||||
|
为什么使用纯 ASGI 而非 BaseHTTPMiddleware:
|
||||||
|
- BaseHTTPMiddleware 会缓冲整个响应体,对流式响应不友好
|
||||||
|
- BaseHTTPMiddleware 与 StreamingResponse 存在已知兼容性问题
|
||||||
|
- 纯 ASGI 可以直接透传流式响应,无额外开销
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app: Any) -> None:
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
super().__init__(app)
|
self.app = app
|
||||||
self.plugin_manager = get_plugin_manager()
|
self.plugin_manager = get_plugin_manager()
|
||||||
|
|
||||||
# 从配置读取速率限制值
|
# 从配置读取速率限制值
|
||||||
@@ -62,175 +67,159 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
"/v1/completions",
|
"/v1/completions",
|
||||||
]
|
]
|
||||||
|
|
||||||
async def dispatch(
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
self, request: Request, call_next: Callable[[Request], Awaitable[StarletteResponse]]
|
"""ASGI 入口点"""
|
||||||
) -> StarletteResponse:
|
if scope["type"] != "http":
|
||||||
"""处理请求并调用相应插件"""
|
# 非 HTTP 请求(如 WebSocket)直接透传
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 构建 Request 对象以便复用现有逻辑
|
||||||
|
request = Request(scope, receive, send)
|
||||||
|
|
||||||
# 记录请求开始时间
|
# 记录请求开始时间
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 设置 request.state 属性
|
||||||
|
# 注意:Starlette 的 Request 对象总是有 state 属性(State 实例)
|
||||||
request.state.request_id = request.headers.get("x-request-id", "")
|
request.state.request_id = request.headers.get("x-request-id", "")
|
||||||
request.state.start_time = start_time
|
request.state.start_time = start_time
|
||||||
|
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||||
|
request.state.db_managed_by_middleware = True
|
||||||
|
|
||||||
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数)
|
# 1. 限流检查(在调用下游之前)
|
||||||
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
|
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||||
db_func = get_db
|
if rate_limit_result and not rate_limit_result.allowed:
|
||||||
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
|
# 限流触发,返回429
|
||||||
if get_db in request.app.dependency_overrides:
|
await self._send_rate_limit_response(send, rate_limit_result)
|
||||||
db_func = request.app.dependency_overrides[get_db]
|
return
|
||||||
logger.debug("Using overridden get_db from app.dependency_overrides")
|
|
||||||
|
|
||||||
# 创建数据库会话供需要的插件或后续处理使用
|
# 2. 预处理插件调用
|
||||||
db_gen = db_func()
|
await self._call_pre_request_plugins(request)
|
||||||
db = None
|
|
||||||
response = None
|
|
||||||
exception_to_raise = None
|
|
||||||
|
|
||||||
|
# 用于捕获响应状态码
|
||||||
|
response_status_code: int = 0
|
||||||
|
|
||||||
|
async def send_wrapper(message: Message) -> None:
|
||||||
|
nonlocal response_status_code
|
||||||
|
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
response_status_code = message.get("status", 0)
|
||||||
|
|
||||||
|
await send(message)
|
||||||
|
|
||||||
|
# 3. 调用下游应用
|
||||||
|
exception_occurred: Optional[Exception] = None
|
||||||
try:
|
try:
|
||||||
# 获取数据库会话
|
await self.app(scope, receive, send_wrapper)
|
||||||
db = next(db_gen)
|
|
||||||
request.state.db = db
|
|
||||||
|
|
||||||
# 1. 限流插件调用(可选功能)
|
|
||||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
|
||||||
if rate_limit_result and not rate_limit_result.allowed:
|
|
||||||
# 限流触发,返回429
|
|
||||||
headers = rate_limit_result.headers or {}
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429,
|
|
||||||
detail=rate_limit_result.message or "Rate limit exceeded",
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 预处理插件调用
|
|
||||||
await self._call_pre_request_plugins(request)
|
|
||||||
|
|
||||||
# 处理请求
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# 3. 提交关键数据库事务(在返回响应前)
|
|
||||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
except Exception as commit_error:
|
|
||||||
logger.error(f"关键事务提交失败: {commit_error}")
|
|
||||||
db.rollback()
|
|
||||||
# 返回 500 错误,因为数据可能不一致
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "database_error",
|
|
||||||
"message": "数据保存失败,请重试",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# 跳过后处理插件,直接返回错误响应
|
|
||||||
return response
|
|
||||||
|
|
||||||
# 4. 后处理插件调用(监控等,非关键操作)
|
|
||||||
# 这些操作失败不应影响用户响应
|
|
||||||
await self._call_post_request_plugins(request, response, start_time)
|
|
||||||
|
|
||||||
# 注意:不在此处添加限流响应头,因为在BaseHTTPMiddleware中
|
|
||||||
# 响应返回后修改headers会导致Content-Length不匹配错误
|
|
||||||
# 限流响应头已在返回429错误时正确包含(见上面的HTTPException)
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
if str(e) == "No response returned.":
|
|
||||||
if db:
|
|
||||||
db.rollback()
|
|
||||||
|
|
||||||
logger.error("Downstream handler completed without returning a response")
|
|
||||||
|
|
||||||
await self._call_error_plugins(request, e, start_time)
|
|
||||||
|
|
||||||
if db:
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=500,
|
|
||||||
content={
|
|
||||||
"type": "error",
|
|
||||||
"error": {
|
|
||||||
"type": "internal_error",
|
|
||||||
"message": "Internal server error: downstream handler returned no response.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
exception_to_raise = e
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 回滚数据库事务
|
exception_occurred = e
|
||||||
if db:
|
|
||||||
db.rollback()
|
|
||||||
|
|
||||||
# 错误处理插件调用
|
# 错误处理插件调用
|
||||||
await self._call_error_plugins(request, e, start_time)
|
await self._call_error_plugins(request, e, start_time)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 4. 数据库会话清理(无论成功与否)
|
||||||
|
await self._cleanup_db_session(request, exception_occurred)
|
||||||
|
|
||||||
# 尝试提交错误日志
|
# 5. 后处理插件调用(仅在成功时)
|
||||||
if db:
|
if not exception_occurred and response_status_code > 0:
|
||||||
|
await self._call_post_request_plugins(request, response_status_code, start_time)
|
||||||
|
|
||||||
|
async def _send_rate_limit_response(
|
||||||
|
self, send: Send, result: RateLimitResult
|
||||||
|
) -> None:
|
||||||
|
"""发送 429 限流响应"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
body = json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"error": {
|
||||||
|
"type": "rate_limit_error",
|
||||||
|
"message": result.message or "Rate limit exceeded",
|
||||||
|
},
|
||||||
|
}).encode("utf-8")
|
||||||
|
|
||||||
|
headers = [(b"content-type", b"application/json")]
|
||||||
|
if result.headers:
|
||||||
|
for key, value in result.headers.items():
|
||||||
|
headers.append((key.lower().encode(), str(value).encode()))
|
||||||
|
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 429,
|
||||||
|
"headers": headers,
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": body,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _cleanup_db_session(
|
||||||
|
self, request: Request, exception: Optional[Exception]
|
||||||
|
) -> None:
|
||||||
|
"""清理数据库会话
|
||||||
|
|
||||||
|
事务策略:
|
||||||
|
- 如果 request.state.tx_committed_by_route 为 True,说明路由已自行提交,中间件只负责 close
|
||||||
|
- 否则由中间件统一 commit/rollback
|
||||||
|
|
||||||
|
这避免了双重提交的问题,同时保持向后兼容。
|
||||||
|
"""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
db = getattr(request.state, "db", None)
|
||||||
|
if not isinstance(db, Session):
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查是否由路由层已经提交
|
||||||
|
tx_committed_by_route = getattr(request.state, "tx_committed_by_route", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if exception is not None:
|
||||||
|
# 发生异常,回滚事务(无论谁负责提交)
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception as rollback_error:
|
||||||
|
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||||
|
elif not tx_committed_by_route:
|
||||||
|
# 正常完成且路由未自行提交,由中间件提交事务
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
except:
|
except Exception as commit_error:
|
||||||
pass
|
logger.error(f"关键事务提交失败: {commit_error}")
|
||||||
|
try:
|
||||||
exception_to_raise = e
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
finally:
|
|
||||||
# 确保数据库会话被正确关闭
|
|
||||||
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
|
|
||||||
if db is not None:
|
|
||||||
try:
|
|
||||||
# 检查会话是否可以安全地进行回滚
|
|
||||||
# 只有当没有进行中的事务操作时才尝试回滚
|
|
||||||
if db.is_active and not db.get_transaction().is_active:
|
|
||||||
# 事务不在活跃状态,可以安全回滚
|
|
||||||
pass
|
pass
|
||||||
elif db.is_active:
|
# 如果 tx_committed_by_route 为 True,跳过 commit(路由已提交)
|
||||||
# 事务在活跃状态,尝试回滚
|
finally:
|
||||||
try:
|
# 关闭会话,归还连接到连接池
|
||||||
db.rollback()
|
|
||||||
except Exception as rollback_error:
|
|
||||||
# 回滚失败(可能是 commit 正在进行中),忽略错误
|
|
||||||
logger.debug(f"Rollback skipped: {rollback_error}")
|
|
||||||
except Exception:
|
|
||||||
# 检查状态时出错,忽略
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 通过触发生成器的 finally 块来关闭会话(标准模式)
|
|
||||||
# 这会调用 get_db() 的 finally 块,执行 db.close()
|
|
||||||
try:
|
try:
|
||||||
next(db_gen, None)
|
db.close()
|
||||||
except StopIteration:
|
except Exception as close_error:
|
||||||
# 正常情况:生成器已耗尽
|
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||||
pass
|
|
||||||
except Exception as cleanup_error:
|
|
||||||
# 忽略 IllegalStateChangeError 等清理错误
|
|
||||||
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
|
|
||||||
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
|
|
||||||
logger.warning(f"Database cleanup warning: {cleanup_error}")
|
|
||||||
|
|
||||||
# 在 finally 块之后处理异常和响应
|
|
||||||
if exception_to_raise:
|
|
||||||
raise exception_to_raise
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _get_client_ip(self, request: Request) -> str:
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
获取客户端 IP 地址,支持代理头
|
获取客户端 IP 地址,支持代理头
|
||||||
|
|
||||||
|
注意:此方法信任 X-Forwarded-For 和 X-Real-IP 头,
|
||||||
|
仅当服务部署在可信代理(如 Nginx、CloudFlare)后面时才安全。
|
||||||
|
如果服务直接暴露公网,攻击者可伪造这些头绕过限流。
|
||||||
"""
|
"""
|
||||||
|
# 从配置获取可信代理层数(默认为 1,即信任最近一层代理)
|
||||||
|
trusted_proxy_count = getattr(config, "trusted_proxy_count", 1)
|
||||||
|
|
||||||
# 优先从代理头获取真实 IP
|
# 优先从代理头获取真实 IP
|
||||||
forwarded_for = request.headers.get("x-forwarded-for")
|
forwarded_for = request.headers.get("x-forwarded-for")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For 可能包含多个 IP,取第一个
|
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||||
return forwarded_for.split(",")[0].strip()
|
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||||
|
ips = [ip.strip() for ip in forwarded_for.split(",")]
|
||||||
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
real_ip = request.headers.get("x-real-ip")
|
real_ip = request.headers.get("x-real-ip")
|
||||||
if real_ip:
|
if real_ip:
|
||||||
@@ -250,7 +239,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def _get_rate_limit_key_and_config(
|
async def _get_rate_limit_key_and_config(
|
||||||
self, request: Request, db: Session
|
self, request: Request
|
||||||
) -> tuple[Optional[str], Optional[int]]:
|
) -> tuple[Optional[str], Optional[int]]:
|
||||||
"""
|
"""
|
||||||
获取速率限制的key和配置
|
获取速率限制的key和配置
|
||||||
@@ -272,13 +261,11 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
auth_header = request.headers.get("authorization", "")
|
auth_header = request.headers.get("authorization", "")
|
||||||
api_key = request.headers.get("x-api-key", "")
|
api_key = request.headers.get("x-api-key", "")
|
||||||
|
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.lower().startswith("bearer "):
|
||||||
api_key = auth_header[7:]
|
api_key = auth_header[7:]
|
||||||
|
|
||||||
if api_key:
|
if api_key:
|
||||||
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
# 使用 API Key 的哈希作为限制 key(避免日志泄露完整 key)
|
||||||
import hashlib
|
|
||||||
|
|
||||||
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16]
|
||||||
key = f"llm_api_key:{key_hash}"
|
key = f"llm_api_key:{key_hash}"
|
||||||
request.state.rate_limit_key_type = "api_key"
|
request.state.rate_limit_key_type = "api_key"
|
||||||
@@ -318,14 +305,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
# 如果没有限流插件,允许通过
|
# 如果没有限流插件,允许通过
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取数据库会话
|
# 获取速率限制的 key 和配置
|
||||||
db = getattr(request.state, "db", None)
|
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
|
||||||
if not db:
|
|
||||||
logger.warning("速率限制检查:无法获取数据库会话")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 获取速率限制的key和配置(从数据库)
|
|
||||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
|
|
||||||
if not key:
|
if not key:
|
||||||
# 不需要限流的端点(如未分类路径),静默跳过
|
# 不需要限流的端点(如未分类路径),静默跳过
|
||||||
return None
|
return None
|
||||||
@@ -336,7 +317,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
key=key,
|
key=key,
|
||||||
endpoint=request.url.path,
|
endpoint=request.url.path,
|
||||||
method=request.method,
|
method=request.method,
|
||||||
rate_limit=rate_limit_value, # 传入数据库配置的限制值
|
rate_limit=rate_limit_value, # 传入配置的限制值
|
||||||
)
|
)
|
||||||
# 类型检查:确保返回的是RateLimitResult类型
|
# 类型检查:确保返回的是RateLimitResult类型
|
||||||
if isinstance(result, RateLimitResult):
|
if isinstance(result, RateLimitResult):
|
||||||
@@ -349,7 +330,10 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 限流触发,记录日志
|
# 限流触发,记录日志
|
||||||
logger.warning(f"速率限制触发: {getattr(request.state, 'rate_limit_key_type', 'unknown')}")
|
logger.warning(
|
||||||
|
"速率限制触发: {}",
|
||||||
|
getattr(request.state, "rate_limit_key_type", "unknown"),
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -362,7 +346,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
async def _call_post_request_plugins(
|
async def _call_post_request_plugins(
|
||||||
self, request: Request, response: StarletteResponse, start_time: float
|
self, request: Request, status_code: int, start_time: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""调用请求后的插件"""
|
"""调用请求后的插件"""
|
||||||
|
|
||||||
@@ -375,8 +359,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
monitor_labels = {
|
monitor_labels = {
|
||||||
"method": request.method,
|
"method": request.method,
|
||||||
"endpoint": request.url.path,
|
"endpoint": request.url.path,
|
||||||
"status": str(response.status_code),
|
"status": str(status_code),
|
||||||
"status_class": f"{response.status_code // 100}xx",
|
"status_class": f"{status_code // 100}xx",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 记录请求计数
|
# 记录请求计数
|
||||||
@@ -398,6 +382,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
self, request: Request, error: Exception, start_time: float
|
self, request: Request, error: Exception, start_time: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""调用错误处理插件"""
|
"""调用错误处理插件"""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
@@ -410,7 +395,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
|||||||
error=error,
|
error=error,
|
||||||
context={
|
context={
|
||||||
"endpoint": f"{request.method} {request.url.path}",
|
"endpoint": f"{request.method} {request.url.path}",
|
||||||
"request_id": request.state.request_id,
|
"request_id": getattr(request.state, "request_id", ""),
|
||||||
"duration": duration,
|
"duration": duration,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,42 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
|||||||
from src.core.enums import APIFormat, ProviderBillingType
|
from src.core.enums import APIFormat, ProviderBillingType
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyConfig(BaseModel):
|
||||||
|
"""代理配置"""
|
||||||
|
|
||||||
|
url: str = Field(..., description="代理 URL (http://, https://, socks5://)")
|
||||||
|
username: Optional[str] = Field(None, max_length=255, description="代理用户名")
|
||||||
|
password: Optional[str] = Field(None, max_length=500, description="代理密码")
|
||||||
|
enabled: bool = Field(True, description="是否启用代理(false 时保留配置但不使用)")
|
||||||
|
|
||||||
|
@field_validator("url")
|
||||||
|
@classmethod
|
||||||
|
def validate_proxy_url(cls, v: str) -> str:
|
||||||
|
"""验证代理 URL 格式"""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
v = v.strip()
|
||||||
|
|
||||||
|
# 检查禁止的字符(防止注入)
|
||||||
|
if "\n" in v or "\r" in v:
|
||||||
|
raise ValueError("代理 URL 包含非法字符")
|
||||||
|
|
||||||
|
# 验证协议(不支持 SOCKS4)
|
||||||
|
if not re.match(r"^(http|https|socks5)://", v, re.IGNORECASE):
|
||||||
|
raise ValueError("代理 URL 必须以 http://, https:// 或 socks5:// 开头")
|
||||||
|
|
||||||
|
# 验证 URL 结构
|
||||||
|
parsed = urlparse(v)
|
||||||
|
if not parsed.netloc:
|
||||||
|
raise ValueError("代理 URL 必须包含有效的 host")
|
||||||
|
|
||||||
|
# 禁止 URL 中内嵌认证信息,强制使用独立字段
|
||||||
|
if parsed.username or parsed.password:
|
||||||
|
raise ValueError("请勿在 URL 中包含用户名和密码,请使用独立的认证字段")
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class CreateProviderRequest(BaseModel):
|
class CreateProviderRequest(BaseModel):
|
||||||
"""创建 Provider 请求"""
|
"""创建 Provider 请求"""
|
||||||
|
|
||||||
@@ -107,20 +143,6 @@ class CreateProviderRequest(BaseModel):
|
|||||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||||
v = f"https://{v}"
|
v = f"https://{v}"
|
||||||
|
|
||||||
# 防止 SSRF 攻击:禁止内网地址
|
|
||||||
forbidden_patterns = [
|
|
||||||
r"localhost",
|
|
||||||
r"127\.0\.0\.1",
|
|
||||||
r"0\.0\.0\.0",
|
|
||||||
r"192\.168\.",
|
|
||||||
r"10\.",
|
|
||||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
|
||||||
r"169\.254\.",
|
|
||||||
]
|
|
||||||
for pattern in forbidden_patterns:
|
|
||||||
if re.search(pattern, v, re.IGNORECASE):
|
|
||||||
raise ValueError("不允许使用内网地址")
|
|
||||||
|
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("billing_type")
|
@field_validator("billing_type")
|
||||||
@@ -179,6 +201,7 @@ class CreateEndpointRequest(BaseModel):
|
|||||||
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
|
||||||
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
|
||||||
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
|
|
||||||
@field_validator("name")
|
@field_validator("name")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -195,19 +218,6 @@ class CreateEndpointRequest(BaseModel):
|
|||||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||||
|
|
||||||
# 防止 SSRF
|
|
||||||
forbidden_patterns = [
|
|
||||||
r"localhost",
|
|
||||||
r"127\.0\.0\.1",
|
|
||||||
r"0\.0\.0\.0",
|
|
||||||
r"192\.168\.",
|
|
||||||
r"10\.",
|
|
||||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
|
||||||
]
|
|
||||||
for pattern in forbidden_patterns:
|
|
||||||
if re.search(pattern, v, re.IGNORECASE):
|
|
||||||
raise ValueError("不允许使用内网地址")
|
|
||||||
|
|
||||||
return v.rstrip("/") # 移除末尾斜杠
|
return v.rstrip("/") # 移除末尾斜杠
|
||||||
|
|
||||||
@field_validator("api_format")
|
@field_validator("api_format")
|
||||||
@@ -247,6 +257,7 @@ class UpdateEndpointRequest(BaseModel):
|
|||||||
rpm_limit: Optional[int] = Field(None, ge=0)
|
rpm_limit: Optional[int] = Field(None, ge=0)
|
||||||
concurrent_limit: Optional[int] = Field(None, ge=0)
|
concurrent_limit: Optional[int] = Field(None, ge=0)
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
|
||||||
|
|
||||||
# 复用验证器
|
# 复用验证器
|
||||||
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
_validate_name = field_validator("name")(CreateEndpointRequest.validate_name.__func__)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import re
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from ..core.enums import UserRole
|
from ..core.enums import UserRole
|
||||||
|
|
||||||
@@ -336,8 +336,7 @@ class ProviderResponse(BaseModel):
|
|||||||
active_models_count: int = 0
|
active_models_count: int = 0
|
||||||
api_keys_count: int = 0
|
api_keys_count: int = 0
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 模型管理 ==========
|
# ========== 模型管理 ==========
|
||||||
@@ -442,8 +441,7 @@ class ModelResponse(BaseModel):
|
|||||||
global_model_name: Optional[str] = None
|
global_model_name: Optional[str] = None
|
||||||
global_model_display_name: Optional[str] = None
|
global_model_display_name: Optional[str] = None
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
class ModelDetailResponse(BaseModel):
|
class ModelDetailResponse(BaseModel):
|
||||||
@@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 系统设置 ==========
|
# ========== 系统设置 ==========
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Provider API Key相关的API模型
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
class ProviderAPIKeyBase(BaseModel):
|
class ProviderAPIKeyBase(BaseModel):
|
||||||
@@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderAPIKeyStats(BaseModel):
|
class ProviderAPIKeyStats(BaseModel):
|
||||||
|
|||||||
@@ -27,8 +27,7 @@ from sqlalchemy import (
|
|||||||
UniqueConstraint,
|
UniqueConstraint,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.orm import declarative_base, relationship
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..core.enums import ProviderBillingType, UserRole
|
from ..core.enums import ProviderBillingType, UserRole
|
||||||
@@ -539,6 +538,9 @@ class ProviderEndpoint(Base):
|
|||||||
# 额外配置
|
# 额外配置
|
||||||
config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段)
|
config = Column(JSON, nullable=True) # 端点特定配置(不推荐使用,优先使用专用字段)
|
||||||
|
|
||||||
|
# 代理配置
|
||||||
|
proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password}
|
||||||
|
|
||||||
# 时间戳
|
# 时间戳
|
||||||
created_at = Column(
|
created_at = Column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||||
@@ -813,7 +815,9 @@ class Model(Base):
|
|||||||
def get_effective_supports_image_generation(self) -> bool:
|
def get_effective_supports_image_generation(self) -> bool:
|
||||||
return self._get_effective_capability("supports_image_generation", False)
|
return self._get_effective_capability("supports_image_generation", False)
|
||||||
|
|
||||||
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
|
def select_provider_model_name(
|
||||||
|
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
"""按优先级选择要使用的 Provider 模型名称
|
"""按优先级选择要使用的 Provider 模型名称
|
||||||
|
|
||||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||||
@@ -822,6 +826,7 @@ class Model(Base):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||||
|
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的别名
|
||||||
"""
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
@@ -840,6 +845,13 @@ class Model(Base):
|
|||||||
if not isinstance(name, str) or not name.strip():
|
if not isinstance(name, str) or not name.strip():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 检查 api_formats 作用域(如果配置了且当前有 api_format)
|
||||||
|
alias_api_formats = raw.get("api_formats")
|
||||||
|
if api_format and alias_api_formats:
|
||||||
|
# 如果配置了作用域,只有匹配时才生效
|
||||||
|
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
|
||||||
|
continue
|
||||||
|
|
||||||
raw_priority = raw.get("priority", 1)
|
raw_priority = raw.get("priority", 1)
|
||||||
try:
|
try:
|
||||||
priority = int(raw_priority)
|
priority = int(raw_priority)
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import re
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
|
from src.models.admin_requests import ProxyConfig
|
||||||
|
|
||||||
# ========== ProviderEndpoint CRUD ==========
|
# ========== ProviderEndpoint CRUD ==========
|
||||||
|
|
||||||
@@ -30,6 +32,9 @@ class ProviderEndpointCreate(BaseModel):
|
|||||||
# 额外配置
|
# 额外配置
|
||||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置(JSON)")
|
||||||
|
|
||||||
|
# 代理配置
|
||||||
|
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||||
|
|
||||||
@field_validator("api_format")
|
@field_validator("api_format")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_api_format(cls, v: str) -> str:
|
def validate_api_format(cls, v: str) -> str:
|
||||||
@@ -45,24 +50,9 @@ class ProviderEndpointCreate(BaseModel):
|
|||||||
@field_validator("base_url")
|
@field_validator("base_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_base_url(cls, v: str) -> str:
|
def validate_base_url(cls, v: str) -> str:
|
||||||
"""验证 API URL(SSRF 防护)"""
|
|
||||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||||
|
|
||||||
# 防止 SSRF 攻击:禁止内网地址
|
|
||||||
forbidden_patterns = [
|
|
||||||
r"localhost",
|
|
||||||
r"127\.0\.0\.1",
|
|
||||||
r"0\.0\.0\.0",
|
|
||||||
r"192\.168\.",
|
|
||||||
r"10\.",
|
|
||||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
|
||||||
r"169\.254\.",
|
|
||||||
]
|
|
||||||
for pattern in forbidden_patterns:
|
|
||||||
if re.search(pattern, v, re.IGNORECASE):
|
|
||||||
raise ValueError("不允许使用内网地址")
|
|
||||||
|
|
||||||
return v.rstrip("/") # 移除末尾斜杠
|
return v.rstrip("/") # 移除末尾斜杠
|
||||||
|
|
||||||
|
|
||||||
@@ -79,31 +69,18 @@ class ProviderEndpointUpdate(BaseModel):
|
|||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||||
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
is_active: Optional[bool] = Field(default=None, description="是否启用")
|
||||||
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
|
||||||
|
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
|
||||||
|
|
||||||
@field_validator("base_url")
|
@field_validator("base_url")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
|
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
|
||||||
"""验证 API URL(SSRF 防护)"""
|
"""验证 API URL"""
|
||||||
if v is None:
|
if v is None:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||||
|
|
||||||
# 防止 SSRF 攻击:禁止内网地址
|
|
||||||
forbidden_patterns = [
|
|
||||||
r"localhost",
|
|
||||||
r"127\.0\.0\.1",
|
|
||||||
r"0\.0\.0\.0",
|
|
||||||
r"192\.168\.",
|
|
||||||
r"10\.",
|
|
||||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
|
||||||
r"169\.254\.",
|
|
||||||
]
|
|
||||||
for pattern in forbidden_patterns:
|
|
||||||
if re.search(pattern, v, re.IGNORECASE):
|
|
||||||
raise ValueError("不允许使用内网地址")
|
|
||||||
|
|
||||||
return v.rstrip("/") # 移除末尾斜杠
|
return v.rstrip("/") # 移除末尾斜杠
|
||||||
|
|
||||||
|
|
||||||
@@ -133,6 +110,9 @@ class ProviderEndpointResponse(BaseModel):
|
|||||||
# 额外配置
|
# 额外配置
|
||||||
config: Optional[Dict[str, Any]] = None
|
config: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
# 代理配置(响应中密码已脱敏)
|
||||||
|
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置(密码已脱敏)")
|
||||||
|
|
||||||
# 统计(从 Keys 聚合)
|
# 统计(从 Keys 聚合)
|
||||||
total_keys: int = Field(default=0, description="总 Key 数量")
|
total_keys: int = Field(default=0, description="总 Key 数量")
|
||||||
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
active_keys: int = Field(default=0, description="活跃 Key 数量")
|
||||||
@@ -141,8 +121,7 @@ class ProviderEndpointResponse(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
# ========== ProviderAPIKey 相关(新架构) ==========
|
# ========== ProviderAPIKey 相关(新架构) ==========
|
||||||
@@ -384,8 +363,7 @@ class EndpointAPIKeyResponse(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 健康监控相关 ==========
|
# ========== 健康监控相关 ==========
|
||||||
@@ -535,8 +513,7 @@ class ProviderWithEndpointsSummary(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 健康监控可视化模型 ==========
|
# ========== 健康监控可视化模型 ==========
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Pydantic 数据模型(阶段一统一模型管理)
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
# ========== 阶梯计费相关模型 ==========
|
# ========== 阶梯计费相关模型 ==========
|
||||||
@@ -256,8 +256,7 @@ class GlobalModelResponse(BaseModel):
|
|||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: Optional[datetime]
|
updated_at: Optional[datetime]
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(from_attributes=True)
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
class GlobalModelWithStats(GlobalModelResponse):
|
class GlobalModelWithStats(GlobalModelResponse):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ JWT认证插件
|
|||||||
支持JWT Bearer token认证
|
支持JWT Bearer token认证
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -46,12 +47,12 @@ class JwtAuthPlugin(AuthPlugin):
|
|||||||
logger.debug("未找到JWT token")
|
logger.debug("未找到JWT token")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 记录认证尝试的详细信息
|
token_fingerprint = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, Token前20位: {token[:20]}...")
|
logger.info(f"JWT认证尝试 - 路径: {request.url.path}, token_fp={token_fingerprint}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 验证JWT token
|
# 验证JWT token
|
||||||
payload = AuthService.verify_token(token)
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
logger.debug(f"JWT token验证成功, payload: {payload}")
|
logger.debug(f"JWT token验证成功, payload: {payload}")
|
||||||
|
|
||||||
# 从payload中提取用户信息
|
# 从payload中提取用户信息
|
||||||
|
|||||||
@@ -63,14 +63,16 @@ class JWTBlacklistService:
|
|||||||
|
|
||||||
if ttl_seconds <= 0:
|
if ttl_seconds <= 0:
|
||||||
# Token 已经过期,不需要加入黑名单
|
# Token 已经过期,不需要加入黑名单
|
||||||
logger.debug(f"Token 已过期,无需加入黑名单: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.debug("Token 已过期,无需加入黑名单: token_fp={}", token_fp)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
# 存储到 Redis,设置 TTL 为 Token 过期时间
|
||||||
# 值存储为原因字符串
|
# 值存储为原因字符串
|
||||||
await redis_client.setex(redis_key, ttl_seconds, reason)
|
await redis_client.setex(redis_key, ttl_seconds, reason)
|
||||||
|
|
||||||
logger.info(f"Token 已加入黑名单: {token[:10]}... (原因: {reason}, TTL: {ttl_seconds}s)")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.info("Token 已加入黑名单: token_fp={} (原因: {}, TTL: {}s)", token_fp, reason, ttl_seconds)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -109,7 +111,8 @@ class JWTBlacklistService:
|
|||||||
if exists:
|
if exists:
|
||||||
# 获取黑名单原因(可选)
|
# 获取黑名单原因(可选)
|
||||||
reason = await redis_client.get(redis_key)
|
reason = await redis_client.get(redis_key)
|
||||||
logger.warning(f"检测到黑名单 Token: {token[:10]}... (原因: {reason})")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.warning("检测到黑名单 Token: token_fp={} (原因: {})", token_fp, reason)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -148,9 +151,11 @@ class JWTBlacklistService:
|
|||||||
deleted = await redis_client.delete(redis_key)
|
deleted = await redis_client.delete(redis_key)
|
||||||
|
|
||||||
if deleted:
|
if deleted:
|
||||||
logger.info(f"Token 已从黑名单移除: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.info("Token 已从黑名单移除: token_fp={}", token_fp)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Token 不在黑名单中: {token[:10]}...")
|
token_fp = JWTBlacklistService._get_token_hash(token)[:12]
|
||||||
|
logger.debug("Token 不在黑名单中: token_fp={}", token_fp)
|
||||||
|
|
||||||
return bool(deleted)
|
return bool(deleted)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
@@ -93,8 +94,8 @@ class AuthService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||||
"""用户登录认证"""
|
"""用户登录认证"""
|
||||||
# 使用缓存查询用户
|
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||||
user = await UserCacheService.get_user_by_email(db, email)
|
user = db.query(User).filter(User.email == email).first()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||||
@@ -109,13 +110,10 @@ class AuthService:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 更新最后登录时间
|
# 更新最后登录时间
|
||||||
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
|
user.last_login_at = datetime.now(timezone.utc)
|
||||||
db_user = db.query(User).filter(User.id == user.id).first()
|
db.commit() # 立即提交事务,释放数据库锁
|
||||||
if db_user:
|
# 清除缓存,因为用户信息已更新
|
||||||
db_user.last_login_at = datetime.now(timezone.utc)
|
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||||
db.commit() # 立即提交事务,释放数据库锁
|
|
||||||
# 清除缓存,因为用户信息已更新
|
|
||||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
|
||||||
|
|
||||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||||
return user
|
return user
|
||||||
@@ -172,7 +170,8 @@ class AuthService:
|
|||||||
key_record.last_used_at = datetime.now(timezone.utc)
|
key_record.last_used_at = datetime.now(timezone.utc)
|
||||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||||
|
|
||||||
logger.debug(f"API认证成功: 用户 {user.email} (Key: {api_key[:10]}...)")
|
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||||
|
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||||
return user, key_record
|
return user, key_record
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -198,7 +197,10 @@ class AuthService:
|
|||||||
if user.role == UserRole.ADMIN:
|
if user.role == UserRole.ADMIN:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if user.role.value >= required_role.value:
|
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin')
|
||||||
|
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
|
||||||
|
# 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝)
|
||||||
|
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
|
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
|
||||||
@@ -230,7 +232,7 @@ class AuthService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
user_id = payload.get("sub")
|
user_id = payload.get("user_id")
|
||||||
logger.info(f"用户登出成功: user_id={user_id}")
|
logger.info(f"用户登出成功: user_id={user_id}")
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|||||||
71
src/services/cache/aware_scheduler.py
vendored
71
src/services/cache/aware_scheduler.py
vendored
@@ -59,7 +59,6 @@ from src.services.health.monitor import health_monitor
|
|||||||
from src.services.provider.format import normalize_api_format
|
from src.services.provider.format import normalize_api_format
|
||||||
from src.services.rate_limit.adaptive_reservation import (
|
from src.services.rate_limit.adaptive_reservation import (
|
||||||
AdaptiveReservationManager,
|
AdaptiveReservationManager,
|
||||||
ReservationResult,
|
|
||||||
get_adaptive_reservation_manager,
|
get_adaptive_reservation_manager,
|
||||||
)
|
)
|
||||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||||
@@ -112,8 +111,6 @@ class CacheAwareScheduler:
|
|||||||
- 健康度监控
|
- 健康度监控
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 静态常量作为默认值(实际由 AdaptiveReservationManager 动态计算)
|
|
||||||
CACHE_RESERVATION_RATIO = 0.3
|
|
||||||
# 优先级模式常量
|
# 优先级模式常量
|
||||||
PRIORITY_MODE_PROVIDER = "provider" # 提供商优先模式
|
PRIORITY_MODE_PROVIDER = "provider" # 提供商优先模式
|
||||||
PRIORITY_MODE_GLOBAL_KEY = "global_key" # 全局 Key 优先模式
|
PRIORITY_MODE_GLOBAL_KEY = "global_key" # 全局 Key 优先模式
|
||||||
@@ -121,8 +118,17 @@ class CacheAwareScheduler:
|
|||||||
PRIORITY_MODE_PROVIDER,
|
PRIORITY_MODE_PROVIDER,
|
||||||
PRIORITY_MODE_GLOBAL_KEY,
|
PRIORITY_MODE_GLOBAL_KEY,
|
||||||
}
|
}
|
||||||
|
# 调度模式常量
|
||||||
|
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式
|
||||||
|
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式
|
||||||
|
ALLOWED_SCHEDULING_MODES = {
|
||||||
|
SCHEDULING_MODE_FIXED_ORDER,
|
||||||
|
SCHEDULING_MODE_CACHE_AFFINITY,
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, redis_client=None, priority_mode: Optional[str] = None):
|
def __init__(
|
||||||
|
self, redis_client=None, priority_mode: Optional[str] = None, scheduling_mode: Optional[str] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
初始化调度器
|
初始化调度器
|
||||||
|
|
||||||
@@ -132,12 +138,16 @@ class CacheAwareScheduler:
|
|||||||
Args:
|
Args:
|
||||||
redis_client: Redis客户端(可选)
|
redis_client: Redis客户端(可选)
|
||||||
priority_mode: 候选排序策略(provider | global_key)
|
priority_mode: 候选排序策略(provider | global_key)
|
||||||
|
scheduling_mode: 调度模式(fixed_order | cache_affinity)
|
||||||
"""
|
"""
|
||||||
self.redis = redis_client
|
self.redis = redis_client
|
||||||
self.priority_mode = self._normalize_priority_mode(
|
self.priority_mode = self._normalize_priority_mode(
|
||||||
priority_mode or self.PRIORITY_MODE_PROVIDER
|
priority_mode or self.PRIORITY_MODE_PROVIDER
|
||||||
)
|
)
|
||||||
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}")
|
self.scheduling_mode = self._normalize_scheduling_mode(
|
||||||
|
scheduling_mode or self.SCHEDULING_MODE_CACHE_AFFINITY
|
||||||
|
)
|
||||||
|
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}, 调度模式: {self.scheduling_mode}")
|
||||||
|
|
||||||
# 初始化子组件(将在第一次使用时异步初始化)
|
# 初始化子组件(将在第一次使用时异步初始化)
|
||||||
self._affinity_manager: Optional[CacheAffinityManager] = None
|
self._affinity_manager: Optional[CacheAffinityManager] = None
|
||||||
@@ -673,14 +683,19 @@ class CacheAwareScheduler:
|
|||||||
f"(api_format={target_format.value}, model={model_name})"
|
f"(api_format={target_format.value}, model={model_name})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 应用缓存亲和性排序(使用 global_model_id 作为模型标识)
|
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
|
||||||
if affinity_key and candidates:
|
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
|
||||||
candidates = await self._apply_cache_affinity(
|
if affinity_key and candidates:
|
||||||
candidates=candidates,
|
candidates = await self._apply_cache_affinity(
|
||||||
affinity_key=affinity_key,
|
candidates=candidates,
|
||||||
api_format=target_format,
|
affinity_key=affinity_key,
|
||||||
global_model_id=global_model_id,
|
api_format=target_format,
|
||||||
)
|
global_model_id=global_model_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 固定顺序模式:标记所有候选为非缓存
|
||||||
|
for candidate in candidates:
|
||||||
|
candidate.is_cached = False
|
||||||
|
|
||||||
return candidates, global_model_id
|
return candidates, global_model_id
|
||||||
|
|
||||||
@@ -1060,6 +1075,22 @@ class CacheAwareScheduler:
|
|||||||
self.priority_mode = normalized
|
self.priority_mode = normalized
|
||||||
logger.debug(f"[CacheAwareScheduler] 切换优先级模式为: {self.priority_mode}")
|
logger.debug(f"[CacheAwareScheduler] 切换优先级模式为: {self.priority_mode}")
|
||||||
|
|
||||||
|
def _normalize_scheduling_mode(self, mode: Optional[str]) -> str:
|
||||||
|
normalized = (mode or "").strip().lower()
|
||||||
|
if normalized not in self.ALLOWED_SCHEDULING_MODES:
|
||||||
|
if normalized:
|
||||||
|
logger.warning(f"[CacheAwareScheduler] 无效的调度模式 '{mode}',回退为 cache_affinity")
|
||||||
|
return self.SCHEDULING_MODE_CACHE_AFFINITY
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def set_scheduling_mode(self, mode: Optional[str]) -> None:
|
||||||
|
"""运行时更新调度模式"""
|
||||||
|
normalized = self._normalize_scheduling_mode(mode)
|
||||||
|
if normalized == self.scheduling_mode:
|
||||||
|
return
|
||||||
|
self.scheduling_mode = normalized
|
||||||
|
logger.debug(f"[CacheAwareScheduler] 切换调度模式为: {self.scheduling_mode}")
|
||||||
|
|
||||||
def _apply_priority_mode_sort(
|
def _apply_priority_mode_sort(
|
||||||
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
|
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
|
||||||
) -> List[ProviderCandidate]:
|
) -> List[ProviderCandidate]:
|
||||||
@@ -1286,7 +1317,6 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"scheduler": "cache_aware",
|
"scheduler": "cache_aware",
|
||||||
"cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
|
|
||||||
"dynamic_reservation": {
|
"dynamic_reservation": {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"config": reservation_stats["config"],
|
"config": reservation_stats["config"],
|
||||||
@@ -1307,6 +1337,7 @@ _scheduler: Optional[CacheAwareScheduler] = None
|
|||||||
async def get_cache_aware_scheduler(
|
async def get_cache_aware_scheduler(
|
||||||
redis_client=None,
|
redis_client=None,
|
||||||
priority_mode: Optional[str] = None,
|
priority_mode: Optional[str] = None,
|
||||||
|
scheduling_mode: Optional[str] = None,
|
||||||
) -> CacheAwareScheduler:
|
) -> CacheAwareScheduler:
|
||||||
"""
|
"""
|
||||||
获取全局CacheAwareScheduler实例
|
获取全局CacheAwareScheduler实例
|
||||||
@@ -1317,6 +1348,7 @@ async def get_cache_aware_scheduler(
|
|||||||
Args:
|
Args:
|
||||||
redis_client: Redis客户端(可选)
|
redis_client: Redis客户端(可选)
|
||||||
priority_mode: 外部覆盖的优先级模式(provider | global_key)
|
priority_mode: 外部覆盖的优先级模式(provider | global_key)
|
||||||
|
scheduling_mode: 外部覆盖的调度模式(fixed_order | cache_affinity)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
CacheAwareScheduler实例
|
CacheAwareScheduler实例
|
||||||
@@ -1324,8 +1356,13 @@ async def get_cache_aware_scheduler(
|
|||||||
global _scheduler
|
global _scheduler
|
||||||
|
|
||||||
if _scheduler is None:
|
if _scheduler is None:
|
||||||
_scheduler = CacheAwareScheduler(redis_client, priority_mode=priority_mode)
|
_scheduler = CacheAwareScheduler(
|
||||||
elif priority_mode:
|
redis_client, priority_mode=priority_mode, scheduling_mode=scheduling_mode
|
||||||
_scheduler.set_priority_mode(priority_mode)
|
)
|
||||||
|
else:
|
||||||
|
if priority_mode:
|
||||||
|
_scheduler.set_priority_mode(priority_mode)
|
||||||
|
if scheduling_mode:
|
||||||
|
_scheduler.set_scheduling_mode(scheduling_mode)
|
||||||
|
|
||||||
return _scheduler
|
return _scheduler
|
||||||
|
|||||||
22
src/services/cache/model_cache.py
vendored
22
src/services/cache/model_cache.py
vendored
@@ -1,5 +1,21 @@
|
|||||||
"""
|
"""
|
||||||
Model 映射缓存服务 - 减少模型查询
|
Model 映射缓存服务 - 减少模型查询
|
||||||
|
|
||||||
|
架构说明
|
||||||
|
========
|
||||||
|
本服务采用混合 async/sync 模式:
|
||||||
|
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||||
|
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||||
|
|
||||||
|
设计决策
|
||||||
|
--------
|
||||||
|
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||||
|
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
|
||||||
|
3. 调用方必须在 async 上下文中使用 await
|
||||||
|
|
||||||
|
使用示例
|
||||||
|
--------
|
||||||
|
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
|
|||||||
|
|
||||||
|
|
||||||
class ModelCacheService:
|
class ModelCacheService:
|
||||||
"""Model 映射缓存服务"""
|
"""Model 映射缓存服务
|
||||||
|
|
||||||
|
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
|
||||||
|
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||||
|
"""
|
||||||
|
|
||||||
# 缓存 TTL(秒)- 使用统一常量
|
# 缓存 TTL(秒)- 使用统一常量
|
||||||
CACHE_TTL = CacheTTL.MODEL
|
CACHE_TTL = CacheTTL.MODEL
|
||||||
|
|||||||
254
src/services/cache/provider_cache.py
vendored
254
src/services/cache/provider_cache.py
vendored
@@ -1,254 +0,0 @@
|
|||||||
"""
|
|
||||||
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from src.config.constants import CacheTTL
|
|
||||||
from src.core.cache_service import CacheKeys, CacheService
|
|
||||||
from src.core.logger import logger
|
|
||||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ProviderCacheService:
|
|
||||||
"""Provider 配置缓存服务"""
|
|
||||||
|
|
||||||
# 缓存 TTL(秒)- 使用统一常量
|
|
||||||
CACHE_TTL = CacheTTL.PROVIDER
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
|
|
||||||
"""
|
|
||||||
获取 Provider(带缓存)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
provider_id: Provider ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Provider 对象或 None
|
|
||||||
"""
|
|
||||||
cache_key = CacheKeys.provider_by_id(provider_id)
|
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
|
||||||
cached_data = await CacheService.get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
logger.debug(f"Provider 缓存命中: {provider_id}")
|
|
||||||
return ProviderCacheService._dict_to_provider(cached_data)
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
|
||||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
if provider:
|
|
||||||
provider_dict = ProviderCacheService._provider_to_dict(provider)
|
|
||||||
await CacheService.set(
|
|
||||||
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
|
||||||
)
|
|
||||||
logger.debug(f"Provider 已缓存: {provider_id}")
|
|
||||||
|
|
||||||
return provider
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
|
|
||||||
"""
|
|
||||||
获取 Endpoint(带缓存)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
endpoint_id: Endpoint ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ProviderEndpoint 对象或 None
|
|
||||||
"""
|
|
||||||
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
|
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
|
||||||
cached_data = await CacheService.get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
|
|
||||||
return ProviderCacheService._dict_to_endpoint(cached_data)
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
|
||||||
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
if endpoint:
|
|
||||||
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
|
|
||||||
await CacheService.set(
|
|
||||||
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
|
||||||
)
|
|
||||||
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
|
|
||||||
|
|
||||||
return endpoint
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
|
|
||||||
"""
|
|
||||||
获取 API Key(带缓存)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
api_key_id: API Key ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ProviderAPIKey 对象或 None
|
|
||||||
"""
|
|
||||||
cache_key = CacheKeys.api_key_by_id(api_key_id)
|
|
||||||
|
|
||||||
# 1. 尝试从缓存获取
|
|
||||||
cached_data = await CacheService.get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
logger.debug(f"API Key 缓存命中: {api_key_id}")
|
|
||||||
return ProviderCacheService._dict_to_api_key(cached_data)
|
|
||||||
|
|
||||||
# 2. 缓存未命中,查询数据库
|
|
||||||
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
|
|
||||||
|
|
||||||
# 3. 写入缓存
|
|
||||||
if api_key:
|
|
||||||
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
|
|
||||||
await CacheService.set(
|
|
||||||
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
|
||||||
)
|
|
||||||
logger.debug(f"API Key 已缓存: {api_key_id}")
|
|
||||||
|
|
||||||
return api_key
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def invalidate_provider_cache(provider_id: str):
|
|
||||||
"""
|
|
||||||
清除 Provider 缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_id: Provider ID
|
|
||||||
"""
|
|
||||||
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
|
|
||||||
logger.debug(f"Provider 缓存已清除: {provider_id}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def invalidate_endpoint_cache(endpoint_id: str):
|
|
||||||
"""
|
|
||||||
清除 Endpoint 缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
endpoint_id: Endpoint ID
|
|
||||||
"""
|
|
||||||
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
|
|
||||||
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def invalidate_api_key_cache(api_key_id: str):
|
|
||||||
"""
|
|
||||||
清除 API Key 缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key_id: API Key ID
|
|
||||||
"""
|
|
||||||
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
|
|
||||||
logger.debug(f"API Key 缓存已清除: {api_key_id}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _provider_to_dict(provider: Provider) -> dict:
|
|
||||||
"""将 Provider 对象转换为字典(用于缓存)"""
|
|
||||||
return {
|
|
||||||
"id": provider.id,
|
|
||||||
"name": provider.name,
|
|
||||||
"api_format": provider.api_format,
|
|
||||||
"base_url": provider.base_url,
|
|
||||||
"is_active": provider.is_active,
|
|
||||||
"priority": provider.priority,
|
|
||||||
"rpm_limit": provider.rpm_limit,
|
|
||||||
"rpm_used": provider.rpm_used,
|
|
||||||
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
|
|
||||||
"config": provider.config,
|
|
||||||
"description": provider.description,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dict_to_provider(provider_dict: dict) -> Provider:
|
|
||||||
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
provider = Provider(
|
|
||||||
id=provider_dict["id"],
|
|
||||||
name=provider_dict["name"],
|
|
||||||
api_format=provider_dict["api_format"],
|
|
||||||
base_url=provider_dict.get("base_url"),
|
|
||||||
is_active=provider_dict["is_active"],
|
|
||||||
priority=provider_dict.get("priority", 0),
|
|
||||||
rpm_limit=provider_dict.get("rpm_limit"),
|
|
||||||
rpm_used=provider_dict.get("rpm_used", 0),
|
|
||||||
config=provider_dict.get("config"),
|
|
||||||
description=provider_dict.get("description"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if provider_dict.get("rpm_reset_at"):
|
|
||||||
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
|
|
||||||
|
|
||||||
return provider
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
|
|
||||||
"""将 Endpoint 对象转换为字典"""
|
|
||||||
return {
|
|
||||||
"id": endpoint.id,
|
|
||||||
"provider_id": endpoint.provider_id,
|
|
||||||
"name": endpoint.name,
|
|
||||||
"base_url": endpoint.base_url,
|
|
||||||
"is_active": endpoint.is_active,
|
|
||||||
"priority": endpoint.priority,
|
|
||||||
"weight": endpoint.weight,
|
|
||||||
"custom_path": endpoint.custom_path,
|
|
||||||
"config": endpoint.config,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
|
|
||||||
"""从字典重建 Endpoint 对象"""
|
|
||||||
endpoint = ProviderEndpoint(
|
|
||||||
id=endpoint_dict["id"],
|
|
||||||
provider_id=endpoint_dict["provider_id"],
|
|
||||||
name=endpoint_dict["name"],
|
|
||||||
base_url=endpoint_dict["base_url"],
|
|
||||||
is_active=endpoint_dict["is_active"],
|
|
||||||
priority=endpoint_dict.get("priority", 0),
|
|
||||||
weight=endpoint_dict.get("weight", 1.0),
|
|
||||||
custom_path=endpoint_dict.get("custom_path"),
|
|
||||||
config=endpoint_dict.get("config"),
|
|
||||||
)
|
|
||||||
return endpoint
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
|
|
||||||
"""将 API Key 对象转换为字典"""
|
|
||||||
return {
|
|
||||||
"id": api_key.id,
|
|
||||||
"endpoint_id": api_key.endpoint_id,
|
|
||||||
"key_value": api_key.key_value,
|
|
||||||
"is_active": api_key.is_active,
|
|
||||||
"max_rpm": api_key.max_rpm,
|
|
||||||
"current_rpm": api_key.current_rpm,
|
|
||||||
"health_score": api_key.health_score,
|
|
||||||
"circuit_breaker_state": api_key.circuit_breaker_state,
|
|
||||||
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
|
|
||||||
"""从字典重建 API Key 对象"""
|
|
||||||
api_key = ProviderAPIKey(
|
|
||||||
id=api_key_dict["id"],
|
|
||||||
endpoint_id=api_key_dict["endpoint_id"],
|
|
||||||
key_value=api_key_dict["key_value"],
|
|
||||||
is_active=api_key_dict["is_active"],
|
|
||||||
max_rpm=api_key_dict.get("max_rpm"),
|
|
||||||
current_rpm=api_key_dict.get("current_rpm", 0),
|
|
||||||
health_score=api_key_dict.get("health_score", 1.0),
|
|
||||||
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
|
|
||||||
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
|
|
||||||
)
|
|
||||||
return api_key
|
|
||||||
24
src/services/cache/user_cache.py
vendored
24
src/services/cache/user_cache.py
vendored
@@ -1,5 +1,22 @@
|
|||||||
"""
|
"""
|
||||||
用户缓存服务 - 减少数据库查询
|
用户缓存服务 - 减少数据库查询
|
||||||
|
|
||||||
|
架构说明
|
||||||
|
========
|
||||||
|
本服务采用混合 async/sync 模式:
|
||||||
|
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||||
|
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||||
|
|
||||||
|
设计决策
|
||||||
|
--------
|
||||||
|
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||||
|
2. 缓存未命中时的同步查询:FastAPI 会在线程池中执行,不会阻塞事件循环
|
||||||
|
3. 调用方必须在 async 上下文中使用 await
|
||||||
|
|
||||||
|
使用示例
|
||||||
|
--------
|
||||||
|
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||||
|
await UserCacheService.invalidate_user_cache(user_id, email)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -12,9 +29,12 @@ from src.core.logger import logger
|
|||||||
from src.models.database import User
|
from src.models.database import User
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class UserCacheService:
|
class UserCacheService:
|
||||||
"""用户缓存服务"""
|
"""用户缓存服务
|
||||||
|
|
||||||
|
提供 User 的缓存查询功能,减少数据库访问。
|
||||||
|
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||||
|
"""
|
||||||
|
|
||||||
# 缓存 TTL(秒)- 使用统一常量
|
# 缓存 TTL(秒)- 使用统一常量
|
||||||
CACHE_TTL = CacheTTL.USER
|
CACHE_TTL = CacheTTL.USER
|
||||||
|
|||||||
@@ -69,24 +69,29 @@ class ErrorClassifier:
|
|||||||
# 这些错误是由用户请求本身导致的,换 Provider 也无济于事
|
# 这些错误是由用户请求本身导致的,换 Provider 也无济于事
|
||||||
# 注意:标准 API 返回的 error.type 已在 CLIENT_ERROR_TYPES 中处理
|
# 注意:标准 API 返回的 error.type 已在 CLIENT_ERROR_TYPES 中处理
|
||||||
# 这里主要用于匹配非标准格式或第三方代理的错误消息
|
# 这里主要用于匹配非标准格式或第三方代理的错误消息
|
||||||
|
#
|
||||||
|
# 重要:不要在此列表中包含 Provider Key 配置问题(如 invalid_api_key)
|
||||||
|
# 这类错误应该触发故障转移,而不是直接返回给用户
|
||||||
CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
|
CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
|
||||||
"could not process image", # 图片处理失败
|
"could not process image", # 图片处理失败
|
||||||
"image too large", # 图片过大
|
"image too large", # 图片过大
|
||||||
"invalid image", # 无效图片
|
"invalid image", # 无效图片
|
||||||
"unsupported image", # 不支持的图片格式
|
"unsupported image", # 不支持的图片格式
|
||||||
"content_policy_violation", # 内容违规
|
"content_policy_violation", # 内容违规
|
||||||
"invalid_api_key", # 无效的 API Key(不同于认证失败)
|
|
||||||
"context_length_exceeded", # 上下文长度超限
|
"context_length_exceeded", # 上下文长度超限
|
||||||
"content_length_limit", # 请求内容长度超限 (Claude API)
|
"content_length_limit", # 请求内容长度超限 (Claude API)
|
||||||
|
"content_length_exceeds", # 内容长度超限变体 (AWS CodeWhisperer)
|
||||||
"max_tokens", # token 数超限
|
"max_tokens", # token 数超限
|
||||||
"invalid_prompt", # 无效的提示词
|
"invalid_prompt", # 无效的提示词
|
||||||
"content too long", # 内容过长
|
"content too long", # 内容过长
|
||||||
|
"input is too long", # 输入过长 (AWS)
|
||||||
"message is too long", # 消息过长
|
"message is too long", # 消息过长
|
||||||
"prompt is too long", # Prompt 超长(第三方代理常见格式)
|
"prompt is too long", # Prompt 超长(第三方代理常见格式)
|
||||||
"image exceeds", # 图片超出限制
|
"image exceeds", # 图片超出限制
|
||||||
"pdf too large", # PDF 过大
|
"pdf too large", # PDF 过大
|
||||||
"file too large", # 文件过大
|
"file too large", # 文件过大
|
||||||
"tool_use_id", # tool_result 引用了不存在的 tool_use(兼容非标准代理)
|
"tool_use_id", # tool_result 引用了不存在的 tool_use(兼容非标准代理)
|
||||||
|
"validationexception", # AWS 验证异常
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -110,18 +115,124 @@ class ErrorClassifier:
|
|||||||
# 表示客户端错误的 error type(不区分大小写)
|
# 表示客户端错误的 error type(不区分大小写)
|
||||||
# 这些 type 表明是请求本身的问题,不应重试
|
# 这些 type 表明是请求本身的问题,不应重试
|
||||||
CLIENT_ERROR_TYPES: Tuple[str, ...] = (
|
CLIENT_ERROR_TYPES: Tuple[str, ...] = (
|
||||||
"invalid_request_error", # Claude/OpenAI 标准客户端错误类型
|
# Claude/OpenAI 标准
|
||||||
"invalid_argument", # Gemini 参数错误
|
"invalid_request_error",
|
||||||
"failed_precondition", # Gemini 前置条件错误
|
# Gemini
|
||||||
|
"invalid_argument",
|
||||||
|
"failed_precondition",
|
||||||
|
# AWS
|
||||||
|
"validationexception",
|
||||||
|
# 通用
|
||||||
|
"validation_error",
|
||||||
|
"bad_request",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 表示客户端错误的 reason/code 字段值
|
||||||
|
CLIENT_ERROR_REASONS: Tuple[str, ...] = (
|
||||||
|
"CONTENT_LENGTH_EXCEEDS_THRESHOLD",
|
||||||
|
"CONTEXT_LENGTH_EXCEEDED",
|
||||||
|
"MAX_TOKENS_EXCEEDED",
|
||||||
|
"INVALID_CONTENT",
|
||||||
|
"CONTENT_POLICY_VIOLATION",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
解析错误响应为结构化数据
|
||||||
|
|
||||||
|
支持多种格式:
|
||||||
|
- {"error": {"type": "...", "message": "..."}} (Claude/OpenAI)
|
||||||
|
- {"error": {"message": "...", "__type": "..."}} (AWS)
|
||||||
|
- {"errorMessage": "..."} (Lambda)
|
||||||
|
- {"error": "..."}
|
||||||
|
- {"message": "...", "reason": "..."}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
结构化的错误信息: {
|
||||||
|
"type": str, # 错误类型
|
||||||
|
"message": str, # 错误消息
|
||||||
|
"reason": str, # 错误原因/代码
|
||||||
|
"raw": str, # 原始文本
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
result = {"type": "", "message": "", "reason": "", "raw": error_text or ""}
|
||||||
|
|
||||||
|
if not error_text:
|
||||||
|
return result
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(error_text)
|
||||||
|
|
||||||
|
# 格式 1: {"error": {"type": "...", "message": "..."}}
|
||||||
|
if isinstance(data.get("error"), dict):
|
||||||
|
error_obj = data["error"]
|
||||||
|
result["type"] = str(error_obj.get("type", ""))
|
||||||
|
result["message"] = str(error_obj.get("message", ""))
|
||||||
|
|
||||||
|
# AWS 格式: {"error": {"__type": "...", "message": "...", "reason": "..."}}
|
||||||
|
# __type 直接在 error 对象中,而不是嵌套在 message 里
|
||||||
|
if "__type" in error_obj:
|
||||||
|
result["type"] = result["type"] or str(error_obj.get("__type", ""))
|
||||||
|
if "reason" in error_obj:
|
||||||
|
result["reason"] = str(error_obj.get("reason", ""))
|
||||||
|
if "code" in error_obj:
|
||||||
|
result["reason"] = result["reason"] or str(error_obj.get("code", ""))
|
||||||
|
|
||||||
|
# 嵌套 JSON 格式: message 字段本身是 JSON 字符串
|
||||||
|
# 支持多种嵌套格式:
|
||||||
|
# - AWS: {"__type": "...", "message": "...", "reason": "..."}
|
||||||
|
# - 第三方代理: {"error": {"type": "...", "message": "..."}}
|
||||||
|
if result["message"].startswith("{"):
|
||||||
|
try:
|
||||||
|
nested = json.loads(result["message"])
|
||||||
|
if isinstance(nested, dict):
|
||||||
|
# AWS 格式
|
||||||
|
if "__type" in nested:
|
||||||
|
result["type"] = result["type"] or str(nested.get("__type", ""))
|
||||||
|
result["message"] = str(nested.get("message", result["message"]))
|
||||||
|
result["reason"] = str(nested.get("reason", ""))
|
||||||
|
# 第三方代理格式: {"error": {"message": "..."}}
|
||||||
|
elif isinstance(nested.get("error"), dict):
|
||||||
|
inner_error = nested["error"]
|
||||||
|
inner_msg = str(inner_error.get("message", ""))
|
||||||
|
if inner_msg:
|
||||||
|
result["message"] = inner_msg
|
||||||
|
# 简单格式: {"message": "..."}
|
||||||
|
elif "message" in nested:
|
||||||
|
result["message"] = str(nested["message"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 格式 2: {"error": "..."}
|
||||||
|
elif isinstance(data.get("error"), str):
|
||||||
|
result["message"] = str(data["error"])
|
||||||
|
|
||||||
|
# 格式 3: {"errorMessage": "..."} (Lambda)
|
||||||
|
elif "errorMessage" in data:
|
||||||
|
result["message"] = str(data["errorMessage"])
|
||||||
|
|
||||||
|
# 格式 4: {"message": "...", "reason": "..."}
|
||||||
|
elif "message" in data:
|
||||||
|
result["message"] = str(data["message"])
|
||||||
|
result["reason"] = str(data.get("reason", ""))
|
||||||
|
|
||||||
|
# 提取顶层的 reason/code
|
||||||
|
if not result["reason"]:
|
||||||
|
result["reason"] = str(data.get("reason", data.get("code", "")))
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, TypeError, KeyError):
|
||||||
|
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def _is_client_error(self, error_text: Optional[str]) -> bool:
|
def _is_client_error(self, error_text: Optional[str]) -> bool:
|
||||||
"""
|
"""
|
||||||
检测错误响应是否为客户端错误(不应重试)
|
检测错误响应是否为客户端错误(不应重试)
|
||||||
|
|
||||||
判断逻辑:
|
判断逻辑(按优先级):
|
||||||
1. 检查 error.type 是否为已知的客户端错误类型
|
1. 检查 error.type 是否为已知的客户端错误类型
|
||||||
2. 检查错误文本是否包含已知的客户端错误模式
|
2. 检查 reason/code 是否为已知的客户端错误原因
|
||||||
|
3. 回退到关键词匹配
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error_text: 错误响应文本
|
error_text: 错误响应文本
|
||||||
@@ -132,67 +243,53 @@ class ErrorClassifier:
|
|||||||
if not error_text:
|
if not error_text:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 尝试解析 JSON 并检查 error type
|
parsed = self._parse_error_response(error_text)
|
||||||
try:
|
|
||||||
data = json.loads(error_text)
|
|
||||||
if isinstance(data.get("error"), dict):
|
|
||||||
error_type = data["error"].get("type", "")
|
|
||||||
if error_type and any(
|
|
||||||
t.lower() in error_type.lower() for t in self.CLIENT_ERROR_TYPES
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
except (json.JSONDecodeError, TypeError, KeyError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 回退到关键词匹配
|
# 1. 检查 error type
|
||||||
error_lower = error_text.lower()
|
if parsed["type"]:
|
||||||
return any(pattern.lower() in error_lower for pattern in self.CLIENT_ERROR_PATTERNS)
|
error_type_lower = parsed["type"].lower()
|
||||||
|
if any(t.lower() in error_type_lower for t in self.CLIENT_ERROR_TYPES):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 2. 检查 reason/code
|
||||||
|
if parsed["reason"]:
|
||||||
|
reason_upper = parsed["reason"].upper()
|
||||||
|
if any(r in reason_upper for r in self.CLIENT_ERROR_REASONS):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 3. 回退到关键词匹配(合并 message 和 raw)
|
||||||
|
search_text = f"{parsed['message']} {parsed['raw']}".lower()
|
||||||
|
return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)
|
||||||
|
|
||||||
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
|
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
从错误响应中提取错误消息
|
从错误响应中提取错误消息
|
||||||
|
|
||||||
支持格式:
|
|
||||||
- {"error": {"message": "..."}} (OpenAI/Claude)
|
|
||||||
- {"error": {"type": "...", "message": "..."}}
|
|
||||||
- {"error": "..."}
|
|
||||||
- {"message": "..."}
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
error_text: 错误响应文本
|
error_text: 错误响应文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提取的错误消息,如果无法解析则返回原始文本
|
提取的错误消息
|
||||||
"""
|
"""
|
||||||
if not error_text:
|
if not error_text:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
parsed = self._parse_error_response(error_text)
|
||||||
data = json.loads(error_text)
|
|
||||||
|
|
||||||
# {"error": {"message": "..."}} 或 {"error": {"type": "...", "message": "..."}}
|
# 构建可读的错误消息
|
||||||
if isinstance(data.get("error"), dict):
|
parts = []
|
||||||
error_obj = data["error"]
|
if parsed["type"]:
|
||||||
message = error_obj.get("message", "")
|
parts.append(parsed["type"])
|
||||||
error_type = error_obj.get("type", "")
|
if parsed["reason"]:
|
||||||
if message:
|
parts.append(f"[{parsed['reason']}]")
|
||||||
if error_type:
|
if parsed["message"]:
|
||||||
return f"{error_type}: {message}"
|
parts.append(parsed["message"])
|
||||||
return str(message)
|
|
||||||
|
|
||||||
# {"error": "..."}
|
if parts:
|
||||||
if isinstance(data.get("error"), str):
|
return ": ".join(parts) if len(parts) > 1 else parts[0]
|
||||||
return str(data["error"])
|
|
||||||
|
|
||||||
# {"message": "..."}
|
|
||||||
if isinstance(data.get("message"), str):
|
|
||||||
return str(data["message"])
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, TypeError, KeyError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 无法解析,返回原始文本(截断)
|
# 无法解析,返回原始文本(截断)
|
||||||
return error_text[:500] if len(error_text) > 500 else error_text
|
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
|
||||||
|
|
||||||
def classify(
|
def classify(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -102,9 +102,15 @@ class FallbackOrchestrator:
|
|||||||
"provider_priority_mode",
|
"provider_priority_mode",
|
||||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||||
)
|
)
|
||||||
|
scheduling_mode = SystemConfigService.get_config(
|
||||||
|
self.db,
|
||||||
|
"scheduling_mode",
|
||||||
|
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||||
|
)
|
||||||
self.cache_scheduler = await get_cache_aware_scheduler(
|
self.cache_scheduler = await get_cache_aware_scheduler(
|
||||||
self.redis,
|
self.redis,
|
||||||
priority_mode=priority_mode,
|
priority_mode=priority_mode,
|
||||||
|
scheduling_mode=scheduling_mode,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 确保运行时配置变更能生效
|
# 确保运行时配置变更能生效
|
||||||
@@ -113,7 +119,13 @@ class FallbackOrchestrator:
|
|||||||
"provider_priority_mode",
|
"provider_priority_mode",
|
||||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||||
)
|
)
|
||||||
|
scheduling_mode = SystemConfigService.get_config(
|
||||||
|
self.db,
|
||||||
|
"scheduling_mode",
|
||||||
|
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||||
|
)
|
||||||
self.cache_scheduler.set_priority_mode(priority_mode)
|
self.cache_scheduler.set_priority_mode(priority_mode)
|
||||||
|
self.cache_scheduler.set_scheduling_mode(scheduling_mode)
|
||||||
|
|
||||||
# 确保 cache_scheduler 内部组件也已初始化
|
# 确保 cache_scheduler 内部组件也已初始化
|
||||||
await self.cache_scheduler._ensure_initialized()
|
await self.cache_scheduler._ensure_initialized()
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
"""响应标准化服务,用于 STANDARD 模式下的响应格式验证和补全"""
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from src.core.logger import logger
|
|
||||||
from src.models.claude import ClaudeResponse
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseNormalizer:
|
|
||||||
"""响应标准化器 - 用于标准模式下验证和补全响应字段"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def normalize_claude_response(
|
|
||||||
response_data: Dict[str, Any], request_id: Optional[str] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
标准化 Claude API 响应
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_data: 原始响应数据
|
|
||||||
request_id: 请求ID(用于日志)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
标准化后的响应数据(失败时返回原始数据)
|
|
||||||
"""
|
|
||||||
if "error" in response_data:
|
|
||||||
logger.debug(f"[ResponseNormalizer] 检测到错误响应,跳过标准化 | ID:{request_id}")
|
|
||||||
return response_data
|
|
||||||
|
|
||||||
try:
|
|
||||||
validated = ClaudeResponse.model_validate(response_data)
|
|
||||||
normalized = validated.model_dump(mode="json", exclude_none=False)
|
|
||||||
|
|
||||||
logger.debug(f"[ResponseNormalizer] 响应标准化成功 | ID:{request_id}")
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"[ResponseNormalizer] 响应验证失败,透传原始数据 | ID:{request_id}")
|
|
||||||
return response_data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def should_normalize(response_data: Dict[str, Any]) -> bool:
|
|
||||||
"""
|
|
||||||
判断是否需要进行标准化
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response_data: 响应数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否需要标准化
|
|
||||||
"""
|
|
||||||
# 错误响应不需要标准化
|
|
||||||
if "error" in response_data:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 已经包含新字段的响应不需要再次标准化
|
|
||||||
if "context_management" in response_data and "container" in response_data:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
@@ -5,6 +5,10 @@
|
|||||||
- 使用滑动窗口采样,容忍并发波动
|
- 使用滑动窗口采样,容忍并发波动
|
||||||
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
|
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
|
||||||
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
|
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
|
||||||
|
|
||||||
|
AIMD 参数说明:
|
||||||
|
- 扩容:加性增加 (+INCREASE_STEP)
|
||||||
|
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -34,7 +38,7 @@ class AdaptiveConcurrencyManager:
|
|||||||
核心算法:基于滑动窗口利用率的 AIMD
|
核心算法:基于滑动窗口利用率的 AIMD
|
||||||
- 滑动窗口记录最近 N 次请求的利用率
|
- 滑动窗口记录最近 N 次请求的利用率
|
||||||
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
|
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
|
||||||
- 遇到 429 错误时乘性减少 (*0.7)
|
- 遇到 429 错误时乘性减少 (*0.85)
|
||||||
- 长时间无 429 且有流量时触发探测性扩容
|
- 长时间无 429 且有流量时触发探测性扩容
|
||||||
|
|
||||||
扩容条件(满足任一即可):
|
扩容条件(满足任一即可):
|
||||||
|
|||||||
@@ -11,7 +11,6 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import timedelta # noqa: F401 - kept for potential future use
|
from datetime import timedelta # noqa: F401 - kept for potential future use
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
@@ -40,6 +39,7 @@ class ConcurrencyManager:
|
|||||||
self._memory_lock: asyncio.Lock = asyncio.Lock()
|
self._memory_lock: asyncio.Lock = asyncio.Lock()
|
||||||
self._memory_endpoint_counts: dict[str, int] = {}
|
self._memory_endpoint_counts: dict[str, int] = {}
|
||||||
self._memory_key_counts: dict[str, int] = {}
|
self._memory_key_counts: dict[str, int] = {}
|
||||||
|
self._owns_redis: bool = False
|
||||||
self._memory_initialized = True
|
self._memory_initialized = True
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@@ -47,41 +47,29 @@ class ConcurrencyManager:
|
|||||||
if self._redis is not None:
|
if self._redis is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 优先使用 REDIS_URL,如果没有则根据密码构建 URL
|
|
||||||
redis_url = os.getenv("REDIS_URL")
|
|
||||||
|
|
||||||
if not redis_url:
|
|
||||||
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
|
|
||||||
redis_password = os.getenv("REDIS_PASSWORD")
|
|
||||||
if redis_password:
|
|
||||||
redis_url = f"redis://:{redis_password}@localhost:6379/0"
|
|
||||||
else:
|
|
||||||
redis_url = "redis://localhost:6379/0"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._redis = await aioredis.from_url(
|
# 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
|
||||||
redis_url,
|
from src.clients.redis_client import get_redis_client
|
||||||
encoding="utf-8",
|
|
||||||
decode_responses=True,
|
self._redis = await get_redis_client(require_redis=False)
|
||||||
socket_timeout=5.0,
|
self._owns_redis = False
|
||||||
socket_connect_timeout=5.0,
|
if self._redis:
|
||||||
)
|
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
|
||||||
# 测试连接
|
else:
|
||||||
await self._redis.ping()
|
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
|
||||||
# 脱敏显示(隐藏密码)
|
|
||||||
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
|
|
||||||
logger.info(f"[OK] Redis 连接成功: {safe_url}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[ERROR] Redis 连接失败: {e}")
|
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
|
||||||
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
|
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
|
||||||
self._redis = None
|
self._redis = None
|
||||||
|
self._owns_redis = False
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""关闭 Redis 连接"""
|
"""关闭 Redis 连接"""
|
||||||
if self._redis:
|
if self._redis and self._owns_redis:
|
||||||
await self._redis.close()
|
await self._redis.close()
|
||||||
self._redis = None
|
logger.info("ConcurrencyManager Redis 连接已关闭")
|
||||||
logger.info("Redis 连接已关闭")
|
self._redis = None
|
||||||
|
self._owns_redis = False
|
||||||
|
|
||||||
def _get_endpoint_key(self, endpoint_id: str) -> str:
|
def _get_endpoint_key(self, endpoint_id: str) -> str:
|
||||||
"""获取 Endpoint 并发计数的 Redis Key"""
|
"""获取 Endpoint 并发计数的 Redis Key"""
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Dict, Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -72,11 +72,7 @@ class RPMLimiter:
|
|||||||
# 获取当前分钟窗口
|
# 获取当前分钟窗口
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
window_start = now.replace(second=0, microsecond=0)
|
window_start = now.replace(second=0, microsecond=0)
|
||||||
window_end = (
|
window_end = window_start + timedelta(minutes=1)
|
||||||
window_start.replace(minute=window_start.minute + 1)
|
|
||||||
if window_start.minute < 59
|
|
||||||
else window_start.replace(hour=window_start.hour + 1, minute=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 查找或创建追踪记录
|
# 查找或创建追踪记录
|
||||||
tracking = (
|
tracking = (
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
|||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.database import AuditEventType, AuditLog
|
from src.models.database import AuditEventType, AuditLog
|
||||||
from src.utils.transaction_manager import transactional
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
|
|||||||
|
|
||||||
|
|
||||||
class AuditService:
|
class AuditService:
|
||||||
"""审计服务"""
|
"""审计服务
|
||||||
|
|
||||||
|
事务策略:本服务不负责事务提交,由中间件统一管理。
|
||||||
|
所有方法只做 db.add/flush,提交由请求结束时的中间件处理。
|
||||||
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@transactional(commit=False) # 不自动提交,让调用方决定
|
|
||||||
def log_event(
|
def log_event(
|
||||||
db: Session,
|
db: Session,
|
||||||
event_type: AuditEventType,
|
event_type: AuditEventType,
|
||||||
@@ -54,47 +56,44 @@ class AuditService:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
审计日志记录
|
审计日志记录
|
||||||
|
|
||||||
|
Note:
|
||||||
|
不在此方法内提交事务,由调用方或中间件统一管理。
|
||||||
"""
|
"""
|
||||||
try:
|
audit_log = AuditLog(
|
||||||
audit_log = AuditLog(
|
event_type=event_type.value,
|
||||||
event_type=event_type.value,
|
description=description,
|
||||||
description=description,
|
user_id=user_id,
|
||||||
user_id=user_id,
|
api_key_id=api_key_id,
|
||||||
api_key_id=api_key_id,
|
ip_address=ip_address,
|
||||||
ip_address=ip_address,
|
user_agent=user_agent,
|
||||||
user_agent=user_agent,
|
request_id=request_id,
|
||||||
request_id=request_id,
|
status_code=status_code,
|
||||||
status_code=status_code,
|
error_message=error_message,
|
||||||
error_message=error_message,
|
event_metadata=metadata,
|
||||||
event_metadata=metadata,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
db.add(audit_log)
|
db.add(audit_log)
|
||||||
db.commit() # 立即提交事务,释放数据库锁
|
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
|
||||||
db.refresh(audit_log)
|
db.flush()
|
||||||
|
|
||||||
# 同时记录到系统日志
|
# 同时记录到系统日志
|
||||||
log_message = (
|
log_message = (
|
||||||
f"AUDIT [{event_type.value}] - {description} | "
|
f"AUDIT [{event_type.value}] - {description} | "
|
||||||
f"user_id={user_id}, ip={ip_address}"
|
f"user_id={user_id}, ip={ip_address}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if event_type in [
|
if event_type in [
|
||||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||||
]:
|
]:
|
||||||
logger.warning(log_message)
|
logger.warning(log_message)
|
||||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||||
logger.info(log_message)
|
logger.info(log_message)
|
||||||
else:
|
else:
|
||||||
logger.debug(log_message)
|
logger.debug(log_message)
|
||||||
|
|
||||||
return audit_log
|
return audit_log
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to log audit event: {e}")
|
|
||||||
db.rollback()
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def log_login_attempt(
|
def log_login_attempt(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import create_session
|
from src.database import create_session
|
||||||
from src.models.database import Usage
|
from src.models.database import AuditLog, Usage
|
||||||
from src.services.system.config import SystemConfigService
|
from src.services.system.config import SystemConfigService
|
||||||
from src.services.system.scheduler import get_scheduler
|
from src.services.system.scheduler import get_scheduler
|
||||||
from src.services.system.stats_aggregator import StatsAggregatorService
|
from src.services.system.stats_aggregator import StatsAggregatorService
|
||||||
@@ -35,6 +35,7 @@ class CleanupScheduler:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
self._interval_tasks = []
|
self._interval_tasks = []
|
||||||
|
self._stats_aggregation_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""启动调度器"""
|
"""启动调度器"""
|
||||||
@@ -56,6 +57,14 @@ class CleanupScheduler:
|
|||||||
job_id="stats_aggregation",
|
job_id="stats_aggregation",
|
||||||
name="统计数据聚合",
|
name="统计数据聚合",
|
||||||
)
|
)
|
||||||
|
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
|
||||||
|
scheduler.add_interval_job(
|
||||||
|
self._scheduled_stats_aggregation,
|
||||||
|
minutes=30,
|
||||||
|
job_id="stats_aggregation_backfill",
|
||||||
|
name="统计数据聚合补偿",
|
||||||
|
backfill=True,
|
||||||
|
)
|
||||||
|
|
||||||
# 清理任务 - 凌晨 3 点执行
|
# 清理任务 - 凌晨 3 点执行
|
||||||
scheduler.add_cron_job(
|
scheduler.add_cron_job(
|
||||||
@@ -82,6 +91,15 @@ class CleanupScheduler:
|
|||||||
name="Pending状态清理",
|
name="Pending状态清理",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 审计日志清理 - 凌晨 4 点执行
|
||||||
|
scheduler.add_cron_job(
|
||||||
|
self._scheduled_audit_cleanup,
|
||||||
|
hour=4,
|
||||||
|
minute=0,
|
||||||
|
job_id="audit_cleanup",
|
||||||
|
name="审计日志清理",
|
||||||
|
)
|
||||||
|
|
||||||
# 启动时执行一次初始化任务
|
# 启动时执行一次初始化任务
|
||||||
asyncio.create_task(self._run_startup_tasks())
|
asyncio.create_task(self._run_startup_tasks())
|
||||||
|
|
||||||
@@ -115,9 +133,9 @@ class CleanupScheduler:
|
|||||||
|
|
||||||
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
||||||
|
|
||||||
async def _scheduled_stats_aggregation(self):
|
async def _scheduled_stats_aggregation(self, backfill: bool = False):
|
||||||
"""统计聚合任务(定时调用)"""
|
"""统计聚合任务(定时调用)"""
|
||||||
await self._perform_stats_aggregation()
|
await self._perform_stats_aggregation(backfill=backfill)
|
||||||
|
|
||||||
async def _scheduled_cleanup(self):
|
async def _scheduled_cleanup(self):
|
||||||
"""清理任务(定时调用)"""
|
"""清理任务(定时调用)"""
|
||||||
@@ -136,6 +154,10 @@ class CleanupScheduler:
|
|||||||
"""Pending 清理任务(定时调用)"""
|
"""Pending 清理任务(定时调用)"""
|
||||||
await self._perform_pending_cleanup()
|
await self._perform_pending_cleanup()
|
||||||
|
|
||||||
|
async def _scheduled_audit_cleanup(self):
|
||||||
|
"""审计日志清理任务(定时调用)"""
|
||||||
|
await self._perform_audit_cleanup()
|
||||||
|
|
||||||
# ========== 实际任务实现 ==========
|
# ========== 实际任务实现 ==========
|
||||||
|
|
||||||
async def _perform_stats_aggregation(self, backfill: bool = False):
|
async def _perform_stats_aggregation(self, backfill: bool = False):
|
||||||
@@ -144,136 +166,157 @@ class CleanupScheduler:
|
|||||||
Args:
|
Args:
|
||||||
backfill: 是否回填历史数据(启动时检查缺失的日期)
|
backfill: 是否回填历史数据(启动时检查缺失的日期)
|
||||||
"""
|
"""
|
||||||
db = create_session()
|
if self._stats_aggregation_lock.locked():
|
||||||
try:
|
logger.info("统计聚合任务正在运行,跳过本次触发")
|
||||||
# 检查是否启用统计聚合
|
return
|
||||||
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
|
||||||
logger.info("统计聚合已禁用,跳过聚合任务")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("开始执行统计数据聚合...")
|
async with self._stats_aggregation_lock:
|
||||||
|
db = create_session()
|
||||||
from src.models.database import StatsDaily, User as DBUser
|
try:
|
||||||
from src.services.system.scheduler import APP_TIMEZONE
|
# 检查是否启用统计聚合
|
||||||
from zoneinfo import ZoneInfo
|
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
||||||
|
logger.info("统计聚合已禁用,跳过聚合任务")
|
||||||
# 使用业务时区计算日期,确保与定时任务触发时间一致
|
|
||||||
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
|
|
||||||
app_tz = ZoneInfo(APP_TIMEZONE)
|
|
||||||
now_local = datetime.now(app_tz)
|
|
||||||
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
||||||
|
|
||||||
if backfill:
|
|
||||||
# 启动时检查并回填缺失的日期
|
|
||||||
from src.models.database import StatsSummary
|
|
||||||
|
|
||||||
summary = db.query(StatsSummary).first()
|
|
||||||
if not summary:
|
|
||||||
# 首次运行,回填所有历史数据
|
|
||||||
logger.info("检测到首次运行,开始回填历史统计数据...")
|
|
||||||
days_to_backfill = SystemConfigService.get_config(
|
|
||||||
db, "stats_backfill_days", 365
|
|
||||||
)
|
|
||||||
count = StatsAggregatorService.backfill_historical_data(
|
|
||||||
db, days=days_to_backfill
|
|
||||||
)
|
|
||||||
logger.info(f"历史数据回填完成,共 {count} 天")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
logger.info("开始执行统计数据聚合...")
|
||||||
latest_stat = (
|
|
||||||
db.query(StatsDaily)
|
|
||||||
.order_by(StatsDaily.date.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if latest_stat:
|
from src.models.database import StatsDaily, User as DBUser
|
||||||
latest_date_utc = latest_stat.date
|
from src.services.system.scheduler import APP_TIMEZONE
|
||||||
if latest_date_utc.tzinfo is None:
|
from zoneinfo import ZoneInfo
|
||||||
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
|
||||||
else:
|
|
||||||
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
|
||||||
|
|
||||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
# 使用业务时区计算日期,确保与定时任务触发时间一致
|
||||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
|
||||||
yesterday_business_date = (today_local.date() - timedelta(days=1))
|
app_tz = ZoneInfo(APP_TIMEZONE)
|
||||||
missing_start_date = latest_business_date + timedelta(days=1)
|
now_local = datetime.now(app_tz)
|
||||||
|
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
if missing_start_date <= yesterday_business_date:
|
if backfill:
|
||||||
missing_days = (yesterday_business_date - missing_start_date).days + 1
|
# 启动时检查并回填缺失的日期
|
||||||
logger.info(
|
from src.models.database import StatsSummary
|
||||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
|
||||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
summary = db.query(StatsSummary).first()
|
||||||
|
if not summary:
|
||||||
|
# 首次运行,回填所有历史数据
|
||||||
|
logger.info("检测到首次运行,开始回填历史统计数据...")
|
||||||
|
days_to_backfill = SystemConfigService.get_config(
|
||||||
|
db, "stats_backfill_days", 365
|
||||||
)
|
)
|
||||||
|
count = StatsAggregatorService.backfill_historical_data(
|
||||||
|
db, days=days_to_backfill
|
||||||
|
)
|
||||||
|
logger.info(f"历史数据回填完成,共 {count} 天")
|
||||||
|
return
|
||||||
|
|
||||||
current_date = missing_start_date
|
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
|
||||||
|
|
||||||
while current_date <= yesterday_business_date:
|
if latest_stat:
|
||||||
try:
|
latest_date_utc = latest_stat.date
|
||||||
current_date_local = datetime.combine(
|
if latest_date_utc.tzinfo is None:
|
||||||
current_date, datetime.min.time(), tzinfo=app_tz
|
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
||||||
|
else:
|
||||||
|
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
||||||
|
|
||||||
|
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
||||||
|
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
||||||
|
yesterday_business_date = today_local.date() - timedelta(days=1)
|
||||||
|
missing_start_date = latest_business_date + timedelta(days=1)
|
||||||
|
|
||||||
|
if missing_start_date <= yesterday_business_date:
|
||||||
|
missing_days = (
|
||||||
|
yesterday_business_date - missing_start_date
|
||||||
|
).days + 1
|
||||||
|
|
||||||
|
# 限制最大回填天数,防止停机很久后一次性回填太多
|
||||||
|
max_backfill_days: int = SystemConfigService.get_config(
|
||||||
|
db, "max_stats_backfill_days", 30
|
||||||
|
) or 30
|
||||||
|
if missing_days > max_backfill_days:
|
||||||
|
logger.warning(
|
||||||
|
f"缺失 {missing_days} 天数据超过最大回填限制 "
|
||||||
|
f"{max_backfill_days} 天,只回填最近 {max_backfill_days} 天"
|
||||||
)
|
)
|
||||||
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
|
missing_start_date = yesterday_business_date - timedelta(
|
||||||
# 聚合用户数据
|
days=max_backfill_days - 1
|
||||||
for (user_id,) in users:
|
)
|
||||||
try:
|
missing_days = max_backfill_days
|
||||||
StatsAggregatorService.aggregate_user_daily_stats(
|
|
||||||
db, user_id, current_date_local
|
logger.info(
|
||||||
)
|
f"检测到缺失 {missing_days} 天的统计数据 "
|
||||||
except Exception as e:
|
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
||||||
logger.warning(
|
)
|
||||||
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
|
||||||
)
|
current_date = missing_start_date
|
||||||
try:
|
users = (
|
||||||
db.rollback()
|
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||||
except Exception:
|
)
|
||||||
pass
|
|
||||||
except Exception as e:
|
while current_date <= yesterday_business_date:
|
||||||
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
|
||||||
try:
|
try:
|
||||||
db.rollback()
|
current_date_local = datetime.combine(
|
||||||
except Exception:
|
current_date, datetime.min.time(), tzinfo=app_tz
|
||||||
pass
|
)
|
||||||
|
StatsAggregatorService.aggregate_daily_stats(
|
||||||
|
db, current_date_local
|
||||||
|
)
|
||||||
|
for (user_id,) in users:
|
||||||
|
try:
|
||||||
|
StatsAggregatorService.aggregate_user_daily_stats(
|
||||||
|
db, user_id, current_date_local
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
current_date += timedelta(days=1)
|
current_date += timedelta(days=1)
|
||||||
|
|
||||||
# 更新全局汇总
|
StatsAggregatorService.update_summary(db)
|
||||||
StatsAggregatorService.update_summary(db)
|
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
||||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
else:
|
||||||
else:
|
logger.info("统计数据已是最新,无需回填")
|
||||||
logger.info("统计数据已是最新,无需回填")
|
return
|
||||||
return
|
|
||||||
|
|
||||||
# 定时任务:聚合昨天的数据
|
# 定时任务:聚合昨天的数据
|
||||||
# 注意:aggregate_daily_stats 期望业务时区的日期,不是 UTC
|
yesterday_local = today_local - timedelta(days=1)
|
||||||
yesterday_local = today_local - timedelta(days=1)
|
|
||||||
|
|
||||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||||
|
|
||||||
# 聚合所有用户的昨日数据
|
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
for (user_id,) in users:
|
||||||
for (user_id,) in users:
|
|
||||||
try:
|
|
||||||
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
|
||||||
# 回滚当前用户的失败操作,继续处理其他用户
|
|
||||||
try:
|
try:
|
||||||
db.rollback()
|
StatsAggregatorService.aggregate_user_daily_stats(
|
||||||
except Exception:
|
db, user_id, yesterday_local
|
||||||
pass
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# 更新全局汇总
|
StatsAggregatorService.update_summary(db)
|
||||||
StatsAggregatorService.update_summary(db)
|
|
||||||
|
|
||||||
logger.info("统计数据聚合完成")
|
logger.info("统计数据聚合完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||||
db.rollback()
|
try:
|
||||||
finally:
|
db.rollback()
|
||||||
db.close()
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
async def _perform_pending_cleanup(self):
|
async def _perform_pending_cleanup(self):
|
||||||
"""执行 pending 状态清理"""
|
"""执行 pending 状态清理"""
|
||||||
@@ -300,6 +343,70 @@ class CleanupScheduler:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
async def _perform_audit_cleanup(self):
|
||||||
|
"""执行审计日志清理任务"""
|
||||||
|
db = create_session()
|
||||||
|
try:
|
||||||
|
# 检查是否启用自动清理
|
||||||
|
if not SystemConfigService.get_config(db, "enable_auto_cleanup", True):
|
||||||
|
logger.info("自动清理已禁用,跳过审计日志清理")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取审计日志保留天数(默认 30 天,最少 7 天)
|
||||||
|
audit_retention_days = max(
|
||||||
|
SystemConfigService.get_config(db, "audit_log_retention_days", 30),
|
||||||
|
7, # 最少保留 7 天,防止误配置删除所有审计日志
|
||||||
|
)
|
||||||
|
batch_size = SystemConfigService.get_config(db, "cleanup_batch_size", 1000)
|
||||||
|
|
||||||
|
cutoff_time = datetime.now(timezone.utc) - timedelta(days=audit_retention_days)
|
||||||
|
|
||||||
|
logger.info(f"开始清理 {audit_retention_days} 天前的审计日志...")
|
||||||
|
|
||||||
|
total_deleted = 0
|
||||||
|
while True:
|
||||||
|
# 先查询要删除的记录 ID(分批)
|
||||||
|
records_to_delete = (
|
||||||
|
db.query(AuditLog.id)
|
||||||
|
.filter(AuditLog.created_at < cutoff_time)
|
||||||
|
.limit(batch_size)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not records_to_delete:
|
||||||
|
break
|
||||||
|
|
||||||
|
record_ids = [r.id for r in records_to_delete]
|
||||||
|
|
||||||
|
# 执行删除
|
||||||
|
result = db.execute(
|
||||||
|
delete(AuditLog)
|
||||||
|
.where(AuditLog.id.in_(record_ids))
|
||||||
|
.execution_options(synchronize_session=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows_deleted = result.rowcount
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
total_deleted += rows_deleted
|
||||||
|
logger.debug(f"已删除 {rows_deleted} 条审计日志,累计 {total_deleted} 条")
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
if total_deleted > 0:
|
||||||
|
logger.info(f"审计日志清理完成,共删除 {total_deleted} 条记录")
|
||||||
|
else:
|
||||||
|
logger.info("无需清理的审计日志")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"审计日志清理失败: {e}")
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
async def _perform_cleanup(self):
|
async def _perform_cleanup(self):
|
||||||
"""执行清理任务"""
|
"""执行清理任务"""
|
||||||
db = create_session()
|
db = create_session()
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ class SystemConfigService:
|
|||||||
"value": "provider",
|
"value": "provider",
|
||||||
"description": "优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)",
|
"description": "优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)",
|
||||||
},
|
},
|
||||||
|
"scheduling_mode": {
|
||||||
|
"value": "cache_affinity",
|
||||||
|
"description": "调度模式:fixed_order(固定顺序模式,严格按优先级顺序) 或 cache_affinity(缓存亲和模式,优先使用已缓存的Provider)",
|
||||||
|
},
|
||||||
"auto_delete_expired_keys": {
|
"auto_delete_expired_keys": {
|
||||||
"value": False,
|
"value": False,
|
||||||
"description": "是否自动删除过期的API Key(True=物理删除,False=仅禁用),仅管理员可配置",
|
"description": "是否自动删除过期的API Key(True=物理删除,False=仅禁用),仅管理员可配置",
|
||||||
|
|||||||
@@ -56,65 +56,44 @@ class StatsAggregatorService:
|
|||||||
"""统计数据聚合服务"""
|
"""统计数据聚合服务"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
def compute_daily_stats(db: Session, date: datetime) -> dict:
|
||||||
"""聚合指定日期的统计数据
|
"""计算指定业务日期的统计数据(不写入数据库)"""
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StatsDaily 记录
|
|
||||||
"""
|
|
||||||
# 将业务日期转换为 UTC 时间范围
|
|
||||||
day_start, day_end = _get_business_day_range(date)
|
day_start, day_end = _get_business_day_range(date)
|
||||||
|
|
||||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
|
||||||
# 检查是否已存在该日期的记录
|
|
||||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
|
||||||
if existing:
|
|
||||||
stats = existing
|
|
||||||
else:
|
|
||||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
|
||||||
|
|
||||||
# 基础请求统计
|
|
||||||
base_query = db.query(Usage).filter(
|
base_query = db.query(Usage).filter(
|
||||||
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
||||||
)
|
)
|
||||||
|
|
||||||
total_requests = base_query.count()
|
total_requests = base_query.count()
|
||||||
|
|
||||||
# 如果没有请求,直接返回空记录
|
|
||||||
if total_requests == 0:
|
if total_requests == 0:
|
||||||
stats.total_requests = 0
|
return {
|
||||||
stats.success_requests = 0
|
"day_start": day_start,
|
||||||
stats.error_requests = 0
|
"total_requests": 0,
|
||||||
stats.input_tokens = 0
|
"success_requests": 0,
|
||||||
stats.output_tokens = 0
|
"error_requests": 0,
|
||||||
stats.cache_creation_tokens = 0
|
"input_tokens": 0,
|
||||||
stats.cache_read_tokens = 0
|
"output_tokens": 0,
|
||||||
stats.total_cost = 0.0
|
"cache_creation_tokens": 0,
|
||||||
stats.actual_total_cost = 0.0
|
"cache_read_tokens": 0,
|
||||||
stats.input_cost = 0.0
|
"total_cost": 0.0,
|
||||||
stats.output_cost = 0.0
|
"actual_total_cost": 0.0,
|
||||||
stats.cache_creation_cost = 0.0
|
"input_cost": 0.0,
|
||||||
stats.cache_read_cost = 0.0
|
"output_cost": 0.0,
|
||||||
stats.avg_response_time_ms = 0.0
|
"cache_creation_cost": 0.0,
|
||||||
stats.fallback_count = 0
|
"cache_read_cost": 0.0,
|
||||||
|
"avg_response_time_ms": 0.0,
|
||||||
|
"fallback_count": 0,
|
||||||
|
"unique_models": 0,
|
||||||
|
"unique_providers": 0,
|
||||||
|
}
|
||||||
|
|
||||||
if not existing:
|
|
||||||
db.add(stats)
|
|
||||||
db.commit()
|
|
||||||
return stats
|
|
||||||
|
|
||||||
# 错误请求数
|
|
||||||
error_requests = (
|
error_requests = (
|
||||||
base_query.filter(
|
base_query.filter(
|
||||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||||
).count()
|
).count()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Token 和成本聚合
|
|
||||||
aggregated = (
|
aggregated = (
|
||||||
db.query(
|
db.query(
|
||||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||||
@@ -157,7 +136,6 @@ class StatsAggregatorService:
|
|||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用维度统计
|
|
||||||
unique_models = (
|
unique_models = (
|
||||||
db.query(func.count(func.distinct(Usage.model)))
|
db.query(func.count(func.distinct(Usage.model)))
|
||||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||||
@@ -171,31 +149,74 @@ class StatsAggregatorService:
|
|||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"day_start": day_start,
|
||||||
|
"total_requests": total_requests,
|
||||||
|
"success_requests": total_requests - error_requests,
|
||||||
|
"error_requests": error_requests,
|
||||||
|
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
|
||||||
|
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
|
||||||
|
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
|
||||||
|
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
|
||||||
|
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
|
||||||
|
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
|
||||||
|
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
|
||||||
|
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
|
||||||
|
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
|
||||||
|
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
|
||||||
|
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
|
||||||
|
"fallback_count": fallback_count,
|
||||||
|
"unique_models": unique_models,
|
||||||
|
"unique_providers": unique_providers,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||||
|
"""聚合指定日期的统计数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StatsDaily 记录
|
||||||
|
"""
|
||||||
|
computed = StatsAggregatorService.compute_daily_stats(db, date)
|
||||||
|
day_start = computed["day_start"]
|
||||||
|
|
||||||
|
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||||
|
# 检查是否已存在该日期的记录
|
||||||
|
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||||
|
if existing:
|
||||||
|
stats = existing
|
||||||
|
else:
|
||||||
|
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||||
|
|
||||||
# 更新统计记录
|
# 更新统计记录
|
||||||
stats.total_requests = total_requests
|
stats.total_requests = computed["total_requests"]
|
||||||
stats.success_requests = total_requests - error_requests
|
stats.success_requests = computed["success_requests"]
|
||||||
stats.error_requests = error_requests
|
stats.error_requests = computed["error_requests"]
|
||||||
stats.input_tokens = int(aggregated.input_tokens or 0)
|
stats.input_tokens = computed["input_tokens"]
|
||||||
stats.output_tokens = int(aggregated.output_tokens or 0)
|
stats.output_tokens = computed["output_tokens"]
|
||||||
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
|
stats.cache_creation_tokens = computed["cache_creation_tokens"]
|
||||||
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
|
stats.cache_read_tokens = computed["cache_read_tokens"]
|
||||||
stats.total_cost = float(aggregated.total_cost or 0)
|
stats.total_cost = computed["total_cost"]
|
||||||
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
|
stats.actual_total_cost = computed["actual_total_cost"]
|
||||||
stats.input_cost = float(aggregated.input_cost or 0)
|
stats.input_cost = computed["input_cost"]
|
||||||
stats.output_cost = float(aggregated.output_cost or 0)
|
stats.output_cost = computed["output_cost"]
|
||||||
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
|
stats.cache_creation_cost = computed["cache_creation_cost"]
|
||||||
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
|
stats.cache_read_cost = computed["cache_read_cost"]
|
||||||
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
|
stats.avg_response_time_ms = computed["avg_response_time_ms"]
|
||||||
stats.fallback_count = fallback_count
|
stats.fallback_count = computed["fallback_count"]
|
||||||
stats.unique_models = unique_models
|
stats.unique_models = computed["unique_models"]
|
||||||
stats.unique_providers = unique_providers
|
stats.unique_providers = computed["unique_providers"]
|
||||||
|
|
||||||
if not existing:
|
if not existing:
|
||||||
db.add(stats)
|
db.add(stats)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 日志使用业务日期(输入参数),而不是 UTC 日期
|
# 日志使用业务日期(输入参数),而不是 UTC 日期
|
||||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {total_requests} 请求")
|
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -457,7 +457,7 @@ class StreamUsageTracker:
|
|||||||
|
|
||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
# 更新状态为 streaming
|
# 更新状态为 streaming,同时更新 provider
|
||||||
if self.request_id:
|
if self.request_id:
|
||||||
try:
|
try:
|
||||||
from src.services.usage.service import UsageService
|
from src.services.usage.service import UsageService
|
||||||
@@ -465,6 +465,7 @@ class StreamUsageTracker:
|
|||||||
db=self.db,
|
db=self.db,
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
status="streaming",
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|||||||
@@ -210,7 +210,15 @@ class ApiKeyService:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
def check_rate_limit(db: Session, api_key: ApiKey, window_minutes: int = 1) -> tuple[bool, int]:
|
||||||
"""检查速率限制"""
|
"""检查速率限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_allowed, remaining): 是否允许请求,剩余可用次数
|
||||||
|
当 rate_limit 为 None 时表示不限制,返回 (True, -1)
|
||||||
|
"""
|
||||||
|
# 如果 rate_limit 为 None,表示不限制
|
||||||
|
if api_key.rate_limit is None:
|
||||||
|
return True, -1 # -1 表示无限制
|
||||||
|
|
||||||
# 计算时间窗口
|
# 计算时间窗口
|
||||||
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
window_start = datetime.now(timezone.utc) - timedelta(minutes=window_minutes)
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ class PreferenceService:
|
|||||||
raise NotFoundException("Provider not found or inactive")
|
raise NotFoundException("Provider not found or inactive")
|
||||||
preferences.default_provider_id = default_provider_id
|
preferences.default_provider_id = default_provider_id
|
||||||
if theme is not None:
|
if theme is not None:
|
||||||
if theme not in ["light", "dark", "auto"]:
|
if theme not in ["light", "dark", "auto", "system"]:
|
||||||
raise ValueError("Invalid theme. Must be 'light', 'dark', or 'auto'")
|
raise ValueError("Invalid theme. Must be 'light', 'dark', 'auto', or 'system'")
|
||||||
preferences.theme = theme
|
preferences.theme = theme
|
||||||
if language is not None:
|
if language is not None:
|
||||||
preferences.language = language
|
preferences.language = language
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
提供统一的用户认证和授权功能
|
提供统一的用户认证和授权功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import Depends, Header, HTTPException, status
|
from fastapi import Depends, Header, HTTPException, status
|
||||||
@@ -41,13 +42,20 @@ async def get_current_user(
|
|||||||
try:
|
try:
|
||||||
# 验证Token格式和签名
|
# 验证Token格式和签名
|
||||||
try:
|
try:
|
||||||
payload = await AuthService.verify_token(token)
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
except HTTPException as token_error:
|
except HTTPException as token_error:
|
||||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.error(
|
||||||
|
"Token验证失败: {}: {}, token_fp={}",
|
||||||
|
token_error.status_code,
|
||||||
|
token_error.detail,
|
||||||
|
token_fp,
|
||||||
|
)
|
||||||
raise # 重新抛出原始异常,保持状态码
|
raise # 重新抛出原始异常,保持状态码
|
||||||
except Exception as token_error:
|
except Exception as token_error:
|
||||||
logger.error(f"Token验证失败: {token_error}, Token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.error("Token验证失败: {}, token_fp={}", token_error, token_fp)
|
||||||
raise ForbiddenException("无效的Token")
|
raise ForbiddenException("无效的Token")
|
||||||
|
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
@@ -63,7 +71,8 @@ async def get_current_user(
|
|||||||
raise ForbiddenException("无效的认证凭据")
|
raise ForbiddenException("无效的认证凭据")
|
||||||
|
|
||||||
# 仅在DEBUG模式下记录详细信息
|
# 仅在DEBUG模式下记录详细信息
|
||||||
logger.debug(f"尝试获取用户: user_id={user_id}, token前10位: {token[:10]}...")
|
token_fp = hashlib.sha256(token.encode()).hexdigest()[:12]
|
||||||
|
logger.debug("尝试获取用户: user_id={}, token_fp={}", user_id, token_fp)
|
||||||
|
|
||||||
# 确保user_id是字符串格式(UUID)
|
# 确保user_id是字符串格式(UUID)
|
||||||
if not isinstance(user_id, str):
|
if not isinstance(user_id, str):
|
||||||
@@ -144,7 +153,7 @@ async def get_current_user_from_header(
|
|||||||
token = authorization.replace("Bearer ", "")
|
token = authorization.replace("Bearer ", "")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = await AuthService.verify_token(token)
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
user_id = payload.get("user_id")
|
user_id = payload.get("user_id")
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
|
|||||||
@@ -7,29 +7,47 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
|
|
||||||
|
|
||||||
def get_client_ip(request: Request) -> str:
|
def get_client_ip(request: Request) -> str:
|
||||||
"""
|
"""
|
||||||
获取客户端真实IP地址
|
获取客户端真实IP地址
|
||||||
|
|
||||||
按优先级检查:
|
按优先级检查:
|
||||||
1. X-Forwarded-For 头(支持代理链)
|
1. X-Forwarded-For 头(支持代理链,根据可信代理数量提取)
|
||||||
2. X-Real-IP 头(Nginx 代理)
|
2. X-Real-IP 头(Nginx 代理)
|
||||||
3. 直接客户端IP
|
3. 直接客户端IP
|
||||||
|
|
||||||
|
安全说明:
|
||||||
|
- 此函数根据 TRUSTED_PROXY_COUNT 配置来决定信任的代理层数
|
||||||
|
- 当 TRUSTED_PROXY_COUNT=0 时,不信任任何代理头,直接使用连接 IP
|
||||||
|
- 当服务直接暴露公网时,应设置 TRUSTED_PROXY_COUNT=0 以防止 IP 伪造
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: FastAPI Request 对象
|
request: FastAPI Request 对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
str: 客户端IP地址,如果无法获取则返回 "unknown"
|
||||||
"""
|
"""
|
||||||
|
trusted_proxy_count = config.trusted_proxy_count
|
||||||
|
|
||||||
|
# 如果不信任任何代理,直接返回连接 IP
|
||||||
|
if trusted_proxy_count == 0:
|
||||||
|
if request.client and request.client.host:
|
||||||
|
return request.client.host
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
# 优先检查 X-Forwarded-For 头(可能包含代理链)
|
||||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For 格式: "client, proxy1, proxy2",取第一个(真实客户端)
|
# X-Forwarded-For 格式: "client, proxy1, proxy2"
|
||||||
client_ip = forwarded_for.split(",")[0].strip()
|
# 从右往左数 trusted_proxy_count 个,取其左边的第一个
|
||||||
if client_ip:
|
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||||
return client_ip
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
# 检查 X-Real-IP 头(通常由 Nginx 设置)
|
||||||
real_ip = request.headers.get("X-Real-IP")
|
real_ip = request.headers.get("X-Real-IP")
|
||||||
@@ -91,20 +109,32 @@ def get_request_metadata(request: Request) -> dict:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def extract_ip_from_headers(headers: dict) -> str:
|
def extract_ip_from_headers(headers: dict, trusted_proxy_count: Optional[int] = None) -> str:
|
||||||
"""
|
"""
|
||||||
从HTTP头字典中提取IP地址(用于中间件等场景)
|
从HTTP头字典中提取IP地址(用于中间件等场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
headers: HTTP头字典
|
headers: HTTP头字典
|
||||||
|
trusted_proxy_count: 可信代理层数,None 时使用配置值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 客户端IP地址
|
str: 客户端IP地址
|
||||||
"""
|
"""
|
||||||
|
if trusted_proxy_count is None:
|
||||||
|
trusted_proxy_count = config.trusted_proxy_count
|
||||||
|
|
||||||
|
# 如果不信任任何代理,返回 unknown(调用方需要用其他方式获取连接 IP)
|
||||||
|
if trusted_proxy_count == 0:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
# 检查 X-Forwarded-For
|
# 检查 X-Forwarded-For
|
||||||
forwarded_for = headers.get("x-forwarded-for", "")
|
forwarded_for = headers.get("x-forwarded-for", "")
|
||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
return forwarded_for.split(",")[0].strip()
|
ips = [ip.strip() for ip in forwarded_for.split(",") if ip.strip()]
|
||||||
|
if len(ips) > trusted_proxy_count:
|
||||||
|
return ips[-(trusted_proxy_count + 1)]
|
||||||
|
elif ips:
|
||||||
|
return ips[0]
|
||||||
|
|
||||||
# 检查 X-Real-IP
|
# 检查 X-Real-IP
|
||||||
real_ip = headers.get("x-real-ip", "")
|
real_ip = headers.get("x-real-ip", "")
|
||||||
|
|||||||
421
tests/api/test_pipeline.py
Normal file
421
tests/api/test_pipeline.py
Normal file
@@ -0,0 +1,421 @@
|
|||||||
|
"""
|
||||||
|
API Pipeline 测试
|
||||||
|
|
||||||
|
测试 ApiRequestPipeline 的核心功能:
|
||||||
|
- 认证流程(API Key、JWT Token)
|
||||||
|
- 配额计算
|
||||||
|
- 审计日志记录
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineQuotaCalculation:
|
||||||
|
"""测试 Pipeline 配额计算"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试有配额限制时计算剩余配额"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 100.0
|
||||||
|
mock_user.used_usd = 30.0
|
||||||
|
|
||||||
|
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||||
|
|
||||||
|
assert remaining == 70.0
|
||||||
|
|
||||||
|
def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试无配额限制时返回 None"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = None
|
||||||
|
mock_user.used_usd = 30.0
|
||||||
|
|
||||||
|
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||||
|
|
||||||
|
assert remaining is None
|
||||||
|
|
||||||
|
def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试负配额时返回 None"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = -1
|
||||||
|
mock_user.used_usd = 0.0
|
||||||
|
|
||||||
|
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||||
|
|
||||||
|
assert remaining is None
|
||||||
|
|
||||||
|
def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试配额已超时返回 0"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 100.0
|
||||||
|
mock_user.used_usd = 150.0
|
||||||
|
|
||||||
|
remaining = pipeline._calculate_quota_remaining(mock_user)
|
||||||
|
|
||||||
|
assert remaining == 0.0
|
||||||
|
|
||||||
|
def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试用户为 None 时返回 None"""
|
||||||
|
remaining = pipeline._calculate_quota_remaining(None)
|
||||||
|
|
||||||
|
assert remaining is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineAuditLogging:
|
||||||
|
"""测试 Pipeline 审计日志"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试记录成功的审计事件"""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.db = MagicMock()
|
||||||
|
mock_context.user = MagicMock()
|
||||||
|
mock_context.user.id = "user-123"
|
||||||
|
mock_context.api_key = MagicMock()
|
||||||
|
mock_context.api_key.id = "key-123"
|
||||||
|
mock_context.request_id = "req-123"
|
||||||
|
mock_context.client_ip = "127.0.0.1"
|
||||||
|
mock_context.user_agent = "test-agent"
|
||||||
|
mock_context.request = MagicMock()
|
||||||
|
mock_context.request.method = "POST"
|
||||||
|
mock_context.request.url.path = "/v1/messages"
|
||||||
|
mock_context.start_time = 1000.0
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.name = "test-adapter"
|
||||||
|
mock_adapter.audit_log_enabled = True
|
||||||
|
mock_adapter.audit_success_event = None
|
||||||
|
mock_adapter.audit_failure_event = None
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.audit_service,
|
||||||
|
"log_event",
|
||||||
|
) as mock_log:
|
||||||
|
with patch("time.time", return_value=1001.0):
|
||||||
|
pipeline._record_audit_event(
|
||||||
|
mock_context, mock_adapter, success=True, status_code=200
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_kwargs = mock_log.call_args[1]
|
||||||
|
assert call_kwargs["user_id"] == "user-123"
|
||||||
|
assert call_kwargs["status_code"] == 200
|
||||||
|
|
||||||
|
def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试记录失败的审计事件"""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.db = MagicMock()
|
||||||
|
mock_context.user = MagicMock()
|
||||||
|
mock_context.user.id = "user-123"
|
||||||
|
mock_context.api_key = MagicMock()
|
||||||
|
mock_context.api_key.id = "key-123"
|
||||||
|
mock_context.request_id = "req-123"
|
||||||
|
mock_context.client_ip = "127.0.0.1"
|
||||||
|
mock_context.user_agent = "test-agent"
|
||||||
|
mock_context.request = MagicMock()
|
||||||
|
mock_context.request.method = "POST"
|
||||||
|
mock_context.request.url.path = "/v1/messages"
|
||||||
|
mock_context.start_time = 1000.0
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.name = "test-adapter"
|
||||||
|
mock_adapter.audit_log_enabled = True
|
||||||
|
mock_adapter.audit_success_event = None
|
||||||
|
mock_adapter.audit_failure_event = None
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.audit_service,
|
||||||
|
"log_event",
|
||||||
|
) as mock_log:
|
||||||
|
with patch("time.time", return_value=1001.0):
|
||||||
|
pipeline._record_audit_event(
|
||||||
|
mock_context, mock_adapter, success=False, status_code=500, error="Internal error"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_log.assert_called_once()
|
||||||
|
call_kwargs = mock_log.call_args[1]
|
||||||
|
assert call_kwargs["status_code"] == 500
|
||||||
|
assert call_kwargs["error_message"] == "Internal error"
|
||||||
|
|
||||||
|
def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试没有数据库会话时跳过审计"""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.db = None
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.audit_log_enabled = True
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.audit_service,
|
||||||
|
"log_event",
|
||||||
|
) as mock_log:
|
||||||
|
# 不应该抛出异常
|
||||||
|
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||||
|
|
||||||
|
# 不应该调用 log_event
|
||||||
|
mock_log.assert_not_called()
|
||||||
|
|
||||||
|
def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试审计日志被禁用时跳过"""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.db = MagicMock()
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.audit_log_enabled = False
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.audit_service,
|
||||||
|
"log_event",
|
||||||
|
) as mock_log:
|
||||||
|
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||||
|
|
||||||
|
mock_log.assert_not_called()
|
||||||
|
|
||||||
|
def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试审计日志异常不影响主流程"""
|
||||||
|
mock_context = MagicMock()
|
||||||
|
mock_context.db = MagicMock()
|
||||||
|
mock_context.user = MagicMock()
|
||||||
|
mock_context.user.id = "user-123"
|
||||||
|
mock_context.api_key = MagicMock()
|
||||||
|
mock_context.api_key.id = "key-123"
|
||||||
|
mock_context.request_id = "req-123"
|
||||||
|
mock_context.client_ip = "127.0.0.1"
|
||||||
|
mock_context.user_agent = "test-agent"
|
||||||
|
mock_context.request = MagicMock()
|
||||||
|
mock_context.request.method = "POST"
|
||||||
|
mock_context.request.url.path = "/v1/messages"
|
||||||
|
mock_context.start_time = 1000.0
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.name = "test-adapter"
|
||||||
|
mock_adapter.audit_log_enabled = True
|
||||||
|
mock_adapter.audit_success_event = None
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.audit_service,
|
||||||
|
"log_event",
|
||||||
|
side_effect=Exception("DB error"),
|
||||||
|
):
|
||||||
|
with patch("time.time", return_value=1001.0):
|
||||||
|
# 不应该抛出异常
|
||||||
|
pipeline._record_audit_event(mock_context, mock_adapter, success=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineAuthentication:
|
||||||
|
"""测试 Pipeline 认证相关逻辑"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试缺少 API Key 时抛出异常"""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {}
|
||||||
|
mock_request.url.path = "/v1/messages"
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.extract_api_key = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "API密钥" in exc_info.value.detail
|
||||||
|
|
||||||
|
def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试无效的 API Key"""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"Authorization": "Bearer sk-invalid"}
|
||||||
|
mock_request.url.path = "/v1/messages"
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.extract_api_key = MagicMock(return_value="sk-invalid")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"authenticate_api_key",
|
||||||
|
return_value=None,
|
||||||
|
):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试配额超限时抛出异常"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "user-123"
|
||||||
|
mock_user.quota_usd = 100.0
|
||||||
|
mock_user.used_usd = 100.0
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.id = "key-123"
|
||||||
|
mock_api_key.is_standalone = False
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"Authorization": "Bearer sk-test"}
|
||||||
|
mock_request.url.path = "/v1/messages"
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
mock_adapter = MagicMock()
|
||||||
|
mock_adapter.extract_api_key = MagicMock(return_value="sk-test")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"authenticate_api_key",
|
||||||
|
return_value=(mock_user, mock_api_key),
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
pipeline.usage_service,
|
||||||
|
"check_user_quota",
|
||||||
|
return_value=(False, "配额不足"),
|
||||||
|
):
|
||||||
|
from src.core.exceptions import QuotaExceededException
|
||||||
|
|
||||||
|
with pytest.raises(QuotaExceededException):
|
||||||
|
pipeline._authenticate_client(mock_request, mock_db, mock_adapter)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineAdminAuth:
|
||||||
|
"""测试管理员认证"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试缺少管理员令牌"""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {}
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await pipeline._authenticate_admin(mock_request, mock_db)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "管理员凭证" in exc_info.value.detail
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试无效的管理员令牌"""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "Bearer invalid-token"}
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
side_effect=HTTPException(status_code=401, detail="Invalid token"),
|
||||||
|
):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await pipeline._authenticate_admin(mock_request, mock_db)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试管理员认证成功"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "admin-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "Bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "admin-123"},
|
||||||
|
):
|
||||||
|
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||||
|
|
||||||
|
assert result == mock_user
|
||||||
|
assert mock_request.state.user_id == "admin-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_admin_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "admin-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "admin-123"},
|
||||||
|
) as mock_verify:
|
||||||
|
result = await pipeline._authenticate_admin(mock_request, mock_db)
|
||||||
|
|
||||||
|
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||||
|
assert result == mock_user
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineUserAuth:
|
||||||
|
"""测试普通用户 JWT 认证"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self) -> ApiRequestPipeline:
|
||||||
|
return ApiRequestPipeline()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_lowercase_bearer(self, pipeline: ApiRequestPipeline) -> None:
|
||||||
|
"""测试 bearer (小写) 前缀也能正确解析"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "user-123"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.headers = {"authorization": "bearer valid-token"}
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
pipeline.auth_service,
|
||||||
|
"verify_token",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={"user_id": "user-123"},
|
||||||
|
) as mock_verify:
|
||||||
|
result = await pipeline._authenticate_user(mock_request, mock_db)
|
||||||
|
|
||||||
|
mock_verify.assert_awaited_once_with("valid-token", token_type="access")
|
||||||
|
assert result == mock_user
|
||||||
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""服务层测试"""
|
||||||
299
tests/services/test_auth.py
Normal file
299
tests/services/test_auth.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
"""
|
||||||
|
认证服务测试
|
||||||
|
|
||||||
|
测试 AuthService 的核心功能:
|
||||||
|
- JWT Token 创建和验证
|
||||||
|
- 用户登录认证
|
||||||
|
- API Key 认证
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from src.services.auth.service import (
|
||||||
|
AuthService,
|
||||||
|
JWT_SECRET_KEY,
|
||||||
|
JWT_ALGORITHM,
|
||||||
|
JWT_EXPIRATION_HOURS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTTokenCreation:
|
||||||
|
"""测试 JWT Token 创建"""
|
||||||
|
|
||||||
|
def test_create_access_token_contains_required_fields(self) -> None:
|
||||||
|
"""测试访问令牌包含必要字段"""
|
||||||
|
data = {"sub": "user123", "email": "test@example.com"}
|
||||||
|
token = AuthService.create_access_token(data)
|
||||||
|
|
||||||
|
# 解码验证
|
||||||
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||||
|
|
||||||
|
assert payload["sub"] == "user123"
|
||||||
|
assert payload["email"] == "test@example.com"
|
||||||
|
assert payload["type"] == "access"
|
||||||
|
assert "exp" in payload
|
||||||
|
|
||||||
|
def test_create_access_token_expiration(self) -> None:
|
||||||
|
"""测试访问令牌过期时间正确"""
|
||||||
|
data = {"sub": "user123"}
|
||||||
|
token = AuthService.create_access_token(data)
|
||||||
|
|
||||||
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||||
|
|
||||||
|
# 验证过期时间在预期范围内(允许1分钟误差)
|
||||||
|
exp_time = datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||||
|
expected_exp = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
|
||||||
|
|
||||||
|
assert abs((exp_time - expected_exp).total_seconds()) < 60
|
||||||
|
|
||||||
|
def test_create_refresh_token_type(self) -> None:
|
||||||
|
"""测试刷新令牌类型正确"""
|
||||||
|
data = {"sub": "user123"}
|
||||||
|
token = AuthService.create_refresh_token(data)
|
||||||
|
|
||||||
|
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||||
|
|
||||||
|
assert payload["type"] == "refresh"
|
||||||
|
|
||||||
|
def test_create_refresh_token_longer_expiration(self) -> None:
|
||||||
|
"""测试刷新令牌过期时间更长"""
|
||||||
|
data = {"sub": "user123"}
|
||||||
|
access_token = AuthService.create_access_token(data)
|
||||||
|
refresh_token = AuthService.create_refresh_token(data)
|
||||||
|
|
||||||
|
access_payload = jwt.decode(access_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||||
|
refresh_payload = jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||||
|
|
||||||
|
# 刷新令牌应该比访问令牌过期时间更长
|
||||||
|
assert refresh_payload["exp"] > access_payload["exp"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestJWTTokenVerification:
|
||||||
|
"""测试 JWT Token 验证"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_valid_access_token(self) -> None:
|
||||||
|
"""测试验证有效的访问令牌"""
|
||||||
|
data = {"sub": "user123", "email": "test@example.com"}
|
||||||
|
token = AuthService.create_access_token(data)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
|
payload = await AuthService.verify_token(token, token_type="access")
|
||||||
|
|
||||||
|
assert payload["sub"] == "user123"
|
||||||
|
assert payload["type"] == "access"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_expired_token_raises_error(self) -> None:
|
||||||
|
"""测试验证过期令牌抛出异常"""
|
||||||
|
# 创建一个已过期的 token
|
||||||
|
data = {"sub": "user123", "type": "access"}
|
||||||
|
expire = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||||
|
data["exp"] = expire
|
||||||
|
expired_token = jwt.encode(data, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await AuthService.verify_token(expired_token)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "过期" in exc_info.value.detail
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_invalid_token_raises_error(self) -> None:
|
||||||
|
"""测试验证无效令牌抛出异常"""
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await AuthService.verify_token("invalid.token.here")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_wrong_token_type_raises_error(self) -> None:
|
||||||
|
"""测试令牌类型不匹配抛出异常"""
|
||||||
|
data = {"sub": "user123"}
|
||||||
|
refresh_token = AuthService.create_refresh_token(data)
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await AuthService.verify_token(refresh_token, token_type="access")
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "类型错误" in exc_info.value.detail
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_blacklisted_token_raises_error(self) -> None:
|
||||||
|
"""测试已撤销的令牌抛出异常"""
|
||||||
|
data = {"sub": "user123"}
|
||||||
|
token = AuthService.create_access_token(data)
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"src.services.auth.service.JWTBlacklistService.is_blacklisted",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=True,
|
||||||
|
):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
await AuthService.verify_token(token)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "撤销" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserAuthentication:
|
||||||
|
"""测试用户登录认证"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_success(self) -> None:
|
||||||
|
"""测试用户登录成功"""
|
||||||
|
# Mock 数据库和用户对象
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "user-123"
|
||||||
|
mock_user.email = "test@example.com"
|
||||||
|
mock_user.is_active = True
|
||||||
|
mock_user.verify_password.return_value = True
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"src.services.auth.service.UserCacheService.invalidate_user_cache",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
):
|
||||||
|
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
|
||||||
|
|
||||||
|
assert result == mock_user
|
||||||
|
mock_user.verify_password.assert_called_once_with("password123")
|
||||||
|
mock_db.commit.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_not_found(self) -> None:
|
||||||
|
"""测试用户不存在"""
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||||
|
|
||||||
|
result = await AuthService.authenticate_user(mock_db, "nonexistent@example.com", "password")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_wrong_password(self) -> None:
|
||||||
|
"""测试密码错误"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.email = "test@example.com"
|
||||||
|
mock_user.verify_password.return_value = False
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
result = await AuthService.authenticate_user(mock_db, "test@example.com", "wrongpassword")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_authenticate_user_inactive(self) -> None:
|
||||||
|
"""测试用户已禁用"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.email = "test@example.com"
|
||||||
|
mock_user.is_active = False
|
||||||
|
mock_user.verify_password.return_value = True
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = mock_user
|
||||||
|
|
||||||
|
result = await AuthService.authenticate_user(mock_db, "test@example.com", "password123")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAPIKeyAuthentication:
|
||||||
|
"""测试 API Key 认证"""
|
||||||
|
|
||||||
|
def test_authenticate_api_key_success(self) -> None:
|
||||||
|
"""测试 API Key 认证成功"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.id = "user-123"
|
||||||
|
mock_user.email = "test@example.com"
|
||||||
|
mock_user.is_active = True
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_active = True
|
||||||
|
mock_api_key.expires_at = None
|
||||||
|
mock_api_key.user = mock_user
|
||||||
|
mock_api_key.balance_used_usd = 0.0
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||||
|
mock_api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||||
|
with patch(
|
||||||
|
"src.services.auth.service.ApiKeyService.check_balance",
|
||||||
|
return_value=(True, 100.0),
|
||||||
|
):
|
||||||
|
result = AuthService.authenticate_api_key(mock_db, "sk-test-key")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result[0] == mock_user
|
||||||
|
assert result[1] == mock_api_key
|
||||||
|
|
||||||
|
def test_authenticate_api_key_not_found(self) -> None:
|
||||||
|
"""测试 API Key 不存在"""
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||||
|
result = AuthService.authenticate_api_key(mock_db, "sk-invalid-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_authenticate_api_key_inactive(self) -> None:
|
||||||
|
"""测试 API Key 已禁用"""
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_active = False
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||||
|
mock_api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||||
|
result = AuthService.authenticate_api_key(mock_db, "sk-inactive-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_authenticate_api_key_expired(self) -> None:
|
||||||
|
"""测试 API Key 已过期"""
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_active = True
|
||||||
|
mock_api_key.expires_at = datetime.now(timezone.utc) - timedelta(days=1)
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.query.return_value.options.return_value.filter.return_value.first.return_value = (
|
||||||
|
mock_api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key"):
|
||||||
|
result = AuthService.authenticate_api_key(mock_db, "sk-expired-key")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
292
tests/services/test_usage_service.py
Normal file
292
tests/services/test_usage_service.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
UsageService 测试
|
||||||
|
|
||||||
|
测试用量统计服务的核心功能:
|
||||||
|
- 成本计算
|
||||||
|
- 配额检查
|
||||||
|
- 用量统计查询
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch, AsyncMock
|
||||||
|
|
||||||
|
from src.services.usage.service import UsageService
|
||||||
|
|
||||||
|
|
||||||
|
class TestCostCalculation:
|
||||||
|
"""测试成本计算"""
|
||||||
|
|
||||||
|
def test_calculate_cost_basic(self) -> None:
|
||||||
|
"""测试基础成本计算"""
|
||||||
|
# 价格:输入 $3/1M, 输出 $15/1M
|
||||||
|
result = UsageService.calculate_cost(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost, request_cost, total_cost = result
|
||||||
|
|
||||||
|
# 1000 tokens * $3 / 1M = $0.003
|
||||||
|
assert abs(input_cost - 0.003) < 0.0001
|
||||||
|
# 500 tokens * $15 / 1M = $0.0075
|
||||||
|
assert abs(output_cost - 0.0075) < 0.0001
|
||||||
|
# Total = $0.003 + $0.0075 = $0.0105
|
||||||
|
assert abs(total_cost - 0.0105) < 0.0001
|
||||||
|
|
||||||
|
def test_calculate_cost_with_cache(self) -> None:
|
||||||
|
"""测试带缓存的成本计算"""
|
||||||
|
result = UsageService.calculate_cost(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
cache_creation_input_tokens=200,
|
||||||
|
cache_read_input_tokens=300,
|
||||||
|
cache_creation_price_per_1m=3.75, # 1.25x input price
|
||||||
|
cache_read_price_per_1m=0.3, # 0.1x input price
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
input_cost,
|
||||||
|
output_cost,
|
||||||
|
cache_creation_cost,
|
||||||
|
cache_read_cost,
|
||||||
|
cache_cost,
|
||||||
|
request_cost,
|
||||||
|
total_cost,
|
||||||
|
) = result
|
||||||
|
|
||||||
|
# 验证缓存成本被计算
|
||||||
|
assert cache_creation_cost > 0
|
||||||
|
assert cache_read_cost > 0
|
||||||
|
assert cache_cost == cache_creation_cost + cache_read_cost
|
||||||
|
|
||||||
|
def test_calculate_cost_with_request_price(self) -> None:
|
||||||
|
"""测试按次计费"""
|
||||||
|
result = UsageService.calculate_cost(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
price_per_request=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
input_cost,
|
||||||
|
output_cost,
|
||||||
|
cache_creation_cost,
|
||||||
|
cache_read_cost,
|
||||||
|
cache_cost,
|
||||||
|
request_cost,
|
||||||
|
total_cost,
|
||||||
|
) = result
|
||||||
|
|
||||||
|
assert request_cost == 0.01
|
||||||
|
# Total 包含 request_cost
|
||||||
|
assert total_cost == input_cost + output_cost + request_cost
|
||||||
|
|
||||||
|
def test_calculate_cost_zero_tokens(self) -> None:
|
||||||
|
"""测试零 token 的成本计算"""
|
||||||
|
result = UsageService.calculate_cost(
|
||||||
|
input_tokens=0,
|
||||||
|
output_tokens=0,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
input_cost,
|
||||||
|
output_cost,
|
||||||
|
cache_creation_cost,
|
||||||
|
cache_read_cost,
|
||||||
|
cache_cost,
|
||||||
|
request_cost,
|
||||||
|
total_cost,
|
||||||
|
) = result
|
||||||
|
|
||||||
|
assert input_cost == 0
|
||||||
|
assert output_cost == 0
|
||||||
|
assert total_cost == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaCheck:
|
||||||
|
"""测试配额检查"""
|
||||||
|
|
||||||
|
def test_check_user_quota_sufficient(self) -> None:
|
||||||
|
"""测试配额充足"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 100.0
|
||||||
|
mock_user.used_usd = 30.0
|
||||||
|
mock_user.role = MagicMock()
|
||||||
|
mock_user.role.value = "user"
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_standalone = False
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||||
|
|
||||||
|
assert is_ok is True
|
||||||
|
|
||||||
|
def test_check_user_quota_exceeded(self) -> None:
|
||||||
|
"""测试配额超限(当有预估成本时)"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 100.0
|
||||||
|
mock_user.used_usd = 99.0 # 接近配额上限
|
||||||
|
mock_user.role = MagicMock()
|
||||||
|
mock_user.role.value = "user"
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_standalone = False
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
# 当预估成本超过剩余配额时应该返回 False
|
||||||
|
is_ok, message = UsageService.check_user_quota(
|
||||||
|
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_ok is False
|
||||||
|
assert "配额" in message
|
||||||
|
|
||||||
|
def test_check_user_quota_no_limit(self) -> None:
|
||||||
|
"""测试无配额限制(None)"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = None
|
||||||
|
mock_user.used_usd = 1000.0
|
||||||
|
mock_user.role = MagicMock()
|
||||||
|
mock_user.role.value = "user"
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_standalone = False
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||||
|
|
||||||
|
assert is_ok is True
|
||||||
|
|
||||||
|
def test_check_user_quota_admin_bypass(self) -> None:
|
||||||
|
"""测试管理员绕过配额检查"""
|
||||||
|
from src.models.database import UserRole
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 0.0
|
||||||
|
mock_user.used_usd = 1000.0
|
||||||
|
mock_user.role = UserRole.ADMIN
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_standalone = False
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||||
|
|
||||||
|
assert is_ok is True
|
||||||
|
|
||||||
|
def test_check_standalone_api_key_balance(self) -> None:
|
||||||
|
"""测试独立 API Key 余额检查"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 0.0
|
||||||
|
mock_user.used_usd = 0.0
|
||||||
|
mock_user.role = MagicMock()
|
||||||
|
mock_user.role.value = "user"
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_standalone = True
|
||||||
|
mock_api_key.current_balance_usd = 50.0
|
||||||
|
mock_api_key.balance_used_usd = 10.0
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
is_ok, message = UsageService.check_user_quota(mock_db, mock_user, api_key=mock_api_key)
|
||||||
|
|
||||||
|
assert is_ok is True
|
||||||
|
|
||||||
|
def test_check_standalone_api_key_insufficient_balance(self) -> None:
|
||||||
|
"""测试独立 API Key 余额不足"""
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.quota_usd = 100.0
|
||||||
|
mock_user.used_usd = 0.0
|
||||||
|
mock_user.role = MagicMock()
|
||||||
|
mock_user.role.value = "user"
|
||||||
|
|
||||||
|
mock_api_key = MagicMock()
|
||||||
|
mock_api_key.is_standalone = True
|
||||||
|
mock_api_key.current_balance_usd = 10.0
|
||||||
|
mock_api_key.balance_used_usd = 9.0 # 剩余 $1
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
|
||||||
|
# 需要 mock ApiKeyService.get_remaining_balance
|
||||||
|
with patch(
|
||||||
|
"src.services.user.apikey.ApiKeyService.get_remaining_balance",
|
||||||
|
return_value=1.0,
|
||||||
|
):
|
||||||
|
# 预估成本 $5 超过剩余余额 $1
|
||||||
|
is_ok, message = UsageService.check_user_quota(
|
||||||
|
mock_db, mock_user, estimated_cost=5.0, api_key=mock_api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_ok is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageStatistics:
|
||||||
|
"""测试用量统计查询
|
||||||
|
|
||||||
|
注意:get_usage_summary 方法内部使用了数据库方言特定的日期函数,
|
||||||
|
需要真实数据库或更复杂的 mock。这里只测试方法存在性。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_get_usage_summary_exists(self) -> None:
|
||||||
|
"""测试 get_usage_summary 方法存在"""
|
||||||
|
assert hasattr(UsageService, "get_usage_summary")
|
||||||
|
assert callable(getattr(UsageService, "get_usage_summary"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestHelperMethods:
|
||||||
|
"""测试辅助方法"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
|
||||||
|
"""测试默认费率倍数"""
|
||||||
|
mock_db = MagicMock()
|
||||||
|
# 模拟未找到 provider_api_key
|
||||||
|
mock_db.query.return_value.filter.return_value.first.return_value = None
|
||||||
|
|
||||||
|
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
|
||||||
|
mock_db, provider_api_key_id=None, provider_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert rate_multiplier == 1.0
|
||||||
|
assert is_free_tier is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
|
||||||
|
"""测试从 ProviderAPIKey 获取费率倍数"""
|
||||||
|
mock_provider_api_key = MagicMock()
|
||||||
|
mock_provider_api_key.rate_multiplier = 0.8
|
||||||
|
|
||||||
|
mock_endpoint = MagicMock()
|
||||||
|
mock_endpoint.provider_id = "provider-123"
|
||||||
|
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.billing_type = "standard"
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
# 第一次查询返回 provider_api_key
|
||||||
|
mock_db.query.return_value.filter.return_value.first.side_effect = [
|
||||||
|
mock_provider_api_key,
|
||||||
|
mock_endpoint,
|
||||||
|
mock_provider,
|
||||||
|
]
|
||||||
|
|
||||||
|
rate_multiplier, is_free_tier = await UsageService._get_rate_multiplier_and_free_tier(
|
||||||
|
mock_db, provider_api_key_id="pak-123", provider_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert rate_multiplier == 0.8
|
||||||
|
assert is_free_tier is False
|
||||||
Reference in New Issue
Block a user