mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-02 15:52:26 +08:00
feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。
This commit is contained in:
@@ -58,3 +58,16 @@ export async function deleteProvider(providerId: string): Promise<{ message: str
|
|||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 测试模型连接性
|
||||||
|
*/
|
||||||
|
export async function testModel(data: {
|
||||||
|
provider_id: string
|
||||||
|
model_name: string
|
||||||
|
api_key_id?: string
|
||||||
|
message?: string
|
||||||
|
}): Promise<any> {
|
||||||
|
const response = await client.post('/api/admin/provider-query/test-model', data)
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -163,7 +163,9 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
|
|||||||
useEscapeKey(() => {
|
useEscapeKey(() => {
|
||||||
if (isOpen.value) {
|
if (isOpen.value) {
|
||||||
handleClose()
|
handleClose()
|
||||||
|
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}, {
|
}, {
|
||||||
disableOnInput: true,
|
disableOnInput: true,
|
||||||
once: false
|
once: false
|
||||||
|
|||||||
@@ -47,11 +47,11 @@ export function useConfirm() {
|
|||||||
/**
|
/**
|
||||||
* 便捷方法:危险操作确认(红色主题)
|
* 便捷方法:危险操作确认(红色主题)
|
||||||
*/
|
*/
|
||||||
const confirmDanger = (message: string, title?: string): Promise<boolean> => {
|
const confirmDanger = (message: string, title?: string, confirmText?: string): Promise<boolean> => {
|
||||||
return confirm({
|
return confirm({
|
||||||
message,
|
message,
|
||||||
title: title || '危险操作',
|
title: title || '危险操作',
|
||||||
confirmText: '删除',
|
confirmText: confirmText || '删除',
|
||||||
variant: 'danger'
|
variant: 'danger'
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import { onMounted, onUnmounted, ref } from 'vue'
|
|||||||
* ESC 键监听 Composable(简化版本,直接使用独立监听器)
|
* ESC 键监听 Composable(简化版本,直接使用独立监听器)
|
||||||
* 用于按 ESC 键关闭弹窗或其他可关闭的组件
|
* 用于按 ESC 键关闭弹窗或其他可关闭的组件
|
||||||
*
|
*
|
||||||
* @param callback - 按 ESC 键时执行的回调函数
|
* @param callback - 按 ESC 键时执行的回调函数,返回 true 表示已处理事件,阻止其他监听器执行
|
||||||
* @param options - 配置选项
|
* @param options - 配置选项
|
||||||
*/
|
*/
|
||||||
export function useEscapeKey(
|
export function useEscapeKey(
|
||||||
callback: () => void,
|
callback: () => void | boolean,
|
||||||
options: {
|
options: {
|
||||||
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
|
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
|
||||||
disableOnInput?: boolean
|
disableOnInput?: boolean
|
||||||
@@ -42,8 +42,11 @@ export function useEscapeKey(
|
|||||||
if (isInputElement) return
|
if (isInputElement) return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行回调
|
// 执行回调,如果返回 true 则阻止其他监听器
|
||||||
callback()
|
const handled = callback()
|
||||||
|
if (handled === true) {
|
||||||
|
event.stopImmediatePropagation()
|
||||||
|
}
|
||||||
|
|
||||||
// 移除当前元素的焦点,避免残留样式
|
// 移除当前元素的焦点,避免残留样式
|
||||||
if (document.activeElement instanceof HTMLElement) {
|
if (document.activeElement instanceof HTMLElement) {
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
v-model:open="modelSelectOpen"
|
v-model:open="modelSelectOpen"
|
||||||
:model-value="formData.modelId"
|
:model-value="formData.modelId"
|
||||||
:disabled="!!editingGroup"
|
:disabled="!!editingGroup"
|
||||||
@update:model-value="formData.modelId = $event"
|
@update:model-value="handleModelChange"
|
||||||
>
|
>
|
||||||
<SelectTrigger class="h-9">
|
<SelectTrigger class="h-9">
|
||||||
<SelectValue placeholder="请选择模型" />
|
<SelectValue placeholder="请选择模型" />
|
||||||
@@ -518,6 +518,15 @@ function initForm() {
|
|||||||
upstreamModels.value = []
|
upstreamModels.value = []
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理模型选择变更
|
||||||
|
function handleModelChange(value: string) {
|
||||||
|
formData.value.modelId = value
|
||||||
|
const selectedModel = props.models.find(m => m.id === value)
|
||||||
|
if (selectedModel) {
|
||||||
|
upstreamModelSearch.value = selectedModel.provider_model_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 切换 API 格式
|
// 切换 API 格式
|
||||||
function toggleApiFormat(format: string) {
|
function toggleApiFormat(format: string) {
|
||||||
|
|||||||
@@ -531,6 +531,7 @@
|
|||||||
<!-- 模型名称映射 -->
|
<!-- 模型名称映射 -->
|
||||||
<ModelAliasesTab
|
<ModelAliasesTab
|
||||||
v-if="provider"
|
v-if="provider"
|
||||||
|
ref="modelAliasesTabRef"
|
||||||
:key="`aliases-${provider.id}`"
|
:key="`aliases-${provider.id}`"
|
||||||
:provider="provider"
|
:provider="provider"
|
||||||
@refresh="handleRelatedDataRefresh"
|
@refresh="handleRelatedDataRefresh"
|
||||||
@@ -735,6 +736,9 @@ const deleteModelConfirmOpen = ref(false)
|
|||||||
const modelToDelete = ref<Model | null>(null)
|
const modelToDelete = ref<Model | null>(null)
|
||||||
const batchAssignDialogOpen = ref(false)
|
const batchAssignDialogOpen = ref(false)
|
||||||
|
|
||||||
|
// ModelAliasesTab 组件引用
|
||||||
|
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
|
||||||
|
|
||||||
// 拖动排序相关状态
|
// 拖动排序相关状态
|
||||||
const dragState = ref({
|
const dragState = ref({
|
||||||
isDragging: false,
|
isDragging: false,
|
||||||
@@ -756,7 +760,9 @@ const hasBlockingDialogOpen = computed(() =>
|
|||||||
deleteKeyConfirmOpen.value ||
|
deleteKeyConfirmOpen.value ||
|
||||||
modelFormDialogOpen.value ||
|
modelFormDialogOpen.value ||
|
||||||
deleteModelConfirmOpen.value ||
|
deleteModelConfirmOpen.value ||
|
||||||
batchAssignDialogOpen.value
|
batchAssignDialogOpen.value ||
|
||||||
|
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
|
||||||
|
modelAliasesTabRef.value?.dialogOpen
|
||||||
)
|
)
|
||||||
|
|
||||||
// 监听 providerId 变化
|
// 监听 providerId 变化
|
||||||
|
|||||||
@@ -110,16 +110,30 @@
|
|||||||
<div
|
<div
|
||||||
v-for="mapping in group.aliases"
|
v-for="mapping in group.aliases"
|
||||||
:key="mapping.name"
|
:key="mapping.name"
|
||||||
class="flex items-center gap-2 py-1"
|
class="flex items-center justify-between gap-2 py-1"
|
||||||
>
|
>
|
||||||
<!-- 优先级标签 -->
|
<div class="flex items-center gap-2 flex-1 min-w-0">
|
||||||
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
<!-- 优先级标签 -->
|
||||||
{{ mapping.priority }}
|
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
||||||
</span>
|
{{ mapping.priority }}
|
||||||
<!-- 映射名称 -->
|
</span>
|
||||||
<span class="font-mono text-sm truncate">
|
<!-- 映射名称 -->
|
||||||
{{ mapping.name }}
|
<span class="font-mono text-sm truncate">
|
||||||
</span>
|
{{ mapping.name }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<!-- 测试按钮 -->
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7 shrink-0"
|
||||||
|
title="测试映射"
|
||||||
|
:disabled="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`"
|
||||||
|
@click="testMapping(group, mapping)"
|
||||||
|
>
|
||||||
|
<Loader2 v-if="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`" class="w-3 h-3 animate-spin" />
|
||||||
|
<Play v-else class="w-3 h-3" />
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -166,13 +180,14 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted, watch } from 'vue'
|
import { ref, computed, onMounted, watch } from 'vue'
|
||||||
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next'
|
import { Tag, Plus, Edit, Trash2, ChevronRight, Loader2, Play } from 'lucide-vue-next'
|
||||||
import { Card, Button, Badge } from '@/components/ui'
|
import { Card, Button, Badge } from '@/components/ui'
|
||||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||||
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import {
|
import {
|
||||||
getProviderModels,
|
getProviderModels,
|
||||||
|
testModel,
|
||||||
API_FORMAT_LABELS,
|
API_FORMAT_LABELS,
|
||||||
type Model,
|
type Model,
|
||||||
type ProviderModelAlias
|
type ProviderModelAlias
|
||||||
@@ -196,6 +211,7 @@ const dialogOpen = ref(false)
|
|||||||
const deleteConfirmOpen = ref(false)
|
const deleteConfirmOpen = ref(false)
|
||||||
const editingGroup = ref<AliasGroup | null>(null)
|
const editingGroup = ref<AliasGroup | null>(null)
|
||||||
const deletingGroup = ref<AliasGroup | null>(null)
|
const deletingGroup = ref<AliasGroup | null>(null)
|
||||||
|
const testingMapping = ref<string | null>(null)
|
||||||
|
|
||||||
// 列表展开状态
|
// 列表展开状态
|
||||||
const expandedAliasGroups = ref<Set<string>>(new Set())
|
const expandedAliasGroups = ref<Set<string>>(new Set())
|
||||||
@@ -337,6 +353,81 @@ async function onDialogSaved() {
|
|||||||
emit('refresh')
|
emit('refresh')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 测试模型映射
|
||||||
|
async function testMapping(group: any, mapping: any) {
|
||||||
|
const testingKey = `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`
|
||||||
|
testingMapping.value = testingKey
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 根据分组的 API 格式来确定应该使用的格式
|
||||||
|
let apiFormat = null
|
||||||
|
if (group.apiFormats.length === 1) {
|
||||||
|
apiFormat = group.apiFormats[0]
|
||||||
|
} else if (group.apiFormats.length === 0) {
|
||||||
|
// 如果没有指定格式,但分组显示为"全部",则使用模型的默认格式
|
||||||
|
apiFormat = group.model.effective_api_format || group.model.api_format
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`测试映射 ${mapping.name},使用 API Format: ${apiFormat}`)
|
||||||
|
|
||||||
|
const result = await testModel({
|
||||||
|
provider_id: props.provider.id,
|
||||||
|
model_name: mapping.name, // 使用映射名称进行测试
|
||||||
|
message: "hello",
|
||||||
|
api_format: apiFormat
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
showSuccess(`映射 "${mapping.name}" 测试成功`)
|
||||||
|
|
||||||
|
// 如果有响应内容,可以显示更多信息
|
||||||
|
if (result.data?.response?.choices?.[0]?.message?.content) {
|
||||||
|
const content = result.data.response.choices[0].message.content
|
||||||
|
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
|
||||||
|
} else if (result.data?.content_preview) {
|
||||||
|
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 根据不同的错误类型显示更详细的信息
|
||||||
|
let errorMsg = result.error || '测试失败'
|
||||||
|
|
||||||
|
// 检查HTTP状态码错误
|
||||||
|
if (result.data?.response?.status_code) {
|
||||||
|
const status = result.data.response.status_code
|
||||||
|
if (status === 403) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
|
||||||
|
} else if (status === 401) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或已过期'
|
||||||
|
} else if (status === 404) {
|
||||||
|
errorMsg = '模型不存在: 请检查模型名称是否正确'
|
||||||
|
} else if (status === 429) {
|
||||||
|
errorMsg = '请求频率过高: 请稍后重试'
|
||||||
|
} else if (status >= 500) {
|
||||||
|
errorMsg = `服务器错误: HTTP ${status}`
|
||||||
|
} else {
|
||||||
|
errorMsg = `请求失败: HTTP ${status}`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试从错误响应中提取更多信息
|
||||||
|
if (result.data?.response?.error) {
|
||||||
|
if (typeof result.data.response.error === 'string') {
|
||||||
|
errorMsg = result.data.response.error
|
||||||
|
} else if (result.data.response.error?.message) {
|
||||||
|
errorMsg = result.data.response.error.message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
showError(`映射测试失败: ${errorMsg}`)
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
||||||
|
showError(`映射测试失败: ${errorMsg}`)
|
||||||
|
} finally {
|
||||||
|
testingMapping.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 监听 provider 变化
|
// 监听 provider 变化
|
||||||
watch(() => props.provider?.id, (newId) => {
|
watch(() => props.provider?.id, (newId) => {
|
||||||
if (newId) {
|
if (newId) {
|
||||||
@@ -349,4 +440,9 @@ onMounted(() => {
|
|||||||
loadModels()
|
loadModels()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 暴露给父组件,用于检测是否有弹窗打开
|
||||||
|
defineExpose({
|
||||||
|
dialogOpen: computed(() => dialogOpen.value || deleteConfirmOpen.value)
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -156,6 +156,17 @@
|
|||||||
</td>
|
</td>
|
||||||
<td class="align-top px-4 py-3">
|
<td class="align-top px-4 py-3">
|
||||||
<div class="flex justify-center gap-1.5">
|
<div class="flex justify-center gap-1.5">
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-8 w-8"
|
||||||
|
title="测试模型"
|
||||||
|
:disabled="testingModelId === model.id"
|
||||||
|
@click="testModelConnection(model)"
|
||||||
|
>
|
||||||
|
<Loader2 v-if="testingModelId === model.id" class="w-3.5 h-3.5 animate-spin" />
|
||||||
|
<Play v-else class="w-3.5 h-3.5" />
|
||||||
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
@@ -209,11 +220,11 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Loader2, Play } from 'lucide-vue-next'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { getProviderModels, type Model } from '@/api/endpoints'
|
import { getProviderModels, testModel, type Model } from '@/api/endpoints'
|
||||||
import { updateModel } from '@/api/endpoints/models'
|
import { updateModel } from '@/api/endpoints/models'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
@@ -232,6 +243,7 @@ const { error: showError, success: showSuccess } = useToast()
|
|||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const models = ref<Model[]>([])
|
const models = ref<Model[]>([])
|
||||||
const togglingModelId = ref<string | null>(null)
|
const togglingModelId = ref<string | null>(null)
|
||||||
|
const testingModelId = ref<string | null>(null)
|
||||||
|
|
||||||
// 按名称排序的模型列表
|
// 按名称排序的模型列表
|
||||||
const sortedModels = computed(() => {
|
const sortedModels = computed(() => {
|
||||||
@@ -380,6 +392,69 @@ async function toggleModelActive(model: Model) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 测试模型连接性
|
||||||
|
async function testModelConnection(model: Model) {
|
||||||
|
if (testingModelId.value) return
|
||||||
|
|
||||||
|
testingModelId.value = model.id
|
||||||
|
try {
|
||||||
|
const result = await testModel({
|
||||||
|
provider_id: props.provider.id,
|
||||||
|
model_name: model.provider_model_name,
|
||||||
|
message: "hello"
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
showSuccess(`模型 "${model.provider_model_name}" 测试成功`)
|
||||||
|
|
||||||
|
// 如果有响应内容,可以显示更多信息
|
||||||
|
if (result.data?.response?.choices?.[0]?.message?.content) {
|
||||||
|
const content = result.data.response.choices[0].message.content
|
||||||
|
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
|
||||||
|
} else if (result.data?.content_preview) {
|
||||||
|
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 根据不同的错误类型显示更详细的信息
|
||||||
|
let errorMsg = result.error || '测试失败'
|
||||||
|
|
||||||
|
// 检查HTTP状态码错误
|
||||||
|
if (result.data?.response?.status_code) {
|
||||||
|
const status = result.data.response.status_code
|
||||||
|
if (status === 403) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
|
||||||
|
} else if (status === 401) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或已过期'
|
||||||
|
} else if (status === 404) {
|
||||||
|
errorMsg = '模型不存在: 请检查模型名称是否正确'
|
||||||
|
} else if (status === 429) {
|
||||||
|
errorMsg = '请求频率过高: 请稍后重试'
|
||||||
|
} else if (status >= 500) {
|
||||||
|
errorMsg = `服务器错误: HTTP ${status}`
|
||||||
|
} else {
|
||||||
|
errorMsg = `请求失败: HTTP ${status}`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试从错误响应中提取更多信息
|
||||||
|
if (result.data?.response?.error) {
|
||||||
|
if (typeof result.data.response.error === 'string') {
|
||||||
|
errorMsg = result.data.response.error
|
||||||
|
} else if (result.data.response.error?.message) {
|
||||||
|
errorMsg = result.data.response.error.message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
showError(`模型测试失败: ${errorMsg}`)
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
||||||
|
showError(`模型测试失败: ${errorMsg}`)
|
||||||
|
} finally {
|
||||||
|
testingModelId.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadModels()
|
loadModels()
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
users.value = await usersApi.getAllUsers()
|
users.value = await usersApi.getAllUsers()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '获取用户列表失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取用户列表失败'
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
users.value.push(newUser)
|
users.value.push(newUser)
|
||||||
return newUser
|
return newUser
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '创建用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -52,7 +52,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
}
|
}
|
||||||
return updatedUser
|
return updatedUser
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '更新用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '更新用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -67,7 +67,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
await usersApi.deleteUser(userId)
|
await usersApi.deleteUser(userId)
|
||||||
users.value = users.value.filter(u => u.id !== userId)
|
users.value = users.value.filter(u => u.id !== userId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '删除用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -78,7 +78,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
return await usersApi.getUserApiKeys(userId)
|
return await usersApi.getUserApiKeys(userId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '获取 API Keys 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取 API Keys 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,7 +87,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
return await usersApi.createApiKey(userId, name)
|
return await usersApi.createApiKey(userId, name)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '创建 API Key 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建 API Key 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
await usersApi.deleteApiKey(userId, keyId)
|
await usersApi.deleteApiKey(userId, keyId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '删除 API Key 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除 API Key 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -110,7 +110,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
// 刷新用户列表以获取最新数据
|
// 刷新用户列表以获取最新数据
|
||||||
await fetchUsers()
|
await fetchUsers()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '重置配额失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '重置配额失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
|
|||||||
@@ -723,9 +723,19 @@ async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
|
|||||||
// 切换提供商状态
|
// 切换提供商状态
|
||||||
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
|
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
|
||||||
try {
|
try {
|
||||||
await updateProvider(provider.id, { is_active: !provider.is_active })
|
const newStatus = !provider.is_active
|
||||||
provider.is_active = !provider.is_active
|
await updateProvider(provider.id, { is_active: newStatus })
|
||||||
showSuccess(provider.is_active ? '提供商已启用' : '提供商已停用')
|
|
||||||
|
// 更新抽屉内部的 provider 对象
|
||||||
|
provider.is_active = newStatus
|
||||||
|
|
||||||
|
// 同时更新主页面 providers 数组中的对象,实现无感更新
|
||||||
|
const targetProvider = providers.value.find(p => p.id === provider.id)
|
||||||
|
if (targetProvider) {
|
||||||
|
targetProvider.is_active = newStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
showSuccess(newStatus ? '提供商已启用' : '提供商已停用')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
showError(err.response?.data?.detail || '操作失败', '错误')
|
showError(err.response?.data?.detail || '操作失败', '错误')
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -875,7 +875,8 @@ async function toggleUserStatus(user: any) {
|
|||||||
const action = user.is_active ? '禁用' : '启用'
|
const action = user.is_active ? '禁用' : '启用'
|
||||||
const confirmed = await confirmDanger(
|
const confirmed = await confirmDanger(
|
||||||
`确定要${action}用户 ${user.username} 吗?`,
|
`确定要${action}用户 ${user.username} 吗?`,
|
||||||
`${action}用户`
|
`${action}用户`,
|
||||||
|
action
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!confirmed) return
|
if (!confirmed) return
|
||||||
@@ -884,7 +885,7 @@ async function toggleUserStatus(user: any) {
|
|||||||
await usersStore.updateUser(user.id, { is_active: !user.is_active })
|
await usersStore.updateUser(user.id, { is_active: !user.is_active })
|
||||||
success(`用户已${action}`)
|
success(`用户已${action}`)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', `${action}用户失败`)
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', `${action}用户失败`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -955,7 +956,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
|
|||||||
closeUserFormDialog()
|
closeUserFormDialog()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
const title = data.id ? '更新用户失败' : '创建用户失败'
|
const title = data.id ? '更新用户失败' : '创建用户失败'
|
||||||
error(err.response?.data?.detail || '未知错误', title)
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', title)
|
||||||
} finally {
|
} finally {
|
||||||
userFormDialogRef.value?.setSaving(false)
|
userFormDialogRef.value?.setSaving(false)
|
||||||
}
|
}
|
||||||
@@ -989,7 +990,7 @@ async function createApiKey() {
|
|||||||
showNewApiKeyDialog.value = true
|
showNewApiKeyDialog.value = true
|
||||||
await loadUserApiKeys(selectedUser.value.id)
|
await loadUserApiKeys(selectedUser.value.id)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '创建 API Key 失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '创建 API Key 失败')
|
||||||
} finally {
|
} finally {
|
||||||
creatingApiKey.value = false
|
creatingApiKey.value = false
|
||||||
}
|
}
|
||||||
@@ -1026,7 +1027,7 @@ async function deleteApiKey(apiKey: any) {
|
|||||||
await loadUserApiKeys(selectedUser.value.id)
|
await loadUserApiKeys(selectedUser.value.id)
|
||||||
success('API Key已删除')
|
success('API Key已删除')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '删除 API Key 失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除 API Key 失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1038,7 +1039,7 @@ async function copyFullKey(apiKey: any) {
|
|||||||
success('完整密钥已复制到剪贴板')
|
success('完整密钥已复制到剪贴板')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
log.error('复制密钥失败:', err)
|
log.error('复制密钥失败:', err)
|
||||||
error(err.response?.data?.detail || '未知错误', '复制密钥失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '复制密钥失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1054,7 +1055,7 @@ async function resetQuota(user: any) {
|
|||||||
await usersStore.resetUserQuota(user.id)
|
await usersStore.resetUserQuota(user.id)
|
||||||
success('配额已重置')
|
success('配额已重置')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '重置配额失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '重置配额失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1070,7 +1071,7 @@ async function deleteUser(user: any) {
|
|||||||
await usersStore.deleteUser(user.id)
|
await usersStore.deleteUser(user.id)
|
||||||
success('用户已删除')
|
success('用户已删除')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '删除用户失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除用户失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -102,9 +102,9 @@
|
|||||||
<!-- Main Content -->
|
<!-- Main Content -->
|
||||||
<main class="relative z-10">
|
<main class="relative z-10">
|
||||||
<!-- Fixed Logo Container -->
|
<!-- Fixed Logo Container -->
|
||||||
<div class="fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
|
<div class="mt-4 fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
|
||||||
<div
|
<div
|
||||||
class="transform-gpu logo-container"
|
class="mt-16 transform-gpu logo-container"
|
||||||
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
|
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
|
||||||
:style="fixedLogoStyle"
|
:style="fixedLogoStyle"
|
||||||
>
|
>
|
||||||
@@ -151,7 +151,7 @@
|
|||||||
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
|
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
|
||||||
>
|
>
|
||||||
<div class="max-w-4xl mx-auto text-center">
|
<div class="max-w-4xl mx-auto text-center">
|
||||||
<div class="h-80 w-full mb-16" />
|
<div class="h-80 w-full mb-16 mt-8" />
|
||||||
<h1
|
<h1
|
||||||
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
|
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
|
||||||
:style="getTitleStyle(SECTIONS.HOME)"
|
:style="getTitleStyle(SECTIONS.HOME)"
|
||||||
@@ -166,7 +166,7 @@
|
|||||||
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
|
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
|
||||||
</p>
|
</p>
|
||||||
<button
|
<button
|
||||||
class="mt-16 transition-all duration-700 cursor-pointer hover:scale-110"
|
class="mt-8 transition-all duration-700 cursor-pointer hover:scale-110"
|
||||||
:style="getScrollIndicatorStyle(SECTIONS.HOME)"
|
:style="getScrollIndicatorStyle(SECTIONS.HOME)"
|
||||||
@click="scrollToSection(SECTIONS.CLAUDE)"
|
@click="scrollToSection(SECTIONS.CLAUDE)"
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
|
|||||||
api_key_id: Optional[str] = None
|
api_key_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelRequest(BaseModel):
|
||||||
|
"""模型测试请求"""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
model_name: str
|
||||||
|
api_key_id: Optional[str] = None
|
||||||
|
stream: bool = False
|
||||||
|
message: Optional[str] = "你好"
|
||||||
|
api_format: Optional[str] = None # 指定使用的API格式,如果不指定则使用端点的默认格式
|
||||||
|
|
||||||
|
|
||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
||||||
|
|
||||||
@@ -206,3 +217,228 @@ async def query_available_models(
|
|||||||
"display_name": provider.display_name,
|
"display_name": provider.display_name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-model")
|
||||||
|
async def test_model(
|
||||||
|
request: TestModelRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试模型连接性
|
||||||
|
|
||||||
|
向指定提供商的指定模型发送测试请求,验证模型是否可用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 测试请求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果
|
||||||
|
"""
|
||||||
|
# 获取提供商及其端点
|
||||||
|
provider = (
|
||||||
|
db.query(Provider)
|
||||||
|
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||||
|
.filter(Provider.id == request.provider_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
|
# 找到合适的端点和API Key
|
||||||
|
endpoint_config = None
|
||||||
|
endpoint = None
|
||||||
|
api_key = None
|
||||||
|
|
||||||
|
if request.api_key_id:
|
||||||
|
# 使用指定的API Key
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
for key in ep.api_keys:
|
||||||
|
if key.id == request.api_key_id and key.is_active and ep.is_active:
|
||||||
|
endpoint = ep
|
||||||
|
api_key = key
|
||||||
|
break
|
||||||
|
if endpoint:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 使用第一个可用的端点和密钥
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
if not ep.is_active or not ep.api_keys:
|
||||||
|
continue
|
||||||
|
for key in ep.api_keys:
|
||||||
|
if key.is_active:
|
||||||
|
endpoint = ep
|
||||||
|
api_key = key
|
||||||
|
break
|
||||||
|
if endpoint:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not endpoint or not api_key:
|
||||||
|
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[test-model] Failed to decrypt API key: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
|
|
||||||
|
# 构建请求配置
|
||||||
|
endpoint_config = {
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
"timeout": endpoint.timeout or 30.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取对应的 Adapter 类
|
||||||
|
adapter_class = _get_adapter_for_format(endpoint.api_format)
|
||||||
|
if not adapter_class:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown API format: {endpoint.api_format}",
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
|
||||||
|
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
|
||||||
|
|
||||||
|
# 如果请求指定了 api_format,优先使用它
|
||||||
|
target_api_format = request.api_format or endpoint.api_format
|
||||||
|
if request.api_format and request.api_format != endpoint.api_format:
|
||||||
|
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
|
||||||
|
# 重新获取适配器
|
||||||
|
adapter_class = _get_adapter_for_format(request.api_format)
|
||||||
|
if not adapter_class:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown API format: {request.api_format}",
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
}
|
||||||
|
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
|
||||||
|
|
||||||
|
# 准备测试请求数据
|
||||||
|
check_request = {
|
||||||
|
"model": request.model_name,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": request.message or "Hello! This is a test message."}
|
||||||
|
],
|
||||||
|
"max_tokens": 30,
|
||||||
|
"temperature": 0.7,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送测试请求
|
||||||
|
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
|
||||||
|
# 非流式测试
|
||||||
|
logger.debug(f"[test-model] 开始非流式测试...")
|
||||||
|
|
||||||
|
response = await adapter_class.check_endpoint(
|
||||||
|
client,
|
||||||
|
endpoint_config["base_url"],
|
||||||
|
endpoint_config["api_key"],
|
||||||
|
check_request,
|
||||||
|
endpoint_config.get("extra_headers"),
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=current_user,
|
||||||
|
provider_name=provider.name,
|
||||||
|
provider_id=provider.id,
|
||||||
|
api_key_id=endpoint_config.get("api_key_id"),
|
||||||
|
model_name=request.model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录提供商返回信息
|
||||||
|
logger.debug(f"[test-model] 非流式测试结果:")
|
||||||
|
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
|
||||||
|
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
|
||||||
|
response_data = response.get('response', {})
|
||||||
|
response_body = response_data.get('response_body', {})
|
||||||
|
logger.debug(f"[test-model] Response Data: {response_data}")
|
||||||
|
logger.debug(f"[test-model] Response Body: {response_body}")
|
||||||
|
# 尝试解析 response_body (通常是 JSON 字符串)
|
||||||
|
parsed_body = response_body
|
||||||
|
import json
|
||||||
|
if isinstance(response_body, str):
|
||||||
|
try:
|
||||||
|
parsed_body = json.loads(response_body)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(parsed_body, dict) and 'error' in parsed_body:
|
||||||
|
error_obj = parsed_body['error']
|
||||||
|
# 兼容 error 可能是字典或字符串的情况
|
||||||
|
if isinstance(error_obj, dict):
|
||||||
|
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
|
||||||
|
raise HTTPException(status_code=500, detail=error_obj.get('message'))
|
||||||
|
else:
|
||||||
|
logger.debug(f"[test-model] Error: {error_obj}")
|
||||||
|
raise HTTPException(status_code=500, detail=error_obj)
|
||||||
|
elif 'error' in response:
|
||||||
|
logger.debug(f"[test-model] Error: {response['error']}")
|
||||||
|
raise HTTPException(status_code=500, detail=response['error'])
|
||||||
|
else:
|
||||||
|
# 如果有选择或消息,记录内容预览
|
||||||
|
if isinstance(response_data, dict):
|
||||||
|
if 'choices' in response_data and response_data['choices']:
|
||||||
|
choice = response_data['choices'][0]
|
||||||
|
if 'message' in choice:
|
||||||
|
content = choice['message'].get('content', '')
|
||||||
|
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
|
||||||
|
elif 'content' in response_data and response_data['content']:
|
||||||
|
content = str(response_data['content'])
|
||||||
|
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
|
||||||
|
|
||||||
|
# 检查测试是否成功(基于HTTP状态码)
|
||||||
|
status_code = response.get('status_code', 0)
|
||||||
|
is_success = status_code == 200 and 'error' not in response
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": is_success,
|
||||||
|
"data": {
|
||||||
|
"stream": False,
|
||||||
|
"response": response,
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
"endpoint": {
|
||||||
|
"id": endpoint.id,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
"endpoint": {
|
||||||
|
"id": endpoint.id,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
} if endpoint else None,
|
||||||
|
}
|
||||||
|
|||||||
@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
name: str = "chat.base"
|
name: str = "chat.base"
|
||||||
mode = ApiMode.STANDARD
|
mode = ApiMode.STANDARD
|
||||||
|
|
||||||
|
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建端点URL,子类可以覆盖以自定义URL构建逻辑"""
|
||||||
|
# 默认实现:在base_url后添加特定路径
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建基础请求头,子类可以覆盖以自定义认证头"""
|
||||||
|
# 默认实现:Bearer token认证
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回不应被extra_headers覆盖的头部key,子类可以覆盖"""
|
||||||
|
# 默认保护认证相关头部
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
|
||||||
|
# 默认实现:直接使用请求数据
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||||
|
|
||||||
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
# 默认实现返回空列表,子类应覆盖
|
# 默认实现返回空列表,子类应覆盖
|
||||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试模型连接性(非流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
request_data: 请求数据
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
db: 数据库会话
|
||||||
|
user: 用户对象
|
||||||
|
provider_name: 提供商名称
|
||||||
|
provider_id: 提供商ID
|
||||||
|
api_key_id: API Key ID
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试响应数据
|
||||||
|
"""
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
|
|
||||||
|
# 使用子类配置方法构建请求组件
|
||||||
|
url = cls.build_endpoint_url(base_url)
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 使用通用的endpoint checker执行请求
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=model_name or request_data.get("model"),
|
||||||
|
)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||||
|
|||||||
@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
# 默认实现返回空列表,子类应覆盖
|
# 默认实现返回空列表,子类应覆盖
|
||||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试模型连接性(非流式)
|
||||||
|
|
||||||
|
通用的CLI endpoint测试方法,使用配置方法模式:
|
||||||
|
- build_endpoint_url(): 构建请求URL
|
||||||
|
- build_base_headers(): 构建基础认证头
|
||||||
|
- get_protected_header_keys(): 获取受保护的头部key
|
||||||
|
- build_request_body(): 构建请求体
|
||||||
|
- get_cli_user_agent(): 获取CLI User-Agent(子类可覆盖)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
request_data: 请求数据
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
db: 数据库会话
|
||||||
|
user: 用户对象
|
||||||
|
provider_name: 提供商名称
|
||||||
|
provider_id: 提供商ID
|
||||||
|
api_key_id: API密钥ID
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试响应数据
|
||||||
|
"""
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
|
|
||||||
|
# 构建请求组件
|
||||||
|
url = cls.build_endpoint_url(base_url, request_data, model_name)
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
|
||||||
|
# 添加CLI User-Agent
|
||||||
|
cli_user_agent = cls.get_cli_user_agent()
|
||||||
|
if cli_user_agent:
|
||||||
|
base_headers["User-Agent"] = cli_user_agent
|
||||||
|
protected_keys = tuple(list(protected_keys) + ["user-agent"])
|
||||||
|
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 获取有效的模型名称
|
||||||
|
effective_model_name = model_name or request_data.get("model")
|
||||||
|
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=effective_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
构建CLI API端点URL - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: API基础URL
|
||||||
|
request_data: 请求数据
|
||||||
|
model_name: 模型名称(某些API需要,如Gemini)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的端点URL
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
构建CLI API认证头 - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
基础认证头部字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""
|
||||||
|
返回CLI API的保护头部key - 子类应覆盖
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
保护头部key的元组
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
构建CLI API请求体 - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: 请求数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
请求体字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取CLI User-Agent - 子类可覆盖
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CLI User-Agent字符串,如果不需要则为None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||||
|
|||||||
1252
src/api/handlers/base/endpoint_checker.py
Normal file
1252
src/api/handlers/base/endpoint_checker.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -209,6 +209,38 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建Claude API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/messages"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/messages"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Claude API认证头"""
|
||||||
|
return {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Claude API的保护头部key"""
|
||||||
|
return ("x-api-key", "content-type", "anthropic-version")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Claude API请求体"""
|
||||||
|
return {
|
||||||
|
"model": request_data.get("model"),
|
||||||
|
"max_tokens": request_data.get("max_tokens", 100),
|
||||||
|
"messages": request_data.get("messages", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_claude_adapter(x_app_header: Optional[str]):
|
def build_claude_adapter(x_app_header: Optional[str]):
|
||||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -126,5 +126,41 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建Claude CLI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/messages"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/messages"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Claude CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Claude CLI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Claude CLI API请求体"""
|
||||||
|
return {
|
||||||
|
"model": request_data.get("model"),
|
||||||
|
"max_tokens": request_data.get("max_tokens", 100),
|
||||||
|
"messages": request_data.get("messages", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取Claude CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_claude_cli
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ClaudeCliAdapter"]
|
__all__ = ["ClaudeCliAdapter"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Gemini Chat Adapter
|
|||||||
处理 Gemini API 格式的请求适配
|
处理 Gemini API 格式的请求适配
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
|
|||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.gemini import GeminiRequest
|
from src.models.gemini import GeminiRequest
|
||||||
|
|
||||||
@@ -199,6 +200,94 @@ class GeminiChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建Gemini API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1beta"):
|
||||||
|
return base_url # 子类需要处理model参数
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1beta"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Gemini API认证头"""
|
||||||
|
return {
|
||||||
|
"x-goog-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Gemini API的保护头部key"""
|
||||||
|
return ("x-goog-api-key", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Gemini API请求体"""
|
||||||
|
return {
|
||||||
|
"contents": request_data.get("messages", []),
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": request_data.get("max_tokens", 100),
|
||||||
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""测试 Gemini API 模型连接性(非流式)"""
|
||||||
|
# Gemini需要从request_data或model_name参数获取model名称
|
||||||
|
effective_model_name = model_name or request_data.get("model", "")
|
||||||
|
if not effective_model_name:
|
||||||
|
return {
|
||||||
|
"error": "Model name is required for Gemini API",
|
||||||
|
"status_code": 400,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用基类配置方法,但重写URL构建逻辑
|
||||||
|
base_url = cls.build_endpoint_url(base_url)
|
||||||
|
url = f"{base_url}/models/{effective_model_name}:generateContent"
|
||||||
|
|
||||||
|
# 构建请求组件
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 使用基类的通用endpoint checker
|
||||||
|
from src.api.handlers.base.endpoint_checker import run_endpoint_check
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=effective_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
|||||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -123,6 +123,52 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建Gemini CLI API端点URL"""
|
||||||
|
effective_model_name = model_name or request_data.get("model", "")
|
||||||
|
if not effective_model_name:
|
||||||
|
raise ValueError("Model name is required for Gemini API")
|
||||||
|
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1beta"):
|
||||||
|
prefix = base_url
|
||||||
|
else:
|
||||||
|
prefix = f"{base_url}/v1beta"
|
||||||
|
return f"{prefix}/models/{effective_model_name}:generateContent"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Gemini CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"x-goog-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Gemini CLI API的保护头部key"""
|
||||||
|
return ("x-goog-api-key", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Gemini CLI API请求体"""
|
||||||
|
return {
|
||||||
|
"contents": request_data.get("messages", []),
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": request_data.get("max_tokens", 100),
|
||||||
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取Gemini CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_gemini_cli
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
|||||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.openai import OpenAIRequest
|
from src.models.openai import OpenAIRequest
|
||||||
@@ -154,5 +155,32 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建OpenAI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建OpenAI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回OpenAI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建OpenAI API请求体"""
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAIChatAdapter"]
|
__all__ = ["OpenAIChatAdapter"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -68,5 +68,37 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建OpenAI CLI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建OpenAI CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回OpenAI CLI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建OpenAI CLI API请求体"""
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取OpenAI CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_openai_cli
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAICliAdapter"]
|
__all__ = ["OpenAICliAdapter"]
|
||||||
|
|||||||
@@ -8,116 +8,86 @@ from src.api.handlers.base.utils import build_sse_headers, extract_cache_creatio
|
|||||||
class TestExtractCacheCreationTokens:
|
class TestExtractCacheCreationTokens:
|
||||||
"""测试 extract_cache_creation_tokens 函数"""
|
"""测试 extract_cache_creation_tokens 函数"""
|
||||||
|
|
||||||
# === 嵌套格式测试(优先级最高)===
|
def test_new_format_only(self) -> None:
|
||||||
|
"""测试只有新格式字段"""
|
||||||
def test_nested_cache_creation_format(self) -> None:
|
|
||||||
"""测试嵌套格式正常情况"""
|
|
||||||
usage = {
|
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 456,
|
|
||||||
"ephemeral_1h_input_tokens": 100,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 556
|
|
||||||
|
|
||||||
def test_nested_cache_creation_with_old_format_fallback(self) -> None:
|
|
||||||
"""测试嵌套格式为 0 时回退到旧格式"""
|
|
||||||
usage = {
|
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 0,
|
|
||||||
"ephemeral_1h_input_tokens": 0,
|
|
||||||
},
|
|
||||||
"cache_creation_input_tokens": 549,
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 549
|
|
||||||
|
|
||||||
def test_nested_has_priority_over_flat(self) -> None:
|
|
||||||
"""测试嵌套格式优先于扁平格式"""
|
|
||||||
usage = {
|
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 100,
|
|
||||||
"ephemeral_1h_input_tokens": 200,
|
|
||||||
},
|
|
||||||
"claude_cache_creation_5_m_tokens": 999, # 应该被忽略
|
|
||||||
"claude_cache_creation_1_h_tokens": 888, # 应该被忽略
|
|
||||||
"cache_creation_input_tokens": 777, # 应该被忽略
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 300
|
|
||||||
|
|
||||||
# === 扁平格式测试(优先级第二)===
|
|
||||||
|
|
||||||
def test_flat_new_format_still_works(self) -> None:
|
|
||||||
"""测试扁平新格式兼容性"""
|
|
||||||
usage = {
|
usage = {
|
||||||
"claude_cache_creation_5_m_tokens": 100,
|
"claude_cache_creation_5_m_tokens": 100,
|
||||||
"claude_cache_creation_1_h_tokens": 200,
|
"claude_cache_creation_1_h_tokens": 200,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 300
|
assert extract_cache_creation_tokens(usage) == 300
|
||||||
|
|
||||||
def test_flat_new_format_with_old_format_fallback(self) -> None:
|
def test_new_format_5m_only(self) -> None:
|
||||||
"""测试扁平格式为 0 时回退到旧格式"""
|
"""测试只有 5 分钟缓存"""
|
||||||
usage = {
|
|
||||||
"claude_cache_creation_5_m_tokens": 0,
|
|
||||||
"claude_cache_creation_1_h_tokens": 0,
|
|
||||||
"cache_creation_input_tokens": 549,
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 549
|
|
||||||
|
|
||||||
def test_flat_new_format_5m_only(self) -> None:
|
|
||||||
"""测试只有 5 分钟扁平缓存"""
|
|
||||||
usage = {
|
usage = {
|
||||||
"claude_cache_creation_5_m_tokens": 150,
|
"claude_cache_creation_5_m_tokens": 150,
|
||||||
"claude_cache_creation_1_h_tokens": 0,
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 150
|
assert extract_cache_creation_tokens(usage) == 150
|
||||||
|
|
||||||
def test_flat_new_format_1h_only(self) -> None:
|
def test_new_format_1h_only(self) -> None:
|
||||||
"""测试只有 1 小时扁平缓存"""
|
"""测试只有 1 小时缓存"""
|
||||||
usage = {
|
usage = {
|
||||||
"claude_cache_creation_5_m_tokens": 0,
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
"claude_cache_creation_1_h_tokens": 250,
|
"claude_cache_creation_1_h_tokens": 250,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 250
|
assert extract_cache_creation_tokens(usage) == 250
|
||||||
|
|
||||||
# === 旧格式测试(优先级第三)===
|
|
||||||
|
|
||||||
def test_old_format_only(self) -> None:
|
def test_old_format_only(self) -> None:
|
||||||
"""测试只有旧格式"""
|
"""测试只有旧格式字段"""
|
||||||
usage = {
|
usage = {
|
||||||
"cache_creation_input_tokens": 549,
|
"cache_creation_input_tokens": 500,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 549
|
assert extract_cache_creation_tokens(usage) == 500
|
||||||
|
|
||||||
# === 边界情况测试 ===
|
def test_both_formats_prefers_new(self) -> None:
|
||||||
|
"""测试同时存在时优先使用新格式"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 100,
|
||||||
|
"claude_cache_creation_1_h_tokens": 200,
|
||||||
|
"cache_creation_input_tokens": 999, # 应该被忽略
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 300
|
||||||
|
|
||||||
def test_no_cache_creation_tokens(self) -> None:
|
def test_empty_usage(self) -> None:
|
||||||
"""测试没有任何缓存字段"""
|
"""测试空字典"""
|
||||||
usage = {}
|
usage = {}
|
||||||
assert extract_cache_creation_tokens(usage) == 0
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
def test_all_formats_zero(self) -> None:
|
def test_all_zeros(self) -> None:
|
||||||
"""测试所有格式都为 0"""
|
"""测试所有字段都为 0"""
|
||||||
usage = {
|
usage = {
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 0,
|
|
||||||
"ephemeral_1h_input_tokens": 0,
|
|
||||||
},
|
|
||||||
"claude_cache_creation_5_m_tokens": 0,
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
"claude_cache_creation_1_h_tokens": 0,
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
"cache_creation_input_tokens": 0,
|
"cache_creation_input_tokens": 0,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 0
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
|
def test_partial_new_format_with_old_format_fallback(self) -> None:
|
||||||
|
"""测试新格式字段不存在时回退到旧格式"""
|
||||||
|
usage = {
|
||||||
|
"cache_creation_input_tokens": 123,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 123
|
||||||
|
|
||||||
|
def test_new_format_zero_should_not_fallback(self) -> None:
|
||||||
|
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
|
"cache_creation_input_tokens": 456,
|
||||||
|
}
|
||||||
|
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0)
|
||||||
|
# 而不是 fallback 到旧格式(返回 456)
|
||||||
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
def test_unrelated_fields_ignored(self) -> None:
|
def test_unrelated_fields_ignored(self) -> None:
|
||||||
"""测试忽略无关字段"""
|
"""测试忽略无关字段"""
|
||||||
usage = {
|
usage = {
|
||||||
"input_tokens": 1000,
|
"input_tokens": 1000,
|
||||||
"output_tokens": 2000,
|
"output_tokens": 2000,
|
||||||
"cache_read_input_tokens": 300,
|
"cache_read_input_tokens": 300,
|
||||||
"cache_creation": {
|
"claude_cache_creation_5_m_tokens": 50,
|
||||||
"ephemeral_5m_input_tokens": 50,
|
"claude_cache_creation_1_h_tokens": 75,
|
||||||
"ephemeral_1h_input_tokens": 75,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 125
|
assert extract_cache_creation_tokens(usage) == 125
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user