mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
Compare commits
24 Commits
v0.1.22
...
394cc536a9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
394cc536a9 | ||
|
|
e20a09f15a | ||
|
|
b89a4af0cf | ||
|
|
a56854af43 | ||
|
|
4a35d78c8d | ||
|
|
26b281271e | ||
|
|
96094cfde2 | ||
|
|
7e26af5476 | ||
|
|
c8dfb784bc | ||
|
|
fd3a5a5afe | ||
|
|
599b3d4c95 | ||
|
|
41719a00e7 | ||
|
|
b5c0f85dca | ||
|
|
7d6d262ed3 | ||
|
|
e21acd73eb | ||
|
|
702f9bc5f1 | ||
|
|
d0ce798881 | ||
|
|
2b1d197047 | ||
|
|
71bc2e6aab | ||
|
|
afb329934a | ||
|
|
1313af45a3 | ||
|
|
dddb327885 | ||
|
|
26b4a37323 | ||
|
|
9dad194130 |
15
LICENSE
15
LICENSE
@@ -5,12 +5,17 @@ Aether 非商业开源许可证
|
|||||||
特此授予任何获得本软件及其相关文档文件(以下简称"软件")副本的人免费使用、
|
特此授予任何获得本软件及其相关文档文件(以下简称"软件")副本的人免费使用、
|
||||||
复制、修改、合并、发布和分发本软件的权限,但须遵守以下条件:
|
复制、修改、合并、发布和分发本软件的权限,但须遵守以下条件:
|
||||||
|
|
||||||
1. 仅限非商业用途
|
1. 仅限非盈利用途
|
||||||
本软件不得用于商业目的。商业目的包括但不限于:
|
本软件不得用于盈利目的。盈利目的包括但不限于:
|
||||||
- 出售本软件或任何衍生作品
|
- 出售本软件或任何衍生作品
|
||||||
- 使用本软件提供付费服务
|
- 使用本软件提供付费服务
|
||||||
- 将本软件用于商业产品或服务
|
- 将本软件用于以盈利为目的的商业产品或服务
|
||||||
- 将本软件用于任何旨在获取商业利益或金钱报酬的活动
|
|
||||||
|
以下用途被明确允许:
|
||||||
|
- 个人学习和研究
|
||||||
|
- 教育机构的教学和研究
|
||||||
|
- 非盈利组织的内部使用
|
||||||
|
- 企业内部非盈利性质的使用(如内部工具、测试环境等)
|
||||||
|
|
||||||
2. 署名要求
|
2. 署名要求
|
||||||
上述版权声明和本许可声明应包含在本软件的所有副本或主要部分中。
|
上述版权声明和本许可声明应包含在本软件的所有副本或主要部分中。
|
||||||
@@ -22,7 +27,7 @@ Aether 非商业开源许可证
|
|||||||
您不得以不同的条款将本软件再许可给他人。
|
您不得以不同的条款将本软件再许可给他人。
|
||||||
|
|
||||||
5. 商业许可
|
5. 商业许可
|
||||||
如需商业使用,请联系版权持有人以获取单独的商业许可。
|
如需将本软件用于盈利目的,请联系版权持有人以获取单独的商业许可。
|
||||||
|
|
||||||
本软件按"原样"提供,不提供任何明示或暗示的保证,包括但不限于对适销性、
|
本软件按"原样"提供,不提供任何明示或暗示的保证,包括但不限于对适销性、
|
||||||
特定用途适用性和非侵权性的保证。在任何情况下,作者或版权持有人均不对任何
|
特定用途适用性和非侵权性的保证。在任何情况下,作者或版权持有人均不对任何
|
||||||
|
|||||||
16
README.md
16
README.md
@@ -143,7 +143,7 @@ cd frontend && npm install && npm run dev
|
|||||||
- **模型级别**: 在模型管理中针对指定模型开启 1H缓存策略
|
- **模型级别**: 在模型管理中针对指定模型开启 1H缓存策略
|
||||||
- **密钥级别**: 在密钥管理中针对指定密钥使用 1H缓存策略
|
- **密钥级别**: 在密钥管理中针对指定密钥使用 1H缓存策略
|
||||||
|
|
||||||
> **注意**: 若对密钥设置强制 1H缓存, 则该密钥只能调用支持 1H缓存的模型
|
> **注意**: 若对密钥设置强制 1H缓存, 则该密钥只能使用支持 1H缓存的模型, 匹配提供商Key, 将会导致这个Key无法同时用于Claude Code、Codex、GeminiCLI, 因为更推荐使用模型开启1H缓存.
|
||||||
|
|
||||||
### Q: 如何配置负载均衡?
|
### Q: 如何配置负载均衡?
|
||||||
|
|
||||||
@@ -162,4 +162,16 @@ cd frontend && npm install && npm run dev
|
|||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
本项目采用 [Aether 非商业开源许可证](LICENSE)。
|
本项目采用 [Aether 非商业开源许可证](LICENSE)。允许个人学习、教育研究、非盈利组织及企业内部非盈利性质的使用;禁止用于盈利目的。商业使用请联系获取商业许可。
|
||||||
|
|
||||||
|
## 联系作者
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="docs/author/qq_qrcode.jpg" width="200" alt="QQ二维码">
|
||||||
|
</p>
|
||||||
|
|
||||||
|
## Star History
|
||||||
|
|
||||||
|
[](https://star-history.com/#fawney19/Aether&Date)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
BIN
docs/author/qq_qrcode.jpg
Normal file
BIN
docs/author/qq_qrcode.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 266 KiB |
BIN
docs/author/wechat_payment.jpg
Normal file
BIN
docs/author/wechat_payment.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 113 KiB |
@@ -87,6 +87,8 @@ export interface DashboardStatsResponse {
|
|||||||
cache_stats?: CacheStats
|
cache_stats?: CacheStats
|
||||||
users?: UserStats
|
users?: UserStats
|
||||||
token_breakdown?: TokenBreakdown
|
token_breakdown?: TokenBreakdown
|
||||||
|
// 普通用户专用字段
|
||||||
|
monthly_cost?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RecentRequestsResponse {
|
export interface RecentRequestsResponse {
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ import type {
|
|||||||
GlobalModelUpdate,
|
GlobalModelUpdate,
|
||||||
GlobalModelResponse,
|
GlobalModelResponse,
|
||||||
GlobalModelWithStats,
|
GlobalModelWithStats,
|
||||||
GlobalModelListResponse
|
GlobalModelListResponse,
|
||||||
|
ModelCatalogProviderDetail,
|
||||||
} from './types'
|
} from './types'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -83,3 +84,16 @@ export async function batchAssignToProviders(
|
|||||||
)
|
)
|
||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取 GlobalModel 的所有关联提供商(包括非活跃的)
|
||||||
|
*/
|
||||||
|
export async function getGlobalModelProviders(globalModelId: string): Promise<{
|
||||||
|
providers: ModelCatalogProviderDetail[]
|
||||||
|
total: number
|
||||||
|
}> {
|
||||||
|
const response = await client.get(
|
||||||
|
`/api/admin/models/global/${globalModelId}/providers`
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|||||||
@@ -110,6 +110,14 @@ export async function updateEndpointKey(
|
|||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取完整的 API Key(用于查看和复制)
|
||||||
|
*/
|
||||||
|
export async function revealEndpointKey(keyId: string): Promise<{ api_key: string }> {
|
||||||
|
const response = await client.get(`/api/admin/endpoints/keys/${keyId}/reveal`)
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 删除 Endpoint Key
|
* 删除 Endpoint Key
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -58,3 +58,38 @@ export async function deleteProvider(providerId: string): Promise<{ message: str
|
|||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 测试模型连接性
|
||||||
|
*/
|
||||||
|
export interface TestModelRequest {
|
||||||
|
provider_id: string
|
||||||
|
model_name: string
|
||||||
|
api_key_id?: string
|
||||||
|
message?: string
|
||||||
|
api_format?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TestModelResponse {
|
||||||
|
success: boolean
|
||||||
|
error?: string
|
||||||
|
data?: {
|
||||||
|
response?: {
|
||||||
|
status_code?: number
|
||||||
|
error?: string | { message?: string }
|
||||||
|
choices?: Array<{ message?: { content?: string } }>
|
||||||
|
}
|
||||||
|
content_preview?: string
|
||||||
|
}
|
||||||
|
provider?: {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
}
|
||||||
|
model?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function testModel(data: TestModelRequest): Promise<TestModelResponse> {
|
||||||
|
const response = await client.post('/api/admin/provider-query/test-model', data)
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,4 +20,5 @@ export {
|
|||||||
updateGlobalModel,
|
updateGlobalModel,
|
||||||
deleteGlobalModel,
|
deleteGlobalModel,
|
||||||
batchAssignToProviders,
|
batchAssignToProviders,
|
||||||
|
getGlobalModelProviders,
|
||||||
} from './endpoints/global-models'
|
} from './endpoints/global-models'
|
||||||
|
|||||||
@@ -163,7 +163,9 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
|
|||||||
useEscapeKey(() => {
|
useEscapeKey(() => {
|
||||||
if (isOpen.value) {
|
if (isOpen.value) {
|
||||||
handleClose()
|
handleClose()
|
||||||
|
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}, {
|
}, {
|
||||||
disableOnInput: true,
|
disableOnInput: true,
|
||||||
once: false
|
once: false
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import { log } from '@/utils/logger'
|
|||||||
export function useClipboard() {
|
export function useClipboard() {
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
|
||||||
async function copyToClipboard(text: string): Promise<boolean> {
|
async function copyToClipboard(text: string, showToast = true): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
if (navigator.clipboard && window.isSecureContext) {
|
||||||
await navigator.clipboard.writeText(text)
|
await navigator.clipboard.writeText(text)
|
||||||
success('已复制到剪贴板')
|
if (showToast) success('已复制到剪贴板')
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,17 +25,17 @@ export function useClipboard() {
|
|||||||
try {
|
try {
|
||||||
const successful = document.execCommand('copy')
|
const successful = document.execCommand('copy')
|
||||||
if (successful) {
|
if (successful) {
|
||||||
success('已复制到剪贴板')
|
if (showToast) success('已复制到剪贴板')
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
showError('复制失败,请手动复制')
|
if (showToast) showError('复制失败,请手动复制')
|
||||||
return false
|
return false
|
||||||
} finally {
|
} finally {
|
||||||
document.body.removeChild(textArea)
|
document.body.removeChild(textArea)
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
log.error('复制失败:', err)
|
log.error('复制失败:', err)
|
||||||
showError('复制失败,请手动选择文本进行复制')
|
if (showToast) showError('复制失败,请手动选择文本进行复制')
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,11 +47,11 @@ export function useConfirm() {
|
|||||||
/**
|
/**
|
||||||
* 便捷方法:危险操作确认(红色主题)
|
* 便捷方法:危险操作确认(红色主题)
|
||||||
*/
|
*/
|
||||||
const confirmDanger = (message: string, title?: string): Promise<boolean> => {
|
const confirmDanger = (message: string, title?: string, confirmText?: string): Promise<boolean> => {
|
||||||
return confirm({
|
return confirm({
|
||||||
message,
|
message,
|
||||||
title: title || '危险操作',
|
title: title || '危险操作',
|
||||||
confirmText: '删除',
|
confirmText: confirmText || '删除',
|
||||||
variant: 'danger'
|
variant: 'danger'
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import { onMounted, onUnmounted, ref } from 'vue'
|
|||||||
* ESC 键监听 Composable(简化版本,直接使用独立监听器)
|
* ESC 键监听 Composable(简化版本,直接使用独立监听器)
|
||||||
* 用于按 ESC 键关闭弹窗或其他可关闭的组件
|
* 用于按 ESC 键关闭弹窗或其他可关闭的组件
|
||||||
*
|
*
|
||||||
* @param callback - 按 ESC 键时执行的回调函数
|
* @param callback - 按 ESC 键时执行的回调函数,返回 true 表示已处理事件,阻止其他监听器执行
|
||||||
* @param options - 配置选项
|
* @param options - 配置选项
|
||||||
*/
|
*/
|
||||||
export function useEscapeKey(
|
export function useEscapeKey(
|
||||||
callback: () => void,
|
callback: () => void | boolean,
|
||||||
options: {
|
options: {
|
||||||
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
|
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
|
||||||
disableOnInput?: boolean
|
disableOnInput?: boolean
|
||||||
@@ -42,8 +42,11 @@ export function useEscapeKey(
|
|||||||
if (isInputElement) return
|
if (isInputElement) return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行回调
|
// 执行回调,如果返回 true 则阻止其他监听器
|
||||||
callback()
|
const handled = callback()
|
||||||
|
if (handled === true) {
|
||||||
|
event.stopImmediatePropagation()
|
||||||
|
}
|
||||||
|
|
||||||
// 移除当前元素的焦点,避免残留样式
|
// 移除当前元素的焦点,避免残留样式
|
||||||
if (document.activeElement instanceof HTMLElement) {
|
if (document.activeElement instanceof HTMLElement) {
|
||||||
|
|||||||
@@ -700,6 +700,7 @@ import {
|
|||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
@@ -731,6 +732,7 @@ const emit = defineEmits<{
|
|||||||
'refreshProviders': []
|
'refreshProviders': []
|
||||||
}>()
|
}>()
|
||||||
const { success: showSuccess, error: showError } = useToast()
|
const { success: showSuccess, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
model: GlobalModelResponse | null
|
model: GlobalModelResponse | null
|
||||||
@@ -763,16 +765,6 @@ function handleClose() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制到剪贴板
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
showSuccess('已复制')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 格式化日期
|
// 格式化日期
|
||||||
function formatDate(dateStr: string): string {
|
function formatDate(dateStr: string): string {
|
||||||
if (!dateStr) return '-'
|
if (!dateStr) return '-'
|
||||||
|
|||||||
@@ -433,11 +433,17 @@ const availableGlobalModels = computed(() => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算可添加的上游模型(排除已关联的)
|
// 计算可添加的上游模型(排除已关联的,包括主模型名和映射名称)
|
||||||
const availableUpstreamModelsBase = computed(() => {
|
const availableUpstreamModelsBase = computed(() => {
|
||||||
const existingModelNames = new Set(
|
const existingModelNames = new Set<string>()
|
||||||
existingModels.value.map(m => m.provider_model_name)
|
for (const m of existingModels.value) {
|
||||||
)
|
// 主模型名
|
||||||
|
existingModelNames.add(m.provider_model_name)
|
||||||
|
// 映射名称
|
||||||
|
for (const mapping of m.provider_model_mappings ?? []) {
|
||||||
|
if (mapping.name) existingModelNames.add(mapping.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
|
return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,19 @@
|
|||||||
{{ model.global_model_name }}
|
{{ model.global_model_name }}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 测试按钮 -->
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7 shrink-0"
|
||||||
|
title="测试模型连接"
|
||||||
|
:disabled="testingModelName === model.global_model_name"
|
||||||
|
@click.stop="testModelConnection(model)"
|
||||||
|
>
|
||||||
|
<Loader2 v-if="testingModelName === model.global_model_name" class="w-3.5 h-3.5 animate-spin" />
|
||||||
|
<Play v-else class="w-3.5 h-3.5" />
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -148,16 +161,17 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, watch } from 'vue'
|
import { ref, computed, watch } from 'vue'
|
||||||
import { Box, Loader2, Settings2 } from 'lucide-vue-next'
|
import { Box, Loader2, Settings2, Play } from 'lucide-vue-next'
|
||||||
import { Dialog } from '@/components/ui'
|
import { Dialog } from '@/components/ui'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Checkbox from '@/components/ui/checkbox.vue'
|
import Checkbox from '@/components/ui/checkbox.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { parseApiError } from '@/utils/errorParser'
|
import { parseApiError, parseTestModelError } from '@/utils/errorParser'
|
||||||
import {
|
import {
|
||||||
updateEndpointKey,
|
updateEndpointKey,
|
||||||
getProviderAvailableSourceModels,
|
getProviderAvailableSourceModels,
|
||||||
|
testModel,
|
||||||
type EndpointAPIKey,
|
type EndpointAPIKey,
|
||||||
type ProviderAvailableSourceModel
|
type ProviderAvailableSourceModel
|
||||||
} from '@/api/endpoints'
|
} from '@/api/endpoints'
|
||||||
@@ -181,6 +195,7 @@ const loadingModels = ref(false)
|
|||||||
const availableModels = ref<ProviderAvailableSourceModel[]>([])
|
const availableModels = ref<ProviderAvailableSourceModel[]>([])
|
||||||
const selectedModels = ref<string[]>([])
|
const selectedModels = ref<string[]>([])
|
||||||
const initialModels = ref<string[]>([])
|
const initialModels = ref<string[]>([])
|
||||||
|
const testingModelName = ref<string | null>(null)
|
||||||
|
|
||||||
// 监听对话框打开
|
// 监听对话框打开
|
||||||
watch(() => props.open, (open) => {
|
watch(() => props.open, (open) => {
|
||||||
@@ -268,6 +283,32 @@ function clearModels() {
|
|||||||
selectedModels.value = []
|
selectedModels.value = []
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 测试模型连接
|
||||||
|
async function testModelConnection(model: ProviderAvailableSourceModel) {
|
||||||
|
if (!props.providerId || !props.apiKey || testingModelName.value) return
|
||||||
|
|
||||||
|
testingModelName.value = model.global_model_name
|
||||||
|
try {
|
||||||
|
const result = await testModel({
|
||||||
|
provider_id: props.providerId,
|
||||||
|
model_name: model.provider_model_name,
|
||||||
|
api_key_id: props.apiKey.id,
|
||||||
|
message: "hello"
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
success(`模型 "${model.display_name}" 测试成功`)
|
||||||
|
} else {
|
||||||
|
showError(`模型测试失败: ${parseTestModelError(result)}`)
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
||||||
|
showError(`模型测试失败: ${errorMsg}`)
|
||||||
|
} finally {
|
||||||
|
testingModelName.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function areArraysEqual(a: string[], b: string[]): boolean {
|
function areArraysEqual(a: string[], b: string[]): boolean {
|
||||||
if (a.length !== b.length) return false
|
if (a.length !== b.length) return false
|
||||||
const sortedA = [...a].sort()
|
const sortedA = [...a].sort()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
v-model:open="modelSelectOpen"
|
v-model:open="modelSelectOpen"
|
||||||
:model-value="formData.modelId"
|
:model-value="formData.modelId"
|
||||||
:disabled="!!editingGroup"
|
:disabled="!!editingGroup"
|
||||||
@update:model-value="formData.modelId = $event"
|
@update:model-value="handleModelChange"
|
||||||
>
|
>
|
||||||
<SelectTrigger class="h-9">
|
<SelectTrigger class="h-9">
|
||||||
<SelectValue placeholder="请选择模型" />
|
<SelectValue placeholder="请选择模型" />
|
||||||
@@ -449,7 +449,17 @@ interface UpstreamModelGroup {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
|
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
|
||||||
|
// 收集当前表单已添加的名称
|
||||||
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
|
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
|
||||||
|
|
||||||
|
// 收集所有已存在的映射名称(包括主模型名和映射名称)
|
||||||
|
for (const m of props.models) {
|
||||||
|
addedNames.add(m.provider_model_name)
|
||||||
|
for (const mapping of m.provider_model_mappings ?? []) {
|
||||||
|
if (mapping.name) addedNames.add(mapping.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
|
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
|
||||||
|
|
||||||
const groups = new Map<string, UpstreamModelGroup>()
|
const groups = new Map<string, UpstreamModelGroup>()
|
||||||
@@ -519,6 +529,15 @@ function initForm() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理模型选择变更
|
||||||
|
function handleModelChange(value: string) {
|
||||||
|
formData.value.modelId = value
|
||||||
|
const selectedModel = props.models.find(m => m.id === value)
|
||||||
|
if (selectedModel) {
|
||||||
|
upstreamModelSearch.value = selectedModel.provider_model_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 切换 API 格式
|
// 切换 API 格式
|
||||||
function toggleApiFormat(format: string) {
|
function toggleApiFormat(format: string) {
|
||||||
const index = formData.value.apiFormats.indexOf(format)
|
const index = formData.value.apiFormats.indexOf(format)
|
||||||
|
|||||||
@@ -337,8 +337,40 @@
|
|||||||
{{ key.is_active ? '活跃' : '禁用' }}
|
{{ key.is_active ? '活跃' : '禁用' }}
|
||||||
</Badge>
|
</Badge>
|
||||||
</div>
|
</div>
|
||||||
<div class="text-[10px] font-mono text-muted-foreground truncate">
|
<div class="flex items-center gap-1">
|
||||||
{{ key.api_key_masked }}
|
<span class="text-[10px] font-mono text-muted-foreground truncate max-w-[180px]">
|
||||||
|
{{ revealedKeys.has(key.id) ? revealedKeys.get(key.id) : key.api_key_masked }}
|
||||||
|
</span>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-5 w-5 shrink-0"
|
||||||
|
:title="revealedKeys.has(key.id) ? '隐藏密钥' : '显示密钥'"
|
||||||
|
:disabled="revealingKeyId === key.id"
|
||||||
|
@click.stop="toggleKeyReveal(key)"
|
||||||
|
>
|
||||||
|
<Loader2
|
||||||
|
v-if="revealingKeyId === key.id"
|
||||||
|
class="w-3 h-3 animate-spin"
|
||||||
|
/>
|
||||||
|
<EyeOff
|
||||||
|
v-else-if="revealedKeys.has(key.id)"
|
||||||
|
class="w-3 h-3"
|
||||||
|
/>
|
||||||
|
<Eye
|
||||||
|
v-else
|
||||||
|
class="w-3 h-3"
|
||||||
|
/>
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-5 w-5 shrink-0"
|
||||||
|
title="复制密钥"
|
||||||
|
@click.stop="copyFullKey(key)"
|
||||||
|
>
|
||||||
|
<Copy class="w-3 h-3" />
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex items-center gap-1.5 ml-auto shrink-0">
|
<div class="flex items-center gap-1.5 ml-auto shrink-0">
|
||||||
@@ -531,6 +563,7 @@
|
|||||||
<!-- 模型名称映射 -->
|
<!-- 模型名称映射 -->
|
||||||
<ModelAliasesTab
|
<ModelAliasesTab
|
||||||
v-if="provider"
|
v-if="provider"
|
||||||
|
ref="modelAliasesTabRef"
|
||||||
:key="`aliases-${provider.id}`"
|
:key="`aliases-${provider.id}`"
|
||||||
:provider="provider"
|
:provider="provider"
|
||||||
@refresh="handleRelatedDataRefresh"
|
@refresh="handleRelatedDataRefresh"
|
||||||
@@ -653,13 +686,16 @@ import {
|
|||||||
Power,
|
Power,
|
||||||
Layers,
|
Layers,
|
||||||
GripVertical,
|
GripVertical,
|
||||||
Copy
|
Copy,
|
||||||
|
Eye,
|
||||||
|
EyeOff
|
||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { getProvider, getProviderEndpoints } from '@/api/endpoints'
|
import { getProvider, getProviderEndpoints } from '@/api/endpoints'
|
||||||
import {
|
import {
|
||||||
KeyFormDialog,
|
KeyFormDialog,
|
||||||
@@ -679,6 +715,7 @@ import {
|
|||||||
updateEndpoint,
|
updateEndpoint,
|
||||||
updateEndpointKey,
|
updateEndpointKey,
|
||||||
batchUpdateKeyPriority,
|
batchUpdateKeyPriority,
|
||||||
|
revealEndpointKey,
|
||||||
type ProviderEndpoint,
|
type ProviderEndpoint,
|
||||||
type EndpointAPIKey,
|
type EndpointAPIKey,
|
||||||
type Model
|
type Model
|
||||||
@@ -705,6 +742,7 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { error: showError, success: showSuccess } = useToast()
|
const { error: showError, success: showSuccess } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const provider = ref<any>(null)
|
const provider = ref<any>(null)
|
||||||
@@ -728,6 +766,10 @@ const recoveringEndpointId = ref<string | null>(null)
|
|||||||
const togglingEndpointId = ref<string | null>(null)
|
const togglingEndpointId = ref<string | null>(null)
|
||||||
const togglingKeyId = ref<string | null>(null)
|
const togglingKeyId = ref<string | null>(null)
|
||||||
|
|
||||||
|
// 密钥显示状态:key_id -> 完整密钥
|
||||||
|
const revealedKeys = ref<Map<string, string>>(new Map())
|
||||||
|
const revealingKeyId = ref<string | null>(null)
|
||||||
|
|
||||||
// 模型相关状态
|
// 模型相关状态
|
||||||
const modelFormDialogOpen = ref(false)
|
const modelFormDialogOpen = ref(false)
|
||||||
const editingModel = ref<Model | null>(null)
|
const editingModel = ref<Model | null>(null)
|
||||||
@@ -735,6 +777,9 @@ const deleteModelConfirmOpen = ref(false)
|
|||||||
const modelToDelete = ref<Model | null>(null)
|
const modelToDelete = ref<Model | null>(null)
|
||||||
const batchAssignDialogOpen = ref(false)
|
const batchAssignDialogOpen = ref(false)
|
||||||
|
|
||||||
|
// ModelAliasesTab 组件引用
|
||||||
|
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
|
||||||
|
|
||||||
// 拖动排序相关状态
|
// 拖动排序相关状态
|
||||||
const dragState = ref({
|
const dragState = ref({
|
||||||
isDragging: false,
|
isDragging: false,
|
||||||
@@ -756,7 +801,9 @@ const hasBlockingDialogOpen = computed(() =>
|
|||||||
deleteKeyConfirmOpen.value ||
|
deleteKeyConfirmOpen.value ||
|
||||||
modelFormDialogOpen.value ||
|
modelFormDialogOpen.value ||
|
||||||
deleteModelConfirmOpen.value ||
|
deleteModelConfirmOpen.value ||
|
||||||
batchAssignDialogOpen.value
|
batchAssignDialogOpen.value ||
|
||||||
|
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
|
||||||
|
modelAliasesTabRef.value?.dialogOpen
|
||||||
)
|
)
|
||||||
|
|
||||||
// 监听 providerId 变化
|
// 监听 providerId 变化
|
||||||
@@ -792,6 +839,9 @@ watch(() => props.open, (newOpen) => {
|
|||||||
currentEndpoint.value = null
|
currentEndpoint.value = null
|
||||||
editingKey.value = null
|
editingKey.value = null
|
||||||
keyToDelete.value = null
|
keyToDelete.value = null
|
||||||
|
|
||||||
|
// 清除已显示的密钥(安全考虑)
|
||||||
|
revealedKeys.value.clear()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -880,6 +930,43 @@ function handleConfigKeyModels(key: EndpointAPIKey) {
|
|||||||
keyAllowedModelsDialogOpen.value = true
|
keyAllowedModelsDialogOpen.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 切换密钥显示/隐藏
|
||||||
|
async function toggleKeyReveal(key: EndpointAPIKey) {
|
||||||
|
if (revealedKeys.value.has(key.id)) {
|
||||||
|
// 已显示,隐藏它
|
||||||
|
revealedKeys.value.delete(key.id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 未显示,调用 API 获取完整密钥
|
||||||
|
revealingKeyId.value = key.id
|
||||||
|
try {
|
||||||
|
const result = await revealEndpointKey(key.id)
|
||||||
|
revealedKeys.value.set(key.id, result.api_key)
|
||||||
|
} catch (err: any) {
|
||||||
|
showError(err.response?.data?.detail || '获取密钥失败', '错误')
|
||||||
|
} finally {
|
||||||
|
revealingKeyId.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制完整密钥
|
||||||
|
async function copyFullKey(key: EndpointAPIKey) {
|
||||||
|
// 如果已经显示了,直接复制
|
||||||
|
if (revealedKeys.value.has(key.id)) {
|
||||||
|
copyToClipboard(revealedKeys.value.get(key.id)!)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则先获取再复制
|
||||||
|
try {
|
||||||
|
const result = await revealEndpointKey(key.id)
|
||||||
|
copyToClipboard(result.api_key)
|
||||||
|
} catch (err: any) {
|
||||||
|
showError(err.response?.data?.detail || '获取密钥失败', '错误')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function handleDeleteKey(key: EndpointAPIKey) {
|
function handleDeleteKey(key: EndpointAPIKey) {
|
||||||
keyToDelete.value = key
|
keyToDelete.value = key
|
||||||
deleteKeyConfirmOpen.value = true
|
deleteKeyConfirmOpen.value = true
|
||||||
@@ -1244,16 +1331,6 @@ function getHealthScoreBarColor(score: number): string {
|
|||||||
return 'bg-red-500 dark:bg-red-400'
|
return 'bg-red-500 dark:bg-red-400'
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制到剪贴板
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
showSuccess('已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败', '错误')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 加载 Provider 信息
|
// 加载 Provider 信息
|
||||||
async function loadProvider() {
|
async function loadProvider() {
|
||||||
if (!props.providerId) return
|
if (!props.providerId) return
|
||||||
|
|||||||
@@ -110,8 +110,9 @@
|
|||||||
<div
|
<div
|
||||||
v-for="mapping in group.aliases"
|
v-for="mapping in group.aliases"
|
||||||
:key="mapping.name"
|
:key="mapping.name"
|
||||||
class="flex items-center gap-2 py-1"
|
class="flex items-center justify-between gap-2 py-1"
|
||||||
>
|
>
|
||||||
|
<div class="flex items-center gap-2 flex-1 min-w-0">
|
||||||
<!-- 优先级标签 -->
|
<!-- 优先级标签 -->
|
||||||
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
||||||
{{ mapping.priority }}
|
{{ mapping.priority }}
|
||||||
@@ -121,6 +122,19 @@
|
|||||||
{{ mapping.name }}
|
{{ mapping.name }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
<!-- 测试按钮 -->
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7 shrink-0"
|
||||||
|
title="测试映射"
|
||||||
|
:disabled="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`"
|
||||||
|
@click="testMapping(group, mapping)"
|
||||||
|
>
|
||||||
|
<Loader2 v-if="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`" class="w-3 h-3 animate-spin" />
|
||||||
|
<Play v-else class="w-3 h-3" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -166,18 +180,20 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted, watch } from 'vue'
|
import { ref, computed, onMounted, watch } from 'vue'
|
||||||
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next'
|
import { Tag, Plus, Edit, Trash2, ChevronRight, Loader2, Play } from 'lucide-vue-next'
|
||||||
import { Card, Button, Badge } from '@/components/ui'
|
import { Card, Button, Badge } from '@/components/ui'
|
||||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||||
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import {
|
import {
|
||||||
getProviderModels,
|
getProviderModels,
|
||||||
|
testModel,
|
||||||
API_FORMAT_LABELS,
|
API_FORMAT_LABELS,
|
||||||
type Model,
|
type Model,
|
||||||
type ProviderModelAlias
|
type ProviderModelAlias
|
||||||
} from '@/api/endpoints'
|
} from '@/api/endpoints'
|
||||||
import { updateModel } from '@/api/endpoints/models'
|
import { updateModel } from '@/api/endpoints/models'
|
||||||
|
import { parseTestModelError } from '@/utils/errorParser'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
provider: any
|
provider: any
|
||||||
@@ -196,6 +212,7 @@ const dialogOpen = ref(false)
|
|||||||
const deleteConfirmOpen = ref(false)
|
const deleteConfirmOpen = ref(false)
|
||||||
const editingGroup = ref<AliasGroup | null>(null)
|
const editingGroup = ref<AliasGroup | null>(null)
|
||||||
const deletingGroup = ref<AliasGroup | null>(null)
|
const deletingGroup = ref<AliasGroup | null>(null)
|
||||||
|
const testingMapping = ref<string | null>(null)
|
||||||
|
|
||||||
// 列表展开状态
|
// 列表展开状态
|
||||||
const expandedAliasGroups = ref<Set<string>>(new Set())
|
const expandedAliasGroups = ref<Set<string>>(new Set())
|
||||||
@@ -337,6 +354,49 @@ async function onDialogSaved() {
|
|||||||
emit('refresh')
|
emit('refresh')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 测试模型映射
|
||||||
|
async function testMapping(group: any, mapping: any) {
|
||||||
|
const testingKey = `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`
|
||||||
|
testingMapping.value = testingKey
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 根据分组的 API 格式来确定应该使用的格式
|
||||||
|
let apiFormat = null
|
||||||
|
if (group.apiFormats.length === 1) {
|
||||||
|
apiFormat = group.apiFormats[0]
|
||||||
|
} else if (group.apiFormats.length === 0) {
|
||||||
|
// 如果没有指定格式,但分组显示为"全部",则使用模型的默认格式
|
||||||
|
apiFormat = group.model.effective_api_format || group.model.api_format
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await testModel({
|
||||||
|
provider_id: props.provider.id,
|
||||||
|
model_name: mapping.name, // 使用映射名称进行测试
|
||||||
|
message: "hello",
|
||||||
|
api_format: apiFormat
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
showSuccess(`映射 "${mapping.name}" 测试成功`)
|
||||||
|
|
||||||
|
// 如果有响应内容,可以显示更多信息
|
||||||
|
if (result.data?.response?.choices?.[0]?.message?.content) {
|
||||||
|
const content = result.data.response.choices[0].message.content
|
||||||
|
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
|
||||||
|
} else if (result.data?.content_preview) {
|
||||||
|
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showError(`映射测试失败: ${parseTestModelError(result)}`)
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
||||||
|
showError(`映射测试失败: ${errorMsg}`)
|
||||||
|
} finally {
|
||||||
|
testingMapping.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 监听 provider 变化
|
// 监听 provider 变化
|
||||||
watch(() => props.provider?.id, (newId) => {
|
watch(() => props.provider?.id, (newId) => {
|
||||||
if (newId) {
|
if (newId) {
|
||||||
@@ -349,4 +409,9 @@ onMounted(() => {
|
|||||||
loadModels()
|
loadModels()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 暴露给父组件,用于检测是否有弹窗打开
|
||||||
|
defineExpose({
|
||||||
|
dialogOpen: computed(() => dialogOpen.value || deleteConfirmOpen.value)
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image
|
|||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { getProviderModels, type Model } from '@/api/endpoints'
|
import { getProviderModels, type Model } from '@/api/endpoints'
|
||||||
import { updateModel } from '@/api/endpoints/models'
|
import { updateModel } from '@/api/endpoints/models'
|
||||||
|
|
||||||
@@ -227,6 +228,7 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { error: showError, success: showSuccess } = useToast()
|
const { error: showError, success: showSuccess } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
// 状态
|
// 状态
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -244,12 +246,7 @@ const sortedModels = computed(() => {
|
|||||||
|
|
||||||
// 复制模型 ID 到剪贴板
|
// 复制模型 ID 到剪贴板
|
||||||
async function copyModelId(modelId: string) {
|
async function copyModelId(modelId: string) {
|
||||||
try {
|
await copyToClipboard(modelId)
|
||||||
await navigator.clipboard.writeText(modelId)
|
|
||||||
showSuccess('已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败', '错误')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载模型
|
// 加载模型
|
||||||
|
|||||||
@@ -473,6 +473,7 @@
|
|||||||
import { ref, watch, computed } from 'vue'
|
import { ref, watch, computed } from 'vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Separator from '@/components/ui/separator.vue'
|
import Separator from '@/components/ui/separator.vue'
|
||||||
@@ -505,6 +506,7 @@ const copiedStates = ref<Record<string, boolean>>({})
|
|||||||
const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare')
|
const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare')
|
||||||
const currentExpandDepth = ref(1)
|
const currentExpandDepth = ref(1)
|
||||||
const dataSource = ref<'client' | 'provider'>('client')
|
const dataSource = ref<'client' | 'provider'>('client')
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
const historicalPricing = ref<{
|
const historicalPricing = ref<{
|
||||||
input_price: string
|
input_price: string
|
||||||
output_price: string
|
output_price: string
|
||||||
@@ -784,7 +786,7 @@ function copyJsonToClipboard(tabName: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (data) {
|
if (data) {
|
||||||
navigator.clipboard.writeText(JSON.stringify(data, null, 2))
|
copyToClipboard(JSON.stringify(data, null, 2), false)
|
||||||
copiedStates.value[tabName] = true
|
copiedStates.value[tabName] = true
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
copiedStates.value[tabName] = false
|
copiedStates.value[tabName] = false
|
||||||
|
|||||||
@@ -86,6 +86,34 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="isEditMode && form.password.length > 0"
|
||||||
|
class="space-y-2"
|
||||||
|
>
|
||||||
|
<Label class="text-sm font-medium">
|
||||||
|
确认新密码 <span class="text-muted-foreground">*</span>
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
:id="`pwd-confirm-${formNonce}`"
|
||||||
|
v-model="form.confirmPassword"
|
||||||
|
type="password"
|
||||||
|
autocomplete="new-password"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
:name="`confirm-${formNonce}`"
|
||||||
|
required
|
||||||
|
minlength="6"
|
||||||
|
placeholder="再次输入新密码"
|
||||||
|
class="h-10"
|
||||||
|
/>
|
||||||
|
<p
|
||||||
|
v-if="form.confirmPassword.length > 0 && form.password !== form.confirmPassword"
|
||||||
|
class="text-xs text-destructive"
|
||||||
|
>
|
||||||
|
两次输入的密码不一致
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="space-y-2">
|
<div class="space-y-2">
|
||||||
<Label
|
<Label
|
||||||
for="form-email"
|
for="form-email"
|
||||||
@@ -423,6 +451,7 @@ const apiFormats = ref<Array<{ value: string; label: string }>>([])
|
|||||||
const form = ref({
|
const form = ref({
|
||||||
username: '',
|
username: '',
|
||||||
password: '',
|
password: '',
|
||||||
|
confirmPassword: '',
|
||||||
email: '',
|
email: '',
|
||||||
quota: 10,
|
quota: 10,
|
||||||
role: 'user' as 'admin' | 'user',
|
role: 'user' as 'admin' | 'user',
|
||||||
@@ -443,6 +472,7 @@ function resetForm() {
|
|||||||
form.value = {
|
form.value = {
|
||||||
username: '',
|
username: '',
|
||||||
password: '',
|
password: '',
|
||||||
|
confirmPassword: '',
|
||||||
email: '',
|
email: '',
|
||||||
quota: 10,
|
quota: 10,
|
||||||
role: 'user',
|
role: 'user',
|
||||||
@@ -461,6 +491,7 @@ function loadUserData() {
|
|||||||
form.value = {
|
form.value = {
|
||||||
username: props.user.username,
|
username: props.user.username,
|
||||||
password: '',
|
password: '',
|
||||||
|
confirmPassword: '',
|
||||||
email: props.user.email || '',
|
email: props.user.email || '',
|
||||||
quota: props.user.quota_usd == null ? 10 : props.user.quota_usd,
|
quota: props.user.quota_usd == null ? 10 : props.user.quota_usd,
|
||||||
role: props.user.role,
|
role: props.user.role,
|
||||||
@@ -486,7 +517,9 @@ const isFormValid = computed(() => {
|
|||||||
const hasUsername = form.value.username.trim().length > 0
|
const hasUsername = form.value.username.trim().length > 0
|
||||||
const hasEmail = form.value.email.trim().length > 0
|
const hasEmail = form.value.email.trim().length > 0
|
||||||
const hasPassword = isEditMode.value || form.value.password.length >= 6
|
const hasPassword = isEditMode.value || form.value.password.length >= 6
|
||||||
return hasUsername && hasEmail && hasPassword
|
// 编辑模式下如果填写了密码,必须确认密码一致
|
||||||
|
const passwordConfirmed = !isEditMode.value || form.value.password.length === 0 || form.value.password === form.value.confirmPassword
|
||||||
|
return hasUsername && hasEmail && hasPassword && passwordConfirmed
|
||||||
})
|
})
|
||||||
|
|
||||||
// 加载访问控制选项
|
// 加载访问控制选项
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
import type { User, LoginResponse } from '@/api/auth'
|
import type { User, LoginResponse } from '@/api/auth'
|
||||||
import type { DashboardStatsResponse, RecentRequest, ProviderStatus, DailyStatsResponse } from '@/api/dashboard'
|
import type { DashboardStatsResponse, RecentRequest, ProviderStatus, DailyStatsResponse } from '@/api/dashboard'
|
||||||
import type { User as AdminUser, ApiKey } from '@/api/users'
|
import type { User as AdminUser } from '@/api/users'
|
||||||
import type { AdminApiKeysResponse } from '@/api/admin'
|
import type { AdminApiKeysResponse } from '@/api/admin'
|
||||||
import type { Profile, UsageResponse } from '@/api/me'
|
import type { Profile, UsageResponse } from '@/api/me'
|
||||||
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
|
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
|
||||||
@@ -185,18 +185,20 @@ export const MOCK_DASHBOARD_STATS: DashboardStatsResponse = {
|
|||||||
output: 700000,
|
output: 700000,
|
||||||
cache_creation: 50000,
|
cache_creation: 50000,
|
||||||
cache_read: 200000
|
cache_read: 200000
|
||||||
}
|
},
|
||||||
|
// 普通用户专用字段
|
||||||
|
monthly_cost: 45.67
|
||||||
}
|
}
|
||||||
|
|
||||||
export const MOCK_RECENT_REQUESTS: RecentRequest[] = [
|
export const MOCK_RECENT_REQUESTS: RecentRequest[] = [
|
||||||
{ id: 'req-001', user: 'alice', model: 'claude-sonnet-4-20250514', tokens: 15234, time: '2 分钟前' },
|
{ id: 'req-001', user: 'alice', model: 'claude-sonnet-4-5-20250929', tokens: 15234, time: '2 分钟前' },
|
||||||
{ id: 'req-002', user: 'bob', model: 'gpt-4o', tokens: 8765, time: '5 分钟前' },
|
{ id: 'req-002', user: 'bob', model: 'gpt-5.1', tokens: 8765, time: '5 分钟前' },
|
||||||
{ id: 'req-003', user: 'charlie', model: 'claude-opus-4-20250514', tokens: 32100, time: '8 分钟前' },
|
{ id: 'req-003', user: 'charlie', model: 'claude-opus-4-5-20251101', tokens: 32100, time: '8 分钟前' },
|
||||||
{ id: 'req-004', user: 'diana', model: 'gemini-2.0-flash', tokens: 4521, time: '12 分钟前' },
|
{ id: 'req-004', user: 'diana', model: 'gemini-3-pro-preview', tokens: 4521, time: '12 分钟前' },
|
||||||
{ id: 'req-005', user: 'eve', model: 'claude-sonnet-4-20250514', tokens: 9876, time: '15 分钟前' },
|
{ id: 'req-005', user: 'eve', model: 'claude-sonnet-4-5-20250929', tokens: 9876, time: '15 分钟前' },
|
||||||
{ id: 'req-006', user: 'frank', model: 'gpt-4o-mini', tokens: 2345, time: '18 分钟前' },
|
{ id: 'req-006', user: 'frank', model: 'gpt-5.1-codex-mini', tokens: 2345, time: '18 分钟前' },
|
||||||
{ id: 'req-007', user: 'grace', model: 'claude-haiku-3-5-20241022', tokens: 6789, time: '22 分钟前' },
|
{ id: 'req-007', user: 'grace', model: 'claude-haiku-4-5-20251001', tokens: 6789, time: '22 分钟前' },
|
||||||
{ id: 'req-008', user: 'henry', model: 'gemini-2.5-pro', tokens: 12345, time: '25 分钟前' }
|
{ id: 'req-008', user: 'henry', model: 'gemini-3-pro-preview', tokens: 12345, time: '25 分钟前' }
|
||||||
]
|
]
|
||||||
|
|
||||||
export const MOCK_PROVIDER_STATUS: ProviderStatus[] = [
|
export const MOCK_PROVIDER_STATUS: ProviderStatus[] = [
|
||||||
@@ -231,11 +233,11 @@ function generateDailyStats(): DailyStatsResponse {
|
|||||||
unique_models: 8 + Math.floor(Math.random() * 5),
|
unique_models: 8 + Math.floor(Math.random() * 5),
|
||||||
unique_providers: 4 + Math.floor(Math.random() * 3),
|
unique_providers: 4 + Math.floor(Math.random() * 3),
|
||||||
model_breakdown: [
|
model_breakdown: [
|
||||||
{ model: 'claude-sonnet-4-20250514', requests: Math.floor(baseRequests * 0.35), tokens: Math.floor(baseTokens * 0.35), cost: Number((baseCost * 0.35).toFixed(2)) },
|
{ model: 'claude-sonnet-4-5-20250929', requests: Math.floor(baseRequests * 0.35), tokens: Math.floor(baseTokens * 0.35), cost: Number((baseCost * 0.35).toFixed(2)) },
|
||||||
{ model: 'gpt-4o', requests: Math.floor(baseRequests * 0.25), tokens: Math.floor(baseTokens * 0.25), cost: Number((baseCost * 0.25).toFixed(2)) },
|
{ model: 'gpt-5.1', requests: Math.floor(baseRequests * 0.25), tokens: Math.floor(baseTokens * 0.25), cost: Number((baseCost * 0.25).toFixed(2)) },
|
||||||
{ model: 'claude-opus-4-20250514', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.20).toFixed(2)) },
|
{ model: 'claude-opus-4-5-20251101', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.20).toFixed(2)) },
|
||||||
{ model: 'gemini-2.0-flash', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.10).toFixed(2)) },
|
{ model: 'gemini-3-pro-preview', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.10).toFixed(2)) },
|
||||||
{ model: 'claude-haiku-3-5-20241022', requests: Math.floor(baseRequests * 0.10), tokens: Math.floor(baseTokens * 0.10), cost: Number((baseCost * 0.10).toFixed(2)) }
|
{ model: 'claude-haiku-4-5-20251001', requests: Math.floor(baseRequests * 0.10), tokens: Math.floor(baseTokens * 0.10), cost: Number((baseCost * 0.10).toFixed(2)) }
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -243,11 +245,11 @@ function generateDailyStats(): DailyStatsResponse {
|
|||||||
return {
|
return {
|
||||||
daily_stats: dailyStats,
|
daily_stats: dailyStats,
|
||||||
model_summary: [
|
model_summary: [
|
||||||
{ model: 'claude-sonnet-4-20250514', requests: 2456, tokens: 8500000, cost: 125.45, avg_response_time: 1.2, cost_per_request: 0.051, tokens_per_request: 3461 },
|
{ model: 'claude-sonnet-4-5-20250929', requests: 2456, tokens: 8500000, cost: 125.45, avg_response_time: 1.2, cost_per_request: 0.051, tokens_per_request: 3461 },
|
||||||
{ model: 'gpt-4o', requests: 1823, tokens: 6200000, cost: 98.32, avg_response_time: 0.9, cost_per_request: 0.054, tokens_per_request: 3401 },
|
{ model: 'gpt-5.1', requests: 1823, tokens: 6200000, cost: 98.32, avg_response_time: 0.9, cost_per_request: 0.054, tokens_per_request: 3401 },
|
||||||
{ model: 'claude-opus-4-20250514', requests: 987, tokens: 4100000, cost: 156.78, avg_response_time: 2.1, cost_per_request: 0.159, tokens_per_request: 4154 },
|
{ model: 'claude-opus-4-5-20251101', requests: 987, tokens: 4100000, cost: 156.78, avg_response_time: 2.1, cost_per_request: 0.159, tokens_per_request: 4154 },
|
||||||
{ model: 'gemini-2.0-flash', requests: 1234, tokens: 3800000, cost: 28.56, avg_response_time: 0.6, cost_per_request: 0.023, tokens_per_request: 3079 },
|
{ model: 'gemini-3-pro-preview', requests: 1234, tokens: 3800000, cost: 28.56, avg_response_time: 0.6, cost_per_request: 0.023, tokens_per_request: 3079 },
|
||||||
{ model: 'claude-haiku-3-5-20241022', requests: 2100, tokens: 5200000, cost: 32.10, avg_response_time: 0.5, cost_per_request: 0.015, tokens_per_request: 2476 }
|
{ model: 'claude-haiku-4-5-20251001', requests: 2100, tokens: 5200000, cost: 32.10, avg_response_time: 0.5, cost_per_request: 0.015, tokens_per_request: 2476 }
|
||||||
],
|
],
|
||||||
period: {
|
period: {
|
||||||
start_date: dailyStats[0].date,
|
start_date: dailyStats[0].date,
|
||||||
@@ -336,7 +338,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
|
|||||||
|
|
||||||
// ========== API Key 数据 ==========
|
// ========== API Key 数据 ==========
|
||||||
|
|
||||||
export const MOCK_USER_API_KEYS: ApiKey[] = [
|
export const MOCK_USER_API_KEYS = [
|
||||||
{
|
{
|
||||||
id: 'key-uuid-001',
|
id: 'key-uuid-001',
|
||||||
key_display: 'sk-ae...x7f9',
|
key_display: 'sk-ae...x7f9',
|
||||||
@@ -346,7 +348,8 @@ export const MOCK_USER_API_KEYS: ApiKey[] = [
|
|||||||
is_active: true,
|
is_active: true,
|
||||||
is_standalone: false,
|
is_standalone: false,
|
||||||
total_requests: 1234,
|
total_requests: 1234,
|
||||||
total_cost_usd: 45.67
|
total_cost_usd: 45.67,
|
||||||
|
force_capabilities: null
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'key-uuid-002',
|
id: 'key-uuid-002',
|
||||||
@@ -357,7 +360,8 @@ export const MOCK_USER_API_KEYS: ApiKey[] = [
|
|||||||
is_active: true,
|
is_active: true,
|
||||||
is_standalone: false,
|
is_standalone: false,
|
||||||
total_requests: 5678,
|
total_requests: 5678,
|
||||||
total_cost_usd: 123.45
|
total_cost_usd: 123.45,
|
||||||
|
force_capabilities: { cache_1h: true }
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'key-uuid-003',
|
id: 'key-uuid-003',
|
||||||
@@ -367,7 +371,8 @@ export const MOCK_USER_API_KEYS: ApiKey[] = [
|
|||||||
is_active: false,
|
is_active: false,
|
||||||
is_standalone: false,
|
is_standalone: false,
|
||||||
total_requests: 100,
|
total_requests: 100,
|
||||||
total_cost_usd: 2.34
|
total_cost_usd: 2.34,
|
||||||
|
force_capabilities: null
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -813,16 +818,16 @@ export const MOCK_USAGE_RESPONSE: UsageResponse = {
|
|||||||
quota_usd: 100,
|
quota_usd: 100,
|
||||||
used_usd: 45.32,
|
used_usd: 45.32,
|
||||||
summary_by_model: [
|
summary_by_model: [
|
||||||
{ model: 'claude-sonnet-4-20250514', requests: 456, input_tokens: 650000, output_tokens: 250000, total_tokens: 900000, total_cost_usd: 18.50, actual_total_cost_usd: 13.50 },
|
{ model: 'claude-sonnet-4-5-20250929', requests: 456, input_tokens: 650000, output_tokens: 250000, total_tokens: 900000, total_cost_usd: 18.50, actual_total_cost_usd: 13.50 },
|
||||||
{ model: 'gpt-4o', requests: 312, input_tokens: 480000, output_tokens: 180000, total_tokens: 660000, total_cost_usd: 12.30, actual_total_cost_usd: 9.20 },
|
{ model: 'gpt-5.1', requests: 312, input_tokens: 480000, output_tokens: 180000, total_tokens: 660000, total_cost_usd: 12.30, actual_total_cost_usd: 9.20 },
|
||||||
{ model: 'claude-haiku-3-5-20241022', requests: 289, input_tokens: 420000, output_tokens: 170000, total_tokens: 590000, total_cost_usd: 8.50, actual_total_cost_usd: 6.30 },
|
{ model: 'claude-haiku-4-5-20251001', requests: 289, input_tokens: 420000, output_tokens: 170000, total_tokens: 590000, total_cost_usd: 8.50, actual_total_cost_usd: 6.30 },
|
||||||
{ model: 'gemini-2.0-flash', requests: 177, input_tokens: 250000, output_tokens: 100000, total_tokens: 350000, total_cost_usd: 6.37, actual_total_cost_usd: 4.33 }
|
{ model: 'gemini-3-pro-preview', requests: 177, input_tokens: 250000, output_tokens: 100000, total_tokens: 350000, total_cost_usd: 6.37, actual_total_cost_usd: 4.33 }
|
||||||
],
|
],
|
||||||
records: [
|
records: [
|
||||||
{
|
{
|
||||||
id: 'usage-001',
|
id: 'usage-001',
|
||||||
provider: 'anthropic',
|
provider: 'anthropic',
|
||||||
model: 'claude-sonnet-4-20250514',
|
model: 'claude-sonnet-4-5-20250929',
|
||||||
input_tokens: 1500,
|
input_tokens: 1500,
|
||||||
output_tokens: 800,
|
output_tokens: 800,
|
||||||
total_tokens: 2300,
|
total_tokens: 2300,
|
||||||
@@ -837,7 +842,7 @@ export const MOCK_USAGE_RESPONSE: UsageResponse = {
|
|||||||
{
|
{
|
||||||
id: 'usage-002',
|
id: 'usage-002',
|
||||||
provider: 'openai',
|
provider: 'openai',
|
||||||
model: 'gpt-4o',
|
model: 'gpt-5.1',
|
||||||
input_tokens: 2000,
|
input_tokens: 2000,
|
||||||
output_tokens: 500,
|
output_tokens: 500,
|
||||||
total_tokens: 2500,
|
total_tokens: 2500,
|
||||||
|
|||||||
@@ -405,10 +405,10 @@ function getUsageRecords() {
|
|||||||
|
|
||||||
// Mock 映射数据
|
// Mock 映射数据
|
||||||
const MOCK_ALIASES = [
|
const MOCK_ALIASES = [
|
||||||
{ id: 'alias-001', source_model: 'claude-4-sonnet', target_global_model_id: 'gm-001', target_global_model_name: 'claude-sonnet-4-20250514', target_global_model_display_name: 'Claude Sonnet 4', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
{ id: 'alias-001', source_model: 'claude-4-sonnet', target_global_model_id: 'gm-003', target_global_model_name: 'claude-sonnet-4-5-20250929', target_global_model_display_name: 'Claude Sonnet 4.5', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
||||||
{ id: 'alias-002', source_model: 'claude-4-opus', target_global_model_id: 'gm-002', target_global_model_name: 'claude-opus-4-20250514', target_global_model_display_name: 'Claude Opus 4', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
{ id: 'alias-002', source_model: 'claude-4-opus', target_global_model_id: 'gm-002', target_global_model_name: 'claude-opus-4-5-20251101', target_global_model_display_name: 'Claude Opus 4.5', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
||||||
{ id: 'alias-003', source_model: 'gpt4o', target_global_model_id: 'gm-004', target_global_model_name: 'gpt-4o', target_global_model_display_name: 'GPT-4o', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
{ id: 'alias-003', source_model: 'gpt5', target_global_model_id: 'gm-006', target_global_model_name: 'gpt-5.1', target_global_model_display_name: 'GPT-5.1', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
|
||||||
{ id: 'alias-004', source_model: 'gemini-flash', target_global_model_id: 'gm-005', target_global_model_name: 'gemini-2.0-flash', target_global_model_display_name: 'Gemini 2.0 Flash', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' }
|
{ id: 'alias-004', source_model: 'gemini-pro', target_global_model_id: 'gm-005', target_global_model_name: 'gemini-3-pro-preview', target_global_model_display_name: 'Gemini 3 Pro Preview', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' }
|
||||||
]
|
]
|
||||||
|
|
||||||
// Mock Endpoint Keys
|
// Mock Endpoint Keys
|
||||||
@@ -2172,10 +2172,10 @@ function generateIntervalTimelineData(
|
|||||||
|
|
||||||
// 模型列表(用于按模型区分颜色)
|
// 模型列表(用于按模型区分颜色)
|
||||||
const models = [
|
const models = [
|
||||||
'claude-sonnet-4-20250514',
|
'claude-sonnet-4-5-20250929',
|
||||||
'claude-3-5-sonnet-20241022',
|
'claude-haiku-4-5-20251001',
|
||||||
'claude-3-5-haiku-20241022',
|
'claude-opus-4-5-20251101',
|
||||||
'claude-opus-4-20250514'
|
'gpt-5.1'
|
||||||
]
|
]
|
||||||
|
|
||||||
// 生成模拟的请求间隔数据
|
// 生成模拟的请求间隔数据
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
users.value = await usersApi.getAllUsers()
|
users.value = await usersApi.getAllUsers()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '获取用户列表失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取用户列表失败'
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
users.value.push(newUser)
|
users.value.push(newUser)
|
||||||
return newUser
|
return newUser
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '创建用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -52,7 +52,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
}
|
}
|
||||||
return updatedUser
|
return updatedUser
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '更新用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '更新用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -67,7 +67,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
await usersApi.deleteUser(userId)
|
await usersApi.deleteUser(userId)
|
||||||
users.value = users.value.filter(u => u.id !== userId)
|
users.value = users.value.filter(u => u.id !== userId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '删除用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -78,7 +78,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
return await usersApi.getUserApiKeys(userId)
|
return await usersApi.getUserApiKeys(userId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '获取 API Keys 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取 API Keys 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,7 +87,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
return await usersApi.createApiKey(userId, name)
|
return await usersApi.createApiKey(userId, name)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '创建 API Key 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建 API Key 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
await usersApi.deleteApiKey(userId, keyId)
|
await usersApi.deleteApiKey(userId, keyId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '删除 API Key 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除 API Key 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -110,7 +110,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
// 刷新用户列表以获取最新数据
|
// 刷新用户列表以获取最新数据
|
||||||
await fetchUsers()
|
await fetchUsers()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '重置配额失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '重置配额失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
|
|||||||
@@ -198,3 +198,49 @@ export function parseApiErrorShort(err: unknown, defaultMessage: string = '操
|
|||||||
const lines = fullError.split('\n')
|
const lines = fullError.split('\n')
|
||||||
return lines[0] || defaultMessage
|
return lines[0] || defaultMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析模型测试响应的错误信息
|
||||||
|
* @param result 测试响应结果
|
||||||
|
* @returns 格式化的错误信息
|
||||||
|
*/
|
||||||
|
export function parseTestModelError(result: {
|
||||||
|
error?: string
|
||||||
|
data?: {
|
||||||
|
response?: {
|
||||||
|
status_code?: number
|
||||||
|
error?: string | { message?: string }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}): string {
|
||||||
|
let errorMsg = result.error || '测试失败'
|
||||||
|
|
||||||
|
// 检查HTTP状态码错误
|
||||||
|
if (result.data?.response?.status_code) {
|
||||||
|
const status = result.data.response.status_code
|
||||||
|
if (status === 403) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
|
||||||
|
} else if (status === 401) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或已过期'
|
||||||
|
} else if (status === 404) {
|
||||||
|
errorMsg = '模型不存在: 请检查模型名称是否正确'
|
||||||
|
} else if (status === 429) {
|
||||||
|
errorMsg = '请求频率过高: 请稍后重试'
|
||||||
|
} else if (status >= 500) {
|
||||||
|
errorMsg = `服务器错误: HTTP ${status}`
|
||||||
|
} else {
|
||||||
|
errorMsg = `请求失败: HTTP ${status}`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试从错误响应中提取更多信息
|
||||||
|
if (result.data?.response?.error) {
|
||||||
|
if (typeof result.data.response.error === 'string') {
|
||||||
|
errorMsg = result.data.response.error
|
||||||
|
} else if (result.data.response.error?.message) {
|
||||||
|
errorMsg = result.data.response.error.message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorMsg
|
||||||
|
}
|
||||||
|
|||||||
@@ -650,6 +650,7 @@
|
|||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useConfirm } from '@/composables/useConfirm'
|
import { useConfirm } from '@/composables/useConfirm'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin'
|
import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@@ -693,6 +694,7 @@ import { log } from '@/utils/logger'
|
|||||||
|
|
||||||
const { success, error } = useToast()
|
const { success, error } = useToast()
|
||||||
const { confirmDanger } = useConfirm()
|
const { confirmDanger } = useConfirm()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
const apiKeys = ref<AdminApiKey[]>([])
|
const apiKeys = ref<AdminApiKey[]>([])
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -927,20 +929,14 @@ function selectKey() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function copyKey() {
|
async function copyKey() {
|
||||||
try {
|
await copyToClipboard(newKeyValue.value)
|
||||||
await navigator.clipboard.writeText(newKeyValue.value)
|
|
||||||
success('API Key 已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
error('复制失败,请手动复制')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function copyKeyPrefix(apiKey: AdminApiKey) {
|
async function copyKeyPrefix(apiKey: AdminApiKey) {
|
||||||
try {
|
try {
|
||||||
// 调用后端 API 获取完整密钥
|
// 调用后端 API 获取完整密钥
|
||||||
const response = await adminApi.getFullApiKey(apiKey.id)
|
const response = await adminApi.getFullApiKey(apiKey.id)
|
||||||
await navigator.clipboard.writeText(response.key)
|
await copyToClipboard(response.key)
|
||||||
success('完整密钥已复制到剪贴板')
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
log.error('复制密钥失败:', err)
|
log.error('复制密钥失败:', err)
|
||||||
error('复制失败,请重试')
|
error('复制失败,请重试')
|
||||||
@@ -1046,9 +1042,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
|||||||
rate_limit: data.rate_limit,
|
rate_limit: data.rate_limit,
|
||||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
expire_days: data.never_expire ? null : (data.expire_days || null),
|
||||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||||
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined,
|
// 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
|
||||||
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined,
|
allowed_providers: data.allowed_providers,
|
||||||
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined
|
allowed_api_formats: data.allowed_api_formats,
|
||||||
|
allowed_models: data.allowed_models
|
||||||
}
|
}
|
||||||
await adminApi.updateApiKey(data.id, updateData)
|
await adminApi.updateApiKey(data.id, updateData)
|
||||||
success('API Key 更新成功')
|
success('API Key 更新成功')
|
||||||
@@ -1064,9 +1061,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
|||||||
rate_limit: data.rate_limit,
|
rate_limit: data.rate_limit,
|
||||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
expire_days: data.never_expire ? null : (data.expire_days || null),
|
||||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||||
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined,
|
// 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
|
||||||
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined,
|
allowed_providers: data.allowed_providers,
|
||||||
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined
|
allowed_api_formats: data.allowed_api_formats,
|
||||||
|
allowed_models: data.allowed_models
|
||||||
}
|
}
|
||||||
const response = await adminApi.createStandaloneApiKey(createData)
|
const response = await adminApi.createStandaloneApiKey(createData)
|
||||||
newKeyValue.value = response.key
|
newKeyValue.value = response.key
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ const clearingRowAffinityKey = ref<string | null>(null)
|
|||||||
const currentPage = ref(1)
|
const currentPage = ref(1)
|
||||||
const pageSize = ref(20)
|
const pageSize = ref(20)
|
||||||
const currentTime = ref(Math.floor(Date.now() / 1000))
|
const currentTime = ref(Math.floor(Date.now() / 1000))
|
||||||
|
const analysisHoursSelectOpen = ref(false)
|
||||||
|
|
||||||
// ==================== 模型映射缓存 ====================
|
// ==================== 模型映射缓存 ====================
|
||||||
|
|
||||||
@@ -1056,7 +1057,7 @@ onBeforeUnmount(() => {
|
|||||||
<span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔,推荐合适的缓存 TTL</span>
|
<span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔,推荐合适的缓存 TTL</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex flex-wrap items-center gap-2">
|
<div class="flex flex-wrap items-center gap-2">
|
||||||
<Select v-model="analysisHours">
|
<Select v-model="analysisHours" v-model:open="analysisHoursSelectOpen">
|
||||||
<SelectTrigger class="w-24 sm:w-28 h-8">
|
<SelectTrigger class="w-24 sm:w-28 h-8">
|
||||||
<SelectValue placeholder="时间段" />
|
<SelectValue placeholder="时间段" />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
|
|||||||
@@ -713,6 +713,7 @@ import ProviderModelFormDialog from '@/features/providers/components/ProviderMod
|
|||||||
import type { Model } from '@/api/endpoints'
|
import type { Model } from '@/api/endpoints'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useConfirm } from '@/composables/useConfirm'
|
import { useConfirm } from '@/composables/useConfirm'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { useRowClick } from '@/composables/useRowClick'
|
import { useRowClick } from '@/composables/useRowClick'
|
||||||
import { parseApiError } from '@/utils/errorParser'
|
import { parseApiError } from '@/utils/errorParser'
|
||||||
import {
|
import {
|
||||||
@@ -736,6 +737,7 @@ import {
|
|||||||
updateGlobalModel,
|
updateGlobalModel,
|
||||||
deleteGlobalModel,
|
deleteGlobalModel,
|
||||||
batchAssignToProviders,
|
batchAssignToProviders,
|
||||||
|
getGlobalModelProviders,
|
||||||
type GlobalModelResponse,
|
type GlobalModelResponse,
|
||||||
} from '@/api/global-models'
|
} from '@/api/global-models'
|
||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
@@ -743,6 +745,7 @@ import { getProvidersSummary } from '@/api/endpoints/providers'
|
|||||||
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
|
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
|
||||||
|
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
// 状态
|
// 状态
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -1066,16 +1069,6 @@ function handleRowClick(event: MouseEvent, model: GlobalModelResponse) {
|
|||||||
selectModel(model)
|
selectModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制到剪贴板
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
success('已复制')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function selectModel(model: GlobalModelResponse) {
|
async function selectModel(model: GlobalModelResponse) {
|
||||||
selectedModel.value = model
|
selectedModel.value = model
|
||||||
detailTab.value = 'basic'
|
detailTab.value = 'basic'
|
||||||
@@ -1088,18 +1081,11 @@ async function selectModel(model: GlobalModelResponse) {
|
|||||||
async function loadModelProviders(_globalModelId: string) {
|
async function loadModelProviders(_globalModelId: string) {
|
||||||
loadingModelProviders.value = true
|
loadingModelProviders.value = true
|
||||||
try {
|
try {
|
||||||
// 使用 ModelCatalog API 获取详细的关联提供商信息
|
// 使用新的 API 获取所有关联提供商(包括非活跃的)
|
||||||
const { getModelCatalog } = await import('@/api/endpoints')
|
const response = await getGlobalModelProviders(_globalModelId)
|
||||||
const catalogResponse = await getModelCatalog()
|
|
||||||
|
|
||||||
// 查找当前 GlobalModel 对应的 catalog item
|
// 转换为展示格式
|
||||||
const catalogItem = catalogResponse.models.find(
|
selectedModelProviders.value = response.providers.map(p => ({
|
||||||
m => m.global_model_name === selectedModel.value?.name
|
|
||||||
)
|
|
||||||
|
|
||||||
if (catalogItem) {
|
|
||||||
// 转换为展示格式,包含完整的模型实现信息
|
|
||||||
selectedModelProviders.value = catalogItem.providers.map(p => ({
|
|
||||||
id: p.provider_id,
|
id: p.provider_id,
|
||||||
model_id: p.model_id,
|
model_id: p.model_id,
|
||||||
display_name: p.provider_display_name || p.provider_name,
|
display_name: p.provider_display_name || p.provider_name,
|
||||||
@@ -1121,9 +1107,6 @@ async function loadModelProviders(_globalModelId: string) {
|
|||||||
supports_function_calling: p.supports_function_calling,
|
supports_function_calling: p.supports_function_calling,
|
||||||
supports_streaming: p.supports_streaming
|
supports_streaming: p.supports_streaming
|
||||||
}))
|
}))
|
||||||
} else {
|
|
||||||
selectedModelProviders.value = []
|
|
||||||
}
|
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
log.error('加载关联提供商失败:', err)
|
log.error('加载关联提供商失败:', err)
|
||||||
showError(parseApiError(err, '加载关联提供商失败'), '错误')
|
showError(parseApiError(err, '加载关联提供商失败'), '错误')
|
||||||
|
|||||||
@@ -723,9 +723,19 @@ async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
|
|||||||
// 切换提供商状态
|
// 切换提供商状态
|
||||||
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
|
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
|
||||||
try {
|
try {
|
||||||
await updateProvider(provider.id, { is_active: !provider.is_active })
|
const newStatus = !provider.is_active
|
||||||
provider.is_active = !provider.is_active
|
await updateProvider(provider.id, { is_active: newStatus })
|
||||||
showSuccess(provider.is_active ? '提供商已启用' : '提供商已停用')
|
|
||||||
|
// 更新抽屉内部的 provider 对象
|
||||||
|
provider.is_active = newStatus
|
||||||
|
|
||||||
|
// 同时更新主页面 providers 数组中的对象,实现无感更新
|
||||||
|
const targetProvider = providers.value.find(p => p.id === provider.id)
|
||||||
|
if (targetProvider) {
|
||||||
|
targetProvider.is_active = newStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
showSuccess(newStatus ? '提供商已启用' : '提供商已停用')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
showError(err.response?.data?.detail || '操作失败', '错误')
|
showError(err.response?.data?.detail || '操作失败', '错误')
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -701,6 +701,7 @@ import { ref, computed, onMounted, watch } from 'vue'
|
|||||||
import { useUsersStore } from '@/stores/users'
|
import { useUsersStore } from '@/stores/users'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useConfirm } from '@/composables/useConfirm'
|
import { useConfirm } from '@/composables/useConfirm'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { usageApi, type UsageByUser } from '@/api/usage'
|
import { usageApi, type UsageByUser } from '@/api/usage'
|
||||||
import { adminApi } from '@/api/admin'
|
import { adminApi } from '@/api/admin'
|
||||||
|
|
||||||
@@ -748,6 +749,7 @@ import { log } from '@/utils/logger'
|
|||||||
|
|
||||||
const { success, error } = useToast()
|
const { success, error } = useToast()
|
||||||
const { confirmDanger, confirmWarning } = useConfirm()
|
const { confirmDanger, confirmWarning } = useConfirm()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
const usersStore = useUsersStore()
|
const usersStore = useUsersStore()
|
||||||
|
|
||||||
// 用户表单对话框状态
|
// 用户表单对话框状态
|
||||||
@@ -875,7 +877,8 @@ async function toggleUserStatus(user: any) {
|
|||||||
const action = user.is_active ? '禁用' : '启用'
|
const action = user.is_active ? '禁用' : '启用'
|
||||||
const confirmed = await confirmDanger(
|
const confirmed = await confirmDanger(
|
||||||
`确定要${action}用户 ${user.username} 吗?`,
|
`确定要${action}用户 ${user.username} 吗?`,
|
||||||
`${action}用户`
|
`${action}用户`,
|
||||||
|
action
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!confirmed) return
|
if (!confirmed) return
|
||||||
@@ -884,7 +887,7 @@ async function toggleUserStatus(user: any) {
|
|||||||
await usersStore.updateUser(user.id, { is_active: !user.is_active })
|
await usersStore.updateUser(user.id, { is_active: !user.is_active })
|
||||||
success(`用户已${action}`)
|
success(`用户已${action}`)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', `${action}用户失败`)
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', `${action}用户失败`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -955,7 +958,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
|
|||||||
closeUserFormDialog()
|
closeUserFormDialog()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
const title = data.id ? '更新用户失败' : '创建用户失败'
|
const title = data.id ? '更新用户失败' : '创建用户失败'
|
||||||
error(err.response?.data?.detail || '未知错误', title)
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', title)
|
||||||
} finally {
|
} finally {
|
||||||
userFormDialogRef.value?.setSaving(false)
|
userFormDialogRef.value?.setSaving(false)
|
||||||
}
|
}
|
||||||
@@ -989,7 +992,7 @@ async function createApiKey() {
|
|||||||
showNewApiKeyDialog.value = true
|
showNewApiKeyDialog.value = true
|
||||||
await loadUserApiKeys(selectedUser.value.id)
|
await loadUserApiKeys(selectedUser.value.id)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '创建 API Key 失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '创建 API Key 失败')
|
||||||
} finally {
|
} finally {
|
||||||
creatingApiKey.value = false
|
creatingApiKey.value = false
|
||||||
}
|
}
|
||||||
@@ -1000,12 +1003,7 @@ function selectApiKey() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function copyApiKey() {
|
async function copyApiKey() {
|
||||||
try {
|
await copyToClipboard(newApiKey.value)
|
||||||
await navigator.clipboard.writeText(newApiKey.value)
|
|
||||||
success('API Key已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
error('复制失败,请手动复制')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function closeNewApiKeyDialog() {
|
async function closeNewApiKeyDialog() {
|
||||||
@@ -1026,7 +1024,7 @@ async function deleteApiKey(apiKey: any) {
|
|||||||
await loadUserApiKeys(selectedUser.value.id)
|
await loadUserApiKeys(selectedUser.value.id)
|
||||||
success('API Key已删除')
|
success('API Key已删除')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '删除 API Key 失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除 API Key 失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1034,11 +1032,10 @@ async function copyFullKey(apiKey: any) {
|
|||||||
try {
|
try {
|
||||||
// 调用后端 API 获取完整密钥
|
// 调用后端 API 获取完整密钥
|
||||||
const response = await adminApi.getFullApiKey(apiKey.id)
|
const response = await adminApi.getFullApiKey(apiKey.id)
|
||||||
await navigator.clipboard.writeText(response.key)
|
await copyToClipboard(response.key)
|
||||||
success('完整密钥已复制到剪贴板')
|
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
log.error('复制密钥失败:', err)
|
log.error('复制密钥失败:', err)
|
||||||
error(err.response?.data?.detail || '未知错误', '复制密钥失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '复制密钥失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1054,7 +1051,7 @@ async function resetQuota(user: any) {
|
|||||||
await usersStore.resetUserQuota(user.id)
|
await usersStore.resetUserQuota(user.id)
|
||||||
success('配额已重置')
|
success('配额已重置')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '重置配额失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '重置配额失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1070,7 +1067,7 @@ async function deleteUser(user: any) {
|
|||||||
await usersStore.deleteUser(user.id)
|
await usersStore.deleteUser(user.id)
|
||||||
success('用户已删除')
|
success('用户已删除')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '删除用户失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除用户失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -102,9 +102,9 @@
|
|||||||
<!-- Main Content -->
|
<!-- Main Content -->
|
||||||
<main class="relative z-10">
|
<main class="relative z-10">
|
||||||
<!-- Fixed Logo Container -->
|
<!-- Fixed Logo Container -->
|
||||||
<div class="fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
|
<div class="mt-4 fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
|
||||||
<div
|
<div
|
||||||
class="transform-gpu logo-container"
|
class="mt-16 transform-gpu logo-container"
|
||||||
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
|
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
|
||||||
:style="fixedLogoStyle"
|
:style="fixedLogoStyle"
|
||||||
>
|
>
|
||||||
@@ -151,7 +151,7 @@
|
|||||||
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
|
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
|
||||||
>
|
>
|
||||||
<div class="max-w-4xl mx-auto text-center">
|
<div class="max-w-4xl mx-auto text-center">
|
||||||
<div class="h-80 w-full mb-16" />
|
<div class="h-80 w-full mb-16 mt-8" />
|
||||||
<h1
|
<h1
|
||||||
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
|
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
|
||||||
:style="getTitleStyle(SECTIONS.HOME)"
|
:style="getTitleStyle(SECTIONS.HOME)"
|
||||||
@@ -166,7 +166,7 @@
|
|||||||
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
|
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
|
||||||
</p>
|
</p>
|
||||||
<button
|
<button
|
||||||
class="mt-16 transition-all duration-700 cursor-pointer hover:scale-110"
|
class="mt-8 transition-all duration-700 cursor-pointer hover:scale-110"
|
||||||
:style="getScrollIndicatorStyle(SECTIONS.HOME)"
|
:style="getScrollIndicatorStyle(SECTIONS.HOME)"
|
||||||
@click="scrollToSection(SECTIONS.CLAUDE)"
|
@click="scrollToSection(SECTIONS.CLAUDE)"
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -145,10 +145,10 @@
|
|||||||
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
实际成本
|
本月费用
|
||||||
</p>
|
</p>
|
||||||
<p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground">
|
<p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground">
|
||||||
{{ formatCurrency(costStats.total_actual_cost) }}
|
{{ formatCurrency(costStats.total_cost) }}
|
||||||
</p>
|
</p>
|
||||||
<Badge
|
<Badge
|
||||||
v-if="costStats.cost_savings > 0"
|
v-if="costStats.cost_savings > 0"
|
||||||
@@ -162,14 +162,14 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 普通用户:缓存统计 -->
|
<!-- 普通用户:月度统计 -->
|
||||||
<div
|
<div
|
||||||
v-else-if="!isAdmin && cacheStats && cacheStats.total_cache_tokens > 0"
|
v-else-if="!isAdmin && (hasCacheData || (userMonthlyCost !== null && userMonthlyCost > 0))"
|
||||||
class="mt-6"
|
class="mt-6"
|
||||||
>
|
>
|
||||||
<div class="mb-3 flex items-center justify-between">
|
<div class="mb-3 flex items-center justify-between">
|
||||||
<h3 class="text-sm font-medium text-foreground">
|
<h3 class="text-sm font-medium text-foreground">
|
||||||
本月缓存使用
|
本月统计
|
||||||
</h3>
|
</h3>
|
||||||
<Badge
|
<Badge
|
||||||
variant="outline"
|
variant="outline"
|
||||||
@@ -178,8 +178,16 @@
|
|||||||
Monthly
|
Monthly
|
||||||
</Badge>
|
</Badge>
|
||||||
</div>
|
</div>
|
||||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
<div
|
||||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
:class="[
|
||||||
|
'grid gap-2 sm:gap-3',
|
||||||
|
hasCacheData ? 'grid-cols-2 xl:grid-cols-4' : 'grid-cols-1 max-w-xs'
|
||||||
|
]"
|
||||||
|
>
|
||||||
|
<Card
|
||||||
|
v-if="cacheStats"
|
||||||
|
class="relative p-3 sm:p-4 border-book-cloth/30"
|
||||||
|
>
|
||||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
@@ -190,7 +198,10 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
<Card
|
||||||
|
v-if="cacheStats"
|
||||||
|
class="relative p-3 sm:p-4 border-kraft/30"
|
||||||
|
>
|
||||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
@@ -201,7 +212,10 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
<Card
|
||||||
|
v-if="cacheStats"
|
||||||
|
class="relative p-3 sm:p-4 border-book-cloth/25"
|
||||||
|
>
|
||||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
@@ -213,19 +227,16 @@
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
<Card
|
<Card
|
||||||
v-if="tokenBreakdown"
|
v-if="userMonthlyCost !== null"
|
||||||
class="relative p-3 sm:p-4 border-manilla/40"
|
class="relative p-3 sm:p-4 border-manilla/40"
|
||||||
>
|
>
|
||||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||||
<div class="pr-6">
|
<div class="pr-6">
|
||||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||||
总Token
|
本月费用
|
||||||
</p>
|
</p>
|
||||||
<p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground">
|
<p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground">
|
||||||
{{ formatTokens((tokenBreakdown.input || 0) + (tokenBreakdown.output || 0)) }}
|
{{ formatCurrency(userMonthlyCost) }}
|
||||||
</p>
|
|
||||||
<p class="mt-0.5 sm:mt-1 text-[9px] sm:text-[10px] text-muted-foreground">
|
|
||||||
输入 {{ formatTokens(tokenBreakdown.input || 0) }} / 输出 {{ formatTokens(tokenBreakdown.output || 0) }}
|
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
@@ -831,6 +842,12 @@ const cacheStats = ref<{
|
|||||||
total_cache_tokens: number
|
total_cache_tokens: number
|
||||||
} | null>(null)
|
} | null>(null)
|
||||||
|
|
||||||
|
const userMonthlyCost = ref<number | null>(null)
|
||||||
|
|
||||||
|
const hasCacheData = computed(() =>
|
||||||
|
cacheStats.value && cacheStats.value.total_cache_tokens > 0
|
||||||
|
)
|
||||||
|
|
||||||
const tokenBreakdown = ref<{
|
const tokenBreakdown = ref<{
|
||||||
input: number
|
input: number
|
||||||
output: number
|
output: number
|
||||||
@@ -1086,6 +1103,7 @@ async function loadDashboardData() {
|
|||||||
} else {
|
} else {
|
||||||
if (statsData.cache_stats) cacheStats.value = statsData.cache_stats
|
if (statsData.cache_stats) cacheStats.value = statsData.cache_stats
|
||||||
if (statsData.token_breakdown) tokenBreakdown.value = statsData.token_breakdown
|
if (statsData.token_breakdown) tokenBreakdown.value = statsData.token_breakdown
|
||||||
|
if (statsData.monthly_cost !== undefined) userMonthlyCost.value = statsData.monthly_cost
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
|
|||||||
@@ -301,6 +301,7 @@ function stopGlobalAutoRefresh() {
|
|||||||
function handleAutoRefreshChange(value: boolean) {
|
function handleAutoRefreshChange(value: boolean) {
|
||||||
globalAutoRefresh.value = value
|
globalAutoRefresh.value = value
|
||||||
if (value) {
|
if (value) {
|
||||||
|
refreshData() // 立即刷新一次
|
||||||
startGlobalAutoRefresh()
|
startGlobalAutoRefresh()
|
||||||
} else {
|
} else {
|
||||||
stopGlobalAutoRefresh()
|
stopGlobalAutoRefresh()
|
||||||
|
|||||||
@@ -342,6 +342,7 @@ import {
|
|||||||
Plus,
|
Plus,
|
||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import {
|
import {
|
||||||
Card,
|
Card,
|
||||||
Table,
|
Table,
|
||||||
@@ -370,6 +371,7 @@ import { useRowClick } from '@/composables/useRowClick'
|
|||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
|
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
// 状态
|
// 状态
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -565,16 +567,6 @@ function hasTieredPricing(model: PublicGlobalModel): boolean {
|
|||||||
return (tiered?.tiers?.length || 0) > 1
|
return (tiered?.tiers?.length || 0) > 1
|
||||||
}
|
}
|
||||||
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
success('已复制')
|
|
||||||
} catch (err) {
|
|
||||||
log.error('复制失败:', err)
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
refreshData()
|
refreshData()
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -352,6 +352,7 @@ import {
|
|||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
@@ -375,6 +376,7 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { success: showSuccess, error: showError } = useToast()
|
const { success: showSuccess, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
model: PublicGlobalModel | null
|
model: PublicGlobalModel | null
|
||||||
@@ -408,15 +410,6 @@ function handleClose() {
|
|||||||
emit('update:open', false)
|
emit('update:open', false)
|
||||||
}
|
}
|
||||||
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
showSuccess('已复制')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function getFirstTierPrice(
|
function getFirstTierPrice(
|
||||||
tieredPricing: TieredPricingConfig | undefined | null,
|
tieredPricing: TieredPricingConfig | undefined | null,
|
||||||
priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m'
|
priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m'
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ authors = [
|
|||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 4 - Beta",
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: Other/Proprietary License",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
|
|||||||
@@ -80,6 +80,17 @@ async def get_keys_grouped_by_format(
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/keys/{key_id}/reveal")
|
||||||
|
async def reveal_endpoint_key(
|
||||||
|
key_id: str,
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> dict:
|
||||||
|
"""获取完整的 API Key(用于查看和复制)"""
|
||||||
|
adapter = AdminRevealEndpointKeyAdapter(key_id=key_id)
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/keys/{key_id}")
|
@router.delete("/keys/{key_id}")
|
||||||
async def delete_endpoint_key(
|
async def delete_endpoint_key(
|
||||||
key_id: str,
|
key_id: str,
|
||||||
@@ -293,6 +304,30 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
return EndpointAPIKeyResponse(**response_dict)
|
return EndpointAPIKeyResponse(**response_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminRevealEndpointKeyAdapter(AdminApiAdapter):
|
||||||
|
"""获取完整的 API Key(用于查看和复制)"""
|
||||||
|
|
||||||
|
key_id: str
|
||||||
|
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
db = context.db
|
||||||
|
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
|
||||||
|
if not key:
|
||||||
|
raise NotFoundException(f"Key {self.key_id} 不存在")
|
||||||
|
|
||||||
|
try:
|
||||||
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解密 Key 失败: ID={self.key_id}, Error={e}")
|
||||||
|
raise InvalidRequestException(
|
||||||
|
"无法解密 API Key,可能是加密密钥已更改。请重新添加该密钥。"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"[REVEAL] 查看完整 Key: ID={self.key_id}, Name={key.name}")
|
||||||
|
return {"api_key": decrypted_key}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
|
||||||
key_id: str
|
key_id: str
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ GlobalModel Admin API
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -19,9 +19,11 @@ from src.models.pydantic_models import (
|
|||||||
BatchAssignToProvidersResponse,
|
BatchAssignToProvidersResponse,
|
||||||
GlobalModelCreate,
|
GlobalModelCreate,
|
||||||
GlobalModelListResponse,
|
GlobalModelListResponse,
|
||||||
|
GlobalModelProvidersResponse,
|
||||||
GlobalModelResponse,
|
GlobalModelResponse,
|
||||||
GlobalModelUpdate,
|
GlobalModelUpdate,
|
||||||
GlobalModelWithStats,
|
GlobalModelWithStats,
|
||||||
|
ModelCatalogProviderDetail,
|
||||||
)
|
)
|
||||||
from src.services.model.global_model import GlobalModelService
|
from src.services.model.global_model import GlobalModelService
|
||||||
|
|
||||||
@@ -108,6 +110,17 @@ async def batch_assign_to_providers(
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{global_model_id}/providers", response_model=GlobalModelProvidersResponse)
|
||||||
|
async def get_global_model_providers(
|
||||||
|
request: Request,
|
||||||
|
global_model_id: str,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> GlobalModelProvidersResponse:
|
||||||
|
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
|
||||||
|
adapter = AdminGetGlobalModelProvidersAdapter(global_model_id=global_model_id)
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
# ========== Adapters ==========
|
# ========== Adapters ==========
|
||||||
|
|
||||||
|
|
||||||
@@ -275,3 +288,61 @@ class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
|
|||||||
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
|
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
|
||||||
|
|
||||||
return BatchAssignToProvidersResponse(**result)
|
return BatchAssignToProvidersResponse(**result)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
|
||||||
|
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
|
||||||
|
|
||||||
|
global_model_id: str
|
||||||
|
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
|
from src.models.database import Model
|
||||||
|
|
||||||
|
global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)
|
||||||
|
|
||||||
|
# 获取所有关联的 Model(包括非活跃的)
|
||||||
|
models = (
|
||||||
|
context.db.query(Model)
|
||||||
|
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||||
|
.filter(Model.global_model_id == global_model.id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_entries = []
|
||||||
|
for model in models:
|
||||||
|
provider = model.provider
|
||||||
|
if not provider:
|
||||||
|
continue
|
||||||
|
|
||||||
|
effective_tiered = model.get_effective_tiered_pricing()
|
||||||
|
tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
|
||||||
|
|
||||||
|
provider_entries.append(
|
||||||
|
ModelCatalogProviderDetail(
|
||||||
|
provider_id=provider.id,
|
||||||
|
provider_name=provider.name,
|
||||||
|
provider_display_name=provider.display_name,
|
||||||
|
model_id=model.id,
|
||||||
|
target_model=model.provider_model_name,
|
||||||
|
input_price_per_1m=model.get_effective_input_price(),
|
||||||
|
output_price_per_1m=model.get_effective_output_price(),
|
||||||
|
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||||
|
cache_read_price_per_1m=model.get_effective_cache_read_price(),
|
||||||
|
cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
|
||||||
|
price_per_request=model.get_effective_price_per_request(),
|
||||||
|
effective_tiered_pricing=effective_tiered,
|
||||||
|
tier_count=tier_count,
|
||||||
|
supports_vision=model.get_effective_supports_vision(),
|
||||||
|
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||||
|
supports_streaming=model.get_effective_supports_streaming(),
|
||||||
|
is_active=bool(model.is_active),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return GlobalModelProvidersResponse(
|
||||||
|
providers=provider_entries,
|
||||||
|
total=len(provider_entries),
|
||||||
|
)
|
||||||
|
|||||||
@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
|
|||||||
api_key_id: Optional[str] = None
|
api_key_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelRequest(BaseModel):
|
||||||
|
"""模型测试请求"""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
model_name: str
|
||||||
|
api_key_id: Optional[str] = None
|
||||||
|
stream: bool = False
|
||||||
|
message: Optional[str] = "你好"
|
||||||
|
api_format: Optional[str] = None # 指定使用的API格式,如果不指定则使用端点的默认格式
|
||||||
|
|
||||||
|
|
||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
||||||
|
|
||||||
@@ -206,3 +217,228 @@ async def query_available_models(
|
|||||||
"display_name": provider.display_name,
|
"display_name": provider.display_name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-model")
|
||||||
|
async def test_model(
|
||||||
|
request: TestModelRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试模型连接性
|
||||||
|
|
||||||
|
向指定提供商的指定模型发送测试请求,验证模型是否可用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 测试请求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果
|
||||||
|
"""
|
||||||
|
# 获取提供商及其端点
|
||||||
|
provider = (
|
||||||
|
db.query(Provider)
|
||||||
|
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||||
|
.filter(Provider.id == request.provider_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
|
# 找到合适的端点和API Key
|
||||||
|
endpoint_config = None
|
||||||
|
endpoint = None
|
||||||
|
api_key = None
|
||||||
|
|
||||||
|
if request.api_key_id:
|
||||||
|
# 使用指定的API Key
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
for key in ep.api_keys:
|
||||||
|
if key.id == request.api_key_id and key.is_active and ep.is_active:
|
||||||
|
endpoint = ep
|
||||||
|
api_key = key
|
||||||
|
break
|
||||||
|
if endpoint:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 使用第一个可用的端点和密钥
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
if not ep.is_active or not ep.api_keys:
|
||||||
|
continue
|
||||||
|
for key in ep.api_keys:
|
||||||
|
if key.is_active:
|
||||||
|
endpoint = ep
|
||||||
|
api_key = key
|
||||||
|
break
|
||||||
|
if endpoint:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not endpoint or not api_key:
|
||||||
|
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[test-model] Failed to decrypt API key: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
|
|
||||||
|
# 构建请求配置
|
||||||
|
endpoint_config = {
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
"timeout": endpoint.timeout or 30.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取对应的 Adapter 类
|
||||||
|
adapter_class = _get_adapter_for_format(endpoint.api_format)
|
||||||
|
if not adapter_class:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown API format: {endpoint.api_format}",
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
|
||||||
|
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
|
||||||
|
|
||||||
|
# 如果请求指定了 api_format,优先使用它
|
||||||
|
target_api_format = request.api_format or endpoint.api_format
|
||||||
|
if request.api_format and request.api_format != endpoint.api_format:
|
||||||
|
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
|
||||||
|
# 重新获取适配器
|
||||||
|
adapter_class = _get_adapter_for_format(request.api_format)
|
||||||
|
if not adapter_class:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown API format: {request.api_format}",
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
}
|
||||||
|
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
|
||||||
|
|
||||||
|
# 准备测试请求数据
|
||||||
|
check_request = {
|
||||||
|
"model": request.model_name,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": request.message or "Hello! This is a test message."}
|
||||||
|
],
|
||||||
|
"max_tokens": 30,
|
||||||
|
"temperature": 0.7,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送测试请求
|
||||||
|
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
|
||||||
|
# 非流式测试
|
||||||
|
logger.debug(f"[test-model] 开始非流式测试...")
|
||||||
|
|
||||||
|
response = await adapter_class.check_endpoint(
|
||||||
|
client,
|
||||||
|
endpoint_config["base_url"],
|
||||||
|
endpoint_config["api_key"],
|
||||||
|
check_request,
|
||||||
|
endpoint_config.get("extra_headers"),
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=current_user,
|
||||||
|
provider_name=provider.name,
|
||||||
|
provider_id=provider.id,
|
||||||
|
api_key_id=endpoint_config.get("api_key_id"),
|
||||||
|
model_name=request.model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录提供商返回信息
|
||||||
|
logger.debug(f"[test-model] 非流式测试结果:")
|
||||||
|
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
|
||||||
|
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
|
||||||
|
response_data = response.get('response', {})
|
||||||
|
response_body = response_data.get('response_body', {})
|
||||||
|
logger.debug(f"[test-model] Response Data: {response_data}")
|
||||||
|
logger.debug(f"[test-model] Response Body: {response_body}")
|
||||||
|
# 尝试解析 response_body (通常是 JSON 字符串)
|
||||||
|
parsed_body = response_body
|
||||||
|
import json
|
||||||
|
if isinstance(response_body, str):
|
||||||
|
try:
|
||||||
|
parsed_body = json.loads(response_body)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(parsed_body, dict) and 'error' in parsed_body:
|
||||||
|
error_obj = parsed_body['error']
|
||||||
|
# 兼容 error 可能是字典或字符串的情况
|
||||||
|
if isinstance(error_obj, dict):
|
||||||
|
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
|
||||||
|
raise HTTPException(status_code=500, detail=error_obj.get('message'))
|
||||||
|
else:
|
||||||
|
logger.debug(f"[test-model] Error: {error_obj}")
|
||||||
|
raise HTTPException(status_code=500, detail=error_obj)
|
||||||
|
elif 'error' in response:
|
||||||
|
logger.debug(f"[test-model] Error: {response['error']}")
|
||||||
|
raise HTTPException(status_code=500, detail=response['error'])
|
||||||
|
else:
|
||||||
|
# 如果有选择或消息,记录内容预览
|
||||||
|
if isinstance(response_data, dict):
|
||||||
|
if 'choices' in response_data and response_data['choices']:
|
||||||
|
choice = response_data['choices'][0]
|
||||||
|
if 'message' in choice:
|
||||||
|
content = choice['message'].get('content', '')
|
||||||
|
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
|
||||||
|
elif 'content' in response_data and response_data['content']:
|
||||||
|
content = str(response_data['content'])
|
||||||
|
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
|
||||||
|
|
||||||
|
# 检查测试是否成功(基于HTTP状态码)
|
||||||
|
status_code = response.get('status_code', 0)
|
||||||
|
is_success = status_code == 200 and 'error' not in response
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": is_success,
|
||||||
|
"data": {
|
||||||
|
"stream": False,
|
||||||
|
"response": response,
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
"endpoint": {
|
||||||
|
"id": endpoint.id,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
"endpoint": {
|
||||||
|
"id": endpoint.id,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
} if endpoint else None,
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,7 +18,15 @@ from sqlalchemy.orm import Session, joinedload
|
|||||||
from src.config.constants import CacheTTL
|
from src.config.constants import CacheTTL
|
||||||
from src.core.cache_service import CacheService
|
from src.core.cache_service import CacheService
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint
|
from src.models.database import (
|
||||||
|
ApiKey,
|
||||||
|
GlobalModel,
|
||||||
|
Model,
|
||||||
|
Provider,
|
||||||
|
ProviderAPIKey,
|
||||||
|
ProviderEndpoint,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
# 缓存 key 前缀
|
# 缓存 key 前缀
|
||||||
_CACHE_KEY_PREFIX = "models:list"
|
_CACHE_KEY_PREFIX = "models:list"
|
||||||
@@ -82,6 +90,7 @@ class ModelInfo:
|
|||||||
created_at: Optional[str] # ISO 格式
|
created_at: Optional[str] # ISO 格式
|
||||||
created_timestamp: int # Unix 时间戳
|
created_timestamp: int # Unix 时间戳
|
||||||
provider_name: str
|
provider_name: str
|
||||||
|
provider_id: str = "" # Provider ID,用于权限过滤
|
||||||
# 能力配置
|
# 能力配置
|
||||||
streaming: bool = True
|
streaming: bool = True
|
||||||
vision: bool = False
|
vision: bool = False
|
||||||
@@ -99,6 +108,92 @@ class ModelInfo:
|
|||||||
output_modalities: Optional[list[str]] = None
|
output_modalities: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AccessRestrictions:
|
||||||
|
"""API Key 或 User 的访问限制"""
|
||||||
|
|
||||||
|
allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表
|
||||||
|
allowed_models: Optional[list[str]] = None # 允许的模型名称列表
|
||||||
|
allowed_api_formats: Optional[list[str]] = None # 允许的 API 格式列表
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_api_key_and_user(
|
||||||
|
cls, api_key: Optional[ApiKey], user: Optional[User]
|
||||||
|
) -> "AccessRestrictions":
|
||||||
|
"""
|
||||||
|
从 API Key 和 User 合并访问限制
|
||||||
|
|
||||||
|
限制逻辑:
|
||||||
|
- API Key 的限制优先于 User 的限制
|
||||||
|
- 如果 API Key 有限制,使用 API Key 的限制
|
||||||
|
- 如果 API Key 无限制但 User 有限制,使用 User 的限制
|
||||||
|
- 两者都无限制则返回空限制
|
||||||
|
"""
|
||||||
|
allowed_providers: Optional[list[str]] = None
|
||||||
|
allowed_models: Optional[list[str]] = None
|
||||||
|
allowed_api_formats: Optional[list[str]] = None
|
||||||
|
|
||||||
|
# 优先使用 API Key 的限制
|
||||||
|
if api_key:
|
||||||
|
if api_key.allowed_providers is not None:
|
||||||
|
allowed_providers = api_key.allowed_providers
|
||||||
|
if api_key.allowed_models is not None:
|
||||||
|
allowed_models = api_key.allowed_models
|
||||||
|
if api_key.allowed_api_formats is not None:
|
||||||
|
allowed_api_formats = api_key.allowed_api_formats
|
||||||
|
|
||||||
|
# 如果 API Key 没有限制,检查 User 的限制
|
||||||
|
# 注意: User 没有 allowed_api_formats 字段
|
||||||
|
if user:
|
||||||
|
if allowed_providers is None and user.allowed_providers is not None:
|
||||||
|
allowed_providers = user.allowed_providers
|
||||||
|
if allowed_models is None and user.allowed_models is not None:
|
||||||
|
allowed_models = user.allowed_models
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
allowed_providers=allowed_providers,
|
||||||
|
allowed_models=allowed_models,
|
||||||
|
allowed_api_formats=allowed_api_formats,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_api_format_allowed(self, api_format: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查 API 格式是否被允许
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_format: API 格式 (如 "OPENAI", "CLAUDE", "GEMINI")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 如果格式被允许,False 否则
|
||||||
|
"""
|
||||||
|
if self.allowed_api_formats is None:
|
||||||
|
return True
|
||||||
|
return api_format in self.allowed_api_formats
|
||||||
|
|
||||||
|
def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查模型是否被允许访问
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_id: 模型 ID
|
||||||
|
provider_id: Provider ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 如果模型被允许,False 否则
|
||||||
|
"""
|
||||||
|
# 检查 Provider 限制
|
||||||
|
if self.allowed_providers is not None:
|
||||||
|
if provider_id not in self.allowed_providers:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查模型限制
|
||||||
|
if self.allowed_models is not None:
|
||||||
|
if model_id not in self.allowed_models:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
|
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
|
||||||
"""
|
"""
|
||||||
返回有可用端点的 Provider IDs
|
返回有可用端点的 Provider IDs
|
||||||
@@ -218,6 +313,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
|
|||||||
)
|
)
|
||||||
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
|
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
|
||||||
provider_name: str = model.provider.name if model.provider else "unknown"
|
provider_name: str = model.provider.name if model.provider else "unknown"
|
||||||
|
provider_id: str = model.provider_id or ""
|
||||||
|
|
||||||
# 从 GlobalModel.config 提取配置信息
|
# 从 GlobalModel.config 提取配置信息
|
||||||
config: dict = {}
|
config: dict = {}
|
||||||
@@ -233,6 +329,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
|
|||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
created_timestamp=created_timestamp,
|
created_timestamp=created_timestamp,
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
# 能力配置
|
# 能力配置
|
||||||
streaming=config.get("streaming", True),
|
streaming=config.get("streaming", True),
|
||||||
vision=config.get("vision", False),
|
vision=config.get("vision", False),
|
||||||
@@ -255,6 +352,7 @@ async def list_available_models(
|
|||||||
db: Session,
|
db: Session,
|
||||||
available_provider_ids: set[str],
|
available_provider_ids: set[str],
|
||||||
api_formats: Optional[list[str]] = None,
|
api_formats: Optional[list[str]] = None,
|
||||||
|
restrictions: Optional[AccessRestrictions] = None,
|
||||||
) -> list[ModelInfo]:
|
) -> list[ModelInfo]:
|
||||||
"""
|
"""
|
||||||
获取可用模型列表(已去重,带缓存)
|
获取可用模型列表(已去重,带缓存)
|
||||||
@@ -263,6 +361,7 @@ async def list_available_models(
|
|||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
available_provider_ids: 有可用端点的 Provider ID 集合
|
available_provider_ids: 有可用端点的 Provider ID 集合
|
||||||
api_formats: API 格式列表,用于检查 Key 的 allowed_models
|
api_formats: API 格式列表,用于检查 Key 的 allowed_models
|
||||||
|
restrictions: API Key/User 的访问限制
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
去重后的 ModelInfo 列表,按创建时间倒序
|
去重后的 ModelInfo 列表,按创建时间倒序
|
||||||
@@ -270,8 +369,16 @@ async def list_available_models(
|
|||||||
if not available_provider_ids:
|
if not available_provider_ids:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# 缓存策略:只有完全无访问限制时才使用缓存
|
||||||
|
# - restrictions is None: 未传入限制对象
|
||||||
|
# - restrictions 的两个字段都为 None: 传入了限制对象但无实际限制
|
||||||
|
# 以上两种情况返回的结果相同,可以共享全局缓存
|
||||||
|
use_cache = restrictions is None or (
|
||||||
|
restrictions.allowed_providers is None and restrictions.allowed_models is None
|
||||||
|
)
|
||||||
|
|
||||||
# 尝试从缓存获取
|
# 尝试从缓存获取
|
||||||
if api_formats:
|
if api_formats and use_cache:
|
||||||
cached = await _get_cached_models(api_formats)
|
cached = await _get_cached_models(api_formats)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
return cached
|
return cached
|
||||||
@@ -306,14 +413,19 @@ async def list_available_models(
|
|||||||
if available_model_ids is not None and info.id not in available_model_ids:
|
if available_model_ids is not None and info.id not in available_model_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 检查 API Key/User 访问限制
|
||||||
|
if restrictions is not None:
|
||||||
|
if not restrictions.is_model_allowed(info.id, info.provider_id):
|
||||||
|
continue
|
||||||
|
|
||||||
if info.id in seen_model_ids:
|
if info.id in seen_model_ids:
|
||||||
continue
|
continue
|
||||||
seen_model_ids.add(info.id)
|
seen_model_ids.add(info.id)
|
||||||
|
|
||||||
result.append(info)
|
result.append(info)
|
||||||
|
|
||||||
# 写入缓存
|
# 只有无限制的情况才写入缓存
|
||||||
if api_formats:
|
if api_formats and use_cache:
|
||||||
await _set_cached_models(api_formats, result)
|
await _set_cached_models(api_formats, result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -324,6 +436,7 @@ def find_model_by_id(
|
|||||||
model_id: str,
|
model_id: str,
|
||||||
available_provider_ids: set[str],
|
available_provider_ids: set[str],
|
||||||
api_formats: Optional[list[str]] = None,
|
api_formats: Optional[list[str]] = None,
|
||||||
|
restrictions: Optional[AccessRestrictions] = None,
|
||||||
) -> Optional[ModelInfo]:
|
) -> Optional[ModelInfo]:
|
||||||
"""
|
"""
|
||||||
按 ID 查找模型
|
按 ID 查找模型
|
||||||
@@ -338,6 +451,7 @@ def find_model_by_id(
|
|||||||
model_id: 模型 ID
|
model_id: 模型 ID
|
||||||
available_provider_ids: 有可用端点的 Provider ID 集合
|
available_provider_ids: 有可用端点的 Provider ID 集合
|
||||||
api_formats: API 格式列表,用于检查 Key 的 allowed_models
|
api_formats: API 格式列表,用于检查 Key 的 allowed_models
|
||||||
|
restrictions: API Key/User 的访问限制
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ModelInfo 或 None
|
ModelInfo 或 None
|
||||||
@@ -353,6 +467,11 @@ def find_model_by_id(
|
|||||||
if available_model_ids is not None and model_id not in available_model_ids:
|
if available_model_ids is not None and model_id not in available_model_ids:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# 快速检查:如果 restrictions 明确限制了模型列表且目标模型不在其中,直接返回 None
|
||||||
|
if restrictions is not None and restrictions.allowed_models is not None:
|
||||||
|
if model_id not in restrictions.allowed_models:
|
||||||
|
return None
|
||||||
|
|
||||||
# 先按 GlobalModel.name 查找
|
# 先按 GlobalModel.name 查找
|
||||||
models_by_global = (
|
models_by_global = (
|
||||||
db.query(Model)
|
db.query(Model)
|
||||||
@@ -368,8 +487,19 @@ def find_model_by_id(
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_model_accessible(m: Model) -> bool:
|
||||||
|
"""检查模型是否可访问"""
|
||||||
|
if m.provider_id not in available_provider_ids:
|
||||||
|
return False
|
||||||
|
# 检查 API Key/User 访问限制
|
||||||
|
if restrictions is not None:
|
||||||
|
provider_id = m.provider_id or ""
|
||||||
|
if not restrictions.is_model_allowed(model_id, provider_id):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
model = next(
|
model = next(
|
||||||
(m for m in models_by_global if m.provider_id in available_provider_ids),
|
(m for m in models_by_global if is_model_accessible(m)),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -393,7 +523,7 @@ def find_model_by_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = next(
|
model = next(
|
||||||
(m for m in models_by_provider_name if m.provider_id in available_provider_ids),
|
(m for m in models_by_provider_name if is_model_accessible(m)),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -118,7 +118,9 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
|
|||||||
# 转换为 UTC 用于与 stats_daily.date 比较(存储的是业务日期对应的 UTC 开始时间)
|
# 转换为 UTC 用于与 stats_daily.date 比较(存储的是业务日期对应的 UTC 开始时间)
|
||||||
today = today_local.astimezone(timezone.utc)
|
today = today_local.astimezone(timezone.utc)
|
||||||
yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc)
|
yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc)
|
||||||
last_month = (today_local - timedelta(days=30)).astimezone(timezone.utc)
|
# 本月第一天(自然月)
|
||||||
|
month_start_local = today_local.replace(day=1)
|
||||||
|
month_start = month_start_local.astimezone(timezone.utc)
|
||||||
|
|
||||||
# ==================== 使用预聚合数据 ====================
|
# ==================== 使用预聚合数据 ====================
|
||||||
# 从 stats_summary + 今日实时数据获取全局统计
|
# 从 stats_summary + 今日实时数据获取全局统计
|
||||||
@@ -208,7 +210,7 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
|
|||||||
func.sum(StatsDaily.cache_read_cost).label("cache_read_cost"),
|
func.sum(StatsDaily.cache_read_cost).label("cache_read_cost"),
|
||||||
func.sum(StatsDaily.fallback_count).label("fallback_count"),
|
func.sum(StatsDaily.fallback_count).label("fallback_count"),
|
||||||
)
|
)
|
||||||
.filter(StatsDaily.date >= last_month, StatsDaily.date < today)
|
.filter(StatsDaily.date >= month_start, StatsDaily.date < today)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -227,24 +229,24 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
|
|||||||
else:
|
else:
|
||||||
# 回退到实时查询(没有预聚合数据时)
|
# 回退到实时查询(没有预聚合数据时)
|
||||||
total_requests = (
|
total_requests = (
|
||||||
db.query(func.count(Usage.id)).filter(Usage.created_at >= last_month).scalar() or 0
|
db.query(func.count(Usage.id)).filter(Usage.created_at >= month_start).scalar() or 0
|
||||||
)
|
)
|
||||||
total_cost = (
|
total_cost = (
|
||||||
db.query(func.sum(Usage.total_cost_usd)).filter(Usage.created_at >= last_month).scalar() or 0
|
db.query(func.sum(Usage.total_cost_usd)).filter(Usage.created_at >= month_start).scalar() or 0
|
||||||
)
|
)
|
||||||
total_actual_cost = (
|
total_actual_cost = (
|
||||||
db.query(func.sum(Usage.actual_total_cost_usd))
|
db.query(func.sum(Usage.actual_total_cost_usd))
|
||||||
.filter(Usage.created_at >= last_month).scalar() or 0
|
.filter(Usage.created_at >= month_start).scalar() or 0
|
||||||
)
|
)
|
||||||
error_requests = (
|
error_requests = (
|
||||||
db.query(func.count(Usage.id))
|
db.query(func.count(Usage.id))
|
||||||
.filter(
|
.filter(
|
||||||
Usage.created_at >= last_month,
|
Usage.created_at >= month_start,
|
||||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None)),
|
(Usage.status_code >= 400) | (Usage.error_message.isnot(None)),
|
||||||
).scalar() or 0
|
).scalar() or 0
|
||||||
)
|
)
|
||||||
total_tokens = (
|
total_tokens = (
|
||||||
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= last_month).scalar() or 0
|
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= month_start).scalar() or 0
|
||||||
)
|
)
|
||||||
cache_stats = (
|
cache_stats = (
|
||||||
db.query(
|
db.query(
|
||||||
@@ -253,7 +255,7 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
|
|||||||
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
|
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
|
||||||
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
|
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
|
||||||
)
|
)
|
||||||
.filter(Usage.created_at >= last_month)
|
.filter(Usage.created_at >= month_start)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
||||||
@@ -267,7 +269,7 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
|
|||||||
RequestCandidate.request_id, func.count(RequestCandidate.id).label("executed_count")
|
RequestCandidate.request_id, func.count(RequestCandidate.id).label("executed_count")
|
||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
RequestCandidate.created_at >= last_month,
|
RequestCandidate.created_at >= month_start,
|
||||||
RequestCandidate.status.in_(["success", "failed"]),
|
RequestCandidate.status.in_(["success", "failed"]),
|
||||||
)
|
)
|
||||||
.group_by(RequestCandidate.request_id)
|
.group_by(RequestCandidate.request_id)
|
||||||
@@ -447,7 +449,9 @@ class UserDashboardStatsAdapter(DashboardAdapter):
|
|||||||
# 转换为 UTC 用于数据库查询
|
# 转换为 UTC 用于数据库查询
|
||||||
today = today_local.astimezone(timezone.utc)
|
today = today_local.astimezone(timezone.utc)
|
||||||
yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc)
|
yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc)
|
||||||
last_month = (today_local - timedelta(days=30)).astimezone(timezone.utc)
|
# 本月第一天(自然月)
|
||||||
|
month_start_local = today_local.replace(day=1)
|
||||||
|
month_start = month_start_local.astimezone(timezone.utc)
|
||||||
|
|
||||||
user_api_keys = db.query(func.count(ApiKey.id)).filter(ApiKey.user_id == user.id).scalar()
|
user_api_keys = db.query(func.count(ApiKey.id)).filter(ApiKey.user_id == user.id).scalar()
|
||||||
active_keys = (
|
active_keys = (
|
||||||
@@ -483,12 +487,12 @@ class UserDashboardStatsAdapter(DashboardAdapter):
|
|||||||
# 本月请求统计
|
# 本月请求统计
|
||||||
user_requests = (
|
user_requests = (
|
||||||
db.query(func.count(Usage.id))
|
db.query(func.count(Usage.id))
|
||||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
|
.filter(and_(Usage.user_id == user.id, Usage.created_at >= month_start))
|
||||||
.scalar()
|
.scalar()
|
||||||
)
|
)
|
||||||
user_cost = (
|
user_cost = (
|
||||||
db.query(func.sum(Usage.total_cost_usd))
|
db.query(func.sum(Usage.total_cost_usd))
|
||||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
|
.filter(and_(Usage.user_id == user.id, Usage.created_at >= month_start))
|
||||||
.scalar()
|
.scalar()
|
||||||
or 0
|
or 0
|
||||||
)
|
)
|
||||||
@@ -532,18 +536,19 @@ class UserDashboardStatsAdapter(DashboardAdapter):
|
|||||||
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
|
||||||
func.sum(Usage.input_tokens).label("total_input_tokens"),
|
func.sum(Usage.input_tokens).label("total_input_tokens"),
|
||||||
)
|
)
|
||||||
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month))
|
.filter(and_(Usage.user_id == user.id, Usage.created_at >= month_start))
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
|
||||||
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
|
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
|
||||||
|
monthly_input_tokens = int(cache_stats.total_input_tokens or 0) if cache_stats else 0
|
||||||
|
|
||||||
# 计算缓存命中率:cache_read / (input_tokens + cache_read)
|
# 计算本月缓存命中率:cache_read / (input_tokens + cache_read)
|
||||||
# input_tokens 是实际发送给模型的输入(不含缓存读取),cache_read 是从缓存读取的
|
# input_tokens 是实际发送给模型的输入(不含缓存读取),cache_read 是从缓存读取的
|
||||||
# 总输入 = input_tokens + cache_read,缓存命中率 = cache_read / 总输入
|
# 总输入 = input_tokens + cache_read,缓存命中率 = cache_read / 总输入
|
||||||
total_input_with_cache = all_time_input_tokens + all_time_cache_read
|
total_input_with_cache = monthly_input_tokens + cache_read_tokens
|
||||||
cache_hit_rate = (
|
cache_hit_rate = (
|
||||||
round((all_time_cache_read / total_input_with_cache) * 100, 1)
|
round((cache_read_tokens / total_input_with_cache) * 100, 1)
|
||||||
if total_input_with_cache > 0
|
if total_input_with_cache > 0
|
||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
@@ -569,15 +574,15 @@ class UserDashboardStatsAdapter(DashboardAdapter):
|
|||||||
quota_value = "无限制"
|
quota_value = "无限制"
|
||||||
quota_change = f"已用 ${user.used_usd:.2f}"
|
quota_change = f"已用 ${user.used_usd:.2f}"
|
||||||
quota_high = False
|
quota_high = False
|
||||||
elif user.quota_usd and user.quota_usd > 0:
|
elif user.quota_usd > 0:
|
||||||
percent = min(100, int((user.used_usd / user.quota_usd) * 100))
|
percent = min(100, int((user.used_usd / user.quota_usd) * 100))
|
||||||
quota_value = "无限制"
|
quota_value = f"${user.quota_usd:.0f}"
|
||||||
quota_change = f"已用 ${user.used_usd:.2f}"
|
quota_change = f"已用 ${user.used_usd:.2f}"
|
||||||
quota_high = percent > 80
|
quota_high = percent > 80
|
||||||
else:
|
else:
|
||||||
quota_value = "0%"
|
quota_value = "$0"
|
||||||
quota_change = f"已用 ${user.used_usd:.2f}"
|
quota_change = f"已用 ${user.used_usd:.2f}"
|
||||||
quota_high = False
|
quota_high = True
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"stats": [
|
"stats": [
|
||||||
@@ -605,9 +610,15 @@ class UserDashboardStatsAdapter(DashboardAdapter):
|
|||||||
"icon": "TrendingUp",
|
"icon": "TrendingUp",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "本月费用",
|
"name": "总Token",
|
||||||
"value": f"${user_cost:.2f}",
|
"value": format_tokens(
|
||||||
"icon": "DollarSign",
|
all_time_input_tokens
|
||||||
|
+ all_time_output_tokens
|
||||||
|
+ all_time_cache_creation
|
||||||
|
+ all_time_cache_read
|
||||||
|
),
|
||||||
|
"subValue": f"输入 {format_tokens(all_time_input_tokens)} / 输出 {format_tokens(all_time_output_tokens)}",
|
||||||
|
"icon": "Hash",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"today": {
|
"today": {
|
||||||
@@ -631,6 +642,8 @@ class UserDashboardStatsAdapter(DashboardAdapter):
|
|||||||
"cache_hit_rate": cache_hit_rate,
|
"cache_hit_rate": cache_hit_rate,
|
||||||
"total_cache_tokens": cache_creation_tokens + cache_read_tokens,
|
"total_cache_tokens": cache_creation_tokens + cache_read_tokens,
|
||||||
},
|
},
|
||||||
|
# 本月费用(用于下方缓存区域显示)
|
||||||
|
"monthly_cost": float(user_cost),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
name: str = "chat.base"
|
name: str = "chat.base"
|
||||||
mode = ApiMode.STANDARD
|
mode = ApiMode.STANDARD
|
||||||
|
|
||||||
|
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建端点URL,子类可以覆盖以自定义URL构建逻辑"""
|
||||||
|
# 默认实现:在base_url后添加特定路径
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建基础请求头,子类可以覆盖以自定义认证头"""
|
||||||
|
# 默认实现:Bearer token认证
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回不应被extra_headers覆盖的头部key,子类可以覆盖"""
|
||||||
|
# 默认保护认证相关头部
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
|
||||||
|
# 默认实现:直接使用请求数据
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||||
|
|
||||||
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
# 默认实现返回空列表,子类应覆盖
|
# 默认实现返回空列表,子类应覆盖
|
||||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试模型连接性(非流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
request_data: 请求数据
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
db: 数据库会话
|
||||||
|
user: 用户对象
|
||||||
|
provider_name: 提供商名称
|
||||||
|
provider_id: 提供商ID
|
||||||
|
api_key_id: API Key ID
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试响应数据
|
||||||
|
"""
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
|
|
||||||
|
# 使用子类配置方法构建请求组件
|
||||||
|
url = cls.build_endpoint_url(base_url)
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 使用通用的endpoint checker执行请求
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=model_name or request_data.get("model"),
|
||||||
|
)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||||
|
|||||||
@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
# 默认实现返回空列表,子类应覆盖
|
# 默认实现返回空列表,子类应覆盖
|
||||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试模型连接性(非流式)
|
||||||
|
|
||||||
|
通用的CLI endpoint测试方法,使用配置方法模式:
|
||||||
|
- build_endpoint_url(): 构建请求URL
|
||||||
|
- build_base_headers(): 构建基础认证头
|
||||||
|
- get_protected_header_keys(): 获取受保护的头部key
|
||||||
|
- build_request_body(): 构建请求体
|
||||||
|
- get_cli_user_agent(): 获取CLI User-Agent(子类可覆盖)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
request_data: 请求数据
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
db: 数据库会话
|
||||||
|
user: 用户对象
|
||||||
|
provider_name: 提供商名称
|
||||||
|
provider_id: 提供商ID
|
||||||
|
api_key_id: API密钥ID
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试响应数据
|
||||||
|
"""
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
|
|
||||||
|
# 构建请求组件
|
||||||
|
url = cls.build_endpoint_url(base_url, request_data, model_name)
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
|
||||||
|
# 添加CLI User-Agent
|
||||||
|
cli_user_agent = cls.get_cli_user_agent()
|
||||||
|
if cli_user_agent:
|
||||||
|
base_headers["User-Agent"] = cli_user_agent
|
||||||
|
protected_keys = tuple(list(protected_keys) + ["user-agent"])
|
||||||
|
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 获取有效的模型名称
|
||||||
|
effective_model_name = model_name or request_data.get("model")
|
||||||
|
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=effective_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
构建CLI API端点URL - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: API基础URL
|
||||||
|
request_data: 请求数据
|
||||||
|
model_name: 模型名称(某些API需要,如Gemini)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的端点URL
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
构建CLI API认证头 - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
基础认证头部字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""
|
||||||
|
返回CLI API的保护头部key - 子类应覆盖
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
保护头部key的元组
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
构建CLI API请求体 - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: 请求数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
请求体字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取CLI User-Agent - 子类可覆盖
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CLI User-Agent字符串,如果不需要则为None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||||
|
|||||||
1252
src/api/handlers/base/endpoint_checker.py
Normal file
1252
src/api/handlers/base/endpoint_checker.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -209,6 +209,38 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建Claude API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/messages"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/messages"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Claude API认证头"""
|
||||||
|
return {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Claude API的保护头部key"""
|
||||||
|
return ("x-api-key", "content-type", "anthropic-version")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Claude API请求体"""
|
||||||
|
return {
|
||||||
|
"model": request_data.get("model"),
|
||||||
|
"max_tokens": request_data.get("max_tokens", 100),
|
||||||
|
"messages": request_data.get("messages", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_claude_adapter(x_app_header: Optional[str]):
|
def build_claude_adapter(x_app_header: Optional[str]):
|
||||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -126,5 +126,41 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建Claude CLI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/messages"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/messages"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Claude CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Claude CLI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Claude CLI API请求体"""
|
||||||
|
return {
|
||||||
|
"model": request_data.get("model"),
|
||||||
|
"max_tokens": request_data.get("max_tokens", 100),
|
||||||
|
"messages": request_data.get("messages", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取Claude CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_claude_cli
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ClaudeCliAdapter"]
|
__all__ = ["ClaudeCliAdapter"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Gemini Chat Adapter
|
|||||||
处理 Gemini API 格式的请求适配
|
处理 Gemini API 格式的请求适配
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
|
|||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.gemini import GeminiRequest
|
from src.models.gemini import GeminiRequest
|
||||||
|
|
||||||
@@ -199,6 +200,94 @@ class GeminiChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建Gemini API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1beta"):
|
||||||
|
return base_url # 子类需要处理model参数
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1beta"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Gemini API认证头"""
|
||||||
|
return {
|
||||||
|
"x-goog-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Gemini API的保护头部key"""
|
||||||
|
return ("x-goog-api-key", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Gemini API请求体"""
|
||||||
|
return {
|
||||||
|
"contents": request_data.get("messages", []),
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": request_data.get("max_tokens", 100),
|
||||||
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""测试 Gemini API 模型连接性(非流式)"""
|
||||||
|
# Gemini需要从request_data或model_name参数获取model名称
|
||||||
|
effective_model_name = model_name or request_data.get("model", "")
|
||||||
|
if not effective_model_name:
|
||||||
|
return {
|
||||||
|
"error": "Model name is required for Gemini API",
|
||||||
|
"status_code": 400,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用基类配置方法,但重写URL构建逻辑
|
||||||
|
base_url = cls.build_endpoint_url(base_url)
|
||||||
|
url = f"{base_url}/models/{effective_model_name}:generateContent"
|
||||||
|
|
||||||
|
# 构建请求组件
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 使用基类的通用endpoint checker
|
||||||
|
from src.api.handlers.base.endpoint_checker import run_endpoint_check
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=effective_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
|||||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -123,6 +123,52 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建Gemini CLI API端点URL"""
|
||||||
|
effective_model_name = model_name or request_data.get("model", "")
|
||||||
|
if not effective_model_name:
|
||||||
|
raise ValueError("Model name is required for Gemini API")
|
||||||
|
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1beta"):
|
||||||
|
prefix = base_url
|
||||||
|
else:
|
||||||
|
prefix = f"{base_url}/v1beta"
|
||||||
|
return f"{prefix}/models/{effective_model_name}:generateContent"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Gemini CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"x-goog-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Gemini CLI API的保护头部key"""
|
||||||
|
return ("x-goog-api-key", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Gemini CLI API请求体"""
|
||||||
|
return {
|
||||||
|
"contents": request_data.get("messages", []),
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": request_data.get("max_tokens", 100),
|
||||||
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取Gemini CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_gemini_cli
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
|||||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.openai import OpenAIRequest
|
from src.models.openai import OpenAIRequest
|
||||||
@@ -154,5 +155,32 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建OpenAI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建OpenAI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回OpenAI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建OpenAI API请求体"""
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAIChatAdapter"]
|
__all__ = ["OpenAIChatAdapter"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -68,5 +68,37 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建OpenAI CLI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建OpenAI CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回OpenAI CLI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建OpenAI CLI API请求体"""
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取OpenAI CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_openai_cli
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAICliAdapter"]
|
__all__ = ["OpenAICliAdapter"]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.api.base.models_service import (
|
from src.api.base.models_service import (
|
||||||
|
AccessRestrictions,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
find_model_by_id,
|
find_model_by_id,
|
||||||
get_available_provider_ids,
|
get_available_provider_ids,
|
||||||
@@ -103,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]:
|
|||||||
return _OPENAI_FORMATS
|
return _OPENAI_FORMATS
|
||||||
|
|
||||||
|
|
||||||
|
def _build_empty_list_response(api_format: str) -> dict:
|
||||||
|
"""根据 API 格式构建空列表响应"""
|
||||||
|
if api_format == "claude":
|
||||||
|
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
|
||||||
|
elif api_format == "gemini":
|
||||||
|
return {"models": []}
|
||||||
|
else:
|
||||||
|
return {"object": "list", "data": []}
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_formats_by_restrictions(
|
||||||
|
formats: list[str], restrictions: AccessRestrictions, api_format: str
|
||||||
|
) -> Tuple[list[str], Optional[dict]]:
|
||||||
|
"""
|
||||||
|
根据访问限制过滤 API 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(过滤后的格式列表, 空响应或None)
|
||||||
|
如果过滤后为空,返回对应格式的空响应
|
||||||
|
"""
|
||||||
|
if restrictions.allowed_api_formats is None:
|
||||||
|
return formats, None
|
||||||
|
filtered = [f for f in formats if f in restrictions.allowed_api_formats]
|
||||||
|
if not filtered:
|
||||||
|
logger.info(f"[Models] API Key 不允许访问格式 {api_format}")
|
||||||
|
return [], _build_empty_list_response(api_format)
|
||||||
|
return filtered, None
|
||||||
|
|
||||||
|
|
||||||
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
|
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
|
||||||
"""
|
"""
|
||||||
认证 API Key
|
认证 API Key
|
||||||
@@ -375,22 +405,24 @@ async def list_models(
|
|||||||
logger.info(f"[Models] GET /v1/models | format={api_format}")
|
logger.info(f"[Models] GET /v1/models | format={api_format}")
|
||||||
|
|
||||||
# 认证
|
# 认证
|
||||||
user, _ = _authenticate(db, api_key)
|
user, key_record = _authenticate(db, api_key)
|
||||||
if not user:
|
if not user:
|
||||||
return _build_auth_error_response(api_format)
|
return _build_auth_error_response(api_format)
|
||||||
|
|
||||||
|
# 构建访问限制
|
||||||
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
|
# 检查 API 格式限制
|
||||||
formats = _get_formats_for_api(api_format)
|
formats = _get_formats_for_api(api_format)
|
||||||
|
formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format)
|
||||||
|
if empty_response is not None:
|
||||||
|
return empty_response
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, formats)
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
if not available_provider_ids:
|
if not available_provider_ids:
|
||||||
if api_format == "claude":
|
return _build_empty_list_response(api_format)
|
||||||
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
|
|
||||||
elif api_format == "gemini":
|
|
||||||
return {"models": []}
|
|
||||||
else:
|
|
||||||
return {"object": "list", "data": []}
|
|
||||||
|
|
||||||
models = await list_available_models(db, available_provider_ids, formats)
|
models = await list_available_models(db, available_provider_ids, formats, restrictions)
|
||||||
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
||||||
|
|
||||||
if api_format == "claude":
|
if api_format == "claude":
|
||||||
@@ -419,14 +451,21 @@ async def retrieve_model(
|
|||||||
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
|
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
|
||||||
|
|
||||||
# 认证
|
# 认证
|
||||||
user, _ = _authenticate(db, api_key)
|
user, key_record = _authenticate(db, api_key)
|
||||||
if not user:
|
if not user:
|
||||||
return _build_auth_error_response(api_format)
|
return _build_auth_error_response(api_format)
|
||||||
|
|
||||||
|
# 构建访问限制
|
||||||
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
|
# 检查 API 格式限制
|
||||||
formats = _get_formats_for_api(api_format)
|
formats = _get_formats_for_api(api_format)
|
||||||
|
formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format)
|
||||||
|
if not formats:
|
||||||
|
return _build_404_response(model_id, api_format)
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, formats)
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
model_info = find_model_by_id(db, model_id, available_provider_ids, formats)
|
model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
|
||||||
|
|
||||||
if not model_info:
|
if not model_info:
|
||||||
return _build_404_response(model_id, api_format)
|
return _build_404_response(model_id, api_format)
|
||||||
@@ -455,15 +494,25 @@ async def list_models_gemini(
|
|||||||
api_key = _extract_api_key_from_request(request, gemini_def)
|
api_key = _extract_api_key_from_request(request, gemini_def)
|
||||||
|
|
||||||
# 认证
|
# 认证
|
||||||
user, _ = _authenticate(db, api_key)
|
user, key_record = _authenticate(db, api_key)
|
||||||
if not user:
|
if not user:
|
||||||
return _build_auth_error_response("gemini")
|
return _build_auth_error_response("gemini")
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
|
# 构建访问限制
|
||||||
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
|
# 检查 API 格式限制
|
||||||
|
formats, empty_response = _filter_formats_by_restrictions(
|
||||||
|
_GEMINI_FORMATS, restrictions, "gemini"
|
||||||
|
)
|
||||||
|
if empty_response is not None:
|
||||||
|
return empty_response
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
if not available_provider_ids:
|
if not available_provider_ids:
|
||||||
return {"models": []}
|
return {"models": []}
|
||||||
|
|
||||||
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS)
|
models = await list_available_models(db, available_provider_ids, formats, restrictions)
|
||||||
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
logger.debug(f"[Models] 返回 {len(models)} 个模型")
|
||||||
response = _build_gemini_list_response(models, page_size, page_token)
|
response = _build_gemini_list_response(models, page_size, page_token)
|
||||||
logger.debug(f"[Models] Gemini 响应: {response}")
|
logger.debug(f"[Models] Gemini 响应: {response}")
|
||||||
@@ -486,12 +535,22 @@ async def get_model_gemini(
|
|||||||
api_key = _extract_api_key_from_request(request, gemini_def)
|
api_key = _extract_api_key_from_request(request, gemini_def)
|
||||||
|
|
||||||
# 认证
|
# 认证
|
||||||
user, _ = _authenticate(db, api_key)
|
user, key_record = _authenticate(db, api_key)
|
||||||
if not user:
|
if not user:
|
||||||
return _build_auth_error_response("gemini")
|
return _build_auth_error_response("gemini")
|
||||||
|
|
||||||
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
|
# 构建访问限制
|
||||||
model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS)
|
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
|
||||||
|
|
||||||
|
# 检查 API 格式限制
|
||||||
|
formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini")
|
||||||
|
if not formats:
|
||||||
|
return _build_404_response(model_id, "gemini")
|
||||||
|
|
||||||
|
available_provider_ids = get_available_provider_ids(db, formats)
|
||||||
|
model_info = find_model_by_id(
|
||||||
|
db, model_id, available_provider_ids, formats, restrictions
|
||||||
|
)
|
||||||
|
|
||||||
if not model_info:
|
if not model_info:
|
||||||
return _build_404_response(model_id, "gemini")
|
return _build_404_response(model_id, "gemini")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from urllib.parse import quote, urlparse
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
from src.config import config
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -83,10 +84,10 @@ class HTTPClientPool:
|
|||||||
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||||
verify=True, # 启用SSL验证
|
verify=True, # 启用SSL验证
|
||||||
timeout=httpx.Timeout(
|
timeout=httpx.Timeout(
|
||||||
connect=10.0, # 连接超时
|
connect=config.http_connect_timeout,
|
||||||
read=300.0, # 读取超时(5分钟,适合流式响应)
|
read=config.http_read_timeout,
|
||||||
write=60.0, # 写入超时(60秒,支持大请求体)
|
write=config.http_write_timeout,
|
||||||
pool=5.0, # 连接池超时
|
pool=config.http_pool_timeout,
|
||||||
),
|
),
|
||||||
limits=httpx.Limits(
|
limits=httpx.Limits(
|
||||||
max_connections=100, # 最大连接数
|
max_connections=100, # 最大连接数
|
||||||
@@ -111,15 +112,20 @@ class HTTPClientPool:
|
|||||||
"""
|
"""
|
||||||
if name not in cls._clients:
|
if name not in cls._clients:
|
||||||
# 合并默认配置和自定义配置
|
# 合并默认配置和自定义配置
|
||||||
config = {
|
default_config = {
|
||||||
"http2": False,
|
"http2": False,
|
||||||
"verify": True,
|
"verify": True,
|
||||||
"timeout": httpx.Timeout(10.0, read=300.0),
|
"timeout": httpx.Timeout(
|
||||||
|
connect=config.http_connect_timeout,
|
||||||
|
read=config.http_read_timeout,
|
||||||
|
write=config.http_write_timeout,
|
||||||
|
pool=config.http_pool_timeout,
|
||||||
|
),
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
}
|
}
|
||||||
config.update(kwargs)
|
default_config.update(kwargs)
|
||||||
|
|
||||||
cls._clients[name] = httpx.AsyncClient(**config)
|
cls._clients[name] = httpx.AsyncClient(**default_config)
|
||||||
logger.debug(f"创建命名HTTP客户端: {name}")
|
logger.debug(f"创建命名HTTP客户端: {name}")
|
||||||
|
|
||||||
return cls._clients[name]
|
return cls._clients[name]
|
||||||
@@ -151,14 +157,19 @@ class HTTPClientPool:
|
|||||||
async with HTTPClientPool.get_temp_client() as client:
|
async with HTTPClientPool.get_temp_client() as client:
|
||||||
response = await client.get('https://example.com')
|
response = await client.get('https://example.com')
|
||||||
"""
|
"""
|
||||||
config = {
|
default_config = {
|
||||||
"http2": False,
|
"http2": False,
|
||||||
"verify": True,
|
"verify": True,
|
||||||
"timeout": httpx.Timeout(10.0),
|
"timeout": httpx.Timeout(
|
||||||
|
connect=config.http_connect_timeout,
|
||||||
|
read=config.http_read_timeout,
|
||||||
|
write=config.http_write_timeout,
|
||||||
|
pool=config.http_pool_timeout,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
config.update(kwargs)
|
default_config.update(kwargs)
|
||||||
|
|
||||||
client = httpx.AsyncClient(**config)
|
client = httpx.AsyncClient(**default_config)
|
||||||
try:
|
try:
|
||||||
yield client
|
yield client
|
||||||
finally:
|
finally:
|
||||||
@@ -182,25 +193,30 @@ class HTTPClientPool:
|
|||||||
Returns:
|
Returns:
|
||||||
配置好的 httpx.AsyncClient 实例
|
配置好的 httpx.AsyncClient 实例
|
||||||
"""
|
"""
|
||||||
config: Dict[str, Any] = {
|
client_config: Dict[str, Any] = {
|
||||||
"http2": False,
|
"http2": False,
|
||||||
"verify": True,
|
"verify": True,
|
||||||
"follow_redirects": True,
|
"follow_redirects": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
if timeout:
|
if timeout:
|
||||||
config["timeout"] = timeout
|
client_config["timeout"] = timeout
|
||||||
else:
|
else:
|
||||||
config["timeout"] = httpx.Timeout(10.0, read=300.0)
|
client_config["timeout"] = httpx.Timeout(
|
||||||
|
connect=config.http_connect_timeout,
|
||||||
|
read=config.http_read_timeout,
|
||||||
|
write=config.http_write_timeout,
|
||||||
|
pool=config.http_pool_timeout,
|
||||||
|
)
|
||||||
|
|
||||||
# 添加代理配置
|
# 添加代理配置
|
||||||
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||||
if proxy_url:
|
if proxy_url:
|
||||||
config["proxy"] = proxy_url
|
client_config["proxy"] = proxy_url
|
||||||
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
||||||
|
|
||||||
config.update(kwargs)
|
client_config.update(kwargs)
|
||||||
return httpx.AsyncClient(**config)
|
return httpx.AsyncClient(**client_config)
|
||||||
|
|
||||||
|
|
||||||
# 便捷访问函数
|
# 便捷访问函数
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ class Config:
|
|||||||
|
|
||||||
# HTTP 请求超时配置(秒)
|
# HTTP 请求超时配置(秒)
|
||||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||||
|
self.http_read_timeout = float(os.getenv("HTTP_READ_TIMEOUT", "300.0"))
|
||||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||||
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
||||||
|
|
||||||
|
|||||||
@@ -96,13 +96,15 @@ if not DISABLE_FILE_LOG:
|
|||||||
log_dir.mkdir(exist_ok=True)
|
log_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
# 文件日志通用配置
|
# 文件日志通用配置
|
||||||
|
# 注意: enqueue=False 使用同步模式,避免 multiprocessing 信号量泄漏
|
||||||
|
# 在 macOS 上,进程异常退出时 POSIX 信号量不会自动释放,导致资源耗尽
|
||||||
file_log_config = {
|
file_log_config = {
|
||||||
"format": FILE_FORMAT,
|
"format": FILE_FORMAT,
|
||||||
"filter": _log_filter,
|
"filter": _log_filter,
|
||||||
"rotation": "100 MB",
|
"rotation": "100 MB",
|
||||||
"retention": "30 days",
|
"retention": "30 days",
|
||||||
"compression": "gz",
|
"compression": "gz",
|
||||||
"enqueue": True,
|
"enqueue": False,
|
||||||
"encoding": "utf-8",
|
"encoding": "utf-8",
|
||||||
"catch": True,
|
"catch": True,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -360,6 +360,9 @@ def init_db():
|
|||||||
|
|
||||||
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
|
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
|
||||||
"""
|
"""
|
||||||
|
import sys
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
logger.info("初始化数据库...")
|
logger.info("初始化数据库...")
|
||||||
|
|
||||||
# 确保引擎已创建
|
# 确保引擎已创建
|
||||||
@@ -382,6 +385,38 @@ def init_db():
|
|||||||
db.commit()
|
db.commit()
|
||||||
logger.info("数据库初始化完成")
|
logger.info("数据库初始化完成")
|
||||||
|
|
||||||
|
except OperationalError as e:
|
||||||
|
db.rollback()
|
||||||
|
# 提取数据库连接信息用于提示
|
||||||
|
db_url = config.database_url
|
||||||
|
# 隐藏密码,只显示 host:port/database
|
||||||
|
if "@" in db_url:
|
||||||
|
db_info = db_url.split("@")[-1]
|
||||||
|
else:
|
||||||
|
db_info = db_url
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 直接打印到 stderr,确保消息显示
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("=" * 60, file=sys.stderr)
|
||||||
|
print("数据库连接失败", file=sys.stderr)
|
||||||
|
print("=" * 60, file=sys.stderr)
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print(f"无法连接到数据库: {db_info}", file=sys.stderr)
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("请检查以下事项:", file=sys.stderr)
|
||||||
|
print(" 1. PostgreSQL 服务是否正在运行", file=sys.stderr)
|
||||||
|
print(" 2. 数据库连接配置是否正确 (DATABASE_URL)", file=sys.stderr)
|
||||||
|
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("如果使用 Docker,请先运行:", file=sys.stderr)
|
||||||
|
print(" docker-compose up -d postgres redis", file=sys.stderr)
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("=" * 60, file=sys.stderr)
|
||||||
|
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
|
||||||
|
os._exit(1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库初始化失败: {e}")
|
logger.error(f"数据库初始化失败: {e}")
|
||||||
db.rollback()
|
db.rollback()
|
||||||
|
|||||||
@@ -317,6 +317,7 @@ class UpdateUserRequest(BaseModel):
|
|||||||
|
|
||||||
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
||||||
email: Optional[str] = Field(None, max_length=100)
|
email: Optional[str] = Field(None, max_length=100)
|
||||||
|
password: Optional[str] = Field(None, min_length=6, max_length=128, description="新密码(留空保持不变)")
|
||||||
quota_usd: Optional[float] = Field(None, ge=0)
|
quota_usd: Optional[float] = Field(None, ge=0)
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
role: Optional[str] = None
|
role: Optional[str] = None
|
||||||
|
|||||||
@@ -274,6 +274,13 @@ class GlobalModelListResponse(BaseModel):
|
|||||||
total: int
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalModelProvidersResponse(BaseModel):
|
||||||
|
"""GlobalModel 关联提供商列表响应"""
|
||||||
|
|
||||||
|
providers: List[ModelCatalogProviderDetail]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
class BatchAssignToProvidersRequest(BaseModel):
|
class BatchAssignToProvidersRequest(BaseModel):
|
||||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||||
|
|
||||||
|
|||||||
35
src/services/cache/aware_scheduler.py
vendored
35
src/services/cache/aware_scheduler.py
vendored
@@ -30,6 +30,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
@@ -956,7 +958,16 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
|
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
|
||||||
active_keys = [key for key in endpoint.api_keys if key.is_active]
|
active_keys = [key for key in endpoint.api_keys if key.is_active]
|
||||||
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key)
|
# 检查是否所有 Key 都是 TTL=0(轮换模式)
|
||||||
|
# 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None,则使用随机排序
|
||||||
|
use_random = all(
|
||||||
|
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
|
||||||
|
) if active_keys else False
|
||||||
|
if use_random and len(active_keys) > 1:
|
||||||
|
logger.debug(
|
||||||
|
f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
|
||||||
|
)
|
||||||
|
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
||||||
@@ -1170,6 +1181,7 @@ class CacheAwareScheduler:
|
|||||||
self,
|
self,
|
||||||
keys: List[ProviderAPIKey],
|
keys: List[ProviderAPIKey],
|
||||||
affinity_key: Optional[str] = None,
|
affinity_key: Optional[str] = None,
|
||||||
|
use_random: bool = False,
|
||||||
) -> List[ProviderAPIKey]:
|
) -> List[ProviderAPIKey]:
|
||||||
"""
|
"""
|
||||||
对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱
|
对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱
|
||||||
@@ -1178,10 +1190,12 @@ class CacheAwareScheduler:
|
|||||||
- 数字越小越优先使用
|
- 数字越小越优先使用
|
||||||
- 同优先级 Key 之间实现负载均衡
|
- 同优先级 Key 之间实现负载均衡
|
||||||
- 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性)
|
- 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性)
|
||||||
|
- 当 use_random=True 时,使用随机排序实现轮换(用于 TTL=0 的场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
keys: API Key 列表
|
keys: API Key 列表
|
||||||
affinity_key: 亲和性标识符(通常为 API Key ID,用于确定性打乱)
|
affinity_key: 亲和性标识符(通常为 API Key ID,用于确定性打乱)
|
||||||
|
use_random: 是否使用随机排序(TTL=0 时为 True)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
排序后的 Key 列表
|
排序后的 Key 列表
|
||||||
@@ -1198,15 +1212,19 @@ class CacheAwareScheduler:
|
|||||||
priority = key.internal_priority if key.internal_priority is not None else 999999
|
priority = key.internal_priority if key.internal_priority is not None else 999999
|
||||||
priority_groups[priority].append(key)
|
priority_groups[priority].append(key)
|
||||||
|
|
||||||
# 对每个优先级组内的 Key 进行确定性打乱
|
# 对每个优先级组内的 Key 进行打乱
|
||||||
result = []
|
result = []
|
||||||
for priority in sorted(priority_groups.keys()): # 数字小的优先级高,排前面
|
for priority in sorted(priority_groups.keys()): # 数字小的优先级高,排前面
|
||||||
group_keys = priority_groups[priority]
|
group_keys = priority_groups[priority]
|
||||||
|
|
||||||
if len(group_keys) > 1 and affinity_key:
|
if len(group_keys) > 1:
|
||||||
# 改进的哈希策略:为每个 key 计算独立的哈希值
|
if use_random:
|
||||||
import hashlib
|
# TTL=0 模式:使用随机排序实现 Key 轮换
|
||||||
|
shuffled = list(group_keys)
|
||||||
|
random.shuffle(shuffled)
|
||||||
|
result.extend(shuffled)
|
||||||
|
elif affinity_key:
|
||||||
|
# 正常模式:使用哈希确定性打乱(保持缓存亲和性)
|
||||||
key_scores = []
|
key_scores = []
|
||||||
for key in group_keys:
|
for key in group_keys:
|
||||||
# 使用 affinity_key + key.id 的组合哈希
|
# 使用 affinity_key + key.id 的组合哈希
|
||||||
@@ -1218,8 +1236,11 @@ class CacheAwareScheduler:
|
|||||||
sorted_group = [key for _, key in sorted(key_scores)]
|
sorted_group = [key for _, key in sorted(key_scores)]
|
||||||
result.extend(sorted_group)
|
result.extend(sorted_group)
|
||||||
else:
|
else:
|
||||||
# 单个 Key 或没有 affinity_key 时保持原顺序
|
# 没有 affinity_key 时按 ID 排序保持稳定性
|
||||||
result.extend(sorted(group_keys, key=lambda k: k.id))
|
result.extend(sorted(group_keys, key=lambda k: k.id))
|
||||||
|
else:
|
||||||
|
# 单个 Key 直接添加
|
||||||
|
result.extend(group_keys)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -234,8 +234,15 @@ class EndpointHealthService:
|
|||||||
for api_format in format_key_mapping.keys()
|
for api_format in format_key_mapping.keys()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 参数校验(API 层已通过 Query(ge=1) 保证,这里做防御性检查)
|
||||||
|
if lookback_hours <= 0 or segments <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"lookback_hours and segments must be positive, "
|
||||||
|
f"got lookback_hours={lookback_hours}, segments={segments}"
|
||||||
|
)
|
||||||
|
|
||||||
# 计算时间范围
|
# 计算时间范围
|
||||||
interval_minutes = (lookback_hours * 60) // segments
|
segment_seconds = (lookback_hours * 3600) / segments
|
||||||
start_time = now - timedelta(hours=lookback_hours)
|
start_time = now - timedelta(hours=lookback_hours)
|
||||||
|
|
||||||
# 使用 RequestCandidate 表查询所有尝试记录
|
# 使用 RequestCandidate 表查询所有尝试记录
|
||||||
@@ -243,7 +250,7 @@ class EndpointHealthService:
|
|||||||
final_statuses = ["success", "failed", "skipped"]
|
final_statuses = ["success", "failed", "skipped"]
|
||||||
|
|
||||||
segment_expr = func.floor(
|
segment_expr = func.floor(
|
||||||
func.extract('epoch', RequestCandidate.created_at - start_time) / (interval_minutes * 60)
|
func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
|
||||||
).label('segment_idx')
|
).label('segment_idx')
|
||||||
|
|
||||||
candidate_stats = (
|
candidate_stats = (
|
||||||
|
|||||||
@@ -59,14 +59,15 @@ class ApiKeyService:
|
|||||||
if expire_days:
|
if expire_days:
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
||||||
|
|
||||||
|
# 空数组转为 None(表示不限制)
|
||||||
api_key = ApiKey(
|
api_key = ApiKey(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
key_hash=key_hash,
|
key_hash=key_hash,
|
||||||
key_encrypted=key_encrypted,
|
key_encrypted=key_encrypted,
|
||||||
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
|
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
|
||||||
allowed_providers=allowed_providers,
|
allowed_providers=allowed_providers or None,
|
||||||
allowed_api_formats=allowed_api_formats,
|
allowed_api_formats=allowed_api_formats or None,
|
||||||
allowed_models=allowed_models,
|
allowed_models=allowed_models or None,
|
||||||
rate_limit=rate_limit,
|
rate_limit=rate_limit,
|
||||||
concurrent_limit=concurrent_limit,
|
concurrent_limit=concurrent_limit,
|
||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
@@ -141,8 +142,18 @@ class ApiKeyService:
|
|||||||
"auto_delete_on_expiry",
|
"auto_delete_on_expiry",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 允许显式设置为空数组/None 的字段(空数组会转为 None,表示"全部")
|
||||||
|
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
|
||||||
|
|
||||||
for field, value in kwargs.items():
|
for field, value in kwargs.items():
|
||||||
if field in updatable_fields and value is not None:
|
if field not in updatable_fields:
|
||||||
|
continue
|
||||||
|
# 对于 nullable_list_fields,空数组应该转为 None(表示不限制)
|
||||||
|
if field in nullable_list_fields:
|
||||||
|
if value is not None:
|
||||||
|
# 空数组转为 None(表示允许全部)
|
||||||
|
setattr(api_key, field, value if value else None)
|
||||||
|
elif value is not None:
|
||||||
setattr(api_key, field, value)
|
setattr(api_key, field, value)
|
||||||
|
|
||||||
api_key.updated_at = datetime.now(timezone.utc)
|
api_key.updated_at = datetime.now(timezone.utc)
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
"""分布式任务协调器,确保仅有一个 worker 执行特定任务"""
|
"""分布式任务协调器,确保仅有一个 worker 执行特定任务
|
||||||
|
|
||||||
|
锁清理策略:
|
||||||
|
- 单实例模式(默认):启动时使用原子操作清理旧锁并获取新锁
|
||||||
|
- 多实例模式:使用 NX 选项竞争锁,依赖 TTL 处理异常退出
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
- 默认行为:启动时清理旧锁(适用于单机部署)
|
||||||
|
- 多实例部署:设置 SINGLE_INSTANCE_MODE=false 禁用启动清理
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import uuid
|
import uuid
|
||||||
@@ -19,6 +27,10 @@ except ImportError: # pragma: no cover - Windows 环境
|
|||||||
class StartupTaskCoordinator:
|
class StartupTaskCoordinator:
|
||||||
"""利用 Redis 或文件锁,保证任务只在单个进程/实例中运行"""
|
"""利用 Redis 或文件锁,保证任务只在单个进程/实例中运行"""
|
||||||
|
|
||||||
|
# 类级别标记:在当前进程中是否已尝试过启动清理
|
||||||
|
# 注意:这在 fork 模式下每个 worker 都是独立的
|
||||||
|
_startup_cleanup_attempted = False
|
||||||
|
|
||||||
def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
|
def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
|
||||||
self.redis = redis_client
|
self.redis = redis_client
|
||||||
self._tokens: Dict[str, str] = {}
|
self._tokens: Dict[str, str] = {}
|
||||||
@@ -26,6 +38,8 @@ class StartupTaskCoordinator:
|
|||||||
self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
|
self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
|
||||||
if not self._lock_dir.exists():
|
if not self._lock_dir.exists():
|
||||||
self._lock_dir.mkdir(parents=True, exist_ok=True)
|
self._lock_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# 单实例模式:启动时清理旧锁(适用于单机部署,避免残留锁问题)
|
||||||
|
self._single_instance_mode = os.getenv("SINGLE_INSTANCE_MODE", "true").lower() == "true"
|
||||||
|
|
||||||
def _redis_key(self, name: str) -> str:
|
def _redis_key(self, name: str) -> str:
|
||||||
return f"task_lock:{name}"
|
return f"task_lock:{name}"
|
||||||
@@ -36,7 +50,46 @@ class StartupTaskCoordinator:
|
|||||||
if self.redis:
|
if self.redis:
|
||||||
token = str(uuid.uuid4())
|
token = str(uuid.uuid4())
|
||||||
try:
|
try:
|
||||||
acquired = await self.redis.set(self._redis_key(name), token, nx=True, ex=ttl)
|
if self._single_instance_mode:
|
||||||
|
# 单实例模式:使用 Lua 脚本原子性地"清理旧锁 + 竞争获取"
|
||||||
|
# 只有当锁不存在或成功获取时才返回 1
|
||||||
|
# 这样第一个执行的 worker 会清理旧锁并获取,后续 worker 会正常竞争
|
||||||
|
script = """
|
||||||
|
local key = KEYS[1]
|
||||||
|
local token = ARGV[1]
|
||||||
|
local ttl = tonumber(ARGV[2])
|
||||||
|
local startup_key = KEYS[1] .. ':startup'
|
||||||
|
|
||||||
|
-- 检查是否已有 worker 执行过启动清理
|
||||||
|
local cleaned = redis.call('GET', startup_key)
|
||||||
|
if not cleaned then
|
||||||
|
-- 第一个 worker:删除旧锁,标记已清理
|
||||||
|
redis.call('DEL', key)
|
||||||
|
redis.call('SET', startup_key, '1', 'EX', 60)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 尝试获取锁(NX 模式)
|
||||||
|
local result = redis.call('SET', key, token, 'NX', 'EX', ttl)
|
||||||
|
if result then
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
result = await self.redis.eval(
|
||||||
|
script, 2,
|
||||||
|
self._redis_key(name), self._redis_key(name),
|
||||||
|
token, ttl
|
||||||
|
)
|
||||||
|
if result == 1:
|
||||||
|
self._tokens[name] = token
|
||||||
|
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# 多实例模式:直接使用 NX 选项竞争锁
|
||||||
|
acquired = await self.redis.set(
|
||||||
|
self._redis_key(name), token, nx=True, ex=ttl
|
||||||
|
)
|
||||||
if acquired:
|
if acquired:
|
||||||
self._tokens[name] = token
|
self._tokens[name] = token
|
||||||
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
|
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
|
||||||
|
|||||||
@@ -8,116 +8,86 @@ from src.api.handlers.base.utils import build_sse_headers, extract_cache_creatio
|
|||||||
class TestExtractCacheCreationTokens:
|
class TestExtractCacheCreationTokens:
|
||||||
"""测试 extract_cache_creation_tokens 函数"""
|
"""测试 extract_cache_creation_tokens 函数"""
|
||||||
|
|
||||||
# === 嵌套格式测试(优先级最高)===
|
def test_new_format_only(self) -> None:
|
||||||
|
"""测试只有新格式字段"""
|
||||||
def test_nested_cache_creation_format(self) -> None:
|
|
||||||
"""测试嵌套格式正常情况"""
|
|
||||||
usage = {
|
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 456,
|
|
||||||
"ephemeral_1h_input_tokens": 100,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 556
|
|
||||||
|
|
||||||
def test_nested_cache_creation_with_old_format_fallback(self) -> None:
|
|
||||||
"""测试嵌套格式为 0 时回退到旧格式"""
|
|
||||||
usage = {
|
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 0,
|
|
||||||
"ephemeral_1h_input_tokens": 0,
|
|
||||||
},
|
|
||||||
"cache_creation_input_tokens": 549,
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 549
|
|
||||||
|
|
||||||
def test_nested_has_priority_over_flat(self) -> None:
|
|
||||||
"""测试嵌套格式优先于扁平格式"""
|
|
||||||
usage = {
|
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 100,
|
|
||||||
"ephemeral_1h_input_tokens": 200,
|
|
||||||
},
|
|
||||||
"claude_cache_creation_5_m_tokens": 999, # 应该被忽略
|
|
||||||
"claude_cache_creation_1_h_tokens": 888, # 应该被忽略
|
|
||||||
"cache_creation_input_tokens": 777, # 应该被忽略
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 300
|
|
||||||
|
|
||||||
# === 扁平格式测试(优先级第二)===
|
|
||||||
|
|
||||||
def test_flat_new_format_still_works(self) -> None:
|
|
||||||
"""测试扁平新格式兼容性"""
|
|
||||||
usage = {
|
usage = {
|
||||||
"claude_cache_creation_5_m_tokens": 100,
|
"claude_cache_creation_5_m_tokens": 100,
|
||||||
"claude_cache_creation_1_h_tokens": 200,
|
"claude_cache_creation_1_h_tokens": 200,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 300
|
assert extract_cache_creation_tokens(usage) == 300
|
||||||
|
|
||||||
def test_flat_new_format_with_old_format_fallback(self) -> None:
|
def test_new_format_5m_only(self) -> None:
|
||||||
"""测试扁平格式为 0 时回退到旧格式"""
|
"""测试只有 5 分钟缓存"""
|
||||||
usage = {
|
|
||||||
"claude_cache_creation_5_m_tokens": 0,
|
|
||||||
"claude_cache_creation_1_h_tokens": 0,
|
|
||||||
"cache_creation_input_tokens": 549,
|
|
||||||
}
|
|
||||||
assert extract_cache_creation_tokens(usage) == 549
|
|
||||||
|
|
||||||
def test_flat_new_format_5m_only(self) -> None:
|
|
||||||
"""测试只有 5 分钟扁平缓存"""
|
|
||||||
usage = {
|
usage = {
|
||||||
"claude_cache_creation_5_m_tokens": 150,
|
"claude_cache_creation_5_m_tokens": 150,
|
||||||
"claude_cache_creation_1_h_tokens": 0,
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 150
|
assert extract_cache_creation_tokens(usage) == 150
|
||||||
|
|
||||||
def test_flat_new_format_1h_only(self) -> None:
|
def test_new_format_1h_only(self) -> None:
|
||||||
"""测试只有 1 小时扁平缓存"""
|
"""测试只有 1 小时缓存"""
|
||||||
usage = {
|
usage = {
|
||||||
"claude_cache_creation_5_m_tokens": 0,
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
"claude_cache_creation_1_h_tokens": 250,
|
"claude_cache_creation_1_h_tokens": 250,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 250
|
assert extract_cache_creation_tokens(usage) == 250
|
||||||
|
|
||||||
# === 旧格式测试(优先级第三)===
|
|
||||||
|
|
||||||
def test_old_format_only(self) -> None:
|
def test_old_format_only(self) -> None:
|
||||||
"""测试只有旧格式"""
|
"""测试只有旧格式字段"""
|
||||||
usage = {
|
usage = {
|
||||||
"cache_creation_input_tokens": 549,
|
"cache_creation_input_tokens": 500,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 549
|
assert extract_cache_creation_tokens(usage) == 500
|
||||||
|
|
||||||
# === 边界情况测试 ===
|
def test_both_formats_prefers_new(self) -> None:
|
||||||
|
"""测试同时存在时优先使用新格式"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 100,
|
||||||
|
"claude_cache_creation_1_h_tokens": 200,
|
||||||
|
"cache_creation_input_tokens": 999, # 应该被忽略
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 300
|
||||||
|
|
||||||
def test_no_cache_creation_tokens(self) -> None:
|
def test_empty_usage(self) -> None:
|
||||||
"""测试没有任何缓存字段"""
|
"""测试空字典"""
|
||||||
usage = {}
|
usage = {}
|
||||||
assert extract_cache_creation_tokens(usage) == 0
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
def test_all_formats_zero(self) -> None:
|
def test_all_zeros(self) -> None:
|
||||||
"""测试所有格式都为 0"""
|
"""测试所有字段都为 0"""
|
||||||
usage = {
|
usage = {
|
||||||
"cache_creation": {
|
|
||||||
"ephemeral_5m_input_tokens": 0,
|
|
||||||
"ephemeral_1h_input_tokens": 0,
|
|
||||||
},
|
|
||||||
"claude_cache_creation_5_m_tokens": 0,
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
"claude_cache_creation_1_h_tokens": 0,
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
"cache_creation_input_tokens": 0,
|
"cache_creation_input_tokens": 0,
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 0
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
|
def test_partial_new_format_with_old_format_fallback(self) -> None:
|
||||||
|
"""测试新格式字段不存在时回退到旧格式"""
|
||||||
|
usage = {
|
||||||
|
"cache_creation_input_tokens": 123,
|
||||||
|
}
|
||||||
|
assert extract_cache_creation_tokens(usage) == 123
|
||||||
|
|
||||||
|
def test_new_format_zero_should_not_fallback(self) -> None:
|
||||||
|
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
|
||||||
|
usage = {
|
||||||
|
"claude_cache_creation_5_m_tokens": 0,
|
||||||
|
"claude_cache_creation_1_h_tokens": 0,
|
||||||
|
"cache_creation_input_tokens": 456,
|
||||||
|
}
|
||||||
|
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0)
|
||||||
|
# 而不是 fallback 到旧格式(返回 456)
|
||||||
|
assert extract_cache_creation_tokens(usage) == 0
|
||||||
|
|
||||||
def test_unrelated_fields_ignored(self) -> None:
|
def test_unrelated_fields_ignored(self) -> None:
|
||||||
"""测试忽略无关字段"""
|
"""测试忽略无关字段"""
|
||||||
usage = {
|
usage = {
|
||||||
"input_tokens": 1000,
|
"input_tokens": 1000,
|
||||||
"output_tokens": 2000,
|
"output_tokens": 2000,
|
||||||
"cache_read_input_tokens": 300,
|
"cache_read_input_tokens": 300,
|
||||||
"cache_creation": {
|
"claude_cache_creation_5_m_tokens": 50,
|
||||||
"ephemeral_5m_input_tokens": 50,
|
"claude_cache_creation_1_h_tokens": 75,
|
||||||
"ephemeral_1h_input_tokens": 75,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
assert extract_cache_creation_tokens(usage) == 125
|
assert extract_cache_creation_tokens(usage) == 125
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user