6 Commits

Author SHA1 Message Date
fawney19
2b1d197047 Merge remote-tracking branch 'gitcode/master' into htmambo/master 2025-12-25 22:47:08 +08:00
fawney19
71bc2e6aab fix: 增加参数校验防止除零错误 2025-12-25 22:44:17 +08:00
fawney19
afb329934a fix: 修复端点健康统计时间分段计算的除零错误 2025-12-25 19:54:16 +08:00
elky0401
1313af45a3 !4 merge htmambo/master into master
refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用

Created-by: elky0401
Commit-by: fawney19;hoping
Merged-by: elky0401
Description: feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。
refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用

See merge request: elky0401/Aether!4
2025-12-25 19:39:33 +08:00
fawney19
dddb327885 refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用
- 将 ModelsTab 和 ModelAliasesTab 中重复的错误解析逻辑提取到 errorParser.ts
- 添加 parseTestModelError 函数统一处理测试响应错误
- 为 testModel API 添加 TypeScript 类型定义 (TestModelRequest/TestModelResponse)
- 修复 endpoint_checker.py 中 usage_data 变量引用错误
2025-12-25 19:36:29 +08:00
hoping
26b4a37323 feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。 2025-12-25 00:02:56 +08:00
25 changed files with 2299 additions and 121 deletions

View File

@@ -58,3 +58,38 @@ export async function deleteProvider(providerId: string): Promise<{ message: str
return response.data return response.data
} }
/**
* 测试模型连接性
*/
export interface TestModelRequest {
provider_id: string
model_name: string
api_key_id?: string
message?: string
api_format?: string
}
export interface TestModelResponse {
success: boolean
error?: string
data?: {
response?: {
status_code?: number
error?: string | { message?: string }
choices?: Array<{ message?: { content?: string } }>
}
content_preview?: string
}
provider?: {
id: string
name: string
display_name: string
}
model?: string
}
export async function testModel(data: TestModelRequest): Promise<TestModelResponse> {
const response = await client.post('/api/admin/provider-query/test-model', data)
return response.data
}

View File

