4 Commits

Author SHA1 Message Date
fawney19
1dac4cb156 refactor: optimize provider query and stats aggregation logic 2025-12-17 16:41:10 +08:00
fawney19
50abb55c94 fix(models): clear form state when loading model data for edit
Reset model selection, search query, and expanded provider state
when switching to edit mode to prevent stale UI state carrying over
from previous operations. Also ensure tieredPricing is properly set
or reset based on model data.
2025-12-16 18:42:58 +08:00
fawney19
73d3c9d3e4 ui(models): display model ID in global model form dialog
Show model ID below model name in the dropdown list for better clarity
when selecting models, with appropriate text styling for selected state.
2025-12-16 18:36:23 +08:00
fawney19
d24c3885ab feat(admin): add config and user data import/export functionality
Add comprehensive import/export endpoints for:
- Provider and model configuration (with key decryption for export)
- User data and API keys (preserving encrypted data)

Includes merge modes (skip/overwrite/error) for conflict handling,
10MB size limit for imports, and automatic cache invalidation.

Also fix optional field in GlobalModelResponse tiered_pricing.
2025-12-16 18:33:14 +08:00
24 changed files with 3341 additions and 601 deletions

View File

@@ -20,10 +20,10 @@ depends_on = None
def upgrade() -> None:
# Create ENUM types
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
# Create ENUM types (with IF NOT EXISTS for idempotency)
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
op.execute(
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
)
# ==================== users ====================
@@ -35,7 +35,7 @@ def upgrade() -> None:
sa.Column("password_hash", sa.String(255), nullable=False),
sa.Column(
"role",
sa.Enum("admin", "user", name="userrole", create_type=False),
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
nullable=False,
server_default="user",
),
@@ -67,7 +67,7 @@ def upgrade() -> None:
sa.Column("website", sa.String(500), nullable=True),
sa.Column(
"billing_type",
sa.Enum(
postgresql.ENUM(
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
),
nullable=False,

View File

@@ -1,5 +1,179 @@
import apiClient from './client'
// 配置导出数据结构
export interface ConfigExportData {
version: string
exported_at: string
global_models: GlobalModelExport[]
providers: ProviderExport[]
}
// 用户导出数据结构
export interface UsersExportData {
version: string
exported_at: string
users: UserExport[]
}
export interface UserExport {
email: string
username: string
password_hash: string
role: string
allowed_providers?: string[] | null
allowed_endpoints?: string[] | null
allowed_models?: string[] | null
model_capability_settings?: any
quota_usd?: number | null
used_usd?: number
total_usd?: number
is_active: boolean
api_keys: UserApiKeyExport[]
}
export interface UserApiKeyExport {
key_hash: string
key_encrypted?: string | null
name?: string | null
is_standalone: boolean
balance_used_usd?: number
current_balance_usd?: number | null
allowed_providers?: string[] | null
allowed_endpoints?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
concurrent_limit?: number | null
force_capabilities?: any
is_active: boolean
auto_delete_on_expiry?: boolean
total_requests?: number
total_cost_usd?: number
}
export interface GlobalModelExport {
name: string
display_name: string
default_price_per_request?: number | null
default_tiered_pricing: any
supported_capabilities?: string[] | null
config?: any
is_active: boolean
}
export interface ProviderExport {
name: string
display_name: string
description?: string | null
website?: string | null
billing_type?: string | null
monthly_quota_usd?: number | null
quota_reset_day?: number
rpm_limit?: number | null
provider_priority?: number
is_active: boolean
rate_limit?: number | null
concurrent_limit?: number | null
config?: any
endpoints: EndpointExport[]
models: ModelExport[]
}
export interface EndpointExport {
api_format: string
base_url: string
headers?: any
timeout?: number
max_retries?: number
max_concurrent?: number | null
rate_limit?: number | null
is_active: boolean
custom_path?: string | null
config?: any
keys: KeyExport[]
}
export interface KeyExport {
api_key: string
name?: string | null
note?: string | null
rate_multiplier?: number
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null
rate_limit?: number | null
daily_limit?: number | null
monthly_limit?: number | null
allowed_models?: string[] | null
capabilities?: any
is_active: boolean
}
export interface ModelExport {
global_model_name: string | null
provider_model_name: string
provider_model_aliases?: any
price_per_request?: number | null
tiered_pricing?: any
supports_vision?: boolean | null
supports_function_calling?: boolean | null
supports_streaming?: boolean | null
supports_extended_thinking?: boolean | null
supports_image_generation?: boolean | null
is_active: boolean
config?: any
}
// Provider 模型查询响应
export interface ProviderModelsQueryResponse {
success: boolean
data: {
models: Array<{
id: string
object?: string
created?: number
owned_by?: string
display_name?: string
api_format?: string
}>
error?: string
}
provider: {
id: string
name: string
display_name: string
}
}
export interface ConfigImportRequest extends ConfigExportData {
merge_mode: 'skip' | 'overwrite' | 'error'
}
export interface UsersImportRequest extends UsersExportData {
merge_mode: 'skip' | 'overwrite' | 'error'
}
export interface UsersImportResponse {
message: string
stats: {
users: { created: number; updated: number; skipped: number }
api_keys: { created: number; skipped: number }
errors: string[]
}
}
export interface ConfigImportResponse {
message: string
stats: {
global_models: { created: number; updated: number; skipped: number }
providers: { created: number; updated: number; skipped: number }
endpoints: { created: number; updated: number; skipped: number }
keys: { created: number; updated: number; skipped: number }
models: { created: number; updated: number; skipped: number }
errors: string[]
}
}
// API密钥管理相关接口定义
export interface AdminApiKey {
id: string // UUID
@@ -173,5 +347,44 @@ export const adminApi = {
'/api/admin/system/api-formats'
)
return response.data
},
// 导出配置
async exportConfig(): Promise<ConfigExportData> {
const response = await apiClient.get<ConfigExportData>('/api/admin/system/config/export')
return response.data
},
// 导入配置
async importConfig(data: ConfigImportRequest): Promise<ConfigImportResponse> {
const response = await apiClient.post<ConfigImportResponse>(
'/api/admin/system/config/import',
data
)
return response.data
},
// 导出用户数据
async exportUsers(): Promise<UsersExportData> {
const response = await apiClient.get<UsersExportData>('/api/admin/system/users/export')
return response.data
},
// 导入用户数据
async importUsers(data: UsersImportRequest): Promise<UsersImportResponse> {
const response = await apiClient.post<UsersImportResponse>(
'/api/admin/system/users/import',
data
)
return response.data
},
// 查询 Provider 可用模型(从上游 API 获取)
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
const response = await apiClient.post<ProviderModelsQueryResponse>(
'/api/admin/provider-query/models',
{ provider_id: providerId, api_key_id: apiKeyId }
)
return response.data
}
}

View File

