mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-07 02:02:27 +08:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dac4cb156 | ||
|
|
50abb55c94 | ||
|
|
73d3c9d3e4 | ||
|
|
d24c3885ab |
@@ -20,10 +20,10 @@ depends_on = None
|
|||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Create ENUM types
|
# Create ENUM types (with IF NOT EXISTS for idempotency)
|
||||||
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
|
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
|
||||||
op.execute(
|
op.execute(
|
||||||
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
|
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==================== users ====================
|
# ==================== users ====================
|
||||||
@@ -35,7 +35,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"role",
|
"role",
|
||||||
sa.Enum("admin", "user", name="userrole", create_type=False),
|
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default="user",
|
server_default="user",
|
||||||
),
|
),
|
||||||
@@ -67,7 +67,7 @@ def upgrade() -> None:
|
|||||||
sa.Column("website", sa.String(500), nullable=True),
|
sa.Column("website", sa.String(500), nullable=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"billing_type",
|
"billing_type",
|
||||||
sa.Enum(
|
postgresql.ENUM(
|
||||||
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
||||||
),
|
),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
|
|||||||
@@ -1,5 +1,179 @@
|
|||||||
import apiClient from './client'
|
import apiClient from './client'
|
||||||
|
|
||||||
|
// 配置导出数据结构
|
||||||
|
export interface ConfigExportData {
|
||||||
|
version: string
|
||||||
|
exported_at: string
|
||||||
|
global_models: GlobalModelExport[]
|
||||||
|
providers: ProviderExport[]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用户导出数据结构
|
||||||
|
export interface UsersExportData {
|
||||||
|
version: string
|
||||||
|
exported_at: string
|
||||||
|
users: UserExport[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UserExport {
|
||||||
|
email: string
|
||||||
|
username: string
|
||||||
|
password_hash: string
|
||||||
|
role: string
|
||||||
|
allowed_providers?: string[] | null
|
||||||
|
allowed_endpoints?: string[] | null
|
||||||
|
allowed_models?: string[] | null
|
||||||
|
model_capability_settings?: any
|
||||||
|
quota_usd?: number | null
|
||||||
|
used_usd?: number
|
||||||
|
total_usd?: number
|
||||||
|
is_active: boolean
|
||||||
|
api_keys: UserApiKeyExport[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UserApiKeyExport {
|
||||||
|
key_hash: string
|
||||||
|
key_encrypted?: string | null
|
||||||
|
name?: string | null
|
||||||
|
is_standalone: boolean
|
||||||
|
balance_used_usd?: number
|
||||||
|
current_balance_usd?: number | null
|
||||||
|
allowed_providers?: string[] | null
|
||||||
|
allowed_endpoints?: string[] | null
|
||||||
|
allowed_api_formats?: string[] | null
|
||||||
|
allowed_models?: string[] | null
|
||||||
|
rate_limit?: number
|
||||||
|
concurrent_limit?: number | null
|
||||||
|
force_capabilities?: any
|
||||||
|
is_active: boolean
|
||||||
|
auto_delete_on_expiry?: boolean
|
||||||
|
total_requests?: number
|
||||||
|
total_cost_usd?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GlobalModelExport {
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
default_price_per_request?: number | null
|
||||||
|
default_tiered_pricing: any
|
||||||
|
supported_capabilities?: string[] | null
|
||||||
|
config?: any
|
||||||
|
is_active: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ProviderExport {
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
description?: string | null
|
||||||
|
website?: string | null
|
||||||
|
billing_type?: string | null
|
||||||
|
monthly_quota_usd?: number | null
|
||||||
|
quota_reset_day?: number
|
||||||
|
rpm_limit?: number | null
|
||||||
|
provider_priority?: number
|
||||||
|
is_active: boolean
|
||||||
|
rate_limit?: number | null
|
||||||
|
concurrent_limit?: number | null
|
||||||
|
config?: any
|
||||||
|
endpoints: EndpointExport[]
|
||||||
|
models: ModelExport[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface EndpointExport {
|
||||||
|
api_format: string
|
||||||
|
base_url: string
|
||||||
|
headers?: any
|
||||||
|
timeout?: number
|
||||||
|
max_retries?: number
|
||||||
|
max_concurrent?: number | null
|
||||||
|
rate_limit?: number | null
|
||||||
|
is_active: boolean
|
||||||
|
custom_path?: string | null
|
||||||
|
config?: any
|
||||||
|
keys: KeyExport[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface KeyExport {
|
||||||
|
api_key: string
|
||||||
|
name?: string | null
|
||||||
|
note?: string | null
|
||||||
|
rate_multiplier?: number
|
||||||
|
internal_priority?: number
|
||||||
|
global_priority?: number | null
|
||||||
|
max_concurrent?: number | null
|
||||||
|
rate_limit?: number | null
|
||||||
|
daily_limit?: number | null
|
||||||
|
monthly_limit?: number | null
|
||||||
|
allowed_models?: string[] | null
|
||||||
|
capabilities?: any
|
||||||
|
is_active: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ModelExport {
|
||||||
|
global_model_name: string | null
|
||||||
|
provider_model_name: string
|
||||||
|
provider_model_aliases?: any
|
||||||
|
price_per_request?: number | null
|
||||||
|
tiered_pricing?: any
|
||||||
|
supports_vision?: boolean | null
|
||||||
|
supports_function_calling?: boolean | null
|
||||||
|
supports_streaming?: boolean | null
|
||||||
|
supports_extended_thinking?: boolean | null
|
||||||
|
supports_image_generation?: boolean | null
|
||||||
|
is_active: boolean
|
||||||
|
config?: any
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider 模型查询响应
|
||||||
|
export interface ProviderModelsQueryResponse {
|
||||||
|
success: boolean
|
||||||
|
data: {
|
||||||
|
models: Array<{
|
||||||
|
id: string
|
||||||
|
object?: string
|
||||||
|
created?: number
|
||||||
|
owned_by?: string
|
||||||
|
display_name?: string
|
||||||
|
api_format?: string
|
||||||
|
}>
|
||||||
|
error?: string
|
||||||
|
}
|
||||||
|
provider: {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ConfigImportRequest extends ConfigExportData {
|
||||||
|
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UsersImportRequest extends UsersExportData {
|
||||||
|
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface UsersImportResponse {
|
||||||
|
message: string
|
||||||
|
stats: {
|
||||||
|
users: { created: number; updated: number; skipped: number }
|
||||||
|
api_keys: { created: number; skipped: number }
|
||||||
|
errors: string[]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ConfigImportResponse {
|
||||||
|
message: string
|
||||||
|
stats: {
|
||||||
|
global_models: { created: number; updated: number; skipped: number }
|
||||||
|
providers: { created: number; updated: number; skipped: number }
|
||||||
|
endpoints: { created: number; updated: number; skipped: number }
|
||||||
|
keys: { created: number; updated: number; skipped: number }
|
||||||
|
models: { created: number; updated: number; skipped: number }
|
||||||
|
errors: string[]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// API密钥管理相关接口定义
|
// API密钥管理相关接口定义
|
||||||
export interface AdminApiKey {
|
export interface AdminApiKey {
|
||||||
id: string // UUID
|
id: string // UUID
|
||||||
@@ -173,5 +347,44 @@ export const adminApi = {
|
|||||||
'/api/admin/system/api-formats'
|
'/api/admin/system/api-formats'
|
||||||
)
|
)
|
||||||
return response.data
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 导出配置
|
||||||
|
async exportConfig(): Promise<ConfigExportData> {
|
||||||
|
const response = await apiClient.get<ConfigExportData>('/api/admin/system/config/export')
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 导入配置
|
||||||
|
async importConfig(data: ConfigImportRequest): Promise<ConfigImportResponse> {
|
||||||
|
const response = await apiClient.post<ConfigImportResponse>(
|
||||||
|
'/api/admin/system/config/import',
|
||||||
|
data
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 导出用户数据
|
||||||
|
async exportUsers(): Promise<UsersExportData> {
|
||||||
|
const response = await apiClient.get<UsersExportData>('/api/admin/system/users/export')
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 导入用户数据
|
||||||
|
async importUsers(data: UsersImportRequest): Promise<UsersImportResponse> {
|
||||||
|
const response = await apiClient.post<UsersImportResponse>(
|
||||||
|
'/api/admin/system/users/import',
|
||||||
|
data
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 查询 Provider 可用模型(从上游 API 获取)
|
||||||
|
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
|
||||||
|
const response = await apiClient.post<ProviderModelsQueryResponse>(
|
||||||
|
'/api/admin/provider-query/models',
|
||||||
|
{ provider_id: providerId, api_key_id: apiKeyId }
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,25 @@
|
|||||||
|
// API 格式常量
|
||||||
|
export const API_FORMATS = {
|
||||||
|
CLAUDE: 'CLAUDE',
|
||||||
|
CLAUDE_CLI: 'CLAUDE_CLI',
|
||||||
|
OPENAI: 'OPENAI',
|
||||||
|
OPENAI_CLI: 'OPENAI_CLI',
|
||||||
|
GEMINI: 'GEMINI',
|
||||||
|
GEMINI_CLI: 'GEMINI_CLI',
|
||||||
|
} as const
|
||||||
|
|
||||||
|
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
|
||||||
|
|
||||||
|
// API 格式显示名称映射(按品牌分组:API 在前,CLI 在后)
|
||||||
|
export const API_FORMAT_LABELS: Record<string, string> = {
|
||||||
|
[API_FORMATS.CLAUDE]: 'Claude',
|
||||||
|
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
|
||||||
|
[API_FORMATS.OPENAI]: 'OpenAI',
|
||||||
|
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
|
||||||
|
[API_FORMATS.GEMINI]: 'Gemini',
|
||||||
|
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||||
|
}
|
||||||
|
|
||||||
export interface ProviderEndpoint {
|
export interface ProviderEndpoint {
|
||||||
id: string
|
id: string
|
||||||
provider_id: string
|
provider_id: string
|
||||||
@@ -214,6 +236,7 @@ export interface ConcurrencyStatus {
|
|||||||
export interface ProviderModelAlias {
|
export interface ProviderModelAlias {
|
||||||
name: string
|
name: string
|
||||||
priority: number // 优先级(数字越小优先级越高)
|
priority: number // 优先级(数字越小优先级越高)
|
||||||
|
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Model {
|
export interface Model {
|
||||||
|
|||||||
@@ -68,13 +68,19 @@
|
|||||||
<div
|
<div
|
||||||
v-for="model in group.models"
|
v-for="model in group.models"
|
||||||
:key="model.modelId"
|
:key="model.modelId"
|
||||||
class="flex items-center gap-2 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
|
class="flex flex-col gap-0.5 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
|
||||||
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
||||||
? 'bg-primary text-primary-foreground'
|
? 'bg-primary text-primary-foreground'
|
||||||
: 'hover:bg-muted'"
|
: 'hover:bg-muted'"
|
||||||
@click="selectModel(model)"
|
@click="selectModel(model)"
|
||||||
>
|
>
|
||||||
<span class="truncate">{{ model.modelName }}</span>
|
<span class="truncate font-medium">{{ model.modelName }}</span>
|
||||||
|
<span
|
||||||
|
class="truncate text-[10px]"
|
||||||
|
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
||||||
|
? 'text-primary-foreground/70'
|
||||||
|
: 'text-muted-foreground'"
|
||||||
|
>{{ model.modelId }}</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -390,15 +396,13 @@ interface ProviderGroup {
|
|||||||
|
|
||||||
const groupedModels = computed(() => {
|
const groupedModels = computed(() => {
|
||||||
let models = allModels.value.filter(m => !m.deprecated)
|
let models = allModels.value.filter(m => !m.deprecated)
|
||||||
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
models = models.filter(model =>
|
models = models.filter(model => {
|
||||||
model.providerId.toLowerCase().includes(query) ||
|
const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
|
||||||
model.providerName.toLowerCase().includes(query) ||
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
model.modelId.toLowerCase().includes(query) ||
|
})
|
||||||
model.modelName.toLowerCase().includes(query) ||
|
|
||||||
model.family?.toLowerCase().includes(query)
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按提供商分组
|
// 按提供商分组
|
||||||
@@ -415,14 +419,16 @@ const groupedModels = computed(() => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 转换为数组并排序
|
// 转换为数组并排序
|
||||||
let result = Array.from(groups.values())
|
const result = Array.from(groups.values())
|
||||||
|
|
||||||
// 如果有搜索词,把提供商名称/ID匹配的排在前面
|
// 如果有搜索词,把提供商名称/ID匹配的排在前面
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result.sort((a, b) => {
|
result.sort((a, b) => {
|
||||||
const aProviderMatch = a.providerId.toLowerCase().includes(query) || a.providerName.toLowerCase().includes(query)
|
const aText = `${a.providerId} ${a.providerName}`.toLowerCase()
|
||||||
const bProviderMatch = b.providerId.toLowerCase().includes(query) || b.providerName.toLowerCase().includes(query)
|
const bText = `${b.providerId} ${b.providerName}`.toLowerCase()
|
||||||
|
const aProviderMatch = keywords.some(k => aText.includes(k))
|
||||||
|
const bProviderMatch = keywords.some(k => bText.includes(k))
|
||||||
if (aProviderMatch && !bProviderMatch) return -1
|
if (aProviderMatch && !bProviderMatch) return -1
|
||||||
if (!aProviderMatch && bProviderMatch) return 1
|
if (!aProviderMatch && bProviderMatch) return 1
|
||||||
return a.providerName.localeCompare(b.providerName)
|
return a.providerName.localeCompare(b.providerName)
|
||||||
@@ -598,6 +604,11 @@ function resetForm() {
|
|||||||
// 加载模型数据(编辑模式)
|
// 加载模型数据(编辑模式)
|
||||||
function loadModelData() {
|
function loadModelData() {
|
||||||
if (!props.model) return
|
if (!props.model) return
|
||||||
|
// 先重置创建模式的残留状态
|
||||||
|
selectedModel.value = null
|
||||||
|
searchQuery.value = ''
|
||||||
|
expandedProvider.value = null
|
||||||
|
|
||||||
form.value = {
|
form.value = {
|
||||||
name: props.model.name,
|
name: props.model.name,
|
||||||
display_name: props.model.display_name,
|
display_name: props.model.display_name,
|
||||||
@@ -606,9 +617,10 @@ function loadModelData() {
|
|||||||
config: props.model.config ? { ...props.model.config } : { streaming: true },
|
config: props.model.config ? { ...props.model.config } : { streaming: true },
|
||||||
is_active: props.model.is_active,
|
is_active: props.model.is_active,
|
||||||
}
|
}
|
||||||
if (props.model.default_tiered_pricing) {
|
// 确保 tieredPricing 也被正确设置或重置
|
||||||
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
|
tieredPricing.value = props.model.default_tiered_pricing
|
||||||
}
|
? JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
|
||||||
|
: null
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用 useFormDialog 统一处理对话框逻辑
|
// 使用 useFormDialog 统一处理对话框逻辑
|
||||||
|
|||||||
@@ -526,7 +526,14 @@
|
|||||||
@edit-model="handleEditModel"
|
@edit-model="handleEditModel"
|
||||||
@delete-model="handleDeleteModel"
|
@delete-model="handleDeleteModel"
|
||||||
@batch-assign="handleBatchAssign"
|
@batch-assign="handleBatchAssign"
|
||||||
@manage-alias="handleManageAlias"
|
/>
|
||||||
|
|
||||||
|
<!-- 模型名称映射 -->
|
||||||
|
<ModelAliasesTab
|
||||||
|
v-if="provider"
|
||||||
|
:key="`aliases-${provider.id}`"
|
||||||
|
:provider="provider"
|
||||||
|
@refresh="handleRelatedDataRefresh"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</template>
|
</template>
|
||||||
@@ -629,16 +636,6 @@
|
|||||||
@update:open="batchAssignDialogOpen = $event"
|
@update:open="batchAssignDialogOpen = $event"
|
||||||
@changed="handleBatchAssignChanged"
|
@changed="handleBatchAssignChanged"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- 模型别名管理对话框 -->
|
|
||||||
<ModelAliasDialog
|
|
||||||
v-if="open && provider"
|
|
||||||
:open="aliasDialogOpen"
|
|
||||||
:provider-id="provider.id"
|
|
||||||
:model="aliasEditingModel"
|
|
||||||
@update:open="aliasDialogOpen = $event"
|
|
||||||
@saved="handleAliasSaved"
|
|
||||||
/>
|
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
@@ -667,8 +664,8 @@ import {
|
|||||||
KeyFormDialog,
|
KeyFormDialog,
|
||||||
KeyAllowedModelsDialog,
|
KeyAllowedModelsDialog,
|
||||||
ModelsTab,
|
ModelsTab,
|
||||||
BatchAssignModelsDialog,
|
ModelAliasesTab,
|
||||||
ModelAliasDialog
|
BatchAssignModelsDialog
|
||||||
} from '@/features/providers/components'
|
} from '@/features/providers/components'
|
||||||
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
||||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||||
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
|
|||||||
const modelToDelete = ref<Model | null>(null)
|
const modelToDelete = ref<Model | null>(null)
|
||||||
const batchAssignDialogOpen = ref(false)
|
const batchAssignDialogOpen = ref(false)
|
||||||
|
|
||||||
// 别名管理相关状态
|
|
||||||
const aliasDialogOpen = ref(false)
|
|
||||||
const aliasEditingModel = ref<Model | null>(null)
|
|
||||||
|
|
||||||
// 拖动排序相关状态
|
// 拖动排序相关状态
|
||||||
const dragState = ref({
|
const dragState = ref({
|
||||||
isDragging: false,
|
isDragging: false,
|
||||||
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
|
|||||||
deleteKeyConfirmOpen.value ||
|
deleteKeyConfirmOpen.value ||
|
||||||
modelFormDialogOpen.value ||
|
modelFormDialogOpen.value ||
|
||||||
deleteModelConfirmOpen.value ||
|
deleteModelConfirmOpen.value ||
|
||||||
batchAssignDialogOpen.value ||
|
batchAssignDialogOpen.value
|
||||||
aliasDialogOpen.value
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 监听 providerId 变化
|
// 监听 providerId 变化
|
||||||
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
|
|||||||
keyAllowedModelsDialogOpen.value = false
|
keyAllowedModelsDialogOpen.value = false
|
||||||
deleteKeyConfirmOpen.value = false
|
deleteKeyConfirmOpen.value = false
|
||||||
batchAssignDialogOpen.value = false
|
batchAssignDialogOpen.value = false
|
||||||
aliasDialogOpen.value = false
|
|
||||||
|
|
||||||
// 重置临时数据
|
// 重置临时数据
|
||||||
endpointToEdit.value = null
|
endpointToEdit.value = null
|
||||||
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
|
|||||||
emit('refresh')
|
emit('refresh')
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理管理映射 - 打开别名对话框
|
|
||||||
function handleManageAlias(model: Model) {
|
|
||||||
aliasEditingModel.value = model
|
|
||||||
aliasDialogOpen.value = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理别名保存完成
|
|
||||||
async function handleAliasSaved() {
|
|
||||||
aliasEditingModel.value = null
|
|
||||||
await loadProvider()
|
|
||||||
emit('refresh')
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理模型保存完成
|
// 处理模型保存完成
|
||||||
async function handleModelSaved() {
|
async function handleModelSaved() {
|
||||||
editingModel.value = null
|
editingModel.value = null
|
||||||
|
|||||||
@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
|
|||||||
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
||||||
|
|
||||||
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
||||||
|
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -165,15 +165,6 @@
|
|||||||
>
|
>
|
||||||
<Edit class="w-3.5 h-3.5" />
|
<Edit class="w-3.5 h-3.5" />
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
size="icon"
|
|
||||||
class="h-8 w-8"
|
|
||||||
title="管理映射"
|
|
||||||
@click="openAliasDialog(model)"
|
|
||||||
>
|
|
||||||
<Tag class="w-3.5 h-3.5" />
|
|
||||||
</Button>
|
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
@@ -218,7 +209,7 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
|
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
@@ -233,7 +224,6 @@ const emit = defineEmits<{
|
|||||||
'editModel': [model: Model]
|
'editModel': [model: Model]
|
||||||
'deleteModel': [model: Model]
|
'deleteModel': [model: Model]
|
||||||
'batchAssign': []
|
'batchAssign': []
|
||||||
'manageAlias': [model: Model]
|
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { error: showError, success: showSuccess } = useToast()
|
const { error: showError, success: showSuccess } = useToast()
|
||||||
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
|
|||||||
emit('batchAssign')
|
emit('batchAssign')
|
||||||
}
|
}
|
||||||
|
|
||||||
// 打开别名管理对话框
|
|
||||||
function openAliasDialog(model: Model) {
|
|
||||||
emit('manageAlias', model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 切换模型启用状态
|
// 切换模型启用状态
|
||||||
async function toggleModelActive(model: Model) {
|
async function toggleModelActive(model: Model) {
|
||||||
if (togglingModelId.value) return
|
if (togglingModelId.value) return
|
||||||
|
|||||||
@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
|
|||||||
const filteredApiKeys = computed(() => {
|
const filteredApiKeys = computed(() => {
|
||||||
let result = apiKeys.value
|
let result = apiKeys.value
|
||||||
|
|
||||||
// 搜索筛选
|
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(key =>
|
result = result.filter(key => {
|
||||||
(key.name && key.name.toLowerCase().includes(query)) ||
|
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
|
||||||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
(key.username && key.username.toLowerCase().includes(query)) ||
|
})
|
||||||
(key.user_email && key.user_email.toLowerCase().includes(query))
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 状态筛选
|
// 状态筛选
|
||||||
|
|||||||
@@ -1002,13 +1002,13 @@ async function batchRemoveSelectedProviders() {
|
|||||||
const filteredGlobalModels = computed(() => {
|
const filteredGlobalModels = computed(() => {
|
||||||
let result = globalModels.value
|
let result = globalModels.value
|
||||||
|
|
||||||
// 搜索
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(m =>
|
result = result.filter(m => {
|
||||||
m.name.toLowerCase().includes(query) ||
|
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||||
m.display_name?.toLowerCase().includes(query)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 能力筛选
|
// 能力筛选
|
||||||
|
|||||||
@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
|
|||||||
const filteredProviders = computed(() => {
|
const filteredProviders = computed(() => {
|
||||||
let result = [...providers.value]
|
let result = [...providers.value]
|
||||||
|
|
||||||
// 搜索筛选
|
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value.trim()) {
|
if (searchQuery.value.trim()) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(p =>
|
result = result.filter(p => {
|
||||||
p.display_name.toLowerCase().includes(query) ||
|
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
|
||||||
p.name.toLowerCase().includes(query)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 排序
|
// 排序
|
||||||
|
|||||||
@@ -15,6 +15,94 @@
|
|||||||
</PageHeader>
|
</PageHeader>
|
||||||
|
|
||||||
<div class="mt-6 space-y-6">
|
<div class="mt-6 space-y-6">
|
||||||
|
<!-- 配置导出/导入 -->
|
||||||
|
<CardSection
|
||||||
|
title="配置管理"
|
||||||
|
description="导出或导入提供商和模型配置,便于备份或迁移"
|
||||||
|
>
|
||||||
|
<div class="flex flex-wrap gap-4">
|
||||||
|
<div class="flex-1 min-w-[200px]">
|
||||||
|
<p class="text-sm text-muted-foreground mb-3">
|
||||||
|
导出当前所有提供商、端点、API Key 和模型配置到 JSON 文件
|
||||||
|
</p>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
:disabled="exportLoading"
|
||||||
|
@click="handleExportConfig"
|
||||||
|
>
|
||||||
|
<Download class="w-4 h-4 mr-2" />
|
||||||
|
{{ exportLoading ? '导出中...' : '导出配置' }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<div class="flex-1 min-w-[200px]">
|
||||||
|
<p class="text-sm text-muted-foreground mb-3">
|
||||||
|
从 JSON 文件导入配置,支持跳过、覆盖或报错三种冲突处理模式
|
||||||
|
</p>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<input
|
||||||
|
ref="configFileInput"
|
||||||
|
type="file"
|
||||||
|
accept=".json"
|
||||||
|
class="hidden"
|
||||||
|
@change="handleConfigFileSelect"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
:disabled="importLoading"
|
||||||
|
@click="triggerConfigFileSelect"
|
||||||
|
>
|
||||||
|
<Upload class="w-4 h-4 mr-2" />
|
||||||
|
{{ importLoading ? '导入中...' : '导入配置' }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</CardSection>
|
||||||
|
|
||||||
|
<!-- 用户数据导出/导入 -->
|
||||||
|
<CardSection
|
||||||
|
title="用户数据管理"
|
||||||
|
description="导出或导入用户及其 API Keys 数据(不含管理员)"
|
||||||
|
>
|
||||||
|
<div class="flex flex-wrap gap-4">
|
||||||
|
<div class="flex-1 min-w-[200px]">
|
||||||
|
<p class="text-sm text-muted-foreground mb-3">
|
||||||
|
导出所有普通用户及其 API Keys 到 JSON 文件
|
||||||
|
</p>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
:disabled="exportUsersLoading"
|
||||||
|
@click="handleExportUsers"
|
||||||
|
>
|
||||||
|
<Download class="w-4 h-4 mr-2" />
|
||||||
|
{{ exportUsersLoading ? '导出中...' : '导出用户数据' }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<div class="flex-1 min-w-[200px]">
|
||||||
|
<p class="text-sm text-muted-foreground mb-3">
|
||||||
|
从 JSON 文件导入用户数据(需相同 ENCRYPTION_KEY)
|
||||||
|
</p>
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<input
|
||||||
|
ref="usersFileInput"
|
||||||
|
type="file"
|
||||||
|
accept=".json"
|
||||||
|
class="hidden"
|
||||||
|
@change="handleUsersFileSelect"
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
:disabled="importUsersLoading"
|
||||||
|
@click="triggerUsersFileSelect"
|
||||||
|
>
|
||||||
|
<Upload class="w-4 h-4 mr-2" />
|
||||||
|
{{ importUsersLoading ? '导入中...' : '导入用户数据' }}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</CardSection>
|
||||||
|
|
||||||
<!-- 基础配置 -->
|
<!-- 基础配置 -->
|
||||||
<CardSection
|
<CardSection
|
||||||
title="基础配置"
|
title="基础配置"
|
||||||
@@ -375,11 +463,326 @@
|
|||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 导入配置对话框 -->
|
||||||
|
<Dialog v-model:open="importDialogOpen">
|
||||||
|
<DialogContent class="max-w-lg">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>导入配置</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
选择冲突处理模式并确认导入
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div class="space-y-4 py-4">
|
||||||
|
<div
|
||||||
|
v-if="importPreview"
|
||||||
|
class="p-3 bg-muted rounded-lg text-sm"
|
||||||
|
>
|
||||||
|
<p class="font-medium mb-2">
|
||||||
|
配置预览
|
||||||
|
</p>
|
||||||
|
<ul class="space-y-1 text-muted-foreground">
|
||||||
|
<li>全局模型: {{ importPreview.global_models?.length || 0 }} 个</li>
|
||||||
|
<li>提供商: {{ importPreview.providers?.length || 0 }} 个</li>
|
||||||
|
<li>
|
||||||
|
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个
|
||||||
|
</li>
|
||||||
|
<li>
|
||||||
|
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }} 个
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||||
|
<Select v-model="mergeMode">
|
||||||
|
<SelectTrigger>
|
||||||
|
<SelectValue />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
<SelectItem value="skip">
|
||||||
|
跳过 - 保留现有配置
|
||||||
|
</SelectItem>
|
||||||
|
<SelectItem value="overwrite">
|
||||||
|
覆盖 - 用导入配置替换
|
||||||
|
</SelectItem>
|
||||||
|
<SelectItem value="error">
|
||||||
|
报错 - 遇到冲突时中止
|
||||||
|
</SelectItem>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
<template v-if="mergeMode === 'skip'">
|
||||||
|
已存在的配置将被保留,仅导入新配置
|
||||||
|
</template>
|
||||||
|
<template v-else-if="mergeMode === 'overwrite'">
|
||||||
|
已存在的配置将被导入的配置覆盖
|
||||||
|
</template>
|
||||||
|
<template v-else>
|
||||||
|
如果发现任何冲突,导入将中止并回滚
|
||||||
|
</template>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="p-3 bg-yellow-500/10 border border-yellow-500/20 rounded-lg">
|
||||||
|
<p class="text-sm text-yellow-600 dark:text-yellow-400">
|
||||||
|
注意:相同的 API Keys 会自动跳过,不会创建重复记录。
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<DialogFooter>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
@click="importDialogOpen = false"
|
||||||
|
>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
:disabled="importLoading"
|
||||||
|
@click="confirmImport"
|
||||||
|
>
|
||||||
|
{{ importLoading ? '导入中...' : '确认导入' }}
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
|
<!-- 导入结果对话框 -->
|
||||||
|
<Dialog v-model:open="importResultDialogOpen">
|
||||||
|
<DialogContent class="max-w-lg">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>导入完成</DialogTitle>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="importResult"
|
||||||
|
class="space-y-4 py-4"
|
||||||
|
>
|
||||||
|
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||||
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
|
<p class="font-medium">
|
||||||
|
全局模型
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importResult.stats.global_models.created }},
|
||||||
|
更新: {{ importResult.stats.global_models.updated }},
|
||||||
|
跳过: {{ importResult.stats.global_models.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
|
<p class="font-medium">
|
||||||
|
提供商
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importResult.stats.providers.created }},
|
||||||
|
更新: {{ importResult.stats.providers.updated }},
|
||||||
|
跳过: {{ importResult.stats.providers.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
|
<p class="font-medium">
|
||||||
|
端点
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importResult.stats.endpoints.created }},
|
||||||
|
更新: {{ importResult.stats.endpoints.updated }},
|
||||||
|
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
|
<p class="font-medium">
|
||||||
|
API Keys
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importResult.stats.keys.created }},
|
||||||
|
跳过: {{ importResult.stats.keys.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="p-3 bg-muted rounded-lg col-span-2">
|
||||||
|
<p class="font-medium">
|
||||||
|
模型配置
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importResult.stats.models.created }},
|
||||||
|
更新: {{ importResult.stats.models.updated }},
|
||||||
|
跳过: {{ importResult.stats.models.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="importResult.stats.errors.length > 0"
|
||||||
|
class="p-3 bg-red-500/10 border border-red-500/20 rounded-lg"
|
||||||
|
>
|
||||||
|
<p class="font-medium text-red-600 dark:text-red-400 mb-2">
|
||||||
|
警告信息
|
||||||
|
</p>
|
||||||
|
<ul class="text-sm text-red-600 dark:text-red-400 space-y-1">
|
||||||
|
<li
|
||||||
|
v-for="(err, index) in importResult.stats.errors"
|
||||||
|
:key="index"
|
||||||
|
>
|
||||||
|
{{ err }}
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<DialogFooter>
|
||||||
|
<Button @click="importResultDialogOpen = false">
|
||||||
|
确定
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
|
<!-- 用户数据导入对话框 -->
|
||||||
|
<Dialog v-model:open="importUsersDialogOpen">
|
||||||
|
<DialogContent class="max-w-lg">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>导入用户数据</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
选择冲突处理模式并确认导入
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div class="space-y-4 py-4">
|
||||||
|
<div
|
||||||
|
v-if="importUsersPreview"
|
||||||
|
class="p-3 bg-muted rounded-lg text-sm"
|
||||||
|
>
|
||||||
|
<p class="font-medium mb-2">
|
||||||
|
数据预览
|
||||||
|
</p>
|
||||||
|
<ul class="space-y-1 text-muted-foreground">
|
||||||
|
<li>用户: {{ importUsersPreview.users?.length || 0 }} 个</li>
|
||||||
|
<li>
|
||||||
|
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||||
|
<Select v-model="usersMergeMode">
|
||||||
|
<SelectTrigger>
|
||||||
|
<SelectValue />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
<SelectItem value="skip">
|
||||||
|
跳过 - 保留现有用户
|
||||||
|
</SelectItem>
|
||||||
|
<SelectItem value="overwrite">
|
||||||
|
覆盖 - 用导入数据替换
|
||||||
|
</SelectItem>
|
||||||
|
<SelectItem value="error">
|
||||||
|
报错 - 遇到冲突时中止
|
||||||
|
</SelectItem>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
<p class="mt-1 text-xs text-muted-foreground">
|
||||||
|
<template v-if="usersMergeMode === 'skip'">
|
||||||
|
已存在的用户将被保留,仅导入新用户
|
||||||
|
</template>
|
||||||
|
<template v-else-if="usersMergeMode === 'overwrite'">
|
||||||
|
已存在的用户将被导入的数据覆盖
|
||||||
|
</template>
|
||||||
|
<template v-else>
|
||||||
|
如果发现任何冲突,导入将中止并回滚
|
||||||
|
</template>
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="p-3 bg-yellow-500/10 border border-yellow-500/20 rounded-lg">
|
||||||
|
<p class="text-sm text-yellow-600 dark:text-yellow-400">
|
||||||
|
注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作。
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<DialogFooter>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
@click="importUsersDialogOpen = false"
|
||||||
|
>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
:disabled="importUsersLoading"
|
||||||
|
@click="confirmImportUsers"
|
||||||
|
>
|
||||||
|
{{ importUsersLoading ? '导入中...' : '确认导入' }}
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
|
||||||
|
<!-- 用户数据导入结果对话框 -->
|
||||||
|
<Dialog v-model:open="importUsersResultDialogOpen">
|
||||||
|
<DialogContent class="max-w-lg">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>用户数据导入完成</DialogTitle>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="importUsersResult"
|
||||||
|
class="space-y-4 py-4"
|
||||||
|
>
|
||||||
|
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||||
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
|
<p class="font-medium">
|
||||||
|
用户
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importUsersResult.stats.users.created }},
|
||||||
|
更新: {{ importUsersResult.stats.users.updated }},
|
||||||
|
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="p-3 bg-muted rounded-lg">
|
||||||
|
<p class="font-medium">
|
||||||
|
API Keys
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importUsersResult.stats.api_keys.created }},
|
||||||
|
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="importUsersResult.stats.errors.length > 0"
|
||||||
|
class="p-3 bg-red-500/10 border border-red-500/20 rounded-lg"
|
||||||
|
>
|
||||||
|
<p class="font-medium text-red-600 dark:text-red-400 mb-2">
|
||||||
|
警告信息
|
||||||
|
</p>
|
||||||
|
<ul class="text-sm text-red-600 dark:text-red-400 space-y-1">
|
||||||
|
<li
|
||||||
|
v-for="(err, index) in importUsersResult.stats.errors"
|
||||||
|
:key="index"
|
||||||
|
>
|
||||||
|
{{ err }}
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<DialogFooter>
|
||||||
|
<Button @click="importUsersResultDialogOpen = false">
|
||||||
|
确定
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
</PageContainer>
|
</PageContainer>
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
|
import { Download, Upload } from 'lucide-vue-next'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import Input from '@/components/ui/input.vue'
|
import Input from '@/components/ui/input.vue'
|
||||||
import Label from '@/components/ui/label.vue'
|
import Label from '@/components/ui/label.vue'
|
||||||
@@ -389,9 +792,17 @@ import SelectTrigger from '@/components/ui/select-trigger.vue'
|
|||||||
import SelectValue from '@/components/ui/select-value.vue'
|
import SelectValue from '@/components/ui/select-value.vue'
|
||||||
import SelectContent from '@/components/ui/select-content.vue'
|
import SelectContent from '@/components/ui/select-content.vue'
|
||||||
import SelectItem from '@/components/ui/select-item.vue'
|
import SelectItem from '@/components/ui/select-item.vue'
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogContent,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
DialogDescription,
|
||||||
|
DialogFooter
|
||||||
|
} from '@/components/ui'
|
||||||
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
|
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { adminApi } from '@/api/admin'
|
import { adminApi, type ConfigExportData, type ConfigImportResponse, type UsersExportData, type UsersImportResponse } from '@/api/admin'
|
||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
|
|
||||||
const { success, error } = useToast()
|
const { success, error } = useToast()
|
||||||
@@ -423,6 +834,26 @@ interface SystemConfig {
|
|||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const logLevelSelectOpen = ref(false)
|
const logLevelSelectOpen = ref(false)
|
||||||
|
|
||||||
|
// 导出/导入相关
|
||||||
|
const exportLoading = ref(false)
|
||||||
|
const importLoading = ref(false)
|
||||||
|
const importDialogOpen = ref(false)
|
||||||
|
const importResultDialogOpen = ref(false)
|
||||||
|
const configFileInput = ref<HTMLInputElement | null>(null)
|
||||||
|
const importPreview = ref<ConfigExportData | null>(null)
|
||||||
|
const importResult = ref<ConfigImportResponse | null>(null)
|
||||||
|
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||||
|
|
||||||
|
// 用户数据导出/导入相关
|
||||||
|
const exportUsersLoading = ref(false)
|
||||||
|
const importUsersLoading = ref(false)
|
||||||
|
const importUsersDialogOpen = ref(false)
|
||||||
|
const importUsersResultDialogOpen = ref(false)
|
||||||
|
const usersFileInput = ref<HTMLInputElement | null>(null)
|
||||||
|
const importUsersPreview = ref<UsersExportData | null>(null)
|
||||||
|
const importUsersResult = ref<UsersImportResponse | null>(null)
|
||||||
|
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||||
|
|
||||||
const systemConfig = ref<SystemConfig>({
|
const systemConfig = ref<SystemConfig>({
|
||||||
// 基础配置
|
// 基础配置
|
||||||
default_user_quota_usd: 10.0,
|
default_user_quota_usd: 10.0,
|
||||||
@@ -623,4 +1054,183 @@ async function saveSystemConfig() {
|
|||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 导出配置
|
||||||
|
async function handleExportConfig() {
|
||||||
|
exportLoading.value = true
|
||||||
|
try {
|
||||||
|
const data = await adminApi.exportConfig()
|
||||||
|
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
|
||||||
|
const url = URL.createObjectURL(blob)
|
||||||
|
const a = document.createElement('a')
|
||||||
|
a.href = url
|
||||||
|
a.download = `aether-config-${new Date().toISOString().slice(0, 10)}.json`
|
||||||
|
document.body.appendChild(a)
|
||||||
|
a.click()
|
||||||
|
document.body.removeChild(a)
|
||||||
|
URL.revokeObjectURL(url)
|
||||||
|
success('配置已导出')
|
||||||
|
} catch (err) {
|
||||||
|
error('导出配置失败')
|
||||||
|
log.error('导出配置失败:', err)
|
||||||
|
} finally {
|
||||||
|
exportLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 触发文件选择
|
||||||
|
function triggerConfigFileSelect() {
|
||||||
|
configFileInput.value?.click()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 文件大小限制 (10MB)
|
||||||
|
const MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||||
|
|
||||||
|
// 处理文件选择
|
||||||
|
function handleConfigFileSelect(event: Event) {
|
||||||
|
const input = event.target as HTMLInputElement
|
||||||
|
const file = input.files?.[0]
|
||||||
|
if (!file) return
|
||||||
|
|
||||||
|
if (file.size > MAX_FILE_SIZE) {
|
||||||
|
error('文件大小不能超过 10MB')
|
||||||
|
input.value = ''
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = new FileReader()
|
||||||
|
reader.onload = (e) => {
|
||||||
|
try {
|
||||||
|
const content = e.target?.result as string
|
||||||
|
const data = JSON.parse(content) as ConfigExportData
|
||||||
|
|
||||||
|
// 验证版本
|
||||||
|
if (data.version !== '1.0') {
|
||||||
|
error(`不支持的配置版本: ${data.version}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
importPreview.value = data
|
||||||
|
mergeMode.value = 'skip'
|
||||||
|
importDialogOpen.value = true
|
||||||
|
} catch (err) {
|
||||||
|
error('解析配置文件失败,请确保是有效的 JSON 文件')
|
||||||
|
log.error('解析配置文件失败:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reader.readAsText(file)
|
||||||
|
|
||||||
|
// 重置 input 以便能再次选择同一文件
|
||||||
|
input.value = ''
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确认导入
|
||||||
|
async function confirmImport() {
|
||||||
|
if (!importPreview.value) return
|
||||||
|
|
||||||
|
importLoading.value = true
|
||||||
|
try {
|
||||||
|
const result = await adminApi.importConfig({
|
||||||
|
...importPreview.value,
|
||||||
|
merge_mode: mergeMode.value
|
||||||
|
})
|
||||||
|
importResult.value = result
|
||||||
|
importDialogOpen.value = false
|
||||||
|
importResultDialogOpen.value = true
|
||||||
|
success('配置导入成功')
|
||||||
|
} catch (err: any) {
|
||||||
|
error(err.response?.data?.detail || '导入配置失败')
|
||||||
|
log.error('导入配置失败:', err)
|
||||||
|
} finally {
|
||||||
|
importLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 导出用户数据
|
||||||
|
async function handleExportUsers() {
|
||||||
|
exportUsersLoading.value = true
|
||||||
|
try {
|
||||||
|
const data = await adminApi.exportUsers()
|
||||||
|
const blob = new Blob([JSON.stringify(data, null, 2)], { type: 'application/json' })
|
||||||
|
const url = URL.createObjectURL(blob)
|
||||||
|
const a = document.createElement('a')
|
||||||
|
a.href = url
|
||||||
|
a.download = `aether-users-${new Date().toISOString().slice(0, 10)}.json`
|
||||||
|
document.body.appendChild(a)
|
||||||
|
a.click()
|
||||||
|
document.body.removeChild(a)
|
||||||
|
URL.revokeObjectURL(url)
|
||||||
|
success('用户数据已导出')
|
||||||
|
} catch (err) {
|
||||||
|
error('导出用户数据失败')
|
||||||
|
log.error('导出用户数据失败:', err)
|
||||||
|
} finally {
|
||||||
|
exportUsersLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 触发用户数据文件选择
|
||||||
|
function triggerUsersFileSelect() {
|
||||||
|
usersFileInput.value?.click()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理用户数据文件选择
|
||||||
|
function handleUsersFileSelect(event: Event) {
|
||||||
|
const input = event.target as HTMLInputElement
|
||||||
|
const file = input.files?.[0]
|
||||||
|
if (!file) return
|
||||||
|
|
||||||
|
if (file.size > MAX_FILE_SIZE) {
|
||||||
|
error('文件大小不能超过 10MB')
|
||||||
|
input.value = ''
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = new FileReader()
|
||||||
|
reader.onload = (e) => {
|
||||||
|
try {
|
||||||
|
const content = e.target?.result as string
|
||||||
|
const data = JSON.parse(content) as UsersExportData
|
||||||
|
|
||||||
|
// 验证版本
|
||||||
|
if (data.version !== '1.0') {
|
||||||
|
error(`不支持的配置版本: ${data.version}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
importUsersPreview.value = data
|
||||||
|
usersMergeMode.value = 'skip'
|
||||||
|
importUsersDialogOpen.value = true
|
||||||
|
} catch (err) {
|
||||||
|
error('解析用户数据文件失败,请确保是有效的 JSON 文件')
|
||||||
|
log.error('解析用户数据文件失败:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reader.readAsText(file)
|
||||||
|
|
||||||
|
// 重置 input 以便能再次选择同一文件
|
||||||
|
input.value = ''
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确认导入用户数据
|
||||||
|
async function confirmImportUsers() {
|
||||||
|
if (!importUsersPreview.value) return
|
||||||
|
|
||||||
|
importUsersLoading.value = true
|
||||||
|
try {
|
||||||
|
const result = await adminApi.importUsers({
|
||||||
|
...importUsersPreview.value,
|
||||||
|
merge_mode: usersMergeMode.value
|
||||||
|
})
|
||||||
|
importUsersResult.value = result
|
||||||
|
importUsersDialogOpen.value = false
|
||||||
|
importUsersResultDialogOpen.value = true
|
||||||
|
success('用户数据导入成功')
|
||||||
|
} catch (err: any) {
|
||||||
|
error(err.response?.data?.detail || '导入用户数据失败')
|
||||||
|
log.error('导入用户数据失败:', err)
|
||||||
|
} finally {
|
||||||
|
importUsersLoading.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
|
|||||||
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
filtered = filtered.filter(
|
filtered = filtered.filter(u => {
|
||||||
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
|
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
|
||||||
)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filterRole.value !== 'all') {
|
if (filterRole.value !== 'all') {
|
||||||
|
|||||||
@@ -474,13 +474,13 @@ async function toggleCapability(modelName: string, capName: string) {
|
|||||||
const filteredModels = computed(() => {
|
const filteredModels = computed(() => {
|
||||||
let result = models.value
|
let result = models.value
|
||||||
|
|
||||||
// 搜索
|
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||||
if (searchQuery.value) {
|
if (searchQuery.value) {
|
||||||
const query = searchQuery.value.toLowerCase()
|
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||||
result = result.filter(m =>
|
result = result.filter(m => {
|
||||||
m.name.toLowerCase().includes(query) ||
|
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||||
m.display_name?.toLowerCase().includes(query)
|
return keywords.every(keyword => searchableText.includes(keyword))
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 能力筛选
|
// 能力筛选
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
|
|||||||
from .endpoints import router as endpoints_router
|
from .endpoints import router as endpoints_router
|
||||||
from .models import router as models_router
|
from .models import router as models_router
|
||||||
from .monitoring import router as monitoring_router
|
from .monitoring import router as monitoring_router
|
||||||
|
from .provider_query import router as provider_query_router
|
||||||
from .provider_strategy import router as provider_strategy_router
|
from .provider_strategy import router as provider_strategy_router
|
||||||
from .providers import router as providers_router
|
from .providers import router as providers_router
|
||||||
from .security import router as security_router
|
from .security import router as security_router
|
||||||
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
|
|||||||
router.include_router(adaptive_router)
|
router.include_router(adaptive_router)
|
||||||
router.include_router(models_router)
|
router.include_router(models_router)
|
||||||
router.include_router(security_router)
|
router.include_router(security_router)
|
||||||
|
router.include_router(provider_query_router)
|
||||||
|
|
||||||
__all__ = ["router"]
|
__all__ = ["router"]
|
||||||
|
|||||||
@@ -1,46 +1,28 @@
|
|||||||
"""
|
"""
|
||||||
Provider Query API 端点
|
Provider Query API 端点
|
||||||
用于查询提供商的余额、使用记录等信息
|
用于查询提供商的模型列表等信息
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
import asyncio
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
import httpx
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.orm import Session, joinedload
|
||||||
|
|
||||||
|
from src.core.crypto import crypto_service
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database.database import get_db
|
from src.database.database import get_db
|
||||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
from src.models.database import Provider, ProviderEndpoint, User
|
||||||
|
|
||||||
# 初始化适配器注册
|
|
||||||
from src.plugins.provider_query import init # noqa
|
|
||||||
from src.plugins.provider_query import get_query_registry
|
|
||||||
from src.plugins.provider_query.base import QueryCapability
|
|
||||||
from src.utils.auth_utils import get_current_user
|
from src.utils.auth_utils import get_current_user
|
||||||
|
|
||||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||||
|
|
||||||
|
|
||||||
# ============ Request/Response Models ============
|
# ============ Request/Response Models ============
|
||||||
|
|
||||||
|
|
||||||
class BalanceQueryRequest(BaseModel):
|
|
||||||
"""余额查询请求"""
|
|
||||||
|
|
||||||
provider_id: str
|
|
||||||
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
|
|
||||||
|
|
||||||
|
|
||||||
class UsageSummaryQueryRequest(BaseModel):
|
|
||||||
"""使用汇总查询请求"""
|
|
||||||
|
|
||||||
provider_id: str
|
|
||||||
api_key_id: Optional[str] = None
|
|
||||||
period: str = "month" # day, week, month, year
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsQueryRequest(BaseModel):
|
class ModelsQueryRequest(BaseModel):
|
||||||
"""模型列表查询请求"""
|
"""模型列表查询请求"""
|
||||||
|
|
||||||
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
|
|||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
||||||
|
|
||||||
@router.get("/adapters")
|
async def _fetch_openai_models(
|
||||||
async def list_adapters(
|
client: httpx.AsyncClient,
|
||||||
current_user: User = Depends(get_current_user),
|
base_url: str,
|
||||||
):
|
api_key: str,
|
||||||
"""
|
api_format: str,
|
||||||
获取所有可用的查询适配器
|
extra_headers: Optional[dict] = None,
|
||||||
|
) -> tuple[list, Optional[str]]:
|
||||||
|
"""获取 OpenAI 格式的模型列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
适配器列表
|
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||||
"""
|
"""
|
||||||
registry = get_query_registry()
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
adapters = registry.list_adapters()
|
if extra_headers:
|
||||||
|
# 防止 extra_headers 覆盖 Authorization
|
||||||
|
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||||
|
headers.update(safe_headers)
|
||||||
|
|
||||||
return {"success": True, "data": adapters}
|
# 构建 /v1/models URL
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
@router.get("/capabilities/{provider_id}")
|
|
||||||
async def get_provider_capabilities(
|
|
||||||
provider_id: str,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取提供商支持的查询能力
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_id: 提供商 ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
支持的查询能力列表
|
|
||||||
"""
|
|
||||||
# 获取提供商
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
registry = get_query_registry()
|
|
||||||
capabilities = registry.get_capabilities_for_provider(provider.name)
|
|
||||||
|
|
||||||
if capabilities is None:
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"provider_id": provider_id,
|
|
||||||
"provider_name": provider.name,
|
|
||||||
"capabilities": [],
|
|
||||||
"has_adapter": False,
|
|
||||||
"message": "No query adapter available for this provider",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"data": {
|
|
||||||
"provider_id": provider_id,
|
|
||||||
"provider_name": provider.name,
|
|
||||||
"capabilities": [c.name for c in capabilities],
|
|
||||||
"has_adapter": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/balance")
|
|
||||||
async def query_balance(
|
|
||||||
request: BalanceQueryRequest,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
查询提供商余额
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 查询请求
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
余额信息
|
|
||||||
"""
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
# 获取提供商及其端点
|
|
||||||
result = await db.execute(
|
|
||||||
select(Provider)
|
|
||||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
|
||||||
.where(Provider.id == request.provider_id)
|
|
||||||
)
|
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
# 获取 API Key
|
|
||||||
api_key_value = None
|
|
||||||
endpoint_config = None
|
|
||||||
|
|
||||||
if request.api_key_id:
|
|
||||||
# 查找指定的 API Key
|
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.id == request.api_key_id:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {
|
|
||||||
"base_url": endpoint.base_url,
|
|
||||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
|
||||||
else:
|
else:
|
||||||
# 使用第一个可用的 API Key
|
models_url = f"{base_url}/v1/models"
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
if endpoint.is_active and endpoint.api_keys:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.is_active:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {
|
|
||||||
"base_url": endpoint.base_url,
|
|
||||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
try:
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
response = await client.get(models_url, headers=headers)
|
||||||
|
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||||
# 查询余额
|
if response.status_code == 200:
|
||||||
registry = get_query_registry()
|
data = response.json()
|
||||||
query_result = await registry.query_provider_balance(
|
models = []
|
||||||
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
|
if "data" in data:
|
||||||
)
|
models = data["data"]
|
||||||
|
elif isinstance(data, list):
|
||||||
if not query_result.success:
|
models = data
|
||||||
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
|
# 为每个模型添加 api_format 字段
|
||||||
|
for m in models:
|
||||||
return {
|
m["api_format"] = api_format
|
||||||
"success": query_result.success,
|
return models, None
|
||||||
"data": query_result.to_dict(),
|
else:
|
||||||
"provider": {
|
# 记录详细的错误信息
|
||||||
"id": provider.id,
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
"name": provider.name,
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
"display_name": provider.display_name,
|
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||||
},
|
return [], error_msg
|
||||||
}
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
@router.post("/usage-summary")
|
async def _fetch_claude_models(
|
||||||
async def query_usage_summary(
|
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||||
request: UsageSummaryQueryRequest,
|
) -> tuple[list, Optional[str]]:
|
||||||
db: AsyncSession = Depends(get_db),
|
"""获取 Claude 格式的模型列表
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
查询提供商使用汇总
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: 查询请求
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
使用汇总信息
|
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select
|
headers = {
|
||||||
from sqlalchemy.orm import selectinload
|
"x-api-key": api_key,
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
# 获取提供商及其端点
|
"anthropic-version": "2023-06-01",
|
||||||
result = await db.execute(
|
|
||||||
select(Provider)
|
|
||||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
|
||||||
.where(Provider.id == request.provider_id)
|
|
||||||
)
|
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
# 获取 API Key(逻辑同上)
|
|
||||||
api_key_value = None
|
|
||||||
endpoint_config = None
|
|
||||||
|
|
||||||
if request.api_key_id:
|
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.id == request.api_key_id:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
|
||||||
else:
|
|
||||||
for endpoint in provider.endpoints:
|
|
||||||
if endpoint.is_active and endpoint.api_keys:
|
|
||||||
for api_key in endpoint.api_keys:
|
|
||||||
if api_key.is_active:
|
|
||||||
api_key_value = api_key.api_key
|
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
|
||||||
break
|
|
||||||
if api_key_value:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not api_key_value:
|
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
|
||||||
|
|
||||||
# 查询使用汇总
|
|
||||||
registry = get_query_registry()
|
|
||||||
query_result = await registry.query_provider_usage(
|
|
||||||
provider_type=provider.name,
|
|
||||||
api_key=api_key_value,
|
|
||||||
period=request.period,
|
|
||||||
endpoint_config=endpoint_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": query_result.success,
|
|
||||||
"data": query_result.to_dict(),
|
|
||||||
"provider": {
|
|
||||||
"id": provider.id,
|
|
||||||
"name": provider.name,
|
|
||||||
"display_name": provider.display_name,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 构建 /v1/models URL
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url}/v1/models"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
models = []
|
||||||
|
if "data" in data:
|
||||||
|
models = data["data"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
models = data
|
||||||
|
# 为每个模型添加 api_format 字段
|
||||||
|
for m in models:
|
||||||
|
m["api_format"] = api_format
|
||||||
|
return models, None
|
||||||
|
else:
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_gemini_models(
|
||||||
|
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||||
|
) -> tuple[list, Optional[str]]:
|
||||||
|
"""获取 Gemini 格式的模型列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||||
|
"""
|
||||||
|
# 兼容 base_url 已包含 /v1beta 的情况
|
||||||
|
base_url_clean = base_url.rstrip("/")
|
||||||
|
if base_url_clean.endswith("/v1beta"):
|
||||||
|
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||||
|
else:
|
||||||
|
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.get(models_url)
|
||||||
|
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if "models" in data:
|
||||||
|
# 转换为统一格式
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": m.get("name", "").replace("models/", ""),
|
||||||
|
"owned_by": "google",
|
||||||
|
"display_name": m.get("displayName", ""),
|
||||||
|
"api_format": api_format,
|
||||||
|
}
|
||||||
|
for m in data["models"]
|
||||||
|
], None
|
||||||
|
return [], None
|
||||||
|
else:
|
||||||
|
error_body = response.text[:500] if response.text else "(empty)"
|
||||||
|
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||||
|
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||||
|
return [], error_msg
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Request error: {str(e)}"
|
||||||
|
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||||
|
return [], error_msg
|
||||||
|
|
||||||
|
|
||||||
@router.post("/models")
|
@router.post("/models")
|
||||||
async def query_available_models(
|
async def query_available_models(
|
||||||
request: ModelsQueryRequest,
|
request: ModelsQueryRequest,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
查询提供商可用模型
|
查询提供商可用模型
|
||||||
|
|
||||||
|
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
||||||
|
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
||||||
|
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
||||||
|
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: 查询请求
|
request: 查询请求
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模型列表
|
所有端点的模型列表(合并)
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
# 获取提供商及其端点
|
# 获取提供商及其端点
|
||||||
result = await db.execute(
|
provider = (
|
||||||
select(Provider)
|
db.query(Provider)
|
||||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||||
.where(Provider.id == request.provider_id)
|
.filter(Provider.id == request.provider_id)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
# 获取 API Key
|
# 收集所有活跃端点的配置
|
||||||
api_key_value = None
|
endpoint_configs: list[dict] = []
|
||||||
endpoint_config = None
|
|
||||||
|
|
||||||
if request.api_key_id:
|
if request.api_key_id:
|
||||||
|
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
||||||
for endpoint in provider.endpoints:
|
for endpoint in provider.endpoints:
|
||||||
for api_key in endpoint.api_keys:
|
for api_key in endpoint.api_keys:
|
||||||
if api_key.id == request.api_key_id:
|
if api_key.id == request.api_key_id:
|
||||||
api_key_value = api_key.api_key
|
try:
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to decrypt API key: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
|
endpoint_configs.append({
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
})
|
||||||
break
|
break
|
||||||
if api_key_value:
|
if endpoint_configs:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not api_key_value:
|
if not endpoint_configs:
|
||||||
raise HTTPException(status_code=404, detail="API Key not found")
|
raise HTTPException(status_code=404, detail="API Key not found")
|
||||||
else:
|
else:
|
||||||
|
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
||||||
for endpoint in provider.endpoints:
|
for endpoint in provider.endpoints:
|
||||||
if endpoint.is_active and endpoint.api_keys:
|
if not endpoint.is_active or not endpoint.api_keys:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 找第一个可用的 Key
|
||||||
for api_key in endpoint.api_keys:
|
for api_key in endpoint.api_keys:
|
||||||
if api_key.is_active:
|
if api_key.is_active:
|
||||||
api_key_value = api_key.api_key
|
try:
|
||||||
endpoint_config = {"base_url": endpoint.base_url}
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
break
|
except Exception as e:
|
||||||
if api_key_value:
|
logger.error(f"Failed to decrypt API key: {e}")
|
||||||
break
|
continue # 尝试下一个 Key
|
||||||
|
endpoint_configs.append({
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
})
|
||||||
|
break # 只取第一个可用的 Key
|
||||||
|
|
||||||
if not api_key_value:
|
if not endpoint_configs:
|
||||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||||
|
|
||||||
# 查询模型
|
# 并发请求所有端点的模型列表
|
||||||
registry = get_query_registry()
|
all_models: list = []
|
||||||
adapter = registry.get_adapter_for_provider(provider.name)
|
errors: list[str] = []
|
||||||
|
|
||||||
if not adapter:
|
async def fetch_endpoint_models(
|
||||||
raise HTTPException(
|
client: httpx.AsyncClient, config: dict
|
||||||
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
|
) -> tuple[list, Optional[str]]:
|
||||||
)
|
base_url = config["base_url"]
|
||||||
|
if not base_url:
|
||||||
|
return [], None
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
api_format = config["api_format"]
|
||||||
|
api_key_value = config["api_key"]
|
||||||
|
extra_headers = config["extra_headers"]
|
||||||
|
|
||||||
query_result = await adapter.query_available_models(
|
try:
|
||||||
api_key=api_key_value, endpoint_config=endpoint_config
|
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
||||||
|
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
||||||
|
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
||||||
|
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
||||||
|
else:
|
||||||
|
return await _fetch_openai_models(
|
||||||
|
client, base_url, api_key_value, api_format, extra_headers
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||||
|
return [], f"{api_format}: {str(e)}"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||||
|
)
|
||||||
|
for models, error in results:
|
||||||
|
all_models.extend(models)
|
||||||
|
if error:
|
||||||
|
errors.append(error)
|
||||||
|
|
||||||
|
# 按 model id 去重(保留第一个)
|
||||||
|
seen_ids: set[str] = set()
|
||||||
|
unique_models: list = []
|
||||||
|
for model in all_models:
|
||||||
|
model_id = model.get("id")
|
||||||
|
if model_id and model_id not in seen_ids:
|
||||||
|
seen_ids.add(model_id)
|
||||||
|
unique_models.append(model)
|
||||||
|
|
||||||
|
error = "; ".join(errors) if errors else None
|
||||||
|
if not unique_models and not error:
|
||||||
|
error = "No models returned from any endpoint"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": query_result.success,
|
"success": len(unique_models) > 0,
|
||||||
"data": query_result.to_dict(),
|
"data": {"models": unique_models, "error": error},
|
||||||
"provider": {
|
"provider": {
|
||||||
"id": provider.id,
|
"id": provider.id,
|
||||||
"name": provider.name,
|
"name": provider.name,
|
||||||
"display_name": provider.display_name,
|
"display_name": provider.display_name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cache/{provider_id}")
|
|
||||||
async def clear_query_cache(
|
|
||||||
provider_id: str,
|
|
||||||
api_key_id: Optional[str] = None,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
清除查询缓存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_id: 提供商 ID
|
|
||||||
api_key_id: 可选,指定清除某个 API Key 的缓存
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
清除结果
|
|
||||||
"""
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
# 获取提供商
|
|
||||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
|
||||||
provider = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
|
||||||
|
|
||||||
registry = get_query_registry()
|
|
||||||
adapter = registry.get_adapter_for_provider(provider.name)
|
|
||||||
|
|
||||||
if adapter:
|
|
||||||
if api_key_id:
|
|
||||||
# 获取 API Key 值来清除缓存
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
|
|
||||||
api_key = result.scalar_one_or_none()
|
|
||||||
if api_key:
|
|
||||||
adapter.clear_cache(api_key.api_key)
|
|
||||||
else:
|
|
||||||
adapter.clear_cache()
|
|
||||||
|
|
||||||
return {"success": True, "message": "Cache cleared successfully"}
|
|
||||||
|
|||||||
@@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)):
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/export")
|
||||||
|
async def export_config(request: Request, db: Session = Depends(get_db)):
|
||||||
|
"""导出提供商和模型配置(管理员)"""
|
||||||
|
adapter = AdminExportConfigAdapter()
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config/import")
|
||||||
|
async def import_config(request: Request, db: Session = Depends(get_db)):
|
||||||
|
"""导入提供商和模型配置(管理员)"""
|
||||||
|
adapter = AdminImportConfigAdapter()
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/users/export")
|
||||||
|
async def export_users(request: Request, db: Session = Depends(get_db)):
|
||||||
|
"""导出用户数据(管理员)"""
|
||||||
|
adapter = AdminExportUsersAdapter()
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/users/import")
|
||||||
|
async def import_users(request: Request, db: Session = Depends(get_db)):
|
||||||
|
"""导入用户数据(管理员)"""
|
||||||
|
adapter = AdminImportUsersAdapter()
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
# -------- 系统设置适配器 --------
|
# -------- 系统设置适配器 --------
|
||||||
|
|
||||||
|
|
||||||
@@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {"formats": formats}
|
return {"formats": formats}
|
||||||
|
|
||||||
|
|
||||||
|
class AdminExportConfigAdapter(AdminApiAdapter):
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
"""导出提供商和模型配置(解密数据)"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from src.core.crypto import crypto_service
|
||||||
|
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
|
||||||
|
|
||||||
|
db = context.db
|
||||||
|
|
||||||
|
# 导出 GlobalModels
|
||||||
|
global_models = db.query(GlobalModel).all()
|
||||||
|
global_models_data = []
|
||||||
|
for gm in global_models:
|
||||||
|
global_models_data.append(
|
||||||
|
{
|
||||||
|
"name": gm.name,
|
||||||
|
"display_name": gm.display_name,
|
||||||
|
"default_price_per_request": gm.default_price_per_request,
|
||||||
|
"default_tiered_pricing": gm.default_tiered_pricing,
|
||||||
|
"supported_capabilities": gm.supported_capabilities,
|
||||||
|
"config": gm.config,
|
||||||
|
"is_active": gm.is_active,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导出 Providers 及其关联数据
|
||||||
|
providers = db.query(Provider).all()
|
||||||
|
providers_data = []
|
||||||
|
for provider in providers:
|
||||||
|
# 导出 Endpoints
|
||||||
|
endpoints = (
|
||||||
|
db.query(ProviderEndpoint)
|
||||||
|
.filter(ProviderEndpoint.provider_id == provider.id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
endpoints_data = []
|
||||||
|
for ep in endpoints:
|
||||||
|
# 导出 Endpoint Keys
|
||||||
|
keys = (
|
||||||
|
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
|
||||||
|
)
|
||||||
|
keys_data = []
|
||||||
|
for key in keys:
|
||||||
|
# 解密 API Key
|
||||||
|
try:
|
||||||
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||||
|
except Exception:
|
||||||
|
decrypted_key = ""
|
||||||
|
|
||||||
|
keys_data.append(
|
||||||
|
{
|
||||||
|
"api_key": decrypted_key,
|
||||||
|
"name": key.name,
|
||||||
|
"note": key.note,
|
||||||
|
"rate_multiplier": key.rate_multiplier,
|
||||||
|
"internal_priority": key.internal_priority,
|
||||||
|
"global_priority": key.global_priority,
|
||||||
|
"max_concurrent": key.max_concurrent,
|
||||||
|
"rate_limit": key.rate_limit,
|
||||||
|
"daily_limit": key.daily_limit,
|
||||||
|
"monthly_limit": key.monthly_limit,
|
||||||
|
"allowed_models": key.allowed_models,
|
||||||
|
"capabilities": key.capabilities,
|
||||||
|
"is_active": key.is_active,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoints_data.append(
|
||||||
|
{
|
||||||
|
"api_format": ep.api_format,
|
||||||
|
"base_url": ep.base_url,
|
||||||
|
"headers": ep.headers,
|
||||||
|
"timeout": ep.timeout,
|
||||||
|
"max_retries": ep.max_retries,
|
||||||
|
"max_concurrent": ep.max_concurrent,
|
||||||
|
"rate_limit": ep.rate_limit,
|
||||||
|
"is_active": ep.is_active,
|
||||||
|
"custom_path": ep.custom_path,
|
||||||
|
"config": ep.config,
|
||||||
|
"keys": keys_data,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 导出 Provider Models
|
||||||
|
models = db.query(Model).filter(Model.provider_id == provider.id).all()
|
||||||
|
models_data = []
|
||||||
|
for model in models:
|
||||||
|
# 获取关联的 GlobalModel 名称
|
||||||
|
global_model = (
|
||||||
|
db.query(GlobalModel).filter(GlobalModel.id == model.global_model_id).first()
|
||||||
|
)
|
||||||
|
models_data.append(
|
||||||
|
{
|
||||||
|
"global_model_name": global_model.name if global_model else None,
|
||||||
|
"provider_model_name": model.provider_model_name,
|
||||||
|
"provider_model_aliases": model.provider_model_aliases,
|
||||||
|
"price_per_request": model.price_per_request,
|
||||||
|
"tiered_pricing": model.tiered_pricing,
|
||||||
|
"supports_vision": model.supports_vision,
|
||||||
|
"supports_function_calling": model.supports_function_calling,
|
||||||
|
"supports_streaming": model.supports_streaming,
|
||||||
|
"supports_extended_thinking": model.supports_extended_thinking,
|
||||||
|
"supports_image_generation": model.supports_image_generation,
|
||||||
|
"is_active": model.is_active,
|
||||||
|
"config": model.config,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
providers_data.append(
|
||||||
|
{
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
"description": provider.description,
|
||||||
|
"website": provider.website,
|
||||||
|
"billing_type": provider.billing_type.value if provider.billing_type else None,
|
||||||
|
"monthly_quota_usd": provider.monthly_quota_usd,
|
||||||
|
"quota_reset_day": provider.quota_reset_day,
|
||||||
|
"rpm_limit": provider.rpm_limit,
|
||||||
|
"provider_priority": provider.provider_priority,
|
||||||
|
"is_active": provider.is_active,
|
||||||
|
"rate_limit": provider.rate_limit,
|
||||||
|
"concurrent_limit": provider.concurrent_limit,
|
||||||
|
"config": provider.config,
|
||||||
|
"endpoints": endpoints_data,
|
||||||
|
"models": models_data,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"version": "1.0",
|
||||||
|
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"global_models": global_models_data,
|
||||||
|
"providers": providers_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
|
||||||
|
|
||||||
|
|
||||||
|
class AdminImportConfigAdapter(AdminApiAdapter):
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
"""导入提供商和模型配置"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from src.core.crypto import crypto_service
|
||||||
|
from src.core.enums import ProviderBillingType
|
||||||
|
from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint
|
||||||
|
|
||||||
|
# 检查请求体大小
|
||||||
|
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
|
||||||
|
raise InvalidRequestException("请求体大小不能超过 10MB")
|
||||||
|
|
||||||
|
db = context.db
|
||||||
|
payload = context.ensure_json_body()
|
||||||
|
|
||||||
|
# 验证配置版本
|
||||||
|
version = payload.get("version")
|
||||||
|
if version != "1.0":
|
||||||
|
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||||||
|
|
||||||
|
# 获取导入选项
|
||||||
|
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||||||
|
global_models_data = payload.get("global_models", [])
|
||||||
|
providers_data = payload.get("providers", [])
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"global_models": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
|
"providers": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
|
"endpoints": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
|
"keys": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
|
"models": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
|
"errors": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 导入 GlobalModels
|
||||||
|
global_model_map = {} # name -> id 映射
|
||||||
|
for gm_data in global_models_data:
|
||||||
|
existing = (
|
||||||
|
db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
global_model_map[gm_data["name"]] = existing.id
|
||||||
|
if merge_mode == "skip":
|
||||||
|
stats["global_models"]["skipped"] += 1
|
||||||
|
continue
|
||||||
|
elif merge_mode == "error":
|
||||||
|
raise InvalidRequestException(
|
||||||
|
f"GlobalModel '{gm_data['name']}' 已存在"
|
||||||
|
)
|
||||||
|
elif merge_mode == "overwrite":
|
||||||
|
# 更新现有记录
|
||||||
|
existing.display_name = gm_data.get(
|
||||||
|
"display_name", existing.display_name
|
||||||
|
)
|
||||||
|
existing.default_price_per_request = gm_data.get(
|
||||||
|
"default_price_per_request"
|
||||||
|
)
|
||||||
|
existing.default_tiered_pricing = gm_data.get(
|
||||||
|
"default_tiered_pricing", existing.default_tiered_pricing
|
||||||
|
)
|
||||||
|
existing.supported_capabilities = gm_data.get(
|
||||||
|
"supported_capabilities"
|
||||||
|
)
|
||||||
|
existing.config = gm_data.get("config")
|
||||||
|
existing.is_active = gm_data.get("is_active", True)
|
||||||
|
existing.updated_at = datetime.now(timezone.utc)
|
||||||
|
stats["global_models"]["updated"] += 1
|
||||||
|
else:
|
||||||
|
# 创建新记录
|
||||||
|
new_gm = GlobalModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
name=gm_data["name"],
|
||||||
|
display_name=gm_data.get("display_name", gm_data["name"]),
|
||||||
|
default_price_per_request=gm_data.get("default_price_per_request"),
|
||||||
|
default_tiered_pricing=gm_data.get(
|
||||||
|
"default_tiered_pricing",
|
||||||
|
{"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
|
||||||
|
),
|
||||||
|
supported_capabilities=gm_data.get("supported_capabilities"),
|
||||||
|
config=gm_data.get("config"),
|
||||||
|
is_active=gm_data.get("is_active", True),
|
||||||
|
)
|
||||||
|
db.add(new_gm)
|
||||||
|
db.flush()
|
||||||
|
global_model_map[gm_data["name"]] = new_gm.id
|
||||||
|
stats["global_models"]["created"] += 1
|
||||||
|
|
||||||
|
# 导入 Providers
|
||||||
|
for prov_data in providers_data:
|
||||||
|
existing_provider = (
|
||||||
|
db.query(Provider).filter(Provider.name == prov_data["name"]).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_provider:
|
||||||
|
provider_id = existing_provider.id
|
||||||
|
if merge_mode == "skip":
|
||||||
|
stats["providers"]["skipped"] += 1
|
||||||
|
# 仍然需要处理 endpoints 和 models(如果存在)
|
||||||
|
elif merge_mode == "error":
|
||||||
|
raise InvalidRequestException(
|
||||||
|
f"Provider '{prov_data['name']}' 已存在"
|
||||||
|
)
|
||||||
|
elif merge_mode == "overwrite":
|
||||||
|
# 更新现有记录
|
||||||
|
existing_provider.display_name = prov_data.get(
|
||||||
|
"display_name", existing_provider.display_name
|
||||||
|
)
|
||||||
|
existing_provider.description = prov_data.get("description")
|
||||||
|
existing_provider.website = prov_data.get("website")
|
||||||
|
if prov_data.get("billing_type"):
|
||||||
|
existing_provider.billing_type = ProviderBillingType(
|
||||||
|
prov_data["billing_type"]
|
||||||
|
)
|
||||||
|
existing_provider.monthly_quota_usd = prov_data.get(
|
||||||
|
"monthly_quota_usd"
|
||||||
|
)
|
||||||
|
existing_provider.quota_reset_day = prov_data.get(
|
||||||
|
"quota_reset_day", 30
|
||||||
|
)
|
||||||
|
existing_provider.rpm_limit = prov_data.get("rpm_limit")
|
||||||
|
existing_provider.provider_priority = prov_data.get(
|
||||||
|
"provider_priority", 100
|
||||||
|
)
|
||||||
|
existing_provider.is_active = prov_data.get("is_active", True)
|
||||||
|
existing_provider.rate_limit = prov_data.get("rate_limit")
|
||||||
|
existing_provider.concurrent_limit = prov_data.get(
|
||||||
|
"concurrent_limit"
|
||||||
|
)
|
||||||
|
existing_provider.config = prov_data.get("config")
|
||||||
|
existing_provider.updated_at = datetime.now(timezone.utc)
|
||||||
|
stats["providers"]["updated"] += 1
|
||||||
|
else:
|
||||||
|
# 创建新 Provider
|
||||||
|
billing_type = ProviderBillingType.PAY_AS_YOU_GO
|
||||||
|
if prov_data.get("billing_type"):
|
||||||
|
billing_type = ProviderBillingType(prov_data["billing_type"])
|
||||||
|
|
||||||
|
new_provider = Provider(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
name=prov_data["name"],
|
||||||
|
display_name=prov_data.get("display_name", prov_data["name"]),
|
||||||
|
description=prov_data.get("description"),
|
||||||
|
website=prov_data.get("website"),
|
||||||
|
billing_type=billing_type,
|
||||||
|
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
|
||||||
|
quota_reset_day=prov_data.get("quota_reset_day", 30),
|
||||||
|
rpm_limit=prov_data.get("rpm_limit"),
|
||||||
|
provider_priority=prov_data.get("provider_priority", 100),
|
||||||
|
is_active=prov_data.get("is_active", True),
|
||||||
|
rate_limit=prov_data.get("rate_limit"),
|
||||||
|
concurrent_limit=prov_data.get("concurrent_limit"),
|
||||||
|
config=prov_data.get("config"),
|
||||||
|
)
|
||||||
|
db.add(new_provider)
|
||||||
|
db.flush()
|
||||||
|
provider_id = new_provider.id
|
||||||
|
stats["providers"]["created"] += 1
|
||||||
|
|
||||||
|
# 导入 Endpoints
|
||||||
|
for ep_data in prov_data.get("endpoints", []):
|
||||||
|
existing_ep = (
|
||||||
|
db.query(ProviderEndpoint)
|
||||||
|
.filter(
|
||||||
|
ProviderEndpoint.provider_id == provider_id,
|
||||||
|
ProviderEndpoint.api_format == ep_data["api_format"],
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_ep:
|
||||||
|
endpoint_id = existing_ep.id
|
||||||
|
if merge_mode == "skip":
|
||||||
|
stats["endpoints"]["skipped"] += 1
|
||||||
|
elif merge_mode == "error":
|
||||||
|
raise InvalidRequestException(
|
||||||
|
f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'"
|
||||||
|
)
|
||||||
|
elif merge_mode == "overwrite":
|
||||||
|
existing_ep.base_url = ep_data.get(
|
||||||
|
"base_url", existing_ep.base_url
|
||||||
|
)
|
||||||
|
existing_ep.headers = ep_data.get("headers")
|
||||||
|
existing_ep.timeout = ep_data.get("timeout", 300)
|
||||||
|
existing_ep.max_retries = ep_data.get("max_retries", 3)
|
||||||
|
existing_ep.max_concurrent = ep_data.get("max_concurrent")
|
||||||
|
existing_ep.rate_limit = ep_data.get("rate_limit")
|
||||||
|
existing_ep.is_active = ep_data.get("is_active", True)
|
||||||
|
existing_ep.custom_path = ep_data.get("custom_path")
|
||||||
|
existing_ep.config = ep_data.get("config")
|
||||||
|
existing_ep.updated_at = datetime.now(timezone.utc)
|
||||||
|
stats["endpoints"]["updated"] += 1
|
||||||
|
else:
|
||||||
|
new_ep = ProviderEndpoint(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_format=ep_data["api_format"],
|
||||||
|
base_url=ep_data["base_url"],
|
||||||
|
headers=ep_data.get("headers"),
|
||||||
|
timeout=ep_data.get("timeout", 300),
|
||||||
|
max_retries=ep_data.get("max_retries", 3),
|
||||||
|
max_concurrent=ep_data.get("max_concurrent"),
|
||||||
|
rate_limit=ep_data.get("rate_limit"),
|
||||||
|
is_active=ep_data.get("is_active", True),
|
||||||
|
custom_path=ep_data.get("custom_path"),
|
||||||
|
config=ep_data.get("config"),
|
||||||
|
)
|
||||||
|
db.add(new_ep)
|
||||||
|
db.flush()
|
||||||
|
endpoint_id = new_ep.id
|
||||||
|
stats["endpoints"]["created"] += 1
|
||||||
|
|
||||||
|
# 导入 Keys
|
||||||
|
# 获取当前 endpoint 下所有已有的 keys,用于去重
|
||||||
|
existing_keys = (
|
||||||
|
db.query(ProviderAPIKey)
|
||||||
|
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
# 解密已有 keys 用于比对
|
||||||
|
existing_key_values = set()
|
||||||
|
for ek in existing_keys:
|
||||||
|
try:
|
||||||
|
decrypted = crypto_service.decrypt(ek.api_key)
|
||||||
|
existing_key_values.add(decrypted)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for key_data in ep_data.get("keys", []):
|
||||||
|
if not key_data.get("api_key"):
|
||||||
|
stats["errors"].append(
|
||||||
|
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否已存在相同的 Key(通过明文比对)
|
||||||
|
if key_data["api_key"] in existing_key_values:
|
||||||
|
stats["keys"]["skipped"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
encrypted_key = crypto_service.encrypt(key_data["api_key"])
|
||||||
|
|
||||||
|
new_key = ProviderAPIKey(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
endpoint_id=endpoint_id,
|
||||||
|
api_key=encrypted_key,
|
||||||
|
name=key_data.get("name"),
|
||||||
|
note=key_data.get("note"),
|
||||||
|
rate_multiplier=key_data.get("rate_multiplier", 1.0),
|
||||||
|
internal_priority=key_data.get("internal_priority", 100),
|
||||||
|
global_priority=key_data.get("global_priority"),
|
||||||
|
max_concurrent=key_data.get("max_concurrent"),
|
||||||
|
rate_limit=key_data.get("rate_limit"),
|
||||||
|
daily_limit=key_data.get("daily_limit"),
|
||||||
|
monthly_limit=key_data.get("monthly_limit"),
|
||||||
|
allowed_models=key_data.get("allowed_models"),
|
||||||
|
capabilities=key_data.get("capabilities"),
|
||||||
|
is_active=key_data.get("is_active", True),
|
||||||
|
)
|
||||||
|
db.add(new_key)
|
||||||
|
# 添加到已有集合,防止同一批导入中重复
|
||||||
|
existing_key_values.add(key_data["api_key"])
|
||||||
|
stats["keys"]["created"] += 1
|
||||||
|
|
||||||
|
# 导入 Models
|
||||||
|
for model_data in prov_data.get("models", []):
|
||||||
|
global_model_name = model_data.get("global_model_name")
|
||||||
|
if not global_model_name:
|
||||||
|
stats["errors"].append(
|
||||||
|
f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
global_model_id = global_model_map.get(global_model_name)
|
||||||
|
if not global_model_id:
|
||||||
|
# 尝试从数据库查找
|
||||||
|
existing_gm = (
|
||||||
|
db.query(GlobalModel)
|
||||||
|
.filter(GlobalModel.name == global_model_name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing_gm:
|
||||||
|
global_model_id = existing_gm.id
|
||||||
|
else:
|
||||||
|
stats["errors"].append(
|
||||||
|
f"GlobalModel '{global_model_name}' 不存在,跳过模型"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
existing_model = (
|
||||||
|
db.query(Model)
|
||||||
|
.filter(
|
||||||
|
Model.provider_id == provider_id,
|
||||||
|
Model.provider_model_name == model_data["provider_model_name"],
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_model:
|
||||||
|
if merge_mode == "skip":
|
||||||
|
stats["models"]["skipped"] += 1
|
||||||
|
elif merge_mode == "error":
|
||||||
|
raise InvalidRequestException(
|
||||||
|
f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'"
|
||||||
|
)
|
||||||
|
elif merge_mode == "overwrite":
|
||||||
|
existing_model.global_model_id = global_model_id
|
||||||
|
existing_model.provider_model_aliases = model_data.get(
|
||||||
|
"provider_model_aliases"
|
||||||
|
)
|
||||||
|
existing_model.price_per_request = model_data.get(
|
||||||
|
"price_per_request"
|
||||||
|
)
|
||||||
|
existing_model.tiered_pricing = model_data.get(
|
||||||
|
"tiered_pricing"
|
||||||
|
)
|
||||||
|
existing_model.supports_vision = model_data.get(
|
||||||
|
"supports_vision"
|
||||||
|
)
|
||||||
|
existing_model.supports_function_calling = model_data.get(
|
||||||
|
"supports_function_calling"
|
||||||
|
)
|
||||||
|
existing_model.supports_streaming = model_data.get(
|
||||||
|
"supports_streaming"
|
||||||
|
)
|
||||||
|
existing_model.supports_extended_thinking = model_data.get(
|
||||||
|
"supports_extended_thinking"
|
||||||
|
)
|
||||||
|
existing_model.supports_image_generation = model_data.get(
|
||||||
|
"supports_image_generation"
|
||||||
|
)
|
||||||
|
existing_model.is_active = model_data.get("is_active", True)
|
||||||
|
existing_model.config = model_data.get("config")
|
||||||
|
existing_model.updated_at = datetime.now(timezone.utc)
|
||||||
|
stats["models"]["updated"] += 1
|
||||||
|
else:
|
||||||
|
new_model = Model(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
provider_id=provider_id,
|
||||||
|
global_model_id=global_model_id,
|
||||||
|
provider_model_name=model_data["provider_model_name"],
|
||||||
|
provider_model_aliases=model_data.get(
|
||||||
|
"provider_model_aliases"
|
||||||
|
),
|
||||||
|
price_per_request=model_data.get("price_per_request"),
|
||||||
|
tiered_pricing=model_data.get("tiered_pricing"),
|
||||||
|
supports_vision=model_data.get("supports_vision"),
|
||||||
|
supports_function_calling=model_data.get(
|
||||||
|
"supports_function_calling"
|
||||||
|
),
|
||||||
|
supports_streaming=model_data.get("supports_streaming"),
|
||||||
|
supports_extended_thinking=model_data.get(
|
||||||
|
"supports_extended_thinking"
|
||||||
|
),
|
||||||
|
supports_image_generation=model_data.get(
|
||||||
|
"supports_image_generation"
|
||||||
|
),
|
||||||
|
is_active=model_data.get("is_active", True),
|
||||||
|
config=model_data.get("config"),
|
||||||
|
)
|
||||||
|
db.add(new_model)
|
||||||
|
stats["models"]["created"] += 1
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# 失效缓存
|
||||||
|
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||||
|
|
||||||
|
cache_service = get_cache_invalidation_service()
|
||||||
|
cache_service.invalidate_all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "配置导入成功",
|
||||||
|
"stats": stats,
|
||||||
|
}
|
||||||
|
|
||||||
|
except InvalidRequestException:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise InvalidRequestException(f"导入失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
class AdminExportUsersAdapter(AdminApiAdapter):
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
"""导出用户数据(保留加密数据,排除管理员)"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from src.core.enums import UserRole
|
||||||
|
from src.models.database import ApiKey, User
|
||||||
|
|
||||||
|
db = context.db
|
||||||
|
|
||||||
|
# 导出 Users(排除管理员)
|
||||||
|
users = db.query(User).filter(
|
||||||
|
User.is_deleted.is_(False),
|
||||||
|
User.role != UserRole.ADMIN
|
||||||
|
).all()
|
||||||
|
users_data = []
|
||||||
|
for user in users:
|
||||||
|
# 导出用户的 API Keys(保留加密数据)
|
||||||
|
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
|
||||||
|
api_keys_data = []
|
||||||
|
for key in api_keys:
|
||||||
|
api_keys_data.append(
|
||||||
|
{
|
||||||
|
"key_hash": key.key_hash,
|
||||||
|
"key_encrypted": key.key_encrypted,
|
||||||
|
"name": key.name,
|
||||||
|
"is_standalone": key.is_standalone,
|
||||||
|
"balance_used_usd": key.balance_used_usd,
|
||||||
|
"current_balance_usd": key.current_balance_usd,
|
||||||
|
"allowed_providers": key.allowed_providers,
|
||||||
|
"allowed_endpoints": key.allowed_endpoints,
|
||||||
|
"allowed_api_formats": key.allowed_api_formats,
|
||||||
|
"allowed_models": key.allowed_models,
|
||||||
|
"rate_limit": key.rate_limit,
|
||||||
|
"concurrent_limit": key.concurrent_limit,
|
||||||
|
"force_capabilities": key.force_capabilities,
|
||||||
|
"is_active": key.is_active,
|
||||||
|
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||||||
|
"total_requests": key.total_requests,
|
||||||
|
"total_cost_usd": key.total_cost_usd,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
users_data.append(
|
||||||
|
{
|
||||||
|
"email": user.email,
|
||||||
|
"username": user.username,
|
||||||
|
"password_hash": user.password_hash,
|
||||||
|
"role": user.role.value if user.role else "user",
|
||||||
|
"allowed_providers": user.allowed_providers,
|
||||||
|
"allowed_endpoints": user.allowed_endpoints,
|
||||||
|
"allowed_models": user.allowed_models,
|
||||||
|
"model_capability_settings": user.model_capability_settings,
|
||||||
|
"quota_usd": user.quota_usd,
|
||||||
|
"used_usd": user.used_usd,
|
||||||
|
"total_usd": user.total_usd,
|
||||||
|
"is_active": user.is_active,
|
||||||
|
"api_keys": api_keys_data,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"version": "1.0",
|
||||||
|
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"users": users_data,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AdminImportUsersAdapter(AdminApiAdapter):
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
"""导入用户数据"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from src.core.enums import UserRole
|
||||||
|
from src.models.database import ApiKey, User
|
||||||
|
|
||||||
|
# 检查请求体大小
|
||||||
|
if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
|
||||||
|
raise InvalidRequestException("请求体大小不能超过 10MB")
|
||||||
|
|
||||||
|
db = context.db
|
||||||
|
payload = context.ensure_json_body()
|
||||||
|
|
||||||
|
# 验证配置版本
|
||||||
|
version = payload.get("version")
|
||||||
|
if version != "1.0":
|
||||||
|
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
||||||
|
|
||||||
|
# 获取导入选项
|
||||||
|
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||||||
|
users_data = payload.get("users", [])
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"users": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
|
"api_keys": {"created": 0, "skipped": 0},
|
||||||
|
"errors": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
for user_data in users_data:
|
||||||
|
# 跳过管理员角色的导入(不区分大小写)
|
||||||
|
role_str = str(user_data.get("role", "")).lower()
|
||||||
|
if role_str == "admin":
|
||||||
|
stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
|
||||||
|
stats["users"]["skipped"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
existing_user = (
|
||||||
|
db.query(User).filter(User.email == user_data["email"]).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if existing_user:
|
||||||
|
user_id = existing_user.id
|
||||||
|
if merge_mode == "skip":
|
||||||
|
stats["users"]["skipped"] += 1
|
||||||
|
elif merge_mode == "error":
|
||||||
|
raise InvalidRequestException(
|
||||||
|
f"用户 '{user_data['email']}' 已存在"
|
||||||
|
)
|
||||||
|
elif merge_mode == "overwrite":
|
||||||
|
# 更新现有用户
|
||||||
|
existing_user.username = user_data.get(
|
||||||
|
"username", existing_user.username
|
||||||
|
)
|
||||||
|
if user_data.get("password_hash"):
|
||||||
|
existing_user.password_hash = user_data["password_hash"]
|
||||||
|
if user_data.get("role"):
|
||||||
|
existing_user.role = UserRole(user_data["role"])
|
||||||
|
existing_user.allowed_providers = user_data.get("allowed_providers")
|
||||||
|
existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
|
||||||
|
existing_user.allowed_models = user_data.get("allowed_models")
|
||||||
|
existing_user.model_capability_settings = user_data.get(
|
||||||
|
"model_capability_settings"
|
||||||
|
)
|
||||||
|
existing_user.quota_usd = user_data.get("quota_usd")
|
||||||
|
existing_user.used_usd = user_data.get("used_usd", 0.0)
|
||||||
|
existing_user.total_usd = user_data.get("total_usd", 0.0)
|
||||||
|
existing_user.is_active = user_data.get("is_active", True)
|
||||||
|
existing_user.updated_at = datetime.now(timezone.utc)
|
||||||
|
stats["users"]["updated"] += 1
|
||||||
|
else:
|
||||||
|
# 创建新用户
|
||||||
|
role = UserRole.USER
|
||||||
|
if user_data.get("role"):
|
||||||
|
role = UserRole(user_data["role"])
|
||||||
|
|
||||||
|
new_user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=user_data["email"],
|
||||||
|
username=user_data.get("username", user_data["email"].split("@")[0]),
|
||||||
|
password_hash=user_data.get("password_hash", ""),
|
||||||
|
role=role,
|
||||||
|
allowed_providers=user_data.get("allowed_providers"),
|
||||||
|
allowed_endpoints=user_data.get("allowed_endpoints"),
|
||||||
|
allowed_models=user_data.get("allowed_models"),
|
||||||
|
model_capability_settings=user_data.get("model_capability_settings"),
|
||||||
|
quota_usd=user_data.get("quota_usd"),
|
||||||
|
used_usd=user_data.get("used_usd", 0.0),
|
||||||
|
total_usd=user_data.get("total_usd", 0.0),
|
||||||
|
is_active=user_data.get("is_active", True),
|
||||||
|
)
|
||||||
|
db.add(new_user)
|
||||||
|
db.flush()
|
||||||
|
user_id = new_user.id
|
||||||
|
stats["users"]["created"] += 1
|
||||||
|
|
||||||
|
# 导入 API Keys
|
||||||
|
for key_data in user_data.get("api_keys", []):
|
||||||
|
# 检查是否已存在相同的 key_hash
|
||||||
|
if key_data.get("key_hash"):
|
||||||
|
existing_key = (
|
||||||
|
db.query(ApiKey)
|
||||||
|
.filter(ApiKey.key_hash == key_data["key_hash"])
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing_key:
|
||||||
|
stats["api_keys"]["skipped"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_key = ApiKey(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
key_hash=key_data.get("key_hash", ""),
|
||||||
|
key_encrypted=key_data.get("key_encrypted"),
|
||||||
|
name=key_data.get("name"),
|
||||||
|
is_standalone=key_data.get("is_standalone", False),
|
||||||
|
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||||||
|
current_balance_usd=key_data.get("current_balance_usd"),
|
||||||
|
allowed_providers=key_data.get("allowed_providers"),
|
||||||
|
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||||||
|
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||||
|
allowed_models=key_data.get("allowed_models"),
|
||||||
|
rate_limit=key_data.get("rate_limit", 100),
|
||||||
|
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||||
|
force_capabilities=key_data.get("force_capabilities"),
|
||||||
|
is_active=key_data.get("is_active", True),
|
||||||
|
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||||||
|
total_requests=key_data.get("total_requests", 0),
|
||||||
|
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||||||
|
)
|
||||||
|
db.add(new_key)
|
||||||
|
stats["api_keys"]["created"] += 1
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": "用户数据导入成功",
|
||||||
|
"stats": stats,
|
||||||
|
}
|
||||||
|
|
||||||
|
except InvalidRequestException:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
raise InvalidRequestException(f"导入失败: {str(e)}")
|
||||||
|
|||||||
@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
|||||||
)
|
)
|
||||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||||
# 需要转回业务时区再取日期,才能与日期序列匹配
|
# 需要转回业务时区再取日期,才能与日期序列匹配
|
||||||
|
def _to_business_date_str(value: datetime) -> str:
|
||||||
|
if value.tzinfo is None:
|
||||||
|
value_utc = value.replace(tzinfo=timezone.utc)
|
||||||
|
else:
|
||||||
|
value_utc = value.astimezone(timezone.utc)
|
||||||
|
return value_utc.astimezone(app_tz).date().isoformat()
|
||||||
|
|
||||||
stats_map = {
|
stats_map = {
|
||||||
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
|
_to_business_date_str(stat.date): {
|
||||||
"requests": stat.total_requests,
|
"requests": stat.total_requests,
|
||||||
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
||||||
"cost": stat.total_cost,
|
"cost": stat.total_cost,
|
||||||
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
|||||||
"unique_providers": today_unique_providers,
|
"unique_providers": today_unique_providers,
|
||||||
"fallback_count": today_fallback_count,
|
"fallback_count": today_fallback_count,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
|
||||||
|
yesterday_date = today_local.date() - timedelta(days=1)
|
||||||
|
historical_end = min(end_date_local.date(), yesterday_date)
|
||||||
|
missing_dates: list[str] = []
|
||||||
|
cursor = start_date_local.date()
|
||||||
|
while cursor <= historical_end:
|
||||||
|
date_str = cursor.isoformat()
|
||||||
|
if date_str not in stats_map:
|
||||||
|
missing_dates.append(date_str)
|
||||||
|
cursor += timedelta(days=1)
|
||||||
|
|
||||||
|
if missing_dates:
|
||||||
|
for date_str in missing_dates[-7:]:
|
||||||
|
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
|
||||||
|
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
|
||||||
|
stats_map[date_str] = {
|
||||||
|
"requests": computed["total_requests"],
|
||||||
|
"tokens": (
|
||||||
|
computed["input_tokens"]
|
||||||
|
+ computed["output_tokens"]
|
||||||
|
+ computed["cache_creation_tokens"]
|
||||||
|
+ computed["cache_read_tokens"]
|
||||||
|
),
|
||||||
|
"cost": computed["total_cost"],
|
||||||
|
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
|
||||||
|
if computed["avg_response_time_ms"]
|
||||||
|
else 0,
|
||||||
|
"unique_models": computed["unique_models"],
|
||||||
|
"unique_providers": computed["unique_providers"],
|
||||||
|
"fallback_count": computed["fallback_count"],
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# 普通用户:仍需实时查询(用户级预聚合可选)
|
# 普通用户:仍需实时查询(用户级预聚合可选)
|
||||||
query = db.query(Usage).filter(
|
query = db.query(Usage).filter(
|
||||||
|
|||||||
@@ -266,8 +266,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
# 使用 select_provider_model_name 支持别名功能
|
# 使用 select_provider_model_name 支持别名功能
|
||||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||||
|
# 传入 api_format 用于过滤适用的别名作用域
|
||||||
affinity_key = self.api_key.id if self.api_key else None
|
affinity_key = self.api_key.id if self.api_key else None
|
||||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
mapped_name = mapping.model.select_provider_model_name(
|
||||||
|
affinity_key, api_format=self.FORMAT_ID
|
||||||
|
)
|
||||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
|
|||||||
@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if mapping and mapping.model:
|
if mapping and mapping.model:
|
||||||
# 使用 select_provider_model_name 支持别名功能
|
# 使用 select_provider_model_name 支持别名功能
|
||||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||||
|
# 传入 api_format 用于过滤适用的别名作用域
|
||||||
affinity_key = self.api_key.id if self.api_key else None
|
affinity_key = self.api_key.id if self.api_key else None
|
||||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
mapped_name = mapping.model.select_provider_model_name(
|
||||||
|
affinity_key, api_format=self.FORMAT_ID
|
||||||
|
)
|
||||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||||
return mapped_name
|
return mapped_name
|
||||||
|
|
||||||
|
|||||||
@@ -813,7 +813,9 @@ class Model(Base):
|
|||||||
def get_effective_supports_image_generation(self) -> bool:
|
def get_effective_supports_image_generation(self) -> bool:
|
||||||
return self._get_effective_capability("supports_image_generation", False)
|
return self._get_effective_capability("supports_image_generation", False)
|
||||||
|
|
||||||
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
|
def select_provider_model_name(
|
||||||
|
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
"""按优先级选择要使用的 Provider 模型名称
|
"""按优先级选择要使用的 Provider 模型名称
|
||||||
|
|
||||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||||
@@ -822,6 +824,7 @@ class Model(Base):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||||
|
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的别名
|
||||||
"""
|
"""
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
@@ -840,6 +843,13 @@ class Model(Base):
|
|||||||
if not isinstance(name, str) or not name.strip():
|
if not isinstance(name, str) or not name.strip():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 检查 api_formats 作用域(如果配置了且当前有 api_format)
|
||||||
|
alias_api_formats = raw.get("api_formats")
|
||||||
|
if api_format and alias_api_formats:
|
||||||
|
# 如果配置了作用域,只有匹配时才生效
|
||||||
|
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
|
||||||
|
continue
|
||||||
|
|
||||||
raw_priority = raw.get("priority", 1)
|
raw_priority = raw.get("priority", 1)
|
||||||
try:
|
try:
|
||||||
priority = int(raw_priority)
|
priority = int(raw_priority)
|
||||||
|
|||||||
@@ -238,8 +238,8 @@ class GlobalModelResponse(BaseModel):
|
|||||||
# 按次计费配置
|
# 按次计费配置
|
||||||
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
|
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
|
||||||
# 阶梯计费配置
|
# 阶梯计费配置
|
||||||
default_tiered_pricing: TieredPricingConfig = Field(
|
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
|
||||||
..., description="阶梯计费配置"
|
default=None, description="阶梯计费配置"
|
||||||
)
|
)
|
||||||
# Key 能力配置 - 模型支持的能力列表
|
# Key 能力配置 - 模型支持的能力列表
|
||||||
supported_capabilities: Optional[List[str]] = Field(
|
supported_capabilities: Optional[List[str]] = Field(
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class CleanupScheduler:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
self._interval_tasks = []
|
self._interval_tasks = []
|
||||||
|
self._stats_aggregation_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""启动调度器"""
|
"""启动调度器"""
|
||||||
@@ -56,6 +57,14 @@ class CleanupScheduler:
|
|||||||
job_id="stats_aggregation",
|
job_id="stats_aggregation",
|
||||||
name="统计数据聚合",
|
name="统计数据聚合",
|
||||||
)
|
)
|
||||||
|
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
|
||||||
|
scheduler.add_interval_job(
|
||||||
|
self._scheduled_stats_aggregation,
|
||||||
|
minutes=30,
|
||||||
|
job_id="stats_aggregation_backfill",
|
||||||
|
name="统计数据聚合补偿",
|
||||||
|
backfill=True,
|
||||||
|
)
|
||||||
|
|
||||||
# 清理任务 - 凌晨 3 点执行
|
# 清理任务 - 凌晨 3 点执行
|
||||||
scheduler.add_cron_job(
|
scheduler.add_cron_job(
|
||||||
@@ -115,9 +124,9 @@ class CleanupScheduler:
|
|||||||
|
|
||||||
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
||||||
|
|
||||||
async def _scheduled_stats_aggregation(self):
|
async def _scheduled_stats_aggregation(self, backfill: bool = False):
|
||||||
"""统计聚合任务(定时调用)"""
|
"""统计聚合任务(定时调用)"""
|
||||||
await self._perform_stats_aggregation()
|
await self._perform_stats_aggregation(backfill=backfill)
|
||||||
|
|
||||||
async def _scheduled_cleanup(self):
|
async def _scheduled_cleanup(self):
|
||||||
"""清理任务(定时调用)"""
|
"""清理任务(定时调用)"""
|
||||||
@@ -144,6 +153,11 @@ class CleanupScheduler:
|
|||||||
Args:
|
Args:
|
||||||
backfill: 是否回填历史数据(启动时检查缺失的日期)
|
backfill: 是否回填历史数据(启动时检查缺失的日期)
|
||||||
"""
|
"""
|
||||||
|
if self._stats_aggregation_lock.locked():
|
||||||
|
logger.info("统计聚合任务正在运行,跳过本次触发")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with self._stats_aggregation_lock:
|
||||||
db = create_session()
|
db = create_session()
|
||||||
try:
|
try:
|
||||||
# 检查是否启用统计聚合
|
# 检查是否启用统计聚合
|
||||||
@@ -181,11 +195,7 @@ class CleanupScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||||
latest_stat = (
|
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
|
||||||
db.query(StatsDaily)
|
|
||||||
.order_by(StatsDaily.date.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if latest_stat:
|
if latest_stat:
|
||||||
latest_date_utc = latest_stat.date
|
latest_date_utc = latest_stat.date
|
||||||
@@ -196,26 +206,46 @@ class CleanupScheduler:
|
|||||||
|
|
||||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
||||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
||||||
yesterday_business_date = (today_local.date() - timedelta(days=1))
|
yesterday_business_date = today_local.date() - timedelta(days=1)
|
||||||
missing_start_date = latest_business_date + timedelta(days=1)
|
missing_start_date = latest_business_date + timedelta(days=1)
|
||||||
|
|
||||||
if missing_start_date <= yesterday_business_date:
|
if missing_start_date <= yesterday_business_date:
|
||||||
missing_days = (yesterday_business_date - missing_start_date).days + 1
|
missing_days = (
|
||||||
|
yesterday_business_date - missing_start_date
|
||||||
|
).days + 1
|
||||||
|
|
||||||
|
# 限制最大回填天数,防止停机很久后一次性回填太多
|
||||||
|
max_backfill_days: int = SystemConfigService.get_config(
|
||||||
|
db, "max_stats_backfill_days", 30
|
||||||
|
) or 30
|
||||||
|
if missing_days > max_backfill_days:
|
||||||
|
logger.warning(
|
||||||
|
f"缺失 {missing_days} 天数据超过最大回填限制 "
|
||||||
|
f"{max_backfill_days} 天,只回填最近 {max_backfill_days} 天"
|
||||||
|
)
|
||||||
|
missing_start_date = yesterday_business_date - timedelta(
|
||||||
|
days=max_backfill_days - 1
|
||||||
|
)
|
||||||
|
missing_days = max_backfill_days
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
f"检测到缺失 {missing_days} 天的统计数据 "
|
||||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
||||||
)
|
)
|
||||||
|
|
||||||
current_date = missing_start_date
|
current_date = missing_start_date
|
||||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
users = (
|
||||||
|
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||||
|
)
|
||||||
|
|
||||||
while current_date <= yesterday_business_date:
|
while current_date <= yesterday_business_date:
|
||||||
try:
|
try:
|
||||||
current_date_local = datetime.combine(
|
current_date_local = datetime.combine(
|
||||||
current_date, datetime.min.time(), tzinfo=app_tz
|
current_date, datetime.min.time(), tzinfo=app_tz
|
||||||
)
|
)
|
||||||
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
|
StatsAggregatorService.aggregate_daily_stats(
|
||||||
# 聚合用户数据
|
db, current_date_local
|
||||||
|
)
|
||||||
for (user_id,) in users:
|
for (user_id,) in users:
|
||||||
try:
|
try:
|
||||||
StatsAggregatorService.aggregate_user_daily_stats(
|
StatsAggregatorService.aggregate_user_daily_stats(
|
||||||
@@ -238,7 +268,6 @@ class CleanupScheduler:
|
|||||||
|
|
||||||
current_date += timedelta(days=1)
|
current_date += timedelta(days=1)
|
||||||
|
|
||||||
# 更新全局汇总
|
|
||||||
StatsAggregatorService.update_summary(db)
|
StatsAggregatorService.update_summary(db)
|
||||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
||||||
else:
|
else:
|
||||||
@@ -246,32 +275,33 @@ class CleanupScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 定时任务:聚合昨天的数据
|
# 定时任务:聚合昨天的数据
|
||||||
# 注意:aggregate_daily_stats 期望业务时区的日期,不是 UTC
|
|
||||||
yesterday_local = today_local - timedelta(days=1)
|
yesterday_local = today_local - timedelta(days=1)
|
||||||
|
|
||||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||||
|
|
||||||
# 聚合所有用户的昨日数据
|
|
||||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||||
for (user_id,) in users:
|
for (user_id,) in users:
|
||||||
try:
|
try:
|
||||||
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
|
StatsAggregatorService.aggregate_user_daily_stats(
|
||||||
|
db, user_id, yesterday_local
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||||
# 回滚当前用户的失败操作,继续处理其他用户
|
|
||||||
try:
|
try:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 更新全局汇总
|
|
||||||
StatsAggregatorService.update_summary(db)
|
StatsAggregatorService.update_summary(db)
|
||||||
|
|
||||||
logger.info("统计数据聚合完成")
|
logger.info("统计数据聚合完成")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||||
|
try:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|||||||
@@ -56,65 +56,44 @@ class StatsAggregatorService:
|
|||||||
"""统计数据聚合服务"""
|
"""统计数据聚合服务"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
def compute_daily_stats(db: Session, date: datetime) -> dict:
|
||||||
"""聚合指定日期的统计数据
|
"""计算指定业务日期的统计数据(不写入数据库)"""
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话
|
|
||||||
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
StatsDaily 记录
|
|
||||||
"""
|
|
||||||
# 将业务日期转换为 UTC 时间范围
|
|
||||||
day_start, day_end = _get_business_day_range(date)
|
day_start, day_end = _get_business_day_range(date)
|
||||||
|
|
||||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
|
||||||
# 检查是否已存在该日期的记录
|
|
||||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
|
||||||
if existing:
|
|
||||||
stats = existing
|
|
||||||
else:
|
|
||||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
|
||||||
|
|
||||||
# 基础请求统计
|
|
||||||
base_query = db.query(Usage).filter(
|
base_query = db.query(Usage).filter(
|
||||||
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
||||||
)
|
)
|
||||||
|
|
||||||
total_requests = base_query.count()
|
total_requests = base_query.count()
|
||||||
|
|
||||||
# 如果没有请求,直接返回空记录
|
|
||||||
if total_requests == 0:
|
if total_requests == 0:
|
||||||
stats.total_requests = 0
|
return {
|
||||||
stats.success_requests = 0
|
"day_start": day_start,
|
||||||
stats.error_requests = 0
|
"total_requests": 0,
|
||||||
stats.input_tokens = 0
|
"success_requests": 0,
|
||||||
stats.output_tokens = 0
|
"error_requests": 0,
|
||||||
stats.cache_creation_tokens = 0
|
"input_tokens": 0,
|
||||||
stats.cache_read_tokens = 0
|
"output_tokens": 0,
|
||||||
stats.total_cost = 0.0
|
"cache_creation_tokens": 0,
|
||||||
stats.actual_total_cost = 0.0
|
"cache_read_tokens": 0,
|
||||||
stats.input_cost = 0.0
|
"total_cost": 0.0,
|
||||||
stats.output_cost = 0.0
|
"actual_total_cost": 0.0,
|
||||||
stats.cache_creation_cost = 0.0
|
"input_cost": 0.0,
|
||||||
stats.cache_read_cost = 0.0
|
"output_cost": 0.0,
|
||||||
stats.avg_response_time_ms = 0.0
|
"cache_creation_cost": 0.0,
|
||||||
stats.fallback_count = 0
|
"cache_read_cost": 0.0,
|
||||||
|
"avg_response_time_ms": 0.0,
|
||||||
|
"fallback_count": 0,
|
||||||
|
"unique_models": 0,
|
||||||
|
"unique_providers": 0,
|
||||||
|
}
|
||||||
|
|
||||||
if not existing:
|
|
||||||
db.add(stats)
|
|
||||||
db.commit()
|
|
||||||
return stats
|
|
||||||
|
|
||||||
# 错误请求数
|
|
||||||
error_requests = (
|
error_requests = (
|
||||||
base_query.filter(
|
base_query.filter(
|
||||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||||
).count()
|
).count()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Token 和成本聚合
|
|
||||||
aggregated = (
|
aggregated = (
|
||||||
db.query(
|
db.query(
|
||||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||||
@@ -157,7 +136,6 @@ class StatsAggregatorService:
|
|||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用维度统计
|
|
||||||
unique_models = (
|
unique_models = (
|
||||||
db.query(func.count(func.distinct(Usage.model)))
|
db.query(func.count(func.distinct(Usage.model)))
|
||||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||||
@@ -171,31 +149,74 @@ class StatsAggregatorService:
|
|||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"day_start": day_start,
|
||||||
|
"total_requests": total_requests,
|
||||||
|
"success_requests": total_requests - error_requests,
|
||||||
|
"error_requests": error_requests,
|
||||||
|
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
|
||||||
|
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
|
||||||
|
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
|
||||||
|
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
|
||||||
|
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
|
||||||
|
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
|
||||||
|
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
|
||||||
|
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
|
||||||
|
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
|
||||||
|
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
|
||||||
|
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
|
||||||
|
"fallback_count": fallback_count,
|
||||||
|
"unique_models": unique_models,
|
||||||
|
"unique_providers": unique_providers,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||||
|
"""聚合指定日期的统计数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StatsDaily 记录
|
||||||
|
"""
|
||||||
|
computed = StatsAggregatorService.compute_daily_stats(db, date)
|
||||||
|
day_start = computed["day_start"]
|
||||||
|
|
||||||
|
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||||
|
# 检查是否已存在该日期的记录
|
||||||
|
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||||
|
if existing:
|
||||||
|
stats = existing
|
||||||
|
else:
|
||||||
|
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||||
|
|
||||||
# 更新统计记录
|
# 更新统计记录
|
||||||
stats.total_requests = total_requests
|
stats.total_requests = computed["total_requests"]
|
||||||
stats.success_requests = total_requests - error_requests
|
stats.success_requests = computed["success_requests"]
|
||||||
stats.error_requests = error_requests
|
stats.error_requests = computed["error_requests"]
|
||||||
stats.input_tokens = int(aggregated.input_tokens or 0)
|
stats.input_tokens = computed["input_tokens"]
|
||||||
stats.output_tokens = int(aggregated.output_tokens or 0)
|
stats.output_tokens = computed["output_tokens"]
|
||||||
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
|
stats.cache_creation_tokens = computed["cache_creation_tokens"]
|
||||||
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
|
stats.cache_read_tokens = computed["cache_read_tokens"]
|
||||||
stats.total_cost = float(aggregated.total_cost or 0)
|
stats.total_cost = computed["total_cost"]
|
||||||
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
|
stats.actual_total_cost = computed["actual_total_cost"]
|
||||||
stats.input_cost = float(aggregated.input_cost or 0)
|
stats.input_cost = computed["input_cost"]
|
||||||
stats.output_cost = float(aggregated.output_cost or 0)
|
stats.output_cost = computed["output_cost"]
|
||||||
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
|
stats.cache_creation_cost = computed["cache_creation_cost"]
|
||||||
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
|
stats.cache_read_cost = computed["cache_read_cost"]
|
||||||
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
|
stats.avg_response_time_ms = computed["avg_response_time_ms"]
|
||||||
stats.fallback_count = fallback_count
|
stats.fallback_count = computed["fallback_count"]
|
||||||
stats.unique_models = unique_models
|
stats.unique_models = computed["unique_models"]
|
||||||
stats.unique_providers = unique_providers
|
stats.unique_providers = computed["unique_providers"]
|
||||||
|
|
||||||
if not existing:
|
if not existing:
|
||||||
db.add(stats)
|
db.add(stats)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 日志使用业务日期(输入参数),而不是 UTC 日期
|
# 日志使用业务日期(输入参数),而不是 UTC 日期
|
||||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {total_requests} 请求")
|
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user