diff --git a/frontend/src/api/endpoints/models.ts b/frontend/src/api/endpoints/models.ts index 5a00f79..3620ac2 100644 --- a/frontend/src/api/endpoints/models.ts +++ b/frontend/src/api/endpoints/models.ts @@ -5,6 +5,8 @@ import type { ModelUpdate, ModelCatalogResponse, ProviderAvailableSourceModelsResponse, + UpstreamModel, + ImportFromUpstreamResponse, } from './types' /** @@ -119,3 +121,40 @@ export async function batchAssignModelsToProvider( ) return response.data } + +/** + * 查询提供商的上游模型列表 + */ +export async function queryProviderUpstreamModels( + providerId: string +): Promise<{ + success: boolean + data: { + models: UpstreamModel[] + error: string | null + } + provider: { + id: string + name: string + display_name: string + } +}> { + const response = await client.post('/api/admin/provider-query/models', { + provider_id: providerId, + }) + return response.data +} + +/** + * 从上游提供商导入模型 + */ +export async function importModelsFromUpstream( + providerId: string, + modelIds: string[] +): Promise { + const response = await client.post( + `/api/admin/providers/${providerId}/import-from-upstream`, + { model_ids: modelIds } + ) + return response.data +} diff --git a/frontend/src/api/endpoints/types.ts b/frontend/src/api/endpoints/types.ts index ef13cb7..6d373ac 100644 --- a/frontend/src/api/endpoints/types.ts +++ b/frontend/src/api/endpoints/types.ts @@ -495,3 +495,42 @@ export interface GlobalModelListResponse { models: GlobalModelResponse[] total: number } + +// ==================== 上游模型导入相关 ==================== + +/** + * 上游模型(从提供商 API 获取的原始模型) + */ +export interface UpstreamModel { + id: string + owned_by?: string + display_name?: string + api_format?: string +} + +/** + * 导入成功的模型信息 + */ +export interface ImportFromUpstreamSuccessItem { + model_id: string + global_model_id: string + global_model_name: string + provider_model_id: string + created_global_model: boolean +} + +/** + * 导入失败的模型信息 + */ +export interface ImportFromUpstreamErrorItem { + model_id: string + error: string +} + +/** + * 从上游提供商导入模型响应 + */ +export interface ImportFromUpstreamResponse { + success: ImportFromUpstreamSuccessItem[] + errors: ImportFromUpstreamErrorItem[] +} diff --git a/frontend/src/features/providers/components/BatchAssignModelsDialog.vue b/frontend/src/features/providers/components/BatchAssignModelsDialog.vue index d706616..97243d3 100644 --- a/frontend/src/features/providers/components/BatchAssignModelsDialog.vue +++ b/frontend/src/features/providers/components/BatchAssignModelsDialog.vue @@ -31,29 +31,46 @@
- +
-
-
-

- 可添加 -

- +
+

+ 可添加 +

+
+ +
- - {{ availableModels.length }} 个 - + + + +
@@ -73,37 +90,142 @@
+
- -
-

- {{ model.display_name }} -

-

- {{ model.name }} -

+
+ +
- - {{ model.is_active ? '活跃' : '停用' }} - +
+ 所有全局模型均已关联 +
+
+ +
+

+ {{ model.display_name }} +

+

+ {{ model.name }} +

+
+ + {{ model.is_active ? '活跃' : '停用' }} + +
+
+
+ + +
+
+ + +
+
+
+ +
+

+ {{ model.id }} +

+

+ {{ model.owned_by || model.id }} +

+
+
+
@@ -115,8 +237,8 @@ variant="outline" size="sm" class="w-9 h-8" - :class="selectedLeftIds.length > 0 && !submittingAdd ? 'border-primary' : ''" - :disabled="selectedLeftIds.length === 0 || submittingAdd" + :class="totalSelectedCount > 0 && !submittingAdd ? 'border-primary' : ''" + :disabled="totalSelectedCount === 0 || submittingAdd" title="添加选中" @click="batchAddSelected" > @@ -127,7 +249,7 @@ -
- + 已添加 +