@@ -1,3 +1,25 @@
// API 格式常量
export const API_FORMATS = {
CLAUDE: 'CLAUDE',
CLAUDE_CLI: 'CLAUDE_CLI',
OPENAI: 'OPENAI',
OPENAI_CLI: 'OPENAI_CLI',
GEMINI: 'GEMINI',
GEMINI_CLI: 'GEMINI_CLI',
} as const
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
// API 格式显示名称映射按品牌分组API 在前CLI 在后)
export const API_FORMAT_LABELS: Record<string, string> = {
[API_FORMATS.CLAUDE]: 'Claude',
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
[API_FORMATS.OPENAI]: 'OpenAI',
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
[API_FORMATS.GEMINI]: 'Gemini',
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
}
export interface ProviderEndpoint {
id: string
provider_id: string
@@ -214,6 +236,7 @@ export interface ConcurrencyStatus {
export interface ProviderModelAlias {
name: string
priority: number // 优先级(数字越小优先级越高)
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
}
export interface Model {

View File

@@ -68,13 +68,19 @@
<div
v-for="model in group.models"
:key="model.modelId"
class="flex items-center gap-2 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
class="flex flex-col gap-0.5 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'bg-primary text-primary-foreground'
: 'hover:bg-muted'"
@click="selectModel(model)"
>
<span class="truncate">{{ model.modelName }}</span>
<span class="truncate font-medium">{{ model.modelName }}</span>
<span
class="truncate text-[10px]"
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
? 'text-primary-foreground/70'
: 'text-muted-foreground'"
>{{ model.modelId }}</span>
</div>
</div>
</div>
@@ -390,15 +396,13 @@ interface ProviderGroup {
const groupedModels = computed(() => {
let models = allModels.value.filter(m => !m.deprecated)
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
models = models.filter(model =>
model.providerId.toLowerCase().includes(query) ||
model.providerName.toLowerCase().includes(query) ||
model.modelId.toLowerCase().includes(query) ||
model.modelName.toLowerCase().includes(query) ||
model.family?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
models = models.filter(model => {
const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 按提供商分组
@@ -415,14 +419,16 @@ const groupedModels = computed(() => {
}
// 转换为数组并排序
let result = Array.from(groups.values())
const result = Array.from(groups.values())
// 如果有搜索词,把提供商名称/ID匹配的排在前面
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result.sort((a, b) => {
const aProviderMatch = a.providerId.toLowerCase().includes(query) || a.providerName.toLowerCase().includes(query)
const bProviderMatch = b.providerId.toLowerCase().includes(query) || b.providerName.toLowerCase().includes(query)
const aText = `${a.providerId} ${a.providerName}`.toLowerCase()
const bText = `${b.providerId} ${b.providerName}`.toLowerCase()
const aProviderMatch = keywords.some(k => aText.includes(k))
const bProviderMatch = keywords.some(k => bText.includes(k))
if (aProviderMatch && !bProviderMatch) return -1
if (!aProviderMatch && bProviderMatch) return 1
return a.providerName.localeCompare(b.providerName)
@@ -598,6 +604,11 @@ function resetForm() {
// 加载模型数据(编辑模式)
function loadModelData() {
if (!props.model) return
// 先重置创建模式的残留状态
selectedModel.value = null
searchQuery.value = ''
expandedProvider.value = null
form.value = {
name: props.model.name,
display_name: props.model.display_name,
@@ -606,9 +617,10 @@ function loadModelData() {
config: props.model.config ? { ...props.model.config } : { streaming: true },
is_active: props.model.is_active,
}
if (props.model.default_tiered_pricing) {
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
}
// 确保 tieredPricing 也被正确设置或重置
tieredPricing.value = props.model.default_tiered_pricing
? JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
: null
}
// 使用 useFormDialog 统一处理对话框逻辑

View File

@@ -526,7 +526,14 @@
@edit-model="handleEditModel"
@delete-model="handleDeleteModel"
@batch-assign="handleBatchAssign"
@manage-alias="handleManageAlias"
/>
<!-- 模型名称映射 -->
<ModelAliasesTab
v-if="provider"
:key="`aliases-${provider.id}`"
:provider="provider"
@refresh="handleRelatedDataRefresh"
/>
</div>
</template>
@@ -629,16 +636,6 @@
@update:open="batchAssignDialogOpen = $event"
@changed="handleBatchAssignChanged"
/>
<!-- 模型别名管理对话框 -->
<ModelAliasDialog
v-if="open && provider"
:open="aliasDialogOpen"
:provider-id="provider.id"
:model="aliasEditingModel"
@update:open="aliasDialogOpen = $event"
@saved="handleAliasSaved"
/>
</template>
<script setup lang="ts">
@@ -667,8 +664,8 @@ import {
KeyFormDialog,
KeyAllowedModelsDialog,
ModelsTab,
BatchAssignModelsDialog,
ModelAliasDialog
ModelAliasesTab,
BatchAssignModelsDialog
} from '@/features/providers/components'
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false)
// 别名管理相关状态
const aliasDialogOpen = ref(false)
const aliasEditingModel = ref<Model | null>(null)
// 拖动排序相关状态
const dragState = ref({
isDragging: false,
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value ||
deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value ||
aliasDialogOpen.value
batchAssignDialogOpen.value
)
// 监听 providerId 变化
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
keyAllowedModelsDialogOpen.value = false
deleteKeyConfirmOpen.value = false
batchAssignDialogOpen.value = false
aliasDialogOpen.value = false
// 重置临时数据
endpointToEdit.value = null
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
emit('refresh')
}
// 处理管理映射 - 打开别名对话框
function handleManageAlias(model: Model) {
aliasEditingModel.value = model
aliasDialogOpen.value = true
}
// 处理别名保存完成
async function handleAliasSaved() {
aliasEditingModel.value = null
await loadProvider()
emit('refresh')
}
// 处理模型保存完成
async function handleModelSaved() {
editingModel.value = null

View File

@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'

File diff suppressed because it is too large Load Diff

View File

@@ -165,15 +165,6 @@
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="管理映射"
@click="openAliasDialog(model)"
>
<Tag class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
@@ -218,7 +209,7 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast'
@@ -233,7 +224,6 @@ const emit = defineEmits<{
'editModel': [model: Model]
'deleteModel': [model: Model]
'batchAssign': []
'manageAlias': [model: Model]
}>()
const { error: showError, success: showSuccess } = useToast()
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
emit('batchAssign')
}
// 打开别名管理对话框
function openAliasDialog(model: Model) {
emit('manageAlias', model)
}
// 切换模型启用状态
async function toggleModelActive(model: Model) {
if (togglingModelId.value) return

View File

@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
const filteredApiKeys = computed(() => {
let result = apiKeys.value
// 搜索筛选
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(key =>
(key.name && key.name.toLowerCase().includes(query)) ||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
(key.username && key.username.toLowerCase().includes(query)) ||
(key.user_email && key.user_email.toLowerCase().includes(query))
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(key => {
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 状态筛选

View File

@@ -1002,13 +1002,13 @@ async function batchRemoveSelectedProviders() {
const filteredGlobalModels = computed(() => {
let result = globalModels.value
// 搜索
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(m => {
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 能力筛选

View File

@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
const filteredProviders = computed(() => {
let result = [...providers.value]
// 搜索筛选
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value.trim()) {
const query = searchQuery.value.toLowerCase()
result = result.filter(p =>
p.display_name.toLowerCase().includes(query) ||
p.name.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(p => {
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 排序

View File

@@ -15,6 +15,94 @@
</PageHeader>
<div class="mt-6 space-y-6">
<!-- 配置导出/导入 -->
<CardSection
title="配置管理"
description="导出或导入提供商和模型配置,便于备份或迁移"
>
<div class="flex flex-wrap gap-4">
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
导出当前所有提供商端点API Key 和模型配置到 JSON 文件
</p>
<Button
variant="outline"
:disabled="exportLoading"
@click="handleExportConfig"
>
<Download class="w-4 h-4 mr-2" />
{{ exportLoading ? '导出中...' : '导出配置' }}
</Button>
</div>
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
JSON 文件导入配置支持跳过覆盖或报错三种冲突处理模式
</p>
<div class="flex items-center gap-2">
<input
ref="configFileInput"
type="file"
accept=".json"
class="hidden"
@change="handleConfigFileSelect"
>
<Button
variant="outline"
:disabled="importLoading"
@click="triggerConfigFileSelect"
>
<Upload class="w-4 h-4 mr-2" />
{{ importLoading ? '导入中...' : '导入配置' }}
</Button>
</div>
</div>
</div>
</CardSection>
<!-- 用户数据导出/导入 -->
<CardSection
title="用户数据管理"
description="导出或导入用户及其 API Keys 数据(不含管理员)"
>
<div class="flex flex-wrap gap-4">
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
导出所有普通用户及其 API Keys JSON 文件
</p>
<Button
variant="outline"
:disabled="exportUsersLoading"
@click="handleExportUsers"
>
<Download class="w-4 h-4 mr-2" />
{{ exportUsersLoading ? '导出中...' : '导出用户数据' }}
</Button>
</div>
<div class="flex-1 min-w-[200px]">
<p class="text-sm text-muted-foreground mb-3">
JSON 文件导入用户数据需相同 ENCRYPTION_KEY
</p>
<div class="flex items-center gap-2">
<input
ref="usersFileInput"
type="file"
accept=".json"
class="hidden"
@change="handleUsersFileSelect"
>
<Button
variant="outline"
:disabled="importUsersLoading"
@click="triggerUsersFileSelect"
>
<Upload class="w-4 h-4 mr-2" />
{{ importUsersLoading ? '导入中...' : '导入用户数据' }}
</Button>
</div>
</div>
</div>
</CardSection>
<!-- 基础配置 -->
<CardSection
title="基础配置"
@@ -375,11 +463,326 @@
</div>
</CardSection>
</div>
<!-- 导入配置对话框 -->
<Dialog v-model:open="importDialogOpen">
<DialogContent class="max-w-lg">
<DialogHeader>
<DialogTitle>导入配置</DialogTitle>
<DialogDescription>
选择冲突处理模式并确认导入
</DialogDescription>
</DialogHeader>
<div class="space-y-4 py-4">
<div
v-if="importPreview"
class="p-3 bg-muted rounded-lg text-sm"
>
<p class="font-medium mb-2">
配置预览
</p>
<ul class="space-y-1 text-muted-foreground">
<li>全局模型: {{ importPreview.global_models?.length || 0 }} </li>
<li>提供商: {{ importPreview.providers?.length || 0 }} </li>
<li>
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }}
</li>
<li>
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }}
</li>
</ul>
</div>
<div>
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
<Select v-model="mergeMode">
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="skip">
跳过 - 保留现有配置
</SelectItem>
<SelectItem value="overwrite">
覆盖 - 用导入配置替换
</SelectItem>
<SelectItem value="error">
报错 - 遇到冲突时中止
</SelectItem>
</SelectContent>
</Select>
<p class="mt-1 text-xs text-muted-foreground">
<template v-if="mergeMode === 'skip'">
已存在的配置将被保留仅导入新配置
</template>
<template v-else-if="mergeMode === 'overwrite'">
已存在的配置将被导入的配置覆盖
</template>
<template v-else>
如果发现任何冲突导入将中止并回滚
</template>
</p>
</div>
<div class="p-3 bg-yellow-500/10 border border-yellow-500/20 rounded-lg">
<p class="text-sm text-yellow-600 dark:text-yellow-400">
注意相同的 API Keys 会自动跳过不会创建重复记录
</p>
</div>
</div>
<DialogFooter>
<Button
variant="outline"
@click="importDialogOpen = false"
>
取消
</Button>
<Button
:disabled="importLoading"
@click="confirmImport"
>
{{ importLoading ? '导入中...' : '确认导入' }}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<!-- 导入结果对话框 -->
<Dialog v-model:open="importResultDialogOpen">
<DialogContent class="max-w-lg">
<DialogHeader>
<DialogTitle>导入完成</DialogTitle>
</DialogHeader>
<div
v-if="importResult"
class="space-y-4 py-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
全局模型
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.global_models.created }},
更新: {{ importResult.stats.global_models.updated }},
跳过: {{ importResult.stats.global_models.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
提供商
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.providers.created }},
更新: {{ importResult.stats.providers.updated }},
跳过: {{ importResult.stats.providers.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
端点
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.endpoints.created }},
更新: {{ importResult.stats.endpoints.updated }},
跳过: {{ importResult.stats.endpoints.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
API Keys
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.keys.created }},
跳过: {{ importResult.stats.keys.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg col-span-2">
<p class="font-medium">
模型配置
</p>
<p class="text-muted-foreground">
创建: {{ importResult.stats.models.created }},
更新: {{ importResult.stats.models.updated }},
跳过: {{ importResult.stats.models.skipped }}
</p>
</div>
</div>
<div
v-if="importResult.stats.errors.length > 0"
class="p-3 bg-red-500/10 border border-red-500/20 rounded-lg"
>
<p class="font-medium text-red-600 dark:text-red-400 mb-2">
警告信息
</p>
<ul class="text-sm text-red-600 dark:text-red-400 space-y-1">
<li
v-for="(err, index) in importResult.stats.errors"
:key="index"
>
{{ err }}
</li>
</ul>
</div>
</div>
<DialogFooter>
<Button @click="importResultDialogOpen = false">
确定
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<!-- 用户数据导入对话框 -->
<Dialog v-model:open="importUsersDialogOpen">
<DialogContent class="max-w-lg">
<DialogHeader>
<DialogTitle>导入用户数据</DialogTitle>
<DialogDescription>
选择冲突处理模式并确认导入
</DialogDescription>
</DialogHeader>
<div class="space-y-4 py-4">
<div
v-if="importUsersPreview"
class="p-3 bg-muted rounded-lg text-sm"
>
<p class="font-medium mb-2">
数据预览
</p>
<ul class="space-y-1 text-muted-foreground">
<li>用户: {{ importUsersPreview.users?.length || 0 }} </li>
<li>
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }}
</li>
</ul>
</div>
<div>
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
<Select v-model="usersMergeMode">
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="skip">
跳过 - 保留现有用户
</SelectItem>
<SelectItem value="overwrite">
覆盖 - 用导入数据替换
</SelectItem>
<SelectItem value="error">
报错 - 遇到冲突时中止
</SelectItem>
</SelectContent>
</Select>
<p class="mt-1 text-xs text-muted-foreground">
<template v-if="usersMergeMode === 'skip'">
已存在的用户将被保留仅导入新用户
</template>
<template v-else-if="usersMergeMode === 'overwrite'">
已存在的用户将被导入的数据覆盖
</template>
<template v-else>
如果发现任何冲突导入将中止并回滚
</template>
</p>
</div>
<div class="p-3 bg-yellow-500/10 border border-yellow-500/20 rounded-lg">
<p class="text-sm text-yellow-600 dark:text-yellow-400">
注意用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作
</p>
</div>
</div>
<DialogFooter>
<Button
variant="outline"
@click="importUsersDialogOpen = false"
>
取消
</Button>
<Button
:disabled="importUsersLoading"
@click="confirmImportUsers"
>
{{ importUsersLoading ? '导入中...' : '确认导入' }}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
<!-- 用户数据导入结果对话框 -->
<Dialog v-model:open="importUsersResultDialogOpen">
<DialogContent class="max-w-lg">
<DialogHeader>
<DialogTitle>用户数据导入完成</DialogTitle>
</DialogHeader>
<div
v-if="importUsersResult"
class="space-y-4 py-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
用户
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.users.created }},
更新: {{ importUsersResult.stats.users.updated }},
跳过: {{ importUsersResult.stats.users.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<p class="font-medium">
API Keys
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.api_keys.created }},
跳过: {{ importUsersResult.stats.api_keys.skipped }}
</p>
</div>
</div>
<div
v-if="importUsersResult.stats.errors.length > 0"
class="p-3 bg-red-500/10 border border-red-500/20 rounded-lg"
>
<p class="font-medium text-red-600 dark:text-red-400 mb-2">
警告信息
</p>
<ul class="text-sm text-red-600 dark:text-red-400 space-y-1">
<li
v-for="(err, index) in importUsersResult.stats.errors"
:key="index"
>
{{ err }}
</li>
</ul>
</div>
</div>
<DialogFooter>
<Button @click="importUsersResultDialogOpen = false">
确定
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</PageContainer>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Download, Upload } from 'lucide-vue-next'
import Button from '@/components/ui/button.vue'
import Input from '@/components/ui/input.vue'
import Label from '@/components/ui/label.vue'
@@ -389,9 +792,17 @@ import SelectTrigger from '@/components/ui/select-trigger.vue'
import SelectValue from '@/components/ui/select-value.vue'
import SelectContent from '@/components/ui/select-content.vue'
import SelectItem from '@/components/ui/select-item.vue'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
DialogFooter
} from '@/components/ui'
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
import { useToast } from '@/composables/useToast'
import { adminApi } from '@/api/admin'
import { adminApi, type ConfigExportData, type ConfigImportResponse, type UsersExportData, type UsersImportResponse } from '@/api/admin'
import { log } from '@/utils/logger'
const { success, error } = useToast()
@@ -423,6 +834,26 @@ interface SystemConfig {
const loading = ref(false)
const logLevelSelectOpen = ref(false)
// 导出/导入相关
const exportLoading = ref(false)
const importLoading = ref(false)
const importDialogOpen = ref(false)
const importResultDialogOpen = ref(false)
const configFileInput = ref<HTMLInputElement | null>(null)
const importPreview = ref<ConfigExportData | null>(null)
const importResult = ref<ConfigImportResponse | null>(null)
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
// 用户数据导出/导入相关
const exportUsersLoading = ref(false)
const importUsersLoading = ref(false)
const importUsersDialogOpen = ref(false)
const importUsersResultDialogOpen = ref(false)
const usersFileInput = ref<HTMLInputElement | null>(null)
const importUsersPreview = ref<UsersExportData | null>(null)
const importUsersResult = ref<UsersImportResponse | null>(null)
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const systemConfig = ref<SystemConfig>({
// 基础配置
default_user_quota_usd: 10.0,
@@ -623,4 +1054,183 @@ async function saveSystemConfig() {
loading.value = false
}
}
// 导出配置
async function handleExportConfig() {
exportLoading.value = true
try {
const data = await adminApi.exportConfig()
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `aether-config-${new Date().toISOString().slice(0, 10)}.json`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
URL.revokeObjectURL(url)
success('配置已导出')
} catch (err) {
error('导出配置失败')
log.error('导出配置失败:', err)
} finally {
exportLoading.value = false
}
}
// 触发文件选择
function triggerConfigFileSelect() {
configFileInput.value?.click()
}
// 文件大小限制 (10MB)
const MAX_FILE_SIZE = 10 * 1024 * 1024
// 处理文件选择
function handleConfigFileSelect(event: Event) {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
if (file.size > MAX_FILE_SIZE) {
error('文件大小不能超过 10MB')
input.value = ''
return
}
const reader = new FileReader()
reader.onload = (e) => {
try {
const content = e.target?.result as string
const data = JSON.parse(content) as ConfigExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importPreview.value = data
mergeMode.value = 'skip'
importDialogOpen.value = true
} catch (err) {
error('解析配置文件失败,请确保是有效的 JSON 文件')
log.error('解析配置文件失败:', err)
}
}
reader.readAsText(file)
// 重置 input 以便能再次选择同一文件
input.value = ''
}
// 确认导入
async function confirmImport() {
if (!importPreview.value) return
importLoading.value = true
try {
const result = await adminApi.importConfig({
...importPreview.value,
merge_mode: mergeMode.value
})
importResult.value = result
importDialogOpen.value = false
importResultDialogOpen.value = true
success('配置导入成功')
} catch (err: any) {
error(err.response?.data?.detail || '导入配置失败')
log.error('导入配置失败:', err)
} finally {
importLoading.value = false
}
}
// 导出用户数据
async function handleExportUsers() {
exportUsersLoading.value = true
try {
const data = await adminApi.exportUsers()
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
const url = URL.createObjectURL(blob)
const a = document.createElement('a')
a.href = url
a.download = `aether-users-${new Date().toISOString().slice(0, 10)}.json`
document.body.appendChild(a)
a.click()
document.body.removeChild(a)
URL.revokeObjectURL(url)
success('用户数据已导出')
} catch (err) {
error('导出用户数据失败')
log.error('导出用户数据失败:', err)
} finally {
exportUsersLoading.value = false
}
}
// 触发用户数据文件选择
function triggerUsersFileSelect() {
usersFileInput.value?.click()
}
// 处理用户数据文件选择
function handleUsersFileSelect(event: Event) {
const input = event.target as HTMLInputElement
const file = input.files?.[0]
if (!file) return
if (file.size > MAX_FILE_SIZE) {
error('文件大小不能超过 10MB')
input.value = ''
return
}
const reader = new FileReader()
reader.onload = (e) => {
try {
const content = e.target?.result as string
const data = JSON.parse(content) as UsersExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importUsersPreview.value = data
usersMergeMode.value = 'skip'
importUsersDialogOpen.value = true
} catch (err) {
error('解析用户数据文件失败,请确保是有效的 JSON 文件')
log.error('解析用户数据文件失败:', err)
}
}
reader.readAsText(file)
// 重置 input 以便能再次选择同一文件
input.value = ''
}
// 确认导入用户数据
async function confirmImportUsers() {
if (!importUsersPreview.value) return
importUsersLoading.value = true
try {
const result = await adminApi.importUsers({
...importUsersPreview.value,
merge_mode: usersMergeMode.value
})
importUsersResult.value = result
importUsersDialogOpen.value = false
importUsersResultDialogOpen.value = true
success('用户数据导入成功')
} catch (err: any) {
error(err.response?.data?.detail || '导入用户数据失败')
log.error('导入用户数据失败:', err)
} finally {
importUsersLoading.value = false
}
}
</script>

View File

@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
})
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
filtered = filtered.filter(
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
filtered = filtered.filter(u => {
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
if (filterRole.value !== 'all') {

View File

@@ -474,13 +474,13 @@ async function toggleCapability(modelName: string, capName: string) {
const filteredModels = computed(() => {
let result = models.value
// 搜索
// 搜索(支持空格分隔的多关键词 AND 搜索)
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase()
result = result.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name?.toLowerCase().includes(query)
)
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(m => {
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
// 能力筛选

View File

@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_query import router as provider_query_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
router.include_router(provider_query_router)
__all__ = ["router"]

View File

@@ -1,46 +1,28 @@
"""
Provider Query API 端点
用于查询提供商的余额、使用记录等信息
用于查询提供商的模型列表等信息
"""
from datetime import datetime
import asyncio
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
import httpx
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
# 初始化适配器注册
from src.plugins.provider_query import init # noqa
from src.plugins.provider_query import get_query_registry
from src.plugins.provider_query.base import QueryCapability
from src.models.database import Provider, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
# ============ Request/Response Models ============
class BalanceQueryRequest(BaseModel):
"""余额查询请求"""
provider_id: str
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
class UsageSummaryQueryRequest(BaseModel):
"""使用汇总查询请求"""
provider_id: str
api_key_id: Optional[str] = None
period: str = "month" # day, week, month, year
class ModelsQueryRequest(BaseModel):
"""模型列表查询请求"""
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
# ============ API Endpoints ============
@router.get("/adapters")
async def list_adapters(
current_user: User = Depends(get_current_user),
):
"""
获取所有可用的查询适配器
async def _fetch_openai_models(
client: httpx.AsyncClient,
base_url: str,
api_key: str,
api_format: str,
extra_headers: Optional[dict] = None,
) -> tuple[list, Optional[str]]:
"""获取 OpenAI 格式的模型列表
Returns:
适配器列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
registry = get_query_registry()
adapters = registry.list_adapters()
headers = {"Authorization": f"Bearer {api_key}"}
if extra_headers:
# 防止 extra_headers 覆盖 Authorization
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
headers.update(safe_headers)
return {"success": True, "data": adapters}
# 构建 /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# 为每个模型添加 api_format 字段
for m in models:
m["api_format"] = api_format
return models, None
else:
# 记录详细的错误信息
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg
@router.get("/capabilities/{provider_id}")
async def get_provider_capabilities(
provider_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取提供商支持的查询能力
Args:
provider_id: 提供商 ID
async def _fetch_claude_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Claude 格式的模型列表
Returns:
支持的查询能力列表
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
# 获取提供商
from sqlalchemy import select
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
capabilities = registry.get_capabilities_for_provider(provider.name)
if capabilities is None:
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [],
"has_adapter": False,
"message": "No query adapter available for this provider",
},
}
return {
"success": True,
"data": {
"provider_id": provider_id,
"provider_name": provider.name,
"capabilities": [c.name for c in capabilities],
"has_adapter": True,
},
headers = {
"x-api-key": api_key,
"Authorization": f"Bearer {api_key}",
"anthropic-version": "2023-06-01",
}
# 构建 /v1/models URL
if base_url.endswith("/v1"):
models_url = f"{base_url}/models"
else:
models_url = f"{base_url}/v1/models"
@router.post("/balance")
async def query_balance(
request: BalanceQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商余额
try:
response = await client.get(models_url, headers=headers)
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
models = []
if "data" in data:
models = data["data"]
elif isinstance(data, list):
models = data
# 为每个模型添加 api_format 字段
for m in models:
m["api_format"] = api_format
return models, None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg
Args:
request: 查询请求
async def _fetch_gemini_models(
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
) -> tuple[list, Optional[str]]:
"""获取 Gemini 格式的模型列表
Returns:
余额信息
tuple[list, Optional[str]]: (模型列表, 错误信息)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 兼容 base_url 已包含 /v1beta 的情况
base_url_clean = base_url.rstrip("/")
if base_url_clean.endswith("/v1beta"):
models_url = f"{base_url_clean}/models?key={api_key}"
else:
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
if request.api_key_id:
# 查找指定的 API Key
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
try:
response = await client.get(models_url)
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
if response.status_code == 200:
data = response.json()
if "models" in data:
# 转换为统一格式
return [
{
"id": m.get("name", "").replace("models/", ""),
"owned_by": "google",
"display_name": m.get("displayName", ""),
"api_format": api_format,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 使用第一个可用的 API Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {
"base_url": endpoint.base_url,
"api_format": endpoint.api_format if endpoint.api_format else None,
}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询余额
registry = get_query_registry()
query_result = await registry.query_provider_balance(
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
)
if not query_result.success:
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.post("/usage-summary")
async def query_usage_summary(
request: UsageSummaryQueryRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商使用汇总
Args:
request: 查询请求
Returns:
使用汇总信息
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key逻辑同上
api_key_value = None
endpoint_config = None
if request.api_key_id:
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=404, detail="API Key not found")
else:
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not api_key_value:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询使用汇总
registry = get_query_registry()
query_result = await registry.query_provider_usage(
provider_type=provider.name,
api_key=api_key_value,
period=request.period,
endpoint_config=endpoint_config,
)
return {
"success": query_result.success,
"data": query_result.to_dict(),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
for m in data["models"]
], None
return [], None
else:
error_body = response.text[:500] if response.text else "(empty)"
error_msg = f"HTTP {response.status_code}: {error_body}"
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
return [], error_msg
except Exception as e:
error_msg = f"Request error: {str(e)}"
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg
@router.post("/models")
async def query_available_models(
request: ModelsQueryRequest,
db: AsyncSession = Depends(get_db),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
查询提供商可用模型
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
Args:
request: 查询请求
Returns:
模型列表
所有端点的模型列表(合并)
"""
from sqlalchemy import select
from sqlalchemy.orm import selectinload
# 获取提供商及其端点
result = await db.execute(
select(Provider)
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
.where(Provider.id == request.provider_id)
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 获取 API Key
api_key_value = None
endpoint_config = None
# 收集所有活跃端点的配置
endpoint_configs: list[dict] = []
if request.api_key_id:
# 指定了特定的 API Key只使用该 Key 对应的端点
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break
if api_key_value:
if endpoint_configs:
break
if not api_key_value:
if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found")
else:
# 遍历所有活跃端点,每个端点取第一个可用的 Key
for endpoint in provider.endpoints:
if endpoint.is_active and endpoint.api_keys:
for api_key in endpoint.api_keys:
if api_key.is_active:
api_key_value = api_key.api_key
endpoint_config = {"base_url": endpoint.base_url}
break
if api_key_value:
break
if not endpoint.is_active or not endpoint.api_keys:
continue
if not api_key_value:
# 找第一个可用的 Key
for api_key in endpoint.api_keys:
if api_key.is_active:
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue # 尝试下一个 Key
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # 只取第一个可用的 Key
if not endpoint_configs:
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
# 查询模型
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
# 并发请求所有端点的模型列表
all_models: list = []
errors: list[str] = []
if not adapter:
raise HTTPException(
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
async def fetch_endpoint_models(
client: httpx.AsyncClient, config: dict
) -> tuple[list, Optional[str]]:
base_url = config["base_url"]
if not base_url:
return [], None
base_url = base_url.rstrip("/")
api_format = config["api_format"]
api_key_value = config["api_key"]
extra_headers = config["extra_headers"]
try:
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
elif api_format in ["GEMINI", "GEMINI_CLI"]:
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
else:
return await _fetch_openai_models(
client, base_url, api_key_value, api_format, extra_headers
)
except Exception as e:
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
return [], f"{api_format}: {str(e)}"
async with httpx.AsyncClient(timeout=30.0) as client:
results = await asyncio.gather(
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
)
for models, error in results:
all_models.extend(models)
if error:
errors.append(error)
query_result = await adapter.query_available_models(
api_key=api_key_value, endpoint_config=endpoint_config
)
# 按 model id 去重(保留第一个)
seen_ids: set[str] = set()
unique_models: list = []
for model in all_models:
model_id = model.get("id")
if model_id and model_id not in seen_ids:
seen_ids.add(model_id)
unique_models.append(model)
error = "; ".join(errors) if errors else None
if not unique_models and not error:
error = "No models returned from any endpoint"
return {
"success": query_result.success,
"data": query_result.to_dict(),
"success": len(unique_models) > 0,
"data": {"models": unique_models, "error": error},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
provider_id: str,
api_key_id: Optional[str] = None,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
清除查询缓存
Args:
provider_id: 提供商 ID
api_key_id: 可选,指定清除某个 API Key 的缓存
Returns:
清除结果
"""
from sqlalchemy import select
# 获取提供商
result = await db.execute(select(Provider).where(Provider.id == provider_id))
provider = result.scalar_one_or_none()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
registry = get_query_registry()
adapter = registry.get_adapter_for_provider(provider.name)
if adapter:
if api_key_id:
# 获取 API Key 值来清除缓存
from sqlalchemy.orm import selectinload
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
api_key = result.scalar_one_or_none()
if api_key:
adapter.clear_cache(api_key.api_key)
else:
adapter.clear_cache()
return {"success": True, "message": "Cache cleared successfully"}

View File

@@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)):
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/config/export")
async def export_config(request: Request, db: Session = Depends(get_db)):
"""导出提供商和模型配置(管理员)"""
adapter = AdminExportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/config/import")
async def import_config(request: Request, db: Session = Depends(get_db)):
"""导入提供商和模型配置(管理员)"""
adapter = AdminImportConfigAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/users/export")
async def export_users(request: Request, db: Session = Depends(get_db)):
"""导出用户数据(管理员)"""
adapter = AdminExportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/users/import")
async def import_users(request: Request, db: Session = Depends(get_db)):
"""导入用户数据(管理员)"""
adapter = AdminImportUsersAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- 系统设置适配器 --------
@@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter):
)
return {"formats": formats}
class AdminExportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出提供商和模型配置(解密数据)"""
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
db = context.db
# 导出 GlobalModels
global_models = db.query(GlobalModel).all()
global_models_data = []
for gm in global_models:
global_models_data.append(
{
"name": gm.name,
"display_name": gm.display_name,
"default_price_per_request": gm.default_price_per_request,
"default_tiered_pricing": gm.default_tiered_pricing,
"supported_capabilities": gm.supported_capabilities,
"config": gm.config,
"is_active": gm.is_active,
}
)
# 导出 Providers 及其关联数据
providers = db.query(Provider).all()
providers_data = []
for provider in providers:
# 导出 Endpoints
endpoints = (
db.query(ProviderEndpoint)
.filter(ProviderEndpoint.provider_id == provider.id)
.all()
)
endpoints_data = []
for ep in endpoints:
# 导出 Endpoint Keys
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
)
keys_data = []
for key in keys:
# 解密 API Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"rate_multiplier": key.rate_multiplier,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"is_active": key.is_active,
}
)
endpoints_data.append(
{
"api_format": ep.api_format,
"base_url": ep.base_url,
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"keys": keys_data,
}
)
# 导出 Provider Models
models = db.query(Model).filter(Model.provider_id == provider.id).all()
models_data = []
for model in models:
# 获取关联的 GlobalModel 名称
global_model = (
db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first()
)
models_data.append(
{
"global_model_name": global_model.name if global_model else None,
"provider_model_name": model.provider_model_name,
"provider_model_aliases": model.provider_model_aliases,
"price_per_request": model.price_per_request,
"tiered_pricing": model.tiered_pricing,
"supports_vision": model.supports_vision,
"supports_function_calling": model.supports_function_calling,
"supports_streaming": model.supports_streaming,
"supports_extended_thinking": model.supports_extended_thinking,
"supports_image_generation": model.supports_image_generation,
"is_active": model.is_active,
"config": model.config,
}
)
providers_data.append(
{
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority,
"is_active": provider.is_active,
"rate_limit": provider.rate_limit,
"concurrent_limit": provider.concurrent_limit,
"config": provider.config,
"endpoints": endpoints_data,
"models": models_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data,
"providers": providers_data,
}
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
class AdminImportConfigAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入提供商和模型配置"""
import uuid
from datetime import datetime, timezone
from src.core.crypto import crypto_service
from src.core.enums import ProviderBillingType
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
# 检查请求体大小
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
raise InvalidRequestException("请求体大小不能超过 10MB")
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
global_models_data = payload.get("global_models", [])
providers_data = payload.get("providers", [])
stats = {
"global_models": {"created": 0, "updated": 0, "skipped": 0},
"providers": {"created": 0, "updated": 0, "skipped": 0},
"endpoints": {"created": 0, "updated": 0, "skipped": 0},
"keys": {"created": 0, "updated": 0, "skipped": 0},
"models": {"created": 0, "updated": 0, "skipped": 0},
"errors": [],
}
try:
# 导入 GlobalModels
global_model_map = {} # name -> id 映射
for gm_data in global_models_data:
existing = (
db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
)
if existing:
global_model_map[gm_data["name"]] = existing.id
if merge_mode == "skip":
stats["global_models"]["skipped"] += 1
continue
elif merge_mode == "error":
raise InvalidRequestException(
f"GlobalModel '{gm_data['name']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有记录
existing.display_name = gm_data.get(
"display_name", existing.display_name
)
existing.default_price_per_request = gm_data.get(
"default_price_per_request"
)
existing.default_tiered_pricing = gm_data.get(
"default_tiered_pricing", existing.default_tiered_pricing
)
existing.supported_capabilities = gm_data.get(
"supported_capabilities"
)
existing.config = gm_data.get("config")
existing.is_active = gm_data.get("is_active", True)
existing.updated_at = datetime.now(timezone.utc)
stats["global_models"]["updated"] += 1
else:
# 创建新记录
new_gm = GlobalModel(
id=str(uuid.uuid4()),
name=gm_data["name"],
display_name=gm_data.get("display_name", gm_data["name"]),
default_price_per_request=gm_data.get("default_price_per_request"),
default_tiered_pricing=gm_data.get(
"default_tiered_pricing",
{"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
),
supported_capabilities=gm_data.get("supported_capabilities"),
config=gm_data.get("config"),
is_active=gm_data.get("is_active", True),
)
db.add(new_gm)
db.flush()
global_model_map[gm_data["name"]] = new_gm.id
stats["global_models"]["created"] += 1
# 导入 Providers
for prov_data in providers_data:
existing_provider = (
db.query(Provider).filter(Provider.name == prov_data["name"]).first()
)
if existing_provider:
provider_id = existing_provider.id
if merge_mode == "skip":
stats["providers"]["skipped"] += 1
# 仍然需要处理 endpoints 和 models如果存在
elif merge_mode == "error":
raise InvalidRequestException(
f"Provider '{prov_data['name']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有记录
existing_provider.display_name = prov_data.get(
"display_name", existing_provider.display_name
)
existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website")
if prov_data.get("billing_type"):
existing_provider.billing_type = ProviderBillingType(
prov_data["billing_type"]
)
existing_provider.monthly_quota_usd = prov_data.get(
"monthly_quota_usd"
)
existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30
)
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get(
"provider_priority", 100
)
existing_provider.is_active = prov_data.get("is_active", True)
existing_provider.rate_limit = prov_data.get("rate_limit")
existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit"
)
existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1
else:
# 创建新 Provider
billing_type = ProviderBillingType.PAY_AS_YOU_GO
if prov_data.get("billing_type"):
billing_type = ProviderBillingType(prov_data["billing_type"])
new_provider = Provider(
id=str(uuid.uuid4()),
name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"),
website=prov_data.get("website"),
billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True),
rate_limit=prov_data.get("rate_limit"),
concurrent_limit=prov_data.get("concurrent_limit"),
config=prov_data.get("config"),
)
db.add(new_provider)
db.flush()
provider_id = new_provider.id
stats["providers"]["created"] += 1
# 导入 Endpoints
for ep_data in prov_data.get("endpoints", []):
existing_ep = (
db.query(ProviderEndpoint)
.filter(
ProviderEndpoint.provider_id == provider_id,
ProviderEndpoint.api_format == ep_data["api_format"],
)
.first()
)
if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_ep.base_url = ep_data.get(
"base_url", existing_ep.base_url
)
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config")
existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1
else:
new_ep = ProviderEndpoint(
id=str(uuid.uuid4()),
provider_id=provider_id,
api_format=ep_data["api_format"],
base_url=ep_data["base_url"],
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"),
)
db.add(new_ep)
db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1
# 导入 Keys
# 获取当前 endpoint 下所有已有的 keys用于去重
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
.all()
)
# 解密已有 keys 用于比对
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
for key_data in ep_data.get("keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
)
continue
# 检查是否已存在相同的 Key通过明文比对
if key_data["api_key"] in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
api_key=encrypted_key,
name=key_data.get("name"),
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
is_active=key_data.get("is_active", True),
)
db.add(new_key)
# 添加到已有集合,防止同一批导入中重复
existing_key_values.add(key_data["api_key"])
stats["keys"]["created"] += 1
# 导入 Models
for model_data in prov_data.get("models", []):
global_model_name = model_data.get("global_model_name")
if not global_model_name:
stats["errors"].append(
f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
)
continue
global_model_id = global_model_map.get(global_model_name)
if not global_model_id:
# 尝试从数据库查找
existing_gm = (
db.query(GlobalModel)
.filter(GlobalModel.name == global_model_name)
.first()
)
if existing_gm:
global_model_id = existing_gm.id
else:
stats["errors"].append(
f"GlobalModel '{global_model_name}' 不存在,跳过模型"
)
continue
existing_model = (
db.query(Model)
.filter(
Model.provider_id == provider_id,
Model.provider_model_name == model_data["provider_model_name"],
)
.first()
)
if existing_model:
if merge_mode == "skip":
stats["models"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'"
)
elif merge_mode == "overwrite":
existing_model.global_model_id = global_model_id
existing_model.provider_model_aliases = model_data.get(
"provider_model_aliases"
)
existing_model.price_per_request = model_data.get(
"price_per_request"
)
existing_model.tiered_pricing = model_data.get(
"tiered_pricing"
)
existing_model.supports_vision = model_data.get(
"supports_vision"
)
existing_model.supports_function_calling = model_data.get(
"supports_function_calling"
)
existing_model.supports_streaming = model_data.get(
"supports_streaming"
)
existing_model.supports_extended_thinking = model_data.get(
"supports_extended_thinking"
)
existing_model.supports_image_generation = model_data.get(
"supports_image_generation"
)
existing_model.is_active = model_data.get("is_active", True)
existing_model.config = model_data.get("config")
existing_model.updated_at = datetime.now(timezone.utc)
stats["models"]["updated"] += 1
else:
new_model = Model(
id=str(uuid.uuid4()),
provider_id=provider_id,
global_model_id=global_model_id,
provider_model_name=model_data["provider_model_name"],
provider_model_aliases=model_data.get(
"provider_model_aliases"
),
price_per_request=model_data.get("price_per_request"),
tiered_pricing=model_data.get("tiered_pricing"),
supports_vision=model_data.get("supports_vision"),
supports_function_calling=model_data.get(
"supports_function_calling"
),
supports_streaming=model_data.get("supports_streaming"),
supports_extended_thinking=model_data.get(
"supports_extended_thinking"
),
supports_image_generation=model_data.get(
"supports_image_generation"
),
is_active=model_data.get("is_active", True),
config=model_data.get("config"),
)
db.add(new_model)
stats["models"]["created"] += 1
db.commit()
# 失效缓存
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.invalidate_all()
return {
"message": "配置导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")
class AdminExportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导出用户数据(保留加密数据,排除管理员)"""
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
db = context.db
# 导出 Users排除管理员
users = db.query(User).filter(
User.is_deleted.is_(False),
User.role != UserRole.ADMIN
).all()
users_data = []
for user in users:
# 导出用户的 API Keys保留加密数据
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
api_keys_data = []
for key in api_keys:
api_keys_data.append(
{
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"is_standalone": key.is_standalone,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
"allowed_endpoints": key.allowed_endpoints,
"allowed_api_formats": key.allowed_api_formats,
"allowed_models": key.allowed_models,
"rate_limit": key.rate_limit,
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
)
users_data.append(
{
"email": user.email,
"username": user.username,
"password_hash": user.password_hash,
"role": user.role.value if user.role else "user",
"allowed_providers": user.allowed_providers,
"allowed_endpoints": user.allowed_endpoints,
"allowed_models": user.allowed_models,
"model_capability_settings": user.model_capability_settings,
"quota_usd": user.quota_usd,
"used_usd": user.used_usd,
"total_usd": user.total_usd,
"is_active": user.is_active,
"api_keys": api_keys_data,
}
)
return {
"version": "1.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"users": users_data,
}
class AdminImportUsersAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
"""导入用户数据"""
import uuid
from datetime import datetime, timezone
from src.core.enums import UserRole
from src.models.database import ApiKey, User
# 检查请求体大小
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
raise InvalidRequestException("请求体大小不能超过 10MB")
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
users_data = payload.get("users", [])
stats = {
"users": {"created": 0, "updated": 0, "skipped": 0},
"api_keys": {"created": 0, "skipped": 0},
"errors": [],
}
try:
for user_data in users_data:
# 跳过管理员角色的导入(不区分大小写)
role_str = str(user_data.get("role", "")).lower()
if role_str == "admin":
stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
stats["users"]["skipped"] += 1
continue
existing_user = (
db.query(User).filter(User.email == user_data["email"]).first()
)
if existing_user:
user_id = existing_user.id
if merge_mode == "skip":
stats["users"]["skipped"] += 1
elif merge_mode == "error":
raise InvalidRequestException(
f"用户 '{user_data['email']}' 已存在"
)
elif merge_mode == "overwrite":
# 更新现有用户
existing_user.username = user_data.get(
"username", existing_user.username
)
if user_data.get("password_hash"):
existing_user.password_hash = user_data["password_hash"]
if user_data.get("role"):
existing_user.role = UserRole(user_data["role"])
existing_user.allowed_providers = user_data.get("allowed_providers")
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
existing_user.allowed_models = user_data.get("allowed_models")
existing_user.model_capability_settings = user_data.get(
"model_capability_settings"
)
existing_user.quota_usd = user_data.get("quota_usd")
existing_user.used_usd = user_data.get("used_usd", 0.0)
existing_user.total_usd = user_data.get("total_usd", 0.0)
existing_user.is_active = user_data.get("is_active", True)
existing_user.updated_at = datetime.now(timezone.utc)
stats["users"]["updated"] += 1
else:
# 创建新用户
role = UserRole.USER
if user_data.get("role"):
role = UserRole(user_data["role"])
new_user = User(
id=str(uuid.uuid4()),
email=user_data["email"],
username=user_data.get("username", user_data["email"].split("@")[0]),
password_hash=user_data.get("password_hash", ""),
role=role,
allowed_providers=user_data.get("allowed_providers"),
allowed_endpoints=user_data.get("allowed_endpoints"),
allowed_models=user_data.get("allowed_models"),
model_capability_settings=user_data.get("model_capability_settings"),
quota_usd=user_data.get("quota_usd"),
used_usd=user_data.get("used_usd", 0.0),
total_usd=user_data.get("total_usd", 0.0),
is_active=user_data.get("is_active", True),
)
db.add(new_user)
db.flush()
user_id = new_user.id
stats["users"]["created"] += 1
# 导入 API Keys
for key_data in user_data.get("api_keys", []):
# 检查是否已存在相同的 key_hash
if key_data.get("key_hash"):
existing_key = (
db.query(ApiKey)
.filter(ApiKey.key_hash == key_data["key_hash"])
.first()
)
if existing_key:
stats["api_keys"]["skipped"] += 1
continue
new_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
key_hash=key_data.get("key_hash", ""),
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit", 100),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
)
db.add(new_key)
stats["api_keys"]["created"] += 1
db.commit()
return {
"message": "用户数据导入成功",
"stats": stats,
}
except InvalidRequestException:
db.rollback()
raise
except Exception as e:
db.rollback()
raise InvalidRequestException(f"导入失败: {str(e)}")

View File

@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
)
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 需要转回业务时区再取日期,才能与日期序列匹配
def _to_business_date_str(value: datetime) -> str:
if value.tzinfo is None:
value_utc = value.replace(tzinfo=timezone.utc)
else:
value_utc = value.astimezone(timezone.utc)
return value_utc.astimezone(app_tz).date().isoformat()
stats_map = {
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
_to_business_date_str(stat.date): {
"requests": stat.total_requests,
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
"cost": stat.total_cost,
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
"unique_providers": today_unique_providers,
"fallback_count": today_fallback_count,
}
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
yesterday_date = today_local.date() - timedelta(days=1)
historical_end = min(end_date_local.date(), yesterday_date)
missing_dates: list[str] = []
cursor = start_date_local.date()
while cursor <= historical_end:
date_str = cursor.isoformat()
if date_str not in stats_map:
missing_dates.append(date_str)
cursor += timedelta(days=1)
if missing_dates:
for date_str in missing_dates[-7:]:
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
stats_map[date_str] = {
"requests": computed["total_requests"],
"tokens": (
computed["input_tokens"]
+ computed["output_tokens"]
+ computed["cache_creation_tokens"]
+ computed["cache_read_tokens"]
),
"cost": computed["total_cost"],
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
if computed["avg_response_time_ms"]
else 0,
"unique_models": computed["unique_models"],
"unique_providers": computed["unique_providers"],
"fallback_count": computed["fallback_count"],
}
else:
# 普通用户:仍需实时查询(用户级预聚合可选)
query = db.query(Usage).filter(

View File

@@ -266,8 +266,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if mapping and mapping.model:
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key实现相同用户稳定选择同一别名
# 传入 api_format 用于过滤适用的别名作用域
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name

View File

@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
if mapping and mapping.model:
# 使用 select_provider_model_name 支持别名功能
# 传入 api_key.id 作为 affinity_key实现相同用户稳定选择同一别名
# 传入 api_format 用于过滤适用的别名作用域
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name

View File

@@ -813,7 +813,9 @@ class Model(Base):
def get_effective_supports_image_generation(self) -> bool:
return self._get_effective_capability("supports_image_generation", False)
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
def select_provider_model_name(
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
) -> str:
"""按优先级选择要使用的 Provider 模型名称
如果配置了 provider_model_aliases按优先级选择数字越小越优先
@@ -822,6 +824,7 @@ class Model(Base):
Args:
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的别名
"""
import hashlib
@@ -840,6 +843,13 @@ class Model(Base):
if not isinstance(name, str) or not name.strip():
continue
# 检查 api_formats 作用域(如果配置了且当前有 api_format
alias_api_formats = raw.get("api_formats")
if api_format and alias_api_formats:
# 如果配置了作用域,只有匹配时才生效
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
continue
raw_priority = raw.get("priority", 1)
try:
priority = int(raw_priority)

View File

@@ -238,8 +238,8 @@ class GlobalModelResponse(BaseModel):
# 按次计费配置
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
# 阶梯计费配置
default_tiered_pricing: TieredPricingConfig = Field(
..., description="阶梯计费配置"
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
default=None, description="阶梯计费配置"
)
# Key 能力配置 - 模型支持的能力列表
supported_capabilities: Optional[List[str]] = Field(

View File

@@ -35,6 +35,7 @@ class CleanupScheduler:
def __init__(self):
self.running = False
self._interval_tasks = []
self._stats_aggregation_lock = asyncio.Lock()
async def start(self):
"""启动调度器"""
@@ -56,6 +57,14 @@ class CleanupScheduler:
job_id="stats_aggregation",
name="统计数据聚合",
)
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
scheduler.add_interval_job(
self._scheduled_stats_aggregation,
minutes=30,
job_id="stats_aggregation_backfill",
name="统计数据聚合补偿",
backfill=True,
)
# 清理任务 - 凌晨 3 点执行
scheduler.add_cron_job(
@@ -115,9 +124,9 @@ class CleanupScheduler:
# ========== 任务函数APScheduler 直接调用异步函数) ==========
async def _scheduled_stats_aggregation(self):
async def _scheduled_stats_aggregation(self, backfill: bool = False):
"""统计聚合任务(定时调用)"""
await self._perform_stats_aggregation()
await self._perform_stats_aggregation(backfill=backfill)
async def _scheduled_cleanup(self):
"""清理任务(定时调用)"""
@@ -144,136 +153,157 @@ class CleanupScheduler:
Args:
backfill: 是否回填历史数据(启动时检查缺失的日期)
"""
db = create_session()
try:
# 检查是否启用统计聚合
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("统计聚合已禁用,跳过聚合任务")
return
if self._stats_aggregation_lock.locked():
logger.info("统计聚合任务正在运行,跳过本次触发")
return
logger.info("开始执行统计数据聚合...")
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# 使用业务时区计算日期,确保与定时任务触发时间一致
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if backfill:
# 启动时检查并回填缺失的日期
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# 首次运行,回填所有历史数据
logger.info("检测到首次运行,开始回填历史统计数据...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
async with self._stats_aggregation_lock:
db = create_session()
try:
# 检查是否启用统计聚合
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
logger.info("统计聚合已禁用,跳过聚合任务")
return
# 非首次运行,检查最近是否有缺失的日期需要回填
latest_stat = (
db.query(StatsDaily)
.order_by(StatsDaily.date.desc())
.first()
)
logger.info("开始执行统计数据聚合...")
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
from src.models.database import StatsDaily, User as DBUser
from src.services.system.scheduler import APP_TIMEZONE
from zoneinfo import ZoneInfo
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = (today_local.date() - timedelta(days=1))
missing_start_date = latest_business_date + timedelta(days=1)
# 使用业务时区计算日期,确保与定时任务触发时间一致
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
app_tz = ZoneInfo(APP_TIMEZONE)
now_local = datetime.now(app_tz)
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
if missing_start_date <= yesterday_business_date:
missing_days = (yesterday_business_date - missing_start_date).days + 1
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
if backfill:
# 启动时检查并回填缺失的日期
from src.models.database import StatsSummary
summary = db.query(StatsSummary).first()
if not summary:
# 首次运行,回填所有历史数据
logger.info("检测到首次运行,开始回填历史统计数据...")
days_to_backfill = SystemConfigService.get_config(
db, "stats_backfill_days", 365
)
count = StatsAggregatorService.backfill_historical_data(
db, days=days_to_backfill
)
logger.info(f"历史数据回填完成,共 {count}")
return
current_date = missing_start_date
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
# 非首次运行,检查最近是否有缺失的日期需要回填
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
while current_date <= yesterday_business_date:
try:
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
if latest_stat:
latest_date_utc = latest_stat.date
if latest_date_utc.tzinfo is None:
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
else:
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
latest_business_date = latest_date_utc.astimezone(app_tz).date()
yesterday_business_date = today_local.date() - timedelta(days=1)
missing_start_date = latest_business_date + timedelta(days=1)
if missing_start_date <= yesterday_business_date:
missing_days = (
yesterday_business_date - missing_start_date
).days + 1
# 限制最大回填天数,防止停机很久后一次性回填太多
max_backfill_days: int = SystemConfigService.get_config(
db, "max_stats_backfill_days", 30
) or 30
if missing_days > max_backfill_days:
logger.warning(
f"缺失 {missing_days} 天数据超过最大回填限制 "
f"{max_backfill_days} 天,只回填最近 {max_backfill_days}"
)
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
# 聚合用户数据
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
missing_start_date = yesterday_business_date - timedelta(
days=max_backfill_days - 1
)
missing_days = max_backfill_days
logger.info(
f"检测到缺失 {missing_days} 天的统计数据 "
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
)
current_date = missing_start_date
users = (
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
)
while current_date <= yesterday_business_date:
try:
db.rollback()
except Exception:
pass
current_date_local = datetime.combine(
current_date, datetime.min.time(), tzinfo=app_tz
)
StatsAggregatorService.aggregate_daily_stats(
db, current_date_local
)
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, current_date_local
)
except Exception as e:
logger.warning(
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
)
try:
db.rollback()
except Exception:
pass
except Exception as e:
logger.warning(f"回填日期 {current_date} 失败: {e}")
try:
db.rollback()
except Exception:
pass
current_date += timedelta(days=1)
current_date += timedelta(days=1)
# 更新全局汇总
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
StatsAggregatorService.update_summary(db)
logger.info(f"缺失数据回填完成,共 {missing_days}")
else:
logger.info("统计数据已是最新,无需回填")
return
# 定时任务:聚合昨天的数据
# 注意aggregate_daily_stats 期望业务时区的日期,不是 UTC
yesterday_local = today_local - timedelta(days=1)
# 定时任务:聚合昨天的数据
yesterday_local = today_local - timedelta(days=1)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
# 聚合所有用户的昨日数据
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
# 回滚当前用户的失败操作,继续处理其他用户
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
for (user_id,) in users:
try:
db.rollback()
except Exception:
pass
StatsAggregatorService.aggregate_user_daily_stats(
db, user_id, yesterday_local
)
except Exception as e:
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
try:
db.rollback()
except Exception:
pass
# 更新全局汇总
StatsAggregatorService.update_summary(db)
StatsAggregatorService.update_summary(db)
logger.info("统计数据聚合完成")
logger.info("统计数据聚合完成")
except Exception as e:
logger.exception(f"统计聚合任务执行失败: {e}")
db.rollback()
finally:
db.close()
except Exception as e:
logger.exception(f"统计聚合任务执行失败: {e}")
try:
db.rollback()
except Exception:
pass
finally:
db.close()
async def _perform_pending_cleanup(self):
"""执行 pending 状态清理"""

View File

@@ -56,65 +56,44 @@ class StatsAggregatorService:
"""统计数据聚合服务"""
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
# 将业务日期转换为 UTC 时间范围
def compute_daily_stats(db: Session, date: datetime) -> dict:
"""计算指定业务日期的统计数据(不写入数据库)"""
day_start, day_end = _get_business_day_range(date)
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 检查是否已存在该日期的记录
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# 基础请求统计
base_query = db.query(Usage).filter(
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
)
total_requests = base_query.count()
# 如果没有请求,直接返回空记录
if total_requests == 0:
stats.total_requests = 0
stats.success_requests = 0
stats.error_requests = 0
stats.input_tokens = 0
stats.output_tokens = 0
stats.cache_creation_tokens = 0
stats.cache_read_tokens = 0
stats.total_cost = 0.0
stats.actual_total_cost = 0.0
stats.input_cost = 0.0
stats.output_cost = 0.0
stats.cache_creation_cost = 0.0
stats.cache_read_cost = 0.0
stats.avg_response_time_ms = 0.0
stats.fallback_count = 0
return {
"day_start": day_start,
"total_requests": 0,
"success_requests": 0,
"error_requests": 0,
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_tokens": 0,
"cache_read_tokens": 0,
"total_cost": 0.0,
"actual_total_cost": 0.0,
"input_cost": 0.0,
"output_cost": 0.0,
"cache_creation_cost": 0.0,
"cache_read_cost": 0.0,
"avg_response_time_ms": 0.0,
"fallback_count": 0,
"unique_models": 0,
"unique_providers": 0,
}
if not existing:
db.add(stats)
db.commit()
return stats
# 错误请求数
error_requests = (
base_query.filter(
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
)
# Token 和成本聚合
aggregated = (
db.query(
func.sum(Usage.input_tokens).label("input_tokens"),
@@ -157,7 +136,6 @@ class StatsAggregatorService:
or 0
)
# 使用维度统计
unique_models = (
db.query(func.count(func.distinct(Usage.model)))
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
@@ -171,31 +149,74 @@ class StatsAggregatorService:
or 0
)
return {
"day_start": day_start,
"total_requests": total_requests,
"success_requests": total_requests - error_requests,
"error_requests": error_requests,
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
"fallback_count": fallback_count,
"unique_models": unique_models,
"unique_providers": unique_providers,
}
@staticmethod
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
"""聚合指定日期的统计数据
Args:
db: 数据库会话
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
Returns:
StatsDaily 记录
"""
computed = StatsAggregatorService.compute_daily_stats(db, date)
day_start = computed["day_start"]
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
# 检查是否已存在该日期的记录
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
if existing:
stats = existing
else:
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
# 更新统计记录
stats.total_requests = total_requests
stats.success_requests = total_requests - error_requests
stats.error_requests = error_requests
stats.input_tokens = int(aggregated.input_tokens or 0)
stats.output_tokens = int(aggregated.output_tokens or 0)
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
stats.total_cost = float(aggregated.total_cost or 0)
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
stats.input_cost = float(aggregated.input_cost or 0)
stats.output_cost = float(aggregated.output_cost or 0)
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
stats.fallback_count = fallback_count
stats.unique_models = unique_models
stats.unique_providers = unique_providers
stats.total_requests = computed["total_requests"]
stats.success_requests = computed["success_requests"]
stats.error_requests = computed["error_requests"]
stats.input_tokens = computed["input_tokens"]
stats.output_tokens = computed["output_tokens"]
stats.cache_creation_tokens = computed["cache_creation_tokens"]
stats.cache_read_tokens = computed["cache_read_tokens"]
stats.total_cost = computed["total_cost"]
stats.actual_total_cost = computed["actual_total_cost"]
stats.input_cost = computed["input_cost"]
stats.output_cost = computed["output_cost"]
stats.cache_creation_cost = computed["cache_creation_cost"]
stats.cache_read_cost = computed["cache_read_cost"]
stats.avg_response_time_ms = computed["avg_response_time_ms"]
stats.fallback_count = computed["fallback_count"]
stats.unique_models = computed["unique_models"]
stats.unique_providers = computed["unique_providers"]
if not existing:
db.add(stats)
db.commit()
# 日志使用业务日期(输入参数),而不是 UTC 日期
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {total_requests} 请求")
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
return stats
@staticmethod