@@ -163,7 +163,9 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
useEscapeKey(() => { useEscapeKey(() => {
if (isOpen.value) { if (isOpen.value) {
handleClose() handleClose()
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
} }
return false
}, { }, {
disableOnInput: true, disableOnInput: true,
once: false once: false

View File

@@ -47,11 +47,11 @@ export function useConfirm() {
/** /**
* 便捷方法:危险操作确认(红色主题) * 便捷方法:危险操作确认(红色主题)
*/ */
const confirmDanger = (message: string, title?: string): Promise<boolean> => { const confirmDanger = (message: string, title?: string, confirmText?: string): Promise<boolean> => {
return confirm({ return confirm({
message, message,
title: title || '危险操作', title: title || '危险操作',
confirmText: '删除', confirmText: confirmText || '删除',
variant: 'danger' variant: 'danger'
}) })
} }

View File

@@ -4,11 +4,11 @@ import { onMounted, onUnmounted, ref } from 'vue'
* ESC 键监听 Composable简化版本直接使用独立监听器 * ESC 键监听 Composable简化版本直接使用独立监听器
* 用于按 ESC 键关闭弹窗或其他可关闭的组件 * 用于按 ESC 键关闭弹窗或其他可关闭的组件
* *
* @param callback - 按 ESC 键时执行的回调函数 * @param callback - 按 ESC 键时执行的回调函数,返回 true 表示已处理事件,阻止其他监听器执行
* @param options - 配置选项 * @param options - 配置选项
*/ */
export function useEscapeKey( export function useEscapeKey(
callback: () => void, callback: () => void | boolean,
options: { options: {
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */ /** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
disableOnInput?: boolean disableOnInput?: boolean
@@ -42,8 +42,11 @@ export function useEscapeKey(
if (isInputElement) return if (isInputElement) return
} }
// 执行回调 // 执行回调,如果返回 true 则阻止其他监听器
callback() const handled = callback()
if (handled === true) {
event.stopImmediatePropagation()
}
// 移除当前元素的焦点,避免残留样式 // 移除当前元素的焦点,避免残留样式
if (document.activeElement instanceof HTMLElement) { if (document.activeElement instanceof HTMLElement) {

View File

@@ -17,7 +17,7 @@
v-model:open="modelSelectOpen" v-model:open="modelSelectOpen"
:model-value="formData.modelId" :model-value="formData.modelId"
:disabled="!!editingGroup" :disabled="!!editingGroup"
@update:model-value="formData.modelId = $event" @update:model-value="handleModelChange"
> >
<SelectTrigger class="h-9"> <SelectTrigger class="h-9">
<SelectValue placeholder="请选择模型" /> <SelectValue placeholder="请选择模型" />
@@ -519,6 +519,15 @@ function initForm() {
} }
} }
// 处理模型选择变更
function handleModelChange(value: string) {
formData.value.modelId = value
const selectedModel = props.models.find(m => m.id === value)
if (selectedModel) {
upstreamModelSearch.value = selectedModel.provider_model_name
}
}
// 切换 API 格式 // 切换 API 格式
function toggleApiFormat(format: string) { function toggleApiFormat(format: string) {
const index = formData.value.apiFormats.indexOf(format) const index = formData.value.apiFormats.indexOf(format)

View File

@@ -531,6 +531,7 @@
<!-- 模型名称映射 --> <!-- 模型名称映射 -->
<ModelAliasesTab <ModelAliasesTab
v-if="provider" v-if="provider"
ref="modelAliasesTabRef"
:key="`aliases-${provider.id}`" :key="`aliases-${provider.id}`"
:provider="provider" :provider="provider"
@refresh="handleRelatedDataRefresh" @refresh="handleRelatedDataRefresh"
@@ -735,6 +736,9 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null) const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false) const batchAssignDialogOpen = ref(false)
// ModelAliasesTab 组件引用
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
// 拖动排序相关状态 // 拖动排序相关状态
const dragState = ref({ const dragState = ref({
isDragging: false, isDragging: false,
@@ -756,7 +760,9 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value || deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value || modelFormDialogOpen.value ||
deleteModelConfirmOpen.value || deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value batchAssignDialogOpen.value ||
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
modelAliasesTabRef.value?.dialogOpen
) )
// 监听 providerId 变化 // 监听 providerId 变化

View File

@@ -110,16 +110,30 @@
<div <div
v-for="mapping in group.aliases" v-for="mapping in group.aliases"
:key="mapping.name" :key="mapping.name"
class="flex items-center gap-2 py-1" class="flex items-center justify-between gap-2 py-1"
> >
<!-- 优先级标签 --> <div class="flex items-center gap-2 flex-1 min-w-0">
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0"> <!-- 优先级标签 -->
{{ mapping.priority }} <span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
</span> {{ mapping.priority }}
<!-- 映射名称 --> </span>
<span class="font-mono text-sm truncate"> <!-- 映射名称 -->
{{ mapping.name }} <span class="font-mono text-sm truncate">
</span> {{ mapping.name }}
</span>
</div>
<!-- 测试按钮 -->
<Button
variant="ghost"
size="icon"
class="h-7 w-7 shrink-0"
title="测试映射"
:disabled="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`"
@click="testMapping(group, mapping)"
>
<Loader2 v-if="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`" class="w-3 h-3 animate-spin" />
<Play v-else class="w-3 h-3" />
</Button>
</div> </div>
</div> </div>
</div> </div>
@@ -166,18 +180,20 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted, watch } from 'vue' import { ref, computed, onMounted, watch } from 'vue'
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next' import { Tag, Plus, Edit, Trash2, ChevronRight, Loader2, Play } from 'lucide-vue-next'
import { Card, Button, Badge } from '@/components/ui' import { Card, Button, Badge } from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue' import AlertDialog from '@/components/common/AlertDialog.vue'
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue' import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { import {
getProviderModels, getProviderModels,
testModel,
API_FORMAT_LABELS, API_FORMAT_LABELS,
type Model, type Model,
type ProviderModelAlias type ProviderModelAlias
} from '@/api/endpoints' } from '@/api/endpoints'
import { updateModel } from '@/api/endpoints/models' import { updateModel } from '@/api/endpoints/models'
import { parseTestModelError } from '@/utils/errorParser'
const props = defineProps<{ const props = defineProps<{
provider: any provider: any
@@ -196,6 +212,7 @@ const dialogOpen = ref(false)
const deleteConfirmOpen = ref(false) const deleteConfirmOpen = ref(false)
const editingGroup = ref<AliasGroup | null>(null) const editingGroup = ref<AliasGroup | null>(null)
const deletingGroup = ref<AliasGroup | null>(null) const deletingGroup = ref<AliasGroup | null>(null)
const testingMapping = ref<string | null>(null)
// 列表展开状态 // 列表展开状态
const expandedAliasGroups = ref<Set<string>>(new Set()) const expandedAliasGroups = ref<Set<string>>(new Set())
@@ -337,6 +354,49 @@ async function onDialogSaved() {
emit('refresh') emit('refresh')
} }
// 测试模型映射
async function testMapping(group: any, mapping: any) {
const testingKey = `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`
testingMapping.value = testingKey
try {
// 根据分组的 API 格式来确定应该使用的格式
let apiFormat = null
if (group.apiFormats.length === 1) {
apiFormat = group.apiFormats[0]
} else if (group.apiFormats.length === 0) {
// 如果没有指定格式,但分组显示为"全部",则使用模型的默认格式
apiFormat = group.model.effective_api_format || group.model.api_format
}
const result = await testModel({
provider_id: props.provider.id,
model_name: mapping.name, // 使用映射名称进行测试
message: "hello",
api_format: apiFormat
})
if (result.success) {
showSuccess(`映射 "${mapping.name}" 测试成功`)
// 如果有响应内容,可以显示更多信息
if (result.data?.response?.choices?.[0]?.message?.content) {
const content = result.data.response.choices[0].message.content
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
} else if (result.data?.content_preview) {
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
}
} else {
showError(`映射测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`映射测试失败: ${errorMsg}`)
} finally {
testingMapping.value = null
}
}
// 监听 provider 变化 // 监听 provider 变化
watch(() => props.provider?.id, (newId) => { watch(() => props.provider?.id, (newId) => {
if (newId) { if (newId) {
@@ -349,4 +409,9 @@ onMounted(() => {
loadModels() loadModels()
} }
}) })
// 暴露给父组件,用于检测是否有弹窗打开
defineExpose({
dialogOpen: computed(() => dialogOpen.value || deleteConfirmOpen.value)
})
</script> </script>

View File

@@ -156,6 +156,17 @@
</td> </td>
<td class="align-top px-4 py-3"> <td class="align-top px-4 py-3">
<div class="flex justify-center gap-1.5"> <div class="flex justify-center gap-1.5">
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="测试模型"
:disabled="testingModelId === model.id"
@click="testModelConnection(model)"
>
<Loader2 v-if="testingModelId === model.id" class="w-3.5 h-3.5 animate-spin" />
<Play v-else class="w-3.5 h-3.5" />
</Button>
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
@@ -209,12 +220,13 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted } from 'vue'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next' import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Loader2, Play } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { getProviderModels, type Model } from '@/api/endpoints' import { getProviderModels, testModel, type Model } from '@/api/endpoints'
import { updateModel } from '@/api/endpoints/models' import { updateModel } from '@/api/endpoints/models'
import { parseTestModelError } from '@/utils/errorParser'
const props = defineProps<{ const props = defineProps<{
provider: any provider: any
@@ -232,6 +244,7 @@ const { error: showError, success: showSuccess } = useToast()
const loading = ref(false) const loading = ref(false)
const models = ref<Model[]>([]) const models = ref<Model[]>([])
const togglingModelId = ref<string | null>(null) const togglingModelId = ref<string | null>(null)
const testingModelId = ref<string | null>(null)
// 按名称排序的模型列表 // 按名称排序的模型列表
const sortedModels = computed(() => { const sortedModels = computed(() => {
@@ -380,6 +393,39 @@ async function toggleModelActive(model: Model) {
} }
} }
// 测试模型连接性
async function testModelConnection(model: Model) {
if (testingModelId.value) return
testingModelId.value = model.id
try {
const result = await testModel({
provider_id: props.provider.id,
model_name: model.provider_model_name,
message: "hello"
})
if (result.success) {
showSuccess(`模型 "${model.provider_model_name}" 测试成功`)
// 如果有响应内容,可以显示更多信息
if (result.data?.response?.choices?.[0]?.message?.content) {
const content = result.data.response.choices[0].message.content
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
} else if (result.data?.content_preview) {
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
}
} else {
showError(`模型测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`模型测试失败: ${errorMsg}`)
} finally {
testingModelId.value = null
}
}
onMounted(() => { onMounted(() => {
loadModels() loadModels()
}) })

View File

@@ -14,7 +14,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
users.value = await usersApi.getAllUsers() users.value = await usersApi.getAllUsers()
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '获取用户列表失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取用户列表失败'
} finally { } finally {
loading.value = false loading.value = false
} }
@@ -29,7 +29,7 @@ export const useUsersStore = defineStore('users', () => {
users.value.push(newUser) users.value.push(newUser)
return newUser return newUser
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '创建用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -52,7 +52,7 @@ export const useUsersStore = defineStore('users', () => {
} }
return updatedUser return updatedUser
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '更新用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '更新用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -67,7 +67,7 @@ export const useUsersStore = defineStore('users', () => {
await usersApi.deleteUser(userId) await usersApi.deleteUser(userId)
users.value = users.value.filter(u => u.id !== userId) users.value = users.value.filter(u => u.id !== userId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '删除用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -78,7 +78,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
return await usersApi.getUserApiKeys(userId) return await usersApi.getUserApiKeys(userId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '获取 API Keys 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取 API Keys 失败'
throw err throw err
} }
} }
@@ -87,7 +87,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
return await usersApi.createApiKey(userId, name) return await usersApi.createApiKey(userId, name)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '创建 API Key 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建 API Key 失败'
throw err throw err
} }
} }
@@ -96,7 +96,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
await usersApi.deleteApiKey(userId, keyId) await usersApi.deleteApiKey(userId, keyId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '删除 API Key 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除 API Key 失败'
throw err throw err
} }
} }
@@ -110,7 +110,7 @@ export const useUsersStore = defineStore('users', () => {
// 刷新用户列表以获取最新数据 // 刷新用户列表以获取最新数据
await fetchUsers() await fetchUsers()
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '重置配额失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '重置配额失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false

View File

@@ -198,3 +198,49 @@ export function parseApiErrorShort(err: unknown, defaultMessage: string = '操
const lines = fullError.split('\n') const lines = fullError.split('\n')
return lines[0] || defaultMessage return lines[0] || defaultMessage
} }
/**
* 解析模型测试响应的错误信息
* @param result 测试响应结果
* @returns 格式化的错误信息
*/
export function parseTestModelError(result: {
error?: string
data?: {
response?: {
status_code?: number
error?: string | { message?: string }
}
}
}): string {
let errorMsg = result.error || '测试失败'
// 检查HTTP状态码错误
if (result.data?.response?.status_code) {
const status = result.data.response.status_code
if (status === 403) {
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
} else if (status === 401) {
errorMsg = '认证失败: API密钥无效或已过期'
} else if (status === 404) {
errorMsg = '模型不存在: 请检查模型名称是否正确'
} else if (status === 429) {
errorMsg = '请求频率过高: 请稍后重试'
} else if (status >= 500) {
errorMsg = `服务器错误: HTTP ${status}`
} else {
errorMsg = `请求失败: HTTP ${status}`
}
}
// 尝试从错误响应中提取更多信息
if (result.data?.response?.error) {
if (typeof result.data.response.error === 'string') {
errorMsg = result.data.response.error
} else if (result.data.response.error?.message) {
errorMsg = result.data.response.error.message
}
}
return errorMsg
}

View File

@@ -723,9 +723,19 @@ async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
// 切换提供商状态 // 切换提供商状态
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) { async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
try { try {
await updateProvider(provider.id, { is_active: !provider.is_active }) const newStatus = !provider.is_active
provider.is_active = !provider.is_active await updateProvider(provider.id, { is_active: newStatus })
showSuccess(provider.is_active ? '提供商已启用' : '提供商已停用')
// 更新抽屉内部的 provider 对象
provider.is_active = newStatus
// 同时更新主页面 providers 数组中的对象,实现无感更新
const targetProvider = providers.value.find(p => p.id === provider.id)
if (targetProvider) {
targetProvider.is_active = newStatus
}
showSuccess(newStatus ? '提供商已启用' : '提供商已停用')
} catch (err: any) { } catch (err: any) {
showError(err.response?.data?.detail || '操作失败', '错误') showError(err.response?.data?.detail || '操作失败', '错误')
} }

View File

@@ -875,7 +875,8 @@ async function toggleUserStatus(user: any) {
const action = user.is_active ? '禁用' : '启用' const action = user.is_active ? '禁用' : '启用'
const confirmed = await confirmDanger( const confirmed = await confirmDanger(
`确定要${action}用户 ${user.username} 吗?`, `确定要${action}用户 ${user.username} 吗?`,
`${action}用户` `${action}用户`,
action
) )
if (!confirmed) return if (!confirmed) return
@@ -884,7 +885,7 @@ async function toggleUserStatus(user: any) {
await usersStore.updateUser(user.id, { is_active: !user.is_active }) await usersStore.updateUser(user.id, { is_active: !user.is_active })
success(`用户已${action}`) success(`用户已${action}`)
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', `${action}用户失败`) error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', `${action}用户失败`)
} }
} }
@@ -955,7 +956,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
closeUserFormDialog() closeUserFormDialog()
} catch (err: any) { } catch (err: any) {
const title = data.id ? '更新用户失败' : '创建用户失败' const title = data.id ? '更新用户失败' : '创建用户失败'
error(err.response?.data?.detail || '未知错误', title) error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', title)
} finally { } finally {
userFormDialogRef.value?.setSaving(false) userFormDialogRef.value?.setSaving(false)
} }
@@ -989,7 +990,7 @@ async function createApiKey() {
showNewApiKeyDialog.value = true showNewApiKeyDialog.value = true
await loadUserApiKeys(selectedUser.value.id) await loadUserApiKeys(selectedUser.value.id)
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '创建 API Key 失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '创建 API Key 失败')
} finally { } finally {
creatingApiKey.value = false creatingApiKey.value = false
} }
@@ -1026,7 +1027,7 @@ async function deleteApiKey(apiKey: any) {
await loadUserApiKeys(selectedUser.value.id) await loadUserApiKeys(selectedUser.value.id)
success('API Key已删除') success('API Key已删除')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '删除 API Key 失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除 API Key 失败')
} }
} }
@@ -1038,7 +1039,7 @@ async function copyFullKey(apiKey: any) {
success('完整密钥已复制到剪贴板') success('完整密钥已复制到剪贴板')
} catch (err: any) { } catch (err: any) {
log.error('复制密钥失败:', err) log.error('复制密钥失败:', err)
error(err.response?.data?.detail || '未知错误', '复制密钥失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '复制密钥失败')
} }
} }
@@ -1054,7 +1055,7 @@ async function resetQuota(user: any) {
await usersStore.resetUserQuota(user.id) await usersStore.resetUserQuota(user.id)
success('配额已重置') success('配额已重置')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '重置配额失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '重置配额失败')
} }
} }
@@ -1070,7 +1071,7 @@ async function deleteUser(user: any) {
await usersStore.deleteUser(user.id) await usersStore.deleteUser(user.id)
success('用户已删除') success('用户已删除')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '删除用户失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除用户失败')
} }
} }
</script> </script>

View File

@@ -102,9 +102,9 @@
<!-- Main Content --> <!-- Main Content -->
<main class="relative z-10"> <main class="relative z-10">
<!-- Fixed Logo Container --> <!-- Fixed Logo Container -->
<div class="fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden"> <div class="mt-4 fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
<div <div
class="transform-gpu logo-container" class="mt-16 transform-gpu logo-container"
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]" :class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
:style="fixedLogoStyle" :style="fixedLogoStyle"
> >
@@ -151,7 +151,7 @@
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20" class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
> >
<div class="max-w-4xl mx-auto text-center"> <div class="max-w-4xl mx-auto text-center">
<div class="h-80 w-full mb-16" /> <div class="h-80 w-full mb-16 mt-8" />
<h1 <h1
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700" class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
:style="getTitleStyle(SECTIONS.HOME)" :style="getTitleStyle(SECTIONS.HOME)"
@@ -166,7 +166,7 @@
整合 Claude CodeCodex CLIGemini CLI 等多个 AI 编程助手 整合 Claude CodeCodex CLIGemini CLI 等多个 AI 编程助手
</p> </p>
<button <button
class="mt-16 transition-all duration-700 cursor-pointer hover:scale-110" class="mt-8 transition-all duration-700 cursor-pointer hover:scale-110"
:style="getScrollIndicatorStyle(SECTIONS.HOME)" :style="getScrollIndicatorStyle(SECTIONS.HOME)"
@click="scrollToSection(SECTIONS.CLAUDE)" @click="scrollToSection(SECTIONS.CLAUDE)"
> >

View File

@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
api_key_id: Optional[str] = None api_key_id: Optional[str] = None
class TestModelRequest(BaseModel):
"""模型测试请求"""
provider_id: str
model_name: str
api_key_id: Optional[str] = None
stream: bool = False
message: Optional[str] = "你好"
api_format: Optional[str] = None # 指定使用的API格式如果不指定则使用端点的默认格式
# ============ API Endpoints ============ # ============ API Endpoints ============
@@ -206,3 +217,228 @@ async def query_available_models(
"display_name": provider.display_name, "display_name": provider.display_name,
}, },
} }
@router.post("/test-model")
async def test_model(
request: TestModelRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
测试模型连接性
向指定提供商的指定模型发送测试请求,验证模型是否可用
Args:
request: 测试请求
Returns:
测试结果
"""
# 获取提供商及其端点
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 找到合适的端点和API Key
endpoint_config = None
endpoint = None
api_key = None
if request.api_key_id:
# 使用指定的API Key
for ep in provider.endpoints:
for key in ep.api_keys:
if key.id == request.api_key_id and key.is_active and ep.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
else:
# 使用第一个可用的端点和密钥
for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys:
continue
for key in ep.api_keys:
if key.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
if not endpoint or not api_key:
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 构建请求配置
endpoint_config = {
"api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0,
}
try:
# 获取对应的 Adapter 类
adapter_class = _get_adapter_for_format(endpoint.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {endpoint.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
# 如果请求指定了 api_format优先使用它
target_api_format = request.api_format or endpoint.api_format
if request.api_format and request.api_format != endpoint.api_format:
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
# 重新获取适配器
adapter_class = _get_adapter_for_format(request.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {request.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
# 准备测试请求数据
check_request = {
"model": request.model_name,
"messages": [
{"role": "user", "content": request.message or "Hello! This is a test message."}
],
"max_tokens": 30,
"temperature": 0.7,
}
# 发送测试请求
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
# 非流式测试
logger.debug(f"[test-model] 开始非流式测试...")
response = await adapter_class.check_endpoint(
client,
endpoint_config["base_url"],
endpoint_config["api_key"],
check_request,
endpoint_config.get("extra_headers"),
# 用量计算参数(现在强制记录)
db=db,
user=current_user,
provider_name=provider.name,
provider_id=provider.id,
api_key_id=endpoint_config.get("api_key_id"),
model_name=request.model_name,
)
# 记录提供商返回信息
logger.debug(f"[test-model] 非流式测试结果:")
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
response_data = response.get('response', {})
response_body = response_data.get('response_body', {})
logger.debug(f"[test-model] Response Data: {response_data}")
logger.debug(f"[test-model] Response Body: {response_body}")
# 尝试解析 response_body (通常是 JSON 字符串)
parsed_body = response_body
import json
if isinstance(response_body, str):
try:
parsed_body = json.loads(response_body)
except json.JSONDecodeError:
pass
if isinstance(parsed_body, dict) and 'error' in parsed_body:
error_obj = parsed_body['error']
# 兼容 error 可能是字典或字符串的情况
if isinstance(error_obj, dict):
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
raise HTTPException(status_code=500, detail=error_obj.get('message'))
else:
logger.debug(f"[test-model] Error: {error_obj}")
raise HTTPException(status_code=500, detail=error_obj)
elif 'error' in response:
logger.debug(f"[test-model] Error: {response['error']}")
raise HTTPException(status_code=500, detail=response['error'])
else:
# 如果有选择或消息,记录内容预览
if isinstance(response_data, dict):
if 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice:
content = choice['message'].get('content', '')
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
elif 'content' in response_data and response_data['content']:
content = str(response_data['content'])
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
# 检查测试是否成功基于HTTP状态码
status_code = response.get('status_code', 0)
is_success = status_code == 200 and 'error' not in response
return {
"success": is_success,
"data": {
"stream": False,
"response": response,
},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
},
}
except Exception as e:
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
return {
"success": False,
"error": str(e),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
} if endpoint else None,
}

View File

@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base" name: str = "chat.base"
mode = ApiMode.STANDARD mode = ApiMode.STANDARD
# 子类可以配置的特殊方法用于check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建端点URL子类可以覆盖以自定义URL构建逻辑"""
# 默认实现在base_url后添加特定路径
return base_url
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建基础请求头,子类可以覆盖以自定义认证头"""
# 默认实现Bearer token认证
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回不应被extra_headers覆盖的头部key子类可以覆盖"""
# 默认保护认证相关头部
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
# 默认实现:直接使用请求数据
return request_data.copy()
def __init__(self, allowed_api_formats: Optional[list[str]] = None): def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID] self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖 # 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models" return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数(现在强制记录)
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API Key ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 使用子类配置方法构建请求组件
url = cls.build_endpoint_url(base_url)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用通用的endpoint checker执行请求
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=model_name or request_data.get("model"),
)
# ========================================================================= # =========================================================================
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例 # Adapter 注册表 - 用于根据 API format 获取 Adapter 实例