+
import { ref, computed, watch } from 'vue' -import { Layers, Loader2, ChevronRight, ChevronLeft } from 'lucide-vue-next' +import { Layers, Loader2, ChevronRight, ChevronLeft, ChevronDown, Zap, RefreshCw, Search } from 'lucide-vue-next' import Dialog from '@/components/ui/dialog/Dialog.vue' import Button from '@/components/ui/button.vue' import Badge from '@/components/ui/badge.vue' import Checkbox from '@/components/ui/checkbox.vue' +import Input from '@/components/ui/input.vue' import { useToast } from '@/composables/useToast' import { parseApiError } from '@/utils/errorParser' import { @@ -253,8 +368,13 @@ import { getProviderModels, batchAssignModelsToProvider, deleteModel, + importModelsFromUpstream, + API_FORMAT_LABELS, type Model } from '@/api/endpoints' +import { useUpstreamModelsCache, type UpstreamModel } from '../composables/useUpstreamModelsCache' + +const { fetchModels: fetchCachedModels, clearCache, getCachedModels } = useUpstreamModelsCache() const props = defineProps<{ open: boolean @@ -274,17 +394,27 @@ const { error: showError, success } = useToast() const loadingGlobalModels = ref(false) const submittingAdd = ref(false) const submittingRemove = ref(false) +const fetchingUpstreamModels = ref(false) +const upstreamModelsLoaded = ref(false) // 数据 const allGlobalModels = ref([]) const existingModels = ref([]) +const upstreamModels = ref([]) // 选择状态 -const selectedLeftIds = ref([]) +const selectedGlobalModelIds = ref([]) +const selectedUpstreamModelIds = ref([]) const selectedRightIds = ref([]) -// 计算可添加的模型(排除已关联的) -const availableModels = computed(() => { +// 折叠状态 +const collapsedGroups = ref>(new Set()) + +// 搜索状态 +const searchQuery = ref('') + +// 计算可添加的全局模型(排除已关联的) +const availableGlobalModelsBase = computed(() => { const existingGlobalModelIds = new Set( existingModels.value .filter(m => m.global_model_id) @@ -293,31 +423,123 @@ const availableModels = computed(() => { return allGlobalModels.value.filter(m => !existingGlobalModelIds.has(m.id)) }) -// 全选状态 -const isAllLeftSelected = computed(() => - availableModels.value.length > 0 && - selectedLeftIds.value.length === availableModels.value.length -) +// 搜索过滤后的全局模型 +const availableGlobalModels = computed(() => { + if (!searchQuery.value.trim()) return availableGlobalModelsBase.value + const query = searchQuery.value.toLowerCase() + return availableGlobalModelsBase.value.filter(m => + m.name.toLowerCase().includes(query) || + m.display_name.toLowerCase().includes(query) + ) +}) +// 计算可添加的上游模型(排除已关联的) +const availableUpstreamModelsBase = computed(() => { + const existingModelNames = new Set( + existingModels.value.map(m => m.provider_model_name) + ) + return upstreamModels.value.filter(m => !existingModelNames.has(m.id)) +}) + +// 搜索过滤后的上游模型 +const availableUpstreamModels = computed(() => { + if (!searchQuery.value.trim()) return availableUpstreamModelsBase.value + const query = searchQuery.value.toLowerCase() + return availableUpstreamModelsBase.value.filter(m => + m.id.toLowerCase().includes(query) || + (m.owned_by && m.owned_by.toLowerCase().includes(query)) + ) +}) + +// 按 API 格式分组的上游模型 +const upstreamModelGroups = computed(() => { + const groups: Record = {} + + for (const model of availableUpstreamModels.value) { + const format = model.api_format || 'unknown' + if (!groups[format]) { + groups[format] = [] + } + groups[format].push(model) + } + + // 按 API_FORMAT_LABELS 的顺序排序 + const order = Object.keys(API_FORMAT_LABELS) + return Object.entries(groups) + .map(([api_format, models]) => ({ api_format, models })) + .sort((a, b) => { + const aIndex = order.indexOf(a.api_format) + const bIndex = order.indexOf(b.api_format) + if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format) + if (aIndex === -1) return 1 + if (bIndex === -1) return -1 + return aIndex - bIndex + }) +}) + +// 总可添加数量 +const totalAvailableCount = computed(() => { + return availableGlobalModels.value.length + availableUpstreamModels.value.length +}) + +// 总选中数量 +const totalSelectedCount = computed(() => { + return selectedGlobalModelIds.value.length + selectedUpstreamModelIds.value.length +}) + +// 全选状态 const isAllRightSelected = computed(() => existingModels.value.length > 0 && selectedRightIds.value.length === existingModels.value.length ) +// 全局模型是否全选 +const isAllGlobalModelsSelected = computed(() => { + if (availableGlobalModels.value.length === 0) return false + return availableGlobalModels.value.every(m => selectedGlobalModelIds.value.includes(m.id)) +}) + +// 检查某个上游组是否全选 +function isUpstreamGroupAllSelected(apiFormat: string): boolean { + const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat) + if (!group || group.models.length === 0) return false + return group.models.every(m => selectedUpstreamModelIds.value.includes(m.id)) +} + // 监听打开状态 watch(() => props.open, async (isOpen) => { if (isOpen && props.providerId) { await loadData() } else { // 重置状态 - selectedLeftIds.value = [] + selectedGlobalModelIds.value = [] + selectedUpstreamModelIds.value = [] selectedRightIds.value = [] + upstreamModels.value = [] + upstreamModelsLoaded.value = false + collapsedGroups.value = new Set() + searchQuery.value = '' } }) // 加载数据 async function loadData() { await Promise.all([loadGlobalModels(), loadExistingModels()]) + // 默认折叠全局模型组 + collapsedGroups.value = new Set(['global']) + + // 检查缓存,如果有缓存数据则直接使用 + const cachedModels = getCachedModels(props.providerId) + if (cachedModels) { + upstreamModels.value = cachedModels + upstreamModelsLoaded.value = true + // 折叠所有上游模型组 + for (const model of cachedModels) { + if (model.api_format) { + collapsedGroups.value.add(model.api_format) + } + } + } } // 加载全局模型列表 @@ -342,13 +564,91 @@ async function loadExistingModels() { } } -// 切换左侧选择 -function toggleLeftSelection(id: string) { - const index = selectedLeftIds.value.indexOf(id) - if (index === -1) { - selectedLeftIds.value.push(id) +// 从提供商获取模型 +async function fetchUpstreamModels(forceRefresh = false) { + if (forceRefresh) { + clearCache(props.providerId) + } + + try { + fetchingUpstreamModels.value = true + const result = await fetchCachedModels(props.providerId, forceRefresh) + if (result) { + if (result.error) { + showError(result.error, '错误') + } else { + upstreamModels.value = result.models + upstreamModelsLoaded.value = true + // 折叠所有上游模型组 + const allGroups = new Set(collapsedGroups.value) + for (const model of result.models) { + if (model.api_format) { + allGroups.add(model.api_format) + } + } + collapsedGroups.value = allGroups + } + } + } finally { + fetchingUpstreamModels.value = false + } +} + +// 切换折叠状态 +function toggleGroupCollapse(group: string) { + if (collapsedGroups.value.has(group)) { + collapsedGroups.value.delete(group) } else { - selectedLeftIds.value.splice(index, 1) + collapsedGroups.value.add(group) + } + // 触发响应式更新 + collapsedGroups.value = new Set(collapsedGroups.value) +} + +// 切换全局模型选择 +function toggleGlobalModelSelection(id: string) { + const index = selectedGlobalModelIds.value.indexOf(id) + if (index === -1) { + selectedGlobalModelIds.value.push(id) + } else { + selectedGlobalModelIds.value.splice(index, 1) + } +} + +// 切换上游模型选择 +function toggleUpstreamModelSelection(id: string) { + const index = selectedUpstreamModelIds.value.indexOf(id) + if (index === -1) { + selectedUpstreamModelIds.value.push(id) + } else { + selectedUpstreamModelIds.value.splice(index, 1) + } +} + +// 全选全局模型 +function selectAllGlobalModels() { + const allIds = availableGlobalModels.value.map(m => m.id) + const allSelected = allIds.every(id => selectedGlobalModelIds.value.includes(id)) + if (allSelected) { + selectedGlobalModelIds.value = selectedGlobalModelIds.value.filter(id => !allIds.includes(id)) + } else { + const newIds = allIds.filter(id => !selectedGlobalModelIds.value.includes(id)) + selectedGlobalModelIds.value.push(...newIds) + } +} + +// 全选某个 API 格式的上游模型 +function selectAllUpstreamModels(apiFormat: string) { + const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat) + if (!group) return + + const allIds = group.models.map(m => m.id) + const allSelected = allIds.every(id => selectedUpstreamModelIds.value.includes(id)) + if (allSelected) { + selectedUpstreamModelIds.value = selectedUpstreamModelIds.value.filter(id => !allIds.includes(id)) + } else { + const newIds = allIds.filter(id => !selectedUpstreamModelIds.value.includes(id)) + selectedUpstreamModelIds.value.push(...newIds) } } @@ -362,15 +662,6 @@ function toggleRightSelection(id: string) { } } -// 全选/取消全选左侧 -function toggleSelectAllLeft() { - if (isAllLeftSelected.value) { - selectedLeftIds.value = [] - } else { - selectedLeftIds.value = availableModels.value.map(m => m.id) - } -} - // 全选/取消全选右侧 function toggleSelectAllRight() { if (isAllRightSelected.value) { @@ -382,22 +673,41 @@ function toggleSelectAllRight() { // 批量添加选中的模型 async function batchAddSelected() { - if (selectedLeftIds.value.length === 0) return + if (totalSelectedCount.value === 0) return try { submittingAdd.value = true - const result = await batchAssignModelsToProvider(props.providerId, selectedLeftIds.value) + let totalSuccess = 0 + const allErrors: string[] = [] - if (result.success.length > 0) { - success(`成功添加 ${result.success.length} 个模型`) + // 处理全局模型 + if (selectedGlobalModelIds.value.length > 0) { + const result = await batchAssignModelsToProvider(props.providerId, selectedGlobalModelIds.value) + totalSuccess += result.success.length + if (result.errors.length > 0) { + allErrors.push(...result.errors.map(e => e.error)) + } } - if (result.errors.length > 0) { - const errorMessages = result.errors.map(e => e.error).join(', ') - showError(`部分模型添加失败: ${errorMessages}`, '警告') + // 处理上游模型(调用 import-from-upstream API) + if (selectedUpstreamModelIds.value.length > 0) { + const result = await importModelsFromUpstream(props.providerId, selectedUpstreamModelIds.value) + totalSuccess += result.success.length + if (result.errors.length > 0) { + allErrors.push(...result.errors.map(e => e.error)) + } } - selectedLeftIds.value = [] + if (totalSuccess > 0) { + success(`成功添加 ${totalSuccess} 个模型`) + } + + if (allErrors.length > 0) { + showError(`部分模型添加失败: ${allErrors.slice(0, 3).join(', ')}${allErrors.length > 3 ? '...' : ''}`, '警告') + } + + selectedGlobalModelIds.value = [] + selectedUpstreamModelIds.value = [] await loadExistingModels() emit('changed') } catch (err: any) { diff --git a/frontend/src/features/providers/components/ModelMappingDialog.vue b/frontend/src/features/providers/components/ModelMappingDialog.vue new file mode 100644 index 0000000..fba527a --- /dev/null +++ b/frontend/src/features/providers/components/ModelMappingDialog.vue @@ -0,0 +1,777 @@ + + + diff --git a/frontend/src/features/providers/components/provider-tabs/ModelAliasesTab.vue b/frontend/src/features/providers/components/provider-tabs/ModelAliasesTab.vue index 93c0c6e..77206f9 100644 --- a/frontend/src/features/providers/components/provider-tabs/ModelAliasesTab.vue +++ b/frontend/src/features/providers/components/provider-tabs/ModelAliasesTab.vue @@ -142,330 +142,14 @@ - -
- -
- -
- - -
- - -
- -
- -
-
- 无可用格式 -
-
-
- - -
- -
- -
- 上游模型 - -
- - -
- - - - -
- - -
- - - -
- -

- 点击右上角按钮
从上游获取可用模型 -

-
-
-
- - -
-
-
- 映射名称 - - {{ formData.aliases.length }} - -
-
- - -
-
- - -
-
-
- -
- -
- - -
- -
- {{ alias.priority }} -
-
- - - - - - -
-
- - -
- -

- 从左侧选择模型
或手动添加映射 -

-
-
- - -
- 拖拽调整优先级顺序 -
-
-
-
- - -
+ import { ref, computed, onMounted, watch } from 'vue' -import { Tag, Plus, Edit, Trash2, Loader2, GripVertical, X, Zap, Search, RefreshCw, ChevronRight, Eraser } from 'lucide-vue-next' -import { - Card, - Button, - Badge, - Input, - Label, - Dialog, - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from '@/components/ui' +import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next' +import { Card, Button, Badge } from '@/components/ui' import AlertDialog from '@/components/common/AlertDialog.vue' +import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue' import { useToast } from '@/composables/useToast' import { getProviderModels, @@ -505,17 +178,6 @@ import { type ProviderModelAlias } from '@/api/endpoints' import { updateModel } from '@/api/endpoints/models' -import { adminApi } from '@/api/admin' - -interface AliasItem { - model: Model - alias: ProviderModelAlias -} - -interface FormAlias { - name: string - priority: number -} const props = defineProps<{ provider: any @@ -532,131 +194,22 @@ const loading = ref(false) const models = ref([]) const dialogOpen = ref(false) const deleteConfirmOpen = ref(false) -const submitting = ref(false) -const editingItem = ref(null) +const editingGroup = ref(null) const deletingGroup = ref(null) -const modelSelectOpen = ref(false) -// 拖拽状态 -const draggedIndex = ref(null) -const dragOverIndex = ref(null) - -// 优先级编辑状态 -const editingPriorityIndex = ref(null) - -// 快速添加(上游模型)状态 -const fetchingUpstreamModels = ref(false) -const refreshingUpstreamModels = ref(false) -const upstreamModelsLoaded = ref(false) -const upstreamModels = ref>([]) -const upstreamModelSearch = ref('') - -// 分组折叠状态(上游模型列表) -const collapsedGroups = ref>(new Set()) - -// 列表展开状态(映射组列表) +// 列表展开状态 const expandedAliasGroups = ref>(new Set()) -// 上游模型缓存(按 Provider ID) -const upstreamModelsCache = ref - timestamp: number -}>>(new Map()) -const CACHE_TTL = 5 * 60 * 1000 // 5 分钟缓存 - -// 过滤和排序后的上游模型列表 -const filteredUpstreamModels = computed(() => { - const searchText = upstreamModelSearch.value.toLowerCase().trim() - let result = [...upstreamModels.value] - - // 按名称排序 - result.sort((a, b) => a.id.localeCompare(b.id)) - - // 搜索过滤(支持空格分隔的多关键词 AND 搜索) - if (searchText) { - const keywords = searchText.split(/\s+/).filter(k => k.length > 0) - result = result.filter(m => { - const searchableText = `${m.id} ${m.owned_by || ''} ${m.api_format || ''}`.toLowerCase() - return keywords.every(keyword => searchableText.includes(keyword)) - }) - } - - return result -}) - -// 按 API 格式分组的上游模型列表 -interface UpstreamModelGroup { - api_format: string - models: Array<{ id: string; owned_by?: string; api_format?: string }> -} - -// 可添加的上游模型(排除已添加的)按分组显示 -const groupedAvailableUpstreamModels = computed(() => { - // 获取已添加的映射名称集合 - const addedNames = new Set(formData.value.aliases.map(a => a.name.trim())) - - // 过滤掉已添加的模型 - const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id)) - - // 按 API 格式分组 - const groups = new Map() - - for (const model of availableModels) { - const format = model.api_format || 'UNKNOWN' - if (!groups.has(format)) { - groups.set(format, { api_format: format, models: [] }) - } - groups.get(format)!.models.push(model) - } - - // 按 API_FORMAT_LABELS 的键顺序排序 - const order = Object.keys(API_FORMAT_LABELS) - return Array.from(groups.values()).sort((a, b) => { - const aIndex = order.indexOf(a.api_format) - const bIndex = order.indexOf(b.api_format) - // 未知格式排最后 - if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format) - if (aIndex === -1) return 1 - if (bIndex === -1) return -1 - return aIndex - bIndex - }) -}) - -// 表单数据 -const formData = ref<{ - modelId: string - apiFormats: string[] - aliases: FormAlias[] -}>({ - modelId: '', - apiFormats: [], - aliases: [] -}) - -// 检查是否有有效的别名 -const hasValidAliases = computed(() => { - return formData.value.aliases.some(a => a.name.trim()) -}) - -// 获取 Provider 支持的 API 格式(按 API_FORMATS 定义的顺序排序) +// 获取 Provider 支持的 API 格式 const providerApiFormats = computed(() => { const formats = props.provider?.api_formats if (Array.isArray(formats) && formats.length > 0) { - // 按 API_FORMAT_LABELS 中的键顺序排序 const order = Object.keys(API_FORMAT_LABELS) return [...formats].sort((a, b) => order.indexOf(a) - order.indexOf(b)) } return [] }) -// 分组数据结构 -interface AliasGroup { - model: Model - apiFormatsKey: string // 作用域的唯一标识(排序后的格式数组 JSON) - apiFormats: string[] // 作用域 - aliases: ProviderModelAlias[] // 该组的所有映射 -} - // 生成作用域唯一键 function getApiFormatsKey(formats: string[] | undefined): string { if (!formats || formats.length === 0) return '' @@ -689,12 +242,10 @@ const aliasGroups = computed(() => { } } - // 对每个组内的别名按优先级排序 for (const group of groups) { group.aliases.sort((a, b) => a.priority - b.priority) } - // 按模型名排序,同模型内按作用域排序 return groups.sort((a, b) => { const nameA = (a.model.global_model_display_name || a.model.provider_model_name || '').toLowerCase() const nameB = (b.model.global_model_display_name || b.model.provider_model_name || '').toLowerCase() @@ -703,9 +254,6 @@ const aliasGroups = computed(() => { }) }) -// 当前编辑的分组 -const editingGroup = ref(null) - // 加载模型 async function loadModels() { try { @@ -728,25 +276,6 @@ const deleteConfirmDescription = computed(() => { return `确定要删除模型「${modelName}」在作用域「${scopeText}」下的 ${aliases.length} 个映射吗?\n\n映射名称:${aliasNames}` }) -// 切换 API 格式 -function toggleApiFormat(format: string) { - const index = formData.value.apiFormats.indexOf(format) - if (index >= 0) { - formData.value.apiFormats.splice(index, 1) - } else { - formData.value.apiFormats.push(format) - } -} - -// 切换分组折叠状态(上游模型列表) -function toggleGroupCollapse(apiFormat: string) { - if (collapsedGroups.value.has(apiFormat)) { - collapsedGroups.value.delete(apiFormat) - } else { - collapsedGroups.value.add(apiFormat) - } -} - // 切换映射组展开状态 function toggleAliasGroupExpand(groupKey: string) { if (expandedAliasGroups.value.has(groupKey)) { @@ -756,147 +285,15 @@ function toggleAliasGroupExpand(groupKey: string) { } } -// 添加别名项 -function addAliasItem() { - const maxPriority = formData.value.aliases.length > 0 - ? Math.max(...formData.value.aliases.map(a => a.priority)) - : 0 - formData.value.aliases.push({ name: '', priority: maxPriority + 1 }) -} - -// 删除别名项 -function removeAliasItem(index: number) { - formData.value.aliases.splice(index, 1) -} - -// ===== 拖拽排序 ===== -function handleDragStart(index: number, event: DragEvent) { - draggedIndex.value = index - if (event.dataTransfer) { - event.dataTransfer.effectAllowed = 'move' - } -} - -function handleDragEnd() { - draggedIndex.value = null - dragOverIndex.value = null -} - -function handleDragOver(index: number) { - if (draggedIndex.value !== null && draggedIndex.value !== index) { - dragOverIndex.value = index - } -} - -function handleDragLeave() { - dragOverIndex.value = null -} - -function handleDrop(targetIndex: number) { - const dragIndex = draggedIndex.value - if (dragIndex === null || dragIndex === targetIndex) { - dragOverIndex.value = null - return - } - - const items = [...formData.value.aliases] - const draggedItem = items[dragIndex] - - // 记录每个别名的原始优先级(在修改前) - const originalPriorityMap = new Map() - items.forEach((alias, idx) => { - originalPriorityMap.set(idx, alias.priority) - }) - - // 重排数组 - items.splice(dragIndex, 1) - items.splice(targetIndex, 0, draggedItem) - - // 按新顺序为每个组分配新的优先级 - // 同组的别名保持相同的优先级(被拖动的别名单独成组) - const groupNewPriority = new Map() // 原优先级 -> 新优先级 - let currentPriority = 1 - - items.forEach((alias) => { - // 找到这个别名在原数组中的索引 - const originalIdx = formData.value.aliases.findIndex(a => a === alias) - const originalPriority = originalIdx >= 0 ? originalPriorityMap.get(originalIdx)! : alias.priority - - if (alias === draggedItem) { - // 被拖动的别名是独立的新组,获得当前优先级 - alias.priority = currentPriority - currentPriority++ - } else { - if (groupNewPriority.has(originalPriority)) { - // 这个组已经分配过优先级,使用相同的值 - alias.priority = groupNewPriority.get(originalPriority)! - } else { - // 这个组第一次出现,分配新优先级 - groupNewPriority.set(originalPriority, currentPriority) - alias.priority = currentPriority - currentPriority++ - } - } - }) - - formData.value.aliases = items - draggedIndex.value = null - dragOverIndex.value = null -} - -// ===== 优先级编辑 ===== -function startEditPriority(index: number) { - editingPriorityIndex.value = index -} - -function finishEditPriority(index: number, event: FocusEvent) { - const input = event.target as HTMLInputElement - const newPriority = parseInt(input.value) || 1 - formData.value.aliases[index].priority = Math.max(1, newPriority) - editingPriorityIndex.value = null -} - -function cancelEditPriority() { - editingPriorityIndex.value = null -} - // 打开添加对话框 function openAddDialog() { - editingItem.value = null editingGroup.value = null - formData.value = { - modelId: '', - apiFormats: [], - aliases: [] - } - // 重置状态 - editingPriorityIndex.value = null - draggedIndex.value = null - dragOverIndex.value = null - // 重置上游模型状态 - upstreamModelsLoaded.value = false - upstreamModels.value = [] - upstreamModelSearch.value = '' dialogOpen.value = true } // 编辑分组 function editGroup(group: AliasGroup) { editingGroup.value = group - editingItem.value = { model: group.model, alias: group.aliases[0] } // 保持兼容 - formData.value = { - modelId: group.model.id, - apiFormats: [...group.apiFormats], - aliases: group.aliases.map(a => ({ name: a.name, priority: a.priority })) - } - // 重置状态 - editingPriorityIndex.value = null - draggedIndex.value = null - dragOverIndex.value = null - // 重置上游模型状态 - upstreamModelsLoaded.value = false - upstreamModels.value = [] - upstreamModelSearch.value = '' dialogOpen.value = true } @@ -913,11 +310,9 @@ async function confirmDelete() { const { model, aliases, apiFormatsKey } = deletingGroup.value try { - // 从模型的别名列表中移除该分组的所有别名 const currentAliases = model.provider_model_aliases || [] const aliasNamesToRemove = new Set(aliases.map(a => a.name)) const newAliases = currentAliases.filter((a: ProviderModelAlias) => { - // 只移除同一作用域的别名 const currentKey = getApiFormatsKey(a.api_formats) return !(currentKey === apiFormatsKey && aliasNamesToRemove.has(a.name)) }) @@ -936,89 +331,10 @@ async function confirmDelete() { } } -// 提交表单 -async function handleSubmit() { - if (submitting.value) return - if (!formData.value.modelId || formData.value.aliases.length === 0) return - - // 过滤有效的别名 - const validAliases = formData.value.aliases.filter(a => a.name.trim()) - if (validAliases.length === 0) { - showError('请至少添加一个有效的映射名称', '错误') - return - } - - submitting.value = true - try { - const targetModel = models.value.find(m => m.id === formData.value.modelId) - if (!targetModel) { - showError('模型不存在', '错误') - return - } - - const currentAliases = targetModel.provider_model_aliases || [] - let newAliases: ProviderModelAlias[] - - // 构建新的别名对象(带作用域) - const buildAlias = (a: FormAlias): ProviderModelAlias => ({ - name: a.name.trim(), - priority: a.priority, - ...(formData.value.apiFormats.length > 0 ? { api_formats: formData.value.apiFormats } : {}) - }) - - if (editingGroup.value) { - // 编辑分组模式:替换该分组的所有别名 - const oldApiFormatsKey = editingGroup.value.apiFormatsKey - const oldAliasNames = new Set(editingGroup.value.aliases.map(a => a.name)) - - // 移除旧分组的所有别名 - const filteredAliases = currentAliases.filter((a: ProviderModelAlias) => { - const currentKey = getApiFormatsKey(a.api_formats) - return !(currentKey === oldApiFormatsKey && oldAliasNames.has(a.name)) - }) - - // 检查新别名是否与其他分组的别名重复 - const existingNames = new Set(filteredAliases.map((a: ProviderModelAlias) => a.name)) - const duplicates = validAliases.filter(a => existingNames.has(a.name.trim())) - if (duplicates.length > 0) { - showError(`以下映射名称已存在:${duplicates.map(d => d.name).join(', ')}`, '错误') - return - } - - // 添加新的别名 - newAliases = [ - ...filteredAliases, - ...validAliases.map(buildAlias) - ] - } else { - // 添加模式:检查是否重复并批量添加 - const existingNames = new Set(currentAliases.map((a: ProviderModelAlias) => a.name)) - const duplicates = validAliases.filter(a => existingNames.has(a.name.trim())) - if (duplicates.length > 0) { - showError(`以下映射名称已存在:${duplicates.map(d => d.name).join(', ')}`, '错误') - return - } - newAliases = [ - ...currentAliases, - ...validAliases.map(buildAlias) - ] - } - - await updateModel(props.provider.id, targetModel.id, { - provider_model_aliases: newAliases - }) - - showSuccess(editingGroup.value ? '映射组已更新' : '映射已添加') - dialogOpen.value = false - editingGroup.value = null - editingItem.value = null - await loadModels() - emit('refresh') - } catch (err: any) { - showError(err.response?.data?.detail || '操作失败', '错误') - } finally { - submitting.value = false - } +// 对话框保存后回调 +async function onDialogSaved() { + await loadModels() + emit('refresh') } // 监听 provider 变化 @@ -1033,103 +349,4 @@ onMounted(() => { loadModels() } }) - -// ===== 快速添加(上游模型)===== -async function fetchUpstreamModels() { - if (!props.provider?.id) return - - const providerId = props.provider.id - upstreamModelSearch.value = '' - - // 检查缓存 - const cached = upstreamModelsCache.value.get(providerId) - if (cached && Date.now() - cached.timestamp < CACHE_TTL) { - upstreamModels.value = cached.models - upstreamModelsLoaded.value = true - return - } - - fetchingUpstreamModels.value = true - upstreamModels.value = [] - - try { - const response = await adminApi.queryProviderModels(providerId) - if (response.success && response.data?.models) { - upstreamModels.value = response.data.models - // 写入缓存 - upstreamModelsCache.value.set(providerId, { - models: response.data.models, - timestamp: Date.now() - }) - upstreamModelsLoaded.value = true - } else { - showError(response.data?.error || '获取模型列表失败', '错误') - } - } catch (err: any) { - showError(err.response?.data?.detail || '获取模型列表失败', '错误') - } finally { - fetchingUpstreamModels.value = false - } -} - -// 添加单个上游模型 -function addUpstreamModel(modelId: string) { - // 检查是否已存在 - if (formData.value.aliases.some(a => a.name === modelId)) { - return - } - - const maxPriority = formData.value.aliases.length > 0 - ? Math.max(...formData.value.aliases.map(a => a.priority)) - : 0 - - formData.value.aliases.push({ name: modelId, priority: maxPriority + 1 }) -} - -// 添加某个分组的所有模型 -function addAllFromGroup(apiFormat: string) { - const group = groupedAvailableUpstreamModels.value.find(g => g.api_format === apiFormat) - if (!group) return - - let maxPriority = formData.value.aliases.length > 0 - ? Math.max(...formData.value.aliases.map(a => a.priority)) - : 0 - - for (const model of group.models) { - // 检查是否已存在 - if (!formData.value.aliases.some(a => a.name === model.id)) { - maxPriority++ - formData.value.aliases.push({ name: model.id, priority: maxPriority }) - } - } -} - -// 刷新上游模型列表(清除缓存并重新获取) -async function refreshUpstreamModels() { - if (!props.provider?.id || refreshingUpstreamModels.value) return - - const providerId = props.provider.id - refreshingUpstreamModels.value = true - - // 清除缓存 - upstreamModelsCache.value.delete(providerId) - - try { - const response = await adminApi.queryProviderModels(providerId) - if (response.success && response.data?.models) { - upstreamModels.value = response.data.models - // 写入缓存 - upstreamModelsCache.value.set(providerId, { - models: response.data.models, - timestamp: Date.now() - }) - } else { - showError(response.data?.error || '刷新失败', '错误') - } - } catch (err: any) { - showError(err.response?.data?.detail || '刷新失败', '错误') - } finally { - refreshingUpstreamModels.value = false - } -} diff --git a/frontend/src/features/providers/composables/useUpstreamModelsCache.ts b/frontend/src/features/providers/composables/useUpstreamModelsCache.ts new file mode 100644 index 0000000..2c4d5fe --- /dev/null +++ b/frontend/src/features/providers/composables/useUpstreamModelsCache.ts @@ -0,0 +1,112 @@ +/** + * 上游模型缓存 - 共享缓存,避免重复请求 + */ +import { ref } from 'vue' +import { adminApi } from '@/api/admin' +import type { UpstreamModel } from '@/api/endpoints/types' + +// 扩展类型,包含可能的额外字段 +export type { UpstreamModel } + +interface CacheEntry { + models: UpstreamModel[] + timestamp: number +} + +type FetchResult = { models: UpstreamModel[]; error?: string } + +// 全局缓存(模块级别,所有组件共享) +const cache = new Map() +const CACHE_TTL = 5 * 60 * 1000 // 5分钟 + +// 进行中的请求(用于去重并发请求) +const pendingRequests = new Map>() + +// 请求状态 +const loadingMap = ref>(new Map()) + +export function useUpstreamModelsCache() { + /** + * 获取上游模型列表 + * @param providerId 提供商ID + * @param forceRefresh 是否强制刷新 + * @returns 模型列表或 null(如果请求失败) + */ + async function fetchModels( + providerId: string, + forceRefresh = false + ): Promise { + // 检查缓存 + if (!forceRefresh) { + const cached = cache.get(providerId) + if (cached && Date.now() - cached.timestamp < CACHE_TTL) { + return { models: cached.models } + } + } + + // 检查是否有进行中的请求(非强制刷新时复用) + if (!forceRefresh && pendingRequests.has(providerId)) { + return pendingRequests.get(providerId)! + } + + // 创建新请求 + const requestPromise = (async (): Promise => { + try { + loadingMap.value.set(providerId, true) + const response = await adminApi.queryProviderModels(providerId) + + if (response.success && response.data?.models) { + // 存入缓存 + cache.set(providerId, { + models: response.data.models, + timestamp: Date.now() + }) + return { models: response.data.models } + } else { + return { models: [], error: response.data?.error || '获取上游模型失败' } + } + } catch (err: any) { + return { models: [], error: err.response?.data?.detail || '获取上游模型失败' } + } finally { + loadingMap.value.set(providerId, false) + pendingRequests.delete(providerId) + } + })() + + pendingRequests.set(providerId, requestPromise) + return requestPromise + } + + /** + * 获取缓存的模型(不发起请求) + */ + function getCachedModels(providerId: string): UpstreamModel[] | null { + const cached = cache.get(providerId) + if (cached && Date.now() - cached.timestamp < CACHE_TTL) { + return cached.models + } + return null + } + + /** + * 清除指定提供商的缓存 + */ + function clearCache(providerId: string) { + cache.delete(providerId) + } + + /** + * 检查是否正在加载 + */ + function isLoading(providerId: string): boolean { + return loadingMap.value.get(providerId) || false + } + + return { + fetchModels, + getCachedModels, + clearCache, + isLoading, + loadingMap + } +} diff --git a/src/api/admin/provider_query.py b/src/api/admin/provider_query.py index d84bce3..b91d68a 100644 --- a/src/api/admin/provider_query.py +++ b/src/api/admin/provider_query.py @@ -151,29 +151,46 @@ async def query_available_models( adapter_class = _get_adapter_for_format(api_format) if not adapter_class: return [], f"Unknown API format: {api_format}" - return await adapter_class.fetch_models( + models, error = await adapter_class.fetch_models( client, base_url, api_key_value, extra_headers ) + # 确保所有模型都有 api_format 字段 + for m in models: + if "api_format" not in m: + m["api_format"] = api_format + return models, error except Exception as e: logger.error(f"Error fetching models from {api_format} endpoint: {e}") return [], f"{api_format}: {str(e)}" + # 限制并发请求数量,避免触发上游速率限制 + MAX_CONCURRENT_REQUESTS = 5 + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + async def fetch_with_semaphore( + client: httpx.AsyncClient, config: dict + ) -> tuple[list, Optional[str]]: + async with semaphore: + return await fetch_endpoint_models(client, config) + async with httpx.AsyncClient(timeout=30.0) as client: results = await asyncio.gather( - *[fetch_endpoint_models(client, c) for c in endpoint_configs] + *[fetch_with_semaphore(client, c) for c in endpoint_configs] ) for models, error in results: all_models.extend(models) if error: errors.append(error) - # 按 model id 去重(保留第一个) - seen_ids: set[str] = set() + # 按 model id + api_format 去重(保留第一个) + seen_keys: set[str] = set() unique_models: list = [] for model in all_models: model_id = model.get("id") - if model_id and model_id not in seen_ids: - seen_ids.add(model_id) + api_format = model.get("api_format", "") + unique_key = f"{model_id}:{api_format}" + if model_id and unique_key not in seen_keys: + seen_keys.add(unique_key) unique_models.append(model) error = "; ".join(errors) if errors else None diff --git a/src/api/admin/providers/models.py b/src/api/admin/providers/models.py index 9b56366..42483d2 100644 --- a/src/api/admin/providers/models.py +++ b/src/api/admin/providers/models.py @@ -22,16 +22,18 @@ from src.models.api import ( from src.models.pydantic_models import ( BatchAssignModelsToProviderRequest, BatchAssignModelsToProviderResponse, + ImportFromUpstreamRequest, + ImportFromUpstreamResponse, + ImportFromUpstreamSuccessItem, + ImportFromUpstreamErrorItem, + ProviderAvailableSourceModel, + ProviderAvailableSourceModelsResponse, ) from src.models.database import ( GlobalModel, Model, Provider, ) -from src.models.pydantic_models import ( - ProviderAvailableSourceModel, - ProviderAvailableSourceModelsResponse, -) from src.services.model.service import ModelService router = APIRouter(tags=["Model Management"]) @@ -158,6 +160,28 @@ async def batch_assign_global_models_to_provider( return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) +@router.post( + "/{provider_id}/import-from-upstream", + response_model=ImportFromUpstreamResponse, +) +async def import_models_from_upstream( + provider_id: str, + payload: ImportFromUpstreamRequest, + request: Request, + db: Session = Depends(get_db), +) -> ImportFromUpstreamResponse: + """ + 从上游提供商导入模型 + + 流程: + 1. 根据 model_ids 检查全局模型是否存在(按 name 匹配) + 2. 如不存在,自动创建新的 GlobalModel(使用默认配置) + 3. 创建 Model 关联到当前 Provider + """ + adapter = AdminImportFromUpstreamAdapter(provider_id=provider_id, payload=payload) + return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) + + # -------- Adapters -------- @@ -425,3 +449,130 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter): await invalidate_models_list_cache() return BatchAssignModelsToProviderResponse(success=success, errors=errors) + + +@dataclass +class AdminImportFromUpstreamAdapter(AdminApiAdapter): + """从上游提供商导入模型""" + + provider_id: str + payload: ImportFromUpstreamRequest + + async def handle(self, context): # type: ignore[override] + db = context.db + provider = db.query(Provider).filter(Provider.id == self.provider_id).first() + if not provider: + raise NotFoundException("Provider not found", "provider") + + success: list[ImportFromUpstreamSuccessItem] = [] + errors: list[ImportFromUpstreamErrorItem] = [] + + # 默认阶梯计费配置(免费) + default_tiered_pricing = { + "tiers": [ + { + "up_to": None, + "input_price_per_1m": 0.0, + "output_price_per_1m": 0.0, + } + ] + } + + for model_id in self.payload.model_ids: + # 输入验证:检查 model_id 长度 + if not model_id or len(model_id) > 100: + errors.append( + ImportFromUpstreamErrorItem( + model_id=model_id[:50] + "..." if model_id and len(model_id) > 50 else model_id or "", + error="Invalid model_id: must be 1-100 characters", + ) + ) + continue + + try: + # 使用 savepoint 确保单个模型导入的原子性 + savepoint = db.begin_nested() + try: + # 1. 检查是否已存在同名的 GlobalModel + global_model = ( + db.query(GlobalModel).filter(GlobalModel.name == model_id).first() + ) + created_global_model = False + + if not global_model: + # 2. 创建新的 GlobalModel + global_model = GlobalModel( + name=model_id, + display_name=model_id, + default_tiered_pricing=default_tiered_pricing, + is_active=True, + ) + db.add(global_model) + db.flush() + created_global_model = True + logger.info( + f"Created new GlobalModel: {model_id} during upstream import" + ) + + # 3. 检查是否已存在关联 + existing = ( + db.query(Model) + .filter( + Model.provider_id == self.provider_id, + Model.global_model_id == global_model.id, + ) + .first() + ) + if existing: + # 已存在关联,提交 savepoint 并记录成功 + savepoint.commit() + success.append( + ImportFromUpstreamSuccessItem( + model_id=model_id, + global_model_id=global_model.id, + global_model_name=global_model.name, + provider_model_id=existing.id, + created_global_model=created_global_model, + ) + ) + continue + + # 4. 创建新的 Model 记录 + new_model = Model( + provider_id=self.provider_id, + global_model_id=global_model.id, + provider_model_name=global_model.name, + is_active=True, + ) + db.add(new_model) + db.flush() + + # 提交 savepoint + savepoint.commit() + success.append( + ImportFromUpstreamSuccessItem( + model_id=model_id, + global_model_id=global_model.id, + global_model_name=global_model.name, + provider_model_id=new_model.id, + created_global_model=created_global_model, + ) + ) + except Exception as e: + # 回滚到 savepoint + savepoint.rollback() + raise e + except Exception as e: + logger.error(f"Error importing model {model_id}: {e}") + errors.append(ImportFromUpstreamErrorItem(model_id=model_id, error=str(e))) + + db.commit() + logger.info( + f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}" + ) + + # 清除 /v1/models 列表缓存 + if success: + await invalidate_models_list_cache() + + return ImportFromUpstreamResponse(success=success, errors=errors) diff --git a/src/models/pydantic_models.py b/src/models/pydantic_models.py index 3625a0a..779950e 100644 --- a/src/models/pydantic_models.py +++ b/src/models/pydantic_models.py @@ -301,6 +301,36 @@ class BatchAssignModelsToProviderResponse(BaseModel): errors: List[dict] +class ImportFromUpstreamRequest(BaseModel): + """从上游提供商导入模型请求""" + + model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表") + + +class ImportFromUpstreamSuccessItem(BaseModel): + """导入成功的模型信息""" + + model_id: str = Field(..., description="上游模型 ID") + global_model_id: str = Field(..., description="GlobalModel ID") + global_model_name: str = Field(..., description="GlobalModel 名称") + provider_model_id: str = Field(..., description="Provider Model ID") + created_global_model: bool = Field(..., description="是否新创建了 GlobalModel") + + +class ImportFromUpstreamErrorItem(BaseModel): + """导入失败的模型信息""" + + model_id: str = Field(..., description="上游模型 ID") + error: str = Field(..., description="错误信息") + + +class ImportFromUpstreamResponse(BaseModel): + """从上游提供商导入模型响应""" + + success: List[ImportFromUpstreamSuccessItem] + errors: List[ImportFromUpstreamErrorItem] + + __all__ = [ "BatchAssignModelsToProviderRequest", "BatchAssignModelsToProviderResponse", @@ -311,6 +341,10 @@ __all__ = [ "GlobalModelResponse", "GlobalModelUpdate", "GlobalModelWithStats", + "ImportFromUpstreamErrorItem", + "ImportFromUpstreamRequest", + "ImportFromUpstreamResponse", + "ImportFromUpstreamSuccessItem", "ModelCapabilities", "ModelCatalogItem", "ModelCatalogProviderDetail",