View File

@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖 # 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models" return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
通用的CLI endpoint测试方法使用配置方法模式
- build_endpoint_url(): 构建请求URL
- build_base_headers(): 构建基础认证头
- get_protected_header_keys(): 获取受保护的头部key
- build_request_body(): 构建请求体
- get_cli_user_agent(): 获取CLI User-Agent子类可覆盖
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API密钥ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 构建请求组件
url = cls.build_endpoint_url(base_url, request_data, model_name)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
# 添加CLI User-Agent
cli_user_agent = cls.get_cli_user_agent()
if cli_user_agent:
base_headers["User-Agent"] = cli_user_agent
protected_keys = tuple(list(protected_keys) + ["user-agent"])
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 获取有效的模型名称
effective_model_name = model_name or request_data.get("model")
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
# =========================================================================
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
# =========================================================================
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""
构建CLI API端点URL - 子类应覆盖
Args:
base_url: API基础URL
request_data: 请求数据
model_name: 模型名称某些API需要如Gemini
Returns:
完整的端点URL
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""
构建CLI API认证头 - 子类应覆盖
Args:
api_key: API密钥
Returns:
基础认证头部字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""
返回CLI API的保护头部key - 子类应覆盖
Returns:
保护头部key的元组
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
构建CLI API请求体 - 子类应覆盖
Args:
request_data: 请求数据
Returns:
请求体字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""
获取CLI User-Agent - 子类可覆盖
Returns:
CLI User-Agent字符串如果不需要则为None
"""
return None
# ========================================================================= # =========================================================================
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例 # CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例

File diff suppressed because it is too large Load Diff

View File

@@ -209,6 +209,38 @@ class ClaudeChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}") logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建Claude API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/messages"
else:
return f"{base_url}/v1/messages"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Claude API认证头"""
return {
"x-api-key": api_key,
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Claude API的保护头部key"""
return ("x-api-key", "content-type", "anthropic-version")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Claude API请求体"""
return {
"model": request_data.get("model"),
"max_tokens": request_data.get("max_tokens", 100),
"messages": request_data.get("messages", []),
}
def build_claude_adapter(x_app_header: Optional[str]): def build_claude_adapter(x_app_header: Optional[str]):
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。""" """根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""

View File

@@ -4,7 +4,7 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。 继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -126,5 +126,41 @@ class ClaudeCliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建Claude CLI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/messages"
else:
return f"{base_url}/v1/messages"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Claude CLI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Claude CLI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Claude CLI API请求体"""
return {
"model": request_data.get("model"),
"max_tokens": request_data.get("max_tokens", 100),
"messages": request_data.get("messages", []),
}
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取Claude CLI User-Agent"""
return config.internal_user_agent_claude_cli
__all__ = ["ClaudeCliAdapter"] __all__ = ["ClaudeCliAdapter"]

View File

@@ -4,7 +4,7 @@ Gemini Chat Adapter
处理 Gemini API 格式的请求适配 处理 Gemini API 格式的请求适配
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
from src.core.logger import logger from src.core.logger import logger
from src.models.gemini import GeminiRequest from src.models.gemini import GeminiRequest
@@ -199,6 +200,94 @@ class GeminiChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}") logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建Gemini API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1beta"):
return base_url # 子类需要处理model参数
else:
return f"{base_url}/v1beta"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Gemini API认证头"""
return {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Gemini API的保护头部key"""
return ("x-goog-api-key", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Gemini API请求体"""
return {
"contents": request_data.get("messages", []),
"generationConfig": {
"maxOutputTokens": request_data.get("max_tokens", 100),
"temperature": request_data.get("temperature", 0.7),
},
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
],
}
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""测试 Gemini API 模型连接性(非流式)"""
# Gemini需要从request_data或model_name参数获取model名称
effective_model_name = model_name or request_data.get("model", "")
if not effective_model_name:
return {
"error": "Model name is required for Gemini API",
"status_code": 400,
}
# 使用基类配置方法但重写URL构建逻辑
base_url = cls.build_endpoint_url(base_url)
url = f"{base_url}/models/{effective_model_name}:generateContent"
# 构建请求组件
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用基类的通用endpoint checker
from src.api.handlers.base.endpoint_checker import run_endpoint_check
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter: def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
""" """

View File

@@ -4,7 +4,7 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
继承 CliAdapterBase处理 Gemini CLI 格式的请求。 继承 CliAdapterBase处理 Gemini CLI 格式的请求。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -123,6 +123,52 @@ class GeminiCliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建Gemini CLI API端点URL"""
effective_model_name = model_name or request_data.get("model", "")
if not effective_model_name:
raise ValueError("Model name is required for Gemini API")
base_url = base_url.rstrip("/")
if base_url.endswith("/v1beta"):
prefix = base_url
else:
prefix = f"{base_url}/v1beta"
return f"{prefix}/models/{effective_model_name}:generateContent"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Gemini CLI API认证头"""
return {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Gemini CLI API的保护头部key"""
return ("x-goog-api-key", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Gemini CLI API请求体"""
return {
"contents": request_data.get("messages", []),
"generationConfig": {
"maxOutputTokens": request_data.get("max_tokens", 100),
"temperature": request_data.get("temperature", 0.7),
},
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
],
}
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取Gemini CLI User-Agent"""
return config.internal_user_agent_gemini_cli
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter: def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
""" """

View File

@@ -4,13 +4,14 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。 处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger from src.core.logger import logger
from src.models.openai import OpenAIRequest from src.models.openai import OpenAIRequest
@@ -154,5 +155,32 @@ class OpenAIChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch models from {models_url}: {e}") logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建OpenAI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/chat/completions"
else:
return f"{base_url}/v1/chat/completions"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建OpenAI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回OpenAI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建OpenAI API请求体"""
return request_data.copy()
__all__ = ["OpenAIChatAdapter"] __all__ = ["OpenAIChatAdapter"]

View File

@@ -4,7 +4,7 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。 继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
""" """
from typing import Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -68,5 +68,37 @@ class OpenAICliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建OpenAI CLI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/chat/completions"
else:
return f"{base_url}/v1/chat/completions"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建OpenAI CLI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回OpenAI CLI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建OpenAI CLI API请求体"""
return request_data.copy()
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取OpenAI CLI User-Agent"""
return config.internal_user_agent_openai_cli
__all__ = ["OpenAICliAdapter"] __all__ = ["OpenAICliAdapter"]

View File

@@ -234,8 +234,15 @@ class EndpointHealthService:
for api_format in format_key_mapping.keys() for api_format in format_key_mapping.keys()
} }
# 参数校验API 层已通过 Query(ge=1) 保证,这里做防御性检查)
if lookback_hours <= 0 or segments <= 0:
raise ValueError(
f"lookback_hours and segments must be positive, "
f"got lookback_hours={lookback_hours}, segments={segments}"
)
# 计算时间范围 # 计算时间范围
interval_minutes = (lookback_hours * 60) // segments segment_seconds = (lookback_hours * 3600) / segments
start_time = now - timedelta(hours=lookback_hours) start_time = now - timedelta(hours=lookback_hours)
# 使用 RequestCandidate 表查询所有尝试记录 # 使用 RequestCandidate 表查询所有尝试记录
@@ -243,7 +250,7 @@ class EndpointHealthService:
final_statuses = ["success", "failed", "skipped"] final_statuses = ["success", "failed", "skipped"]
segment_expr = func.floor( segment_expr = func.floor(
func.extract('epoch', RequestCandidate.created_at - start_time) / (interval_minutes * 60) func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
).label('segment_idx') ).label('segment_idx')
candidate_stats = ( candidate_stats = (

View File

@@ -8,116 +8,86 @@ from src.api.handlers.base.utils import build_sse_headers, extract_cache_creatio
class TestExtractCacheCreationTokens: class TestExtractCacheCreationTokens:
"""测试 extract_cache_creation_tokens 函数""" """测试 extract_cache_creation_tokens 函数"""
# === 嵌套格式测试(优先级最高)=== def test_new_format_only(self) -> None:
"""测试只有新格式字段"""
def test_nested_cache_creation_format(self) -> None:
"""测试嵌套格式正常情况"""
usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 456,
"ephemeral_1h_input_tokens": 100,
}
}
assert extract_cache_creation_tokens(usage) == 556
def test_nested_cache_creation_with_old_format_fallback(self) -> None:
"""测试嵌套格式为 0 时回退到旧格式"""
usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"cache_creation_input_tokens": 549,
}
assert extract_cache_creation_tokens(usage) == 549
def test_nested_has_priority_over_flat(self) -> None:
"""测试嵌套格式优先于扁平格式"""
usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 100,
"ephemeral_1h_input_tokens": 200,
},
"claude_cache_creation_5_m_tokens": 999, # 应该被忽略
"claude_cache_creation_1_h_tokens": 888, # 应该被忽略
"cache_creation_input_tokens": 777, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
# === 扁平格式测试(优先级第二)===
def test_flat_new_format_still_works(self) -> None:
"""测试扁平新格式兼容性"""
usage = { usage = {
"claude_cache_creation_5_m_tokens": 100, "claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200, "claude_cache_creation_1_h_tokens": 200,
} }
assert extract_cache_creation_tokens(usage) == 300 assert extract_cache_creation_tokens(usage) == 300
def test_flat_new_format_with_old_format_fallback(self) -> None: def test_new_format_5m_only(self) -> None:
"""测试扁平格式为 0 时回退到旧格式""" """测试只有 5 分钟缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 549,
}
assert extract_cache_creation_tokens(usage) == 549
def test_flat_new_format_5m_only(self) -> None:
"""测试只有 5 分钟扁平缓存"""
usage = { usage = {
"claude_cache_creation_5_m_tokens": 150, "claude_cache_creation_5_m_tokens": 150,
"claude_cache_creation_1_h_tokens": 0, "claude_cache_creation_1_h_tokens": 0,
} }
assert extract_cache_creation_tokens(usage) == 150 assert extract_cache_creation_tokens(usage) == 150
def test_flat_new_format_1h_only(self) -> None: def test_new_format_1h_only(self) -> None:
"""测试只有 1 小时扁平缓存""" """测试只有 1 小时缓存"""
usage = { usage = {
"claude_cache_creation_5_m_tokens": 0, "claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 250, "claude_cache_creation_1_h_tokens": 250,
} }
assert extract_cache_creation_tokens(usage) == 250 assert extract_cache_creation_tokens(usage) == 250
# === 旧格式测试(优先级第三)===
def test_old_format_only(self) -> None: def test_old_format_only(self) -> None:
"""测试只有旧格式""" """测试只有旧格式字段"""
usage = { usage = {
"cache_creation_input_tokens": 549, "cache_creation_input_tokens": 500,
} }
assert extract_cache_creation_tokens(usage) == 549 assert extract_cache_creation_tokens(usage) == 500
# === 边界情况测试 === def test_both_formats_prefers_new(self) -> None:
"""测试同时存在时优先使用新格式"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
"cache_creation_input_tokens": 999, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
def test_no_cache_creation_tokens(self) -> None: def test_empty_usage(self) -> None:
"""测试没有任何缓存字段""" """测试空字典"""
usage = {} usage = {}
assert extract_cache_creation_tokens(usage) == 0 assert extract_cache_creation_tokens(usage) == 0
def test_all_formats_zero(self) -> None: def test_all_zeros(self) -> None:
"""测试所有格式都为 0""" """测试所有字段都为 0"""
usage = { usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"claude_cache_creation_5_m_tokens": 0, "claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0, "claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 0, "cache_creation_input_tokens": 0,
} }
assert extract_cache_creation_tokens(usage) == 0 assert extract_cache_creation_tokens(usage) == 0
def test_partial_new_format_with_old_format_fallback(self) -> None:
"""测试新格式字段不存在时回退到旧格式"""
usage = {
"cache_creation_input_tokens": 123,
}
assert extract_cache_creation_tokens(usage) == 123
def test_new_format_zero_should_not_fallback(self) -> None:
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 456,
}
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0
# 而不是 fallback 到旧格式(返回 456
assert extract_cache_creation_tokens(usage) == 0
def test_unrelated_fields_ignored(self) -> None: def test_unrelated_fields_ignored(self) -> None:
"""测试忽略无关字段""" """测试忽略无关字段"""
usage = { usage = {
"input_tokens": 1000, "input_tokens": 1000,
"output_tokens": 2000, "output_tokens": 2000,
"cache_read_input_tokens": 300, "cache_read_input_tokens": 300,
"cache_creation": { "claude_cache_creation_5_m_tokens": 50,
"ephemeral_5m_input_tokens": 50, "claude_cache_creation_1_h_tokens": 75,
"ephemeral_1h_input_tokens": 75,
},
} }
assert extract_cache_creation_tokens(usage) == 125 assert extract_cache_creation_tokens(usage) == 125