24 Commits

Author SHA1 Message Date
fawney19
394cc536a9 feat: 添加 API 格式访问限制
扩展访问限制功能,支持 API Key 级别的 API 格式限制(OPENAI、CLAUDE、GEMINI)。

- AccessRestrictions 新增 allowed_api_formats 字段
- 新增 is_api_format_allowed() 方法检查格式权限
- models.py 添加 _filter_formats_by_restrictions() 函数过滤 API 格式
- 在所有模型列表和查询端点应用格式限制检查
- 添加 _build_empty_list_response() 统一空响应构建逻辑
2025-12-30 17:50:39 +08:00
fawney19
e20a09f15a feat: 添加模型列表访问限制功能
实现 API Key 和 User 级别的模型访问权限控制,支持按 Provider 和模型名称限制。

- 新增 AccessRestrictions 类处理访问限制合并逻辑(API Key 优先于 User)
- models_service 支持根据限制过滤模型列表
- models.py 在列表查询时构建并应用访问限制
- 优化缓存策略:仅无限制请求使用缓存,有限制的请求旁路缓存
- 修复 logger 配置:enqueue 改为 False 避免 macOS 信号量泄漏
2025-12-30 16:57:59 +08:00
fawney19
b89a4af0cf refactor: 统一 HTTP 客户端超时配置
将 HTTPClientPool 中硬编码的超时参数改为使用可配置的环境变量,提高系统的灵活性和可维护性。

- 添加 HTTP_READ_TIMEOUT 环境变量配置(默认 300 秒)
- 统一所有 HTTP 客户端创建逻辑使用配置化超时
- 改进变量命名清晰性(config -> default_config 或 client_config)
2025-12-30 15:06:55 +08:00
fawney19
a56854af43 feat: 为 GlobalModel 添加关联提供商查询 API
添加新的 API 端点 GET /api/admin/models/global/{global_model_id}/providers,用于获取 GlobalModel 的所有关联提供商(包括非活跃的)。

- 后端:实现 AdminGetGlobalModelProvidersAdapter 适配器
- 前端:使用新 API 替换原有的 ModelCatalog 获取方式
- 数据库:改进初始化时的错误提示和连接异常处理
2025-12-30 14:47:35 +08:00
fawney19
4a35d78c8d fix: 修正 README 中的图片文件扩展名 2025-12-30 09:44:45 +08:00
fawney19
26b281271e docs: 更新许可证说明、README 文档和模拟数据;调整模型版本至最新 2025-12-30 09:41:03 +08:00
fawney19
96094cfde2 fix: 调整仪表盘普通用户月度统计显示,添加月度费用字段 2025-12-29 18:28:37 +08:00
fawney19
7e26af5476 fix: 修复仪表盘缓存命中率和配额显示逻辑
- 本月缓存命中率改为使用本月数据计算(monthly_input_tokens + cache_read_tokens),而非全时数据
- 修复配额显示:有配额时显示实际金额,无配额时显示为 /bin/zsh 并标记为高警告状态
2025-12-29 18:12:33 +08:00
fawney19
c8dfb784bc fix: 修复仪表盘月份统计逻辑,改为自然月而非过去30天 2025-12-29 17:55:42 +08:00
fawney19
fd3a5a5afe refactor: 将模型测试功能从 ModelsTab 移到 KeyAllowedModelsDialog 2025-12-29 17:44:02 +08:00
fawney19
599b3d4c95 feat: 添加 Provider API Key 查看和复制功能
- 后端添加 GET /api/admin/endpoints/keys/{key_id}/reveal 接口
- 前端密钥列表添加眼睛按钮(显示/隐藏完整密钥)和复制按钮
- 关闭抽屉时自动清除已显示的密钥(安全考虑)

Fixes #53
2025-12-29 17:14:26 +08:00
fawney19
41719a00e7 refactor: 改进分布式任务锁的清理策略
实现两种锁清理模式:
- 单实例模式(默认):启动时使用 Lua 脚本原子性清理旧锁,解决 worker 重启时���锁残留问题
- 多实例模式:使用 NX 选项竞争锁,依赖 TTL 处理异常退出

可通过 SINGLE_INSTANCE_MODE 环境变量控制模式选择。
2025-12-28 21:34:43 +08:00
fawney19
b5c0f85dca refactor: 统一剪贴板复制功能到 useClipboard 组合式函数
将各个组件和视图中重复的剪贴板复制逻辑提取到 useClipboard 组合式函数。
增加 showToast 参数支持静默复制,减少代码重复,提高维护性。
2025-12-28 20:41:52 +08:00
fawney19
7d6d262ed3 feat: 增加用户密码修改时的确认验证
在编辑用户时,如果填写了新密码,需要进行密码确认,确保两次输入一致。
同时更新后端请求模型以支持密码字段。
2025-12-28 20:00:25 +08:00
fawney19
e21acd73eb fix: 修复模型映射中重复关联的问题
在批量分配模型和编辑模型映射时,需要检查不仅是主模型名是否已关联,
还要检查其映射名称是否已存在,防止同一个上游模型被重复关联。
2025-12-28 19:40:07 +08:00
fawney19
702f9bc5f1 fix: 修复缓存监控页面TTL分析时间段选择器点击无响应
为 Select 组件添加 v-model:open 绑定,解决 radix-vue Select 组件
在某些情况下点击无响应的问题。

Fixes #55
2025-12-28 19:14:49 +08:00
fawney19
d0ce798881 fix: TTL=0时启用Key随机轮换模式
- 当所有Key的cache_ttl_minutes都为0时,使用随机排序代替确定性哈希
- 将hashlib和random的import移到文件顶部
- 简化单Key场景的处理逻辑

Closes #57
2025-12-28 19:07:25 +08:00
fawney19
2b1d197047 Merge remote-tracking branch 'gitcode/master' into htmambo/master 2025-12-25 22:47:08 +08:00
fawney19
71bc2e6aab fix: 增加参数校验防止除零错误 2025-12-25 22:44:17 +08:00
fawney19
afb329934a fix: 修复端点健康统计时间分段计算的除零错误 2025-12-25 19:54:16 +08:00
elky0401
1313af45a3 !4 merge htmambo/master into master
refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用

Created-by: elky0401
Commit-by: fawney19;hoping
Merged-by: elky0401
Description: feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。
refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用

See merge request: elky0401/Aether!4
2025-12-25 19:39:33 +08:00
fawney19
dddb327885 refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用
- 将 ModelsTab 和 ModelAliasesTab 中重复的错误解析逻辑提取到 errorParser.ts
- 添加 parseTestModelError 函数统一处理测试响应错误
- 为 testModel API 添加 TypeScript 类型定义 (TestModelRequest/TestModelResponse)
- 修复 endpoint_checker.py 中 usage_data 变量引用错误
2025-12-25 19:36:29 +08:00
hoping
26b4a37323 feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。 2025-12-25 00:02:56 +08:00
fawney19
9dad194130 fix: 修复 API Key 访问限制字段无法清除的问题
- 统一前端创建和更新 API Key 时的空数组处理逻辑
- 后端创建和更新接口都支持空数组转 NULL(表示不限制)
- 开启自动刷新时立即刷新一次数据
2025-12-24 22:35:30 +08:00
63 changed files with 3173 additions and 405 deletions

15
LICENSE
View File

@@ -5,12 +5,17 @@ Aether 非商业开源许可证
特此授予任何获得本软件及其相关文档文件(以下简称"软件")副本的人免费使用、 特此授予任何获得本软件及其相关文档文件(以下简称"软件")副本的人免费使用、
复制、修改、合并、发布和分发本软件的权限,但须遵守以下条件: 复制、修改、合并、发布和分发本软件的权限,但须遵守以下条件:
1. 仅限非商业用途 1. 仅限非盈利用途
本软件不得用于商业目的。商业目的包括但不限于: 本软件不得用于盈利目的。盈利目的包括但不限于:
- 出售本软件或任何衍生作品 - 出售本软件或任何衍生作品
- 使用本软件提供付费服务 - 使用本软件提供付费服务
- 将本软件用于商业产品或服务 - 将本软件用于以盈利为目的的商业产品或服务
- 将本软件用于任何旨在获取商业利益或金钱报酬的活动
以下用途被明确允许:
- 个人学习和研究
- 教育机构的教学和研究
- 非盈利组织的内部使用
- 企业内部非盈利性质的使用(如内部工具、测试环境等)
2. 署名要求 2. 署名要求
上述版权声明和本许可声明应包含在本软件的所有副本或主要部分中。 上述版权声明和本许可声明应包含在本软件的所有副本或主要部分中。
@@ -22,7 +27,7 @@ Aether 非商业开源许可证
您不得以不同的条款将本软件再许可给他人。 您不得以不同的条款将本软件再许可给他人。
5. 商业许可 5. 商业许可
如需商业使用,请联系版权持有人以获取单独的商业许可。 如需将本软件用于盈利目的,请联系版权持有人以获取单独的商业许可。
本软件按"原样"提供,不提供任何明示或暗示的保证,包括但不限于对适销性、 本软件按"原样"提供,不提供任何明示或暗示的保证,包括但不限于对适销性、
特定用途适用性和非侵权性的保证。在任何情况下,作者或版权持有人均不对任何 特定用途适用性和非侵权性的保证。在任何情况下,作者或版权持有人均不对任何

View File

@@ -143,7 +143,7 @@ cd frontend && npm install && npm run dev
- **模型级别**: 在模型管理中针对指定模型开启 1H缓存策略 - **模型级别**: 在模型管理中针对指定模型开启 1H缓存策略
- **密钥级别**: 在密钥管理中针对指定密钥使用 1H缓存策略 - **密钥级别**: 在密钥管理中针对指定密钥使用 1H缓存策略
> **注意**: 若对密钥设置强制 1H缓存, 则该密钥只能用支持 1H缓存的模型 > **注意**: 若对密钥设置强制 1H缓存, 则该密钥只能使用支持 1H缓存的模型, 匹配提供商Key, 将会导致这个Key无法同时用于Claude Code、Codex、GeminiCLI, 因为更推荐使用模型开启1H缓存.
### Q: 如何配置负载均衡? ### Q: 如何配置负载均衡?
@@ -162,4 +162,16 @@ cd frontend && npm install && npm run dev
## 许可证 ## 许可证
本项目采用 [Aether 非商业开源许可证](LICENSE)。 本项目采用 [Aether 非商业开源许可证](LICENSE)。允许个人学习、教育研究、非盈利组织及企业内部非盈利性质的使用;禁止用于盈利目的。商业使用请联系获取商业许可。
## 联系作者
<p align="center">
<img src="docs/author/qq_qrcode.jpg" width="200" alt="QQ二维码">
</p>
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=fawney19/Aether&type=Date)](https://star-history.com/#fawney19/Aether&Date)

BIN
docs/author/qq_qrcode.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 266 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

View File

@@ -87,6 +87,8 @@ export interface DashboardStatsResponse {
cache_stats?: CacheStats cache_stats?: CacheStats
users?: UserStats users?: UserStats
token_breakdown?: TokenBreakdown token_breakdown?: TokenBreakdown
// 普通用户专用字段
monthly_cost?: number
} }
export interface RecentRequestsResponse { export interface RecentRequestsResponse {

View File

@@ -4,7 +4,8 @@ import type {
GlobalModelUpdate, GlobalModelUpdate,
GlobalModelResponse, GlobalModelResponse,
GlobalModelWithStats, GlobalModelWithStats,
GlobalModelListResponse GlobalModelListResponse,
ModelCatalogProviderDetail,
} from './types' } from './types'
/** /**
@@ -83,3 +84,16 @@ export async function batchAssignToProviders(
) )
return response.data return response.data
} }
/**
* 获取 GlobalModel 的所有关联提供商(包括非活跃的)
*/
export async function getGlobalModelProviders(globalModelId: string): Promise<{
providers: ModelCatalogProviderDetail[]
total: number
}> {
const response = await client.get(
`/api/admin/models/global/${globalModelId}/providers`
)
return response.data
}

View File

@@ -110,6 +110,14 @@ export async function updateEndpointKey(
return response.data return response.data
} }
/**
* 获取完整的 API Key用于查看和复制
*/
export async function revealEndpointKey(keyId: string): Promise<{ api_key: string }> {
const response = await client.get(`/api/admin/endpoints/keys/${keyId}/reveal`)
return response.data
}
/** /**
* 删除 Endpoint Key * 删除 Endpoint Key
*/ */

View File

@@ -58,3 +58,38 @@ export async function deleteProvider(providerId: string): Promise<{ message: str
return response.data return response.data
} }
/**
* 测试模型连接性
*/
export interface TestModelRequest {
provider_id: string
model_name: string
api_key_id?: string
message?: string
api_format?: string
}
export interface TestModelResponse {
success: boolean
error?: string
data?: {
response?: {
status_code?: number
error?: string | { message?: string }
choices?: Array<{ message?: { content?: string } }>
}
content_preview?: string
}
provider?: {
id: string
name: string
display_name: string
}
model?: string
}
export async function testModel(data: TestModelRequest): Promise<TestModelResponse> {
const response = await client.post('/api/admin/provider-query/test-model', data)
return response.data
}

View File

@@ -20,4 +20,5 @@ export {
updateGlobalModel, updateGlobalModel,
deleteGlobalModel, deleteGlobalModel,
batchAssignToProviders, batchAssignToProviders,
getGlobalModelProviders,
} from './endpoints/global-models' } from './endpoints/global-models'

View File

@@ -163,7 +163,9 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
useEscapeKey(() => { useEscapeKey(() => {
if (isOpen.value) { if (isOpen.value) {
handleClose() handleClose()
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
} }
return false
}, { }, {
disableOnInput: true, disableOnInput: true,
once: false once: false

View File

@@ -4,11 +4,11 @@ import { log } from '@/utils/logger'
export function useClipboard() { export function useClipboard() {
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
async function copyToClipboard(text: string): Promise<boolean> { async function copyToClipboard(text: string, showToast = true): Promise<boolean> {
try { try {
if (navigator.clipboard && window.isSecureContext) { if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(text) await navigator.clipboard.writeText(text)
success('已复制到剪贴板') if (showToast) success('已复制到剪贴板')
return true return true
} }
@@ -25,17 +25,17 @@ export function useClipboard() {
try { try {
const successful = document.execCommand('copy') const successful = document.execCommand('copy')
if (successful) { if (successful) {
success('已复制到剪贴板') if (showToast) success('已复制到剪贴板')
return true return true
} }
showError('复制失败,请手动复制') if (showToast) showError('复制失败,请手动复制')
return false return false
} finally { } finally {
document.body.removeChild(textArea) document.body.removeChild(textArea)
} }
} catch (err) { } catch (err) {
log.error('复制失败:', err) log.error('复制失败:', err)
showError('复制失败,请手动选择文本进行复制') if (showToast) showError('复制失败,请手动选择文本进行复制')
return false return false
} }
} }

View File

@@ -47,11 +47,11 @@ export function useConfirm() {
/** /**
* 便捷方法:危险操作确认(红色主题) * 便捷方法:危险操作确认(红色主题)
*/ */
const confirmDanger = (message: string, title?: string): Promise<boolean> => { const confirmDanger = (message: string, title?: string, confirmText?: string): Promise<boolean> => {
return confirm({ return confirm({
message, message,
title: title || '危险操作', title: title || '危险操作',
confirmText: '删除', confirmText: confirmText || '删除',
variant: 'danger' variant: 'danger'
}) })
} }

View File

@@ -4,11 +4,11 @@ import { onMounted, onUnmounted, ref } from 'vue'
* ESC 键监听 Composable简化版本直接使用独立监听器 * ESC 键监听 Composable简化版本直接使用独立监听器
* 用于按 ESC 键关闭弹窗或其他可关闭的组件 * 用于按 ESC 键关闭弹窗或其他可关闭的组件
* *
* @param callback - 按 ESC 键时执行的回调函数 * @param callback - 按 ESC 键时执行的回调函数,返回 true 表示已处理事件,阻止其他监听器执行
* @param options - 配置选项 * @param options - 配置选项
*/ */
export function useEscapeKey( export function useEscapeKey(
callback: () => void, callback: () => void | boolean,
options: { options: {
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */ /** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
disableOnInput?: boolean disableOnInput?: boolean
@@ -42,8 +42,11 @@ export function useEscapeKey(
if (isInputElement) return if (isInputElement) return
} }
// 执行回调 // 执行回调,如果返回 true 则阻止其他监听器
callback() const handled = callback()
if (handled === true) {
event.stopImmediatePropagation()
}
// 移除当前元素的焦点,避免残留样式 // 移除当前元素的焦点,避免残留样式
if (document.activeElement instanceof HTMLElement) { if (document.activeElement instanceof HTMLElement) {

View File

@@ -700,6 +700,7 @@ import {
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
@@ -731,6 +732,7 @@ const emit = defineEmits<{
'refreshProviders': [] 'refreshProviders': []
}>() }>()
const { success: showSuccess, error: showError } = useToast() const { success: showSuccess, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
interface Props { interface Props {
model: GlobalModelResponse | null model: GlobalModelResponse | null
@@ -763,16 +765,6 @@ function handleClose() {
} }
} }
// 复制到剪贴板
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
showSuccess('已复制')
} catch {
showError('复制失败')
}
}
// 格式化日期 // 格式化日期
function formatDate(dateStr: string): string { function formatDate(dateStr: string): string {
if (!dateStr) return '-' if (!dateStr) return '-'

View File

@@ -433,11 +433,17 @@ const availableGlobalModels = computed(() => {
) )
}) })
// 计算可添加的上游模型(排除已关联的) // 计算可添加的上游模型(排除已关联的,包括主模型名和映射名称
const availableUpstreamModelsBase = computed(() => { const availableUpstreamModelsBase = computed(() => {
const existingModelNames = new Set( const existingModelNames = new Set<string>()
existingModels.value.map(m => m.provider_model_name) for (const m of existingModels.value) {
) // 主模型名
existingModelNames.add(m.provider_model_name)
// 映射名称
for (const mapping of m.provider_model_mappings ?? []) {
if (mapping.name) existingModelNames.add(mapping.name)
}
}
return upstreamModels.value.filter(m => !existingModelNames.has(m.id)) return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
}) })

View File

@@ -116,6 +116,19 @@
{{ model.global_model_name }} {{ model.global_model_name }}
</div> </div>
</div> </div>
<!-- 测试按钮 -->
<Button
variant="ghost"
size="icon"
class="h-7 w-7 shrink-0"
title="测试模型连接"
:disabled="testingModelName === model.global_model_name"
@click.stop="testModelConnection(model)"
>
<Loader2 v-if="testingModelName === model.global_model_name" class="w-3.5 h-3.5 animate-spin" />
<Play v-else class="w-3.5 h-3.5" />
</Button>
</div> </div>
</div> </div>
</div> </div>
@@ -148,16 +161,17 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, watch } from 'vue' import { ref, computed, watch } from 'vue'
import { Box, Loader2, Settings2 } from 'lucide-vue-next' import { Box, Loader2, Settings2, Play } from 'lucide-vue-next'
import { Dialog } from '@/components/ui' import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Checkbox from '@/components/ui/checkbox.vue' import Checkbox from '@/components/ui/checkbox.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { parseApiError } from '@/utils/errorParser' import { parseApiError, parseTestModelError } from '@/utils/errorParser'
import { import {
updateEndpointKey, updateEndpointKey,
getProviderAvailableSourceModels, getProviderAvailableSourceModels,
testModel,
type EndpointAPIKey, type EndpointAPIKey,
type ProviderAvailableSourceModel type ProviderAvailableSourceModel
} from '@/api/endpoints' } from '@/api/endpoints'
@@ -181,6 +195,7 @@ const loadingModels = ref(false)
const availableModels = ref<ProviderAvailableSourceModel[]>([]) const availableModels = ref<ProviderAvailableSourceModel[]>([])
const selectedModels = ref<string[]>([]) const selectedModels = ref<string[]>([])
const initialModels = ref<string[]>([]) const initialModels = ref<string[]>([])
const testingModelName = ref<string | null>(null)
// 监听对话框打开 // 监听对话框打开
watch(() => props.open, (open) => { watch(() => props.open, (open) => {
@@ -268,6 +283,32 @@ function clearModels() {
selectedModels.value = [] selectedModels.value = []
} }
// 测试模型连接
async function testModelConnection(model: ProviderAvailableSourceModel) {
if (!props.providerId || !props.apiKey || testingModelName.value) return
testingModelName.value = model.global_model_name
try {
const result = await testModel({
provider_id: props.providerId,
model_name: model.provider_model_name,
api_key_id: props.apiKey.id,
message: "hello"
})
if (result.success) {
success(`模型 "${model.display_name}" 测试成功`)
} else {
showError(`模型测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`模型测试失败: ${errorMsg}`)
} finally {
testingModelName.value = null
}
}
function areArraysEqual(a: string[], b: string[]): boolean { function areArraysEqual(a: string[], b: string[]): boolean {
if (a.length !== b.length) return false if (a.length !== b.length) return false
const sortedA = [...a].sort() const sortedA = [...a].sort()

View File

@@ -17,7 +17,7 @@
v-model:open="modelSelectOpen" v-model:open="modelSelectOpen"
:model-value="formData.modelId" :model-value="formData.modelId"
:disabled="!!editingGroup" :disabled="!!editingGroup"
@update:model-value="formData.modelId = $event" @update:model-value="handleModelChange"
> >
<SelectTrigger class="h-9"> <SelectTrigger class="h-9">
<SelectValue placeholder="请选择模型" /> <SelectValue placeholder="请选择模型" />
@@ -449,7 +449,17 @@ interface UpstreamModelGroup {
} }
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => { const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
// 收集当前表单已添加的名称
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim())) const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
// 收集所有已存在的映射名称(包括主模型名和映射名称)
for (const m of props.models) {
addedNames.add(m.provider_model_name)
for (const mapping of m.provider_model_mappings ?? []) {
if (mapping.name) addedNames.add(mapping.name)
}
}
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id)) const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
const groups = new Map<string, UpstreamModelGroup>() const groups = new Map<string, UpstreamModelGroup>()
@@ -519,6 +529,15 @@ function initForm() {
} }
} }
// 处理模型选择变更
function handleModelChange(value: string) {
formData.value.modelId = value
const selectedModel = props.models.find(m => m.id === value)
if (selectedModel) {
upstreamModelSearch.value = selectedModel.provider_model_name
}
}
// 切换 API 格式 // 切换 API 格式
function toggleApiFormat(format: string) { function toggleApiFormat(format: string) {
const index = formData.value.apiFormats.indexOf(format) const index = formData.value.apiFormats.indexOf(format)

View File

@@ -337,8 +337,40 @@
{{ key.is_active ? '活跃' : '禁用' }} {{ key.is_active ? '活跃' : '禁用' }}
</Badge> </Badge>
</div> </div>
<div class="text-[10px] font-mono text-muted-foreground truncate"> <div class="flex items-center gap-1">
{{ key.api_key_masked }} <span class="text-[10px] font-mono text-muted-foreground truncate max-w-[180px]">
{{ revealedKeys.has(key.id) ? revealedKeys.get(key.id) : key.api_key_masked }}
</span>
<Button
variant="ghost"
size="icon"
class="h-5 w-5 shrink-0"
:title="revealedKeys.has(key.id) ? '隐藏密钥' : '显示密钥'"
:disabled="revealingKeyId === key.id"
@click.stop="toggleKeyReveal(key)"
>
<Loader2
v-if="revealingKeyId === key.id"
class="w-3 h-3 animate-spin"
/>
<EyeOff
v-else-if="revealedKeys.has(key.id)"
class="w-3 h-3"
/>
<Eye
v-else
class="w-3 h-3"
/>
</Button>
<Button
variant="ghost"
size="icon"
class="h-5 w-5 shrink-0"
title="复制密钥"
@click.stop="copyFullKey(key)"
>
<Copy class="w-3 h-3" />
</Button>
</div> </div>
</div> </div>
<div class="flex items-center gap-1.5 ml-auto shrink-0"> <div class="flex items-center gap-1.5 ml-auto shrink-0">
@@ -531,6 +563,7 @@
<!-- 模型名称映射 --> <!-- 模型名称映射 -->
<ModelAliasesTab <ModelAliasesTab
v-if="provider" v-if="provider"
ref="modelAliasesTabRef"
:key="`aliases-${provider.id}`" :key="`aliases-${provider.id}`"
:provider="provider" :provider="provider"
@refresh="handleRelatedDataRefresh" @refresh="handleRelatedDataRefresh"
@@ -653,13 +686,16 @@ import {
Power, Power,
Layers, Layers,
GripVertical, GripVertical,
Copy Copy,
Eye,
EyeOff
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import { getProvider, getProviderEndpoints } from '@/api/endpoints' import { getProvider, getProviderEndpoints } from '@/api/endpoints'
import { import {
KeyFormDialog, KeyFormDialog,
@@ -679,6 +715,7 @@ import {
updateEndpoint, updateEndpoint,
updateEndpointKey, updateEndpointKey,
batchUpdateKeyPriority, batchUpdateKeyPriority,
revealEndpointKey,
type ProviderEndpoint, type ProviderEndpoint,
type EndpointAPIKey, type EndpointAPIKey,
type Model type Model
@@ -705,6 +742,7 @@ const emit = defineEmits<{
}>() }>()
const { error: showError, success: showSuccess } = useToast() const { error: showError, success: showSuccess } = useToast()
const { copyToClipboard } = useClipboard()
const loading = ref(false) const loading = ref(false)
const provider = ref<any>(null) const provider = ref<any>(null)
@@ -728,6 +766,10 @@ const recoveringEndpointId = ref<string | null>(null)
const togglingEndpointId = ref<string | null>(null) const togglingEndpointId = ref<string | null>(null)
const togglingKeyId = ref<string | null>(null) const togglingKeyId = ref<string | null>(null)
// 密钥显示状态key_id -> 完整密钥
const revealedKeys = ref<Map<string, string>>(new Map())
const revealingKeyId = ref<string | null>(null)
// 模型相关状态 // 模型相关状态
const modelFormDialogOpen = ref(false) const modelFormDialogOpen = ref(false)
const editingModel = ref<Model | null>(null) const editingModel = ref<Model | null>(null)
@@ -735,6 +777,9 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null) const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false) const batchAssignDialogOpen = ref(false)
// ModelAliasesTab 组件引用
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
// 拖动排序相关状态 // 拖动排序相关状态
const dragState = ref({ const dragState = ref({
isDragging: false, isDragging: false,
@@ -756,7 +801,9 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value || deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value || modelFormDialogOpen.value ||
deleteModelConfirmOpen.value || deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value batchAssignDialogOpen.value ||
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
modelAliasesTabRef.value?.dialogOpen
) )
// 监听 providerId 变化 // 监听 providerId 变化
@@ -792,6 +839,9 @@ watch(() => props.open, (newOpen) => {
currentEndpoint.value = null currentEndpoint.value = null
editingKey.value = null editingKey.value = null
keyToDelete.value = null keyToDelete.value = null
// 清除已显示的密钥(安全考虑)
revealedKeys.value.clear()
} }
}) })
@@ -880,6 +930,43 @@ function handleConfigKeyModels(key: EndpointAPIKey) {
keyAllowedModelsDialogOpen.value = true keyAllowedModelsDialogOpen.value = true
} }
// 切换密钥显示/隐藏
async function toggleKeyReveal(key: EndpointAPIKey) {
if (revealedKeys.value.has(key.id)) {
// 已显示,隐藏它
revealedKeys.value.delete(key.id)
return
}
// 未显示,调用 API 获取完整密钥
revealingKeyId.value = key.id
try {
const result = await revealEndpointKey(key.id)
revealedKeys.value.set(key.id, result.api_key)
} catch (err: any) {
showError(err.response?.data?.detail || '获取密钥失败', '错误')
} finally {
revealingKeyId.value = null
}
}
// 复制完整密钥
async function copyFullKey(key: EndpointAPIKey) {
// 如果已经显示了,直接复制
if (revealedKeys.value.has(key.id)) {
copyToClipboard(revealedKeys.value.get(key.id)!)
return
}
// 否则先获取再复制
try {
const result = await revealEndpointKey(key.id)
copyToClipboard(result.api_key)
} catch (err: any) {
showError(err.response?.data?.detail || '获取密钥失败', '错误')
}
}
function handleDeleteKey(key: EndpointAPIKey) { function handleDeleteKey(key: EndpointAPIKey) {
keyToDelete.value = key keyToDelete.value = key
deleteKeyConfirmOpen.value = true deleteKeyConfirmOpen.value = true
@@ -1244,16 +1331,6 @@ function getHealthScoreBarColor(score: number): string {
return 'bg-red-500 dark:bg-red-400' return 'bg-red-500 dark:bg-red-400'
} }
// 复制到剪贴板
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
showSuccess('已复制到剪贴板')
} catch {
showError('复制失败', '错误')
}
}
// 加载 Provider 信息 // 加载 Provider 信息
async function loadProvider() { async function loadProvider() {
if (!props.providerId) return if (!props.providerId) return

View File

@@ -110,8 +110,9 @@
<div <div
v-for="mapping in group.aliases" v-for="mapping in group.aliases"
:key="mapping.name" :key="mapping.name"
class="flex items-center gap-2 py-1" class="flex items-center justify-between gap-2 py-1"
> >
<div class="flex items-center gap-2 flex-1 min-w-0">
<!-- 优先级标签 --> <!-- 优先级标签 -->
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0"> <span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
{{ mapping.priority }} {{ mapping.priority }}
@@ -121,6 +122,19 @@
{{ mapping.name }} {{ mapping.name }}
</span> </span>
</div> </div>
<!-- 测试按钮 -->
<Button
variant="ghost"
size="icon"
class="h-7 w-7 shrink-0"
title="测试映射"
:disabled="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`"
@click="testMapping(group, mapping)"
>
<Loader2 v-if="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`" class="w-3 h-3 animate-spin" />
<Play v-else class="w-3 h-3" />
</Button>
</div>
</div> </div>
</div> </div>
</div> </div>
@@ -166,18 +180,20 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted, watch } from 'vue' import { ref, computed, onMounted, watch } from 'vue'
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next' import { Tag, Plus, Edit, Trash2, ChevronRight, Loader2, Play } from 'lucide-vue-next'
import { Card, Button, Badge } from '@/components/ui' import { Card, Button, Badge } from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue' import AlertDialog from '@/components/common/AlertDialog.vue'
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue' import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { import {
getProviderModels, getProviderModels,
testModel,
API_FORMAT_LABELS, API_FORMAT_LABELS,
type Model, type Model,
type ProviderModelAlias type ProviderModelAlias
} from '@/api/endpoints' } from '@/api/endpoints'
import { updateModel } from '@/api/endpoints/models' import { updateModel } from '@/api/endpoints/models'
import { parseTestModelError } from '@/utils/errorParser'
const props = defineProps<{ const props = defineProps<{
provider: any provider: any
@@ -196,6 +212,7 @@ const dialogOpen = ref(false)
const deleteConfirmOpen = ref(false) const deleteConfirmOpen = ref(false)
const editingGroup = ref<AliasGroup | null>(null) const editingGroup = ref<AliasGroup | null>(null)
const deletingGroup = ref<AliasGroup | null>(null) const deletingGroup = ref<AliasGroup | null>(null)
const testingMapping = ref<string | null>(null)
// 列表展开状态 // 列表展开状态
const expandedAliasGroups = ref<Set<string>>(new Set()) const expandedAliasGroups = ref<Set<string>>(new Set())
@@ -337,6 +354,49 @@ async function onDialogSaved() {
emit('refresh') emit('refresh')
} }
// 测试模型映射
async function testMapping(group: any, mapping: any) {
const testingKey = `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`
testingMapping.value = testingKey
try {
// 根据分组的 API 格式来确定应该使用的格式
let apiFormat = null
if (group.apiFormats.length === 1) {
apiFormat = group.apiFormats[0]
} else if (group.apiFormats.length === 0) {
// 如果没有指定格式,但分组显示为"全部",则使用模型的默认格式
apiFormat = group.model.effective_api_format || group.model.api_format
}
const result = await testModel({
provider_id: props.provider.id,
model_name: mapping.name, // 使用映射名称进行测试
message: "hello",
api_format: apiFormat
})
if (result.success) {
showSuccess(`映射 "${mapping.name}" 测试成功`)
// 如果有响应内容,可以显示更多信息
if (result.data?.response?.choices?.[0]?.message?.content) {
const content = result.data.response.choices[0].message.content
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
} else if (result.data?.content_preview) {
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
}
} else {
showError(`映射测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`映射测试失败: ${errorMsg}`)
} finally {
testingMapping.value = null
}
}
// 监听 provider 变化 // 监听 provider 变化
watch(() => props.provider?.id, (newId) => { watch(() => props.provider?.id, (newId) => {
if (newId) { if (newId) {
@@ -349,4 +409,9 @@ onMounted(() => {
loadModels() loadModels()
} }
}) })
// 暴露给父组件,用于检测是否有弹窗打开
defineExpose({
dialogOpen: computed(() => dialogOpen.value || deleteConfirmOpen.value)
})
</script> </script>

View File

@@ -213,6 +213,7 @@ import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import { getProviderModels, type Model } from '@/api/endpoints' import { getProviderModels, type Model } from '@/api/endpoints'
import { updateModel } from '@/api/endpoints/models' import { updateModel } from '@/api/endpoints/models'
@@ -227,6 +228,7 @@ const emit = defineEmits<{
}>() }>()
const { error: showError, success: showSuccess } = useToast() const { error: showError, success: showSuccess } = useToast()
const { copyToClipboard } = useClipboard()
// 状态 // 状态
const loading = ref(false) const loading = ref(false)
@@ -244,12 +246,7 @@ const sortedModels = computed(() => {
// 复制模型 ID 到剪贴板 // 复制模型 ID 到剪贴板
async function copyModelId(modelId: string) { async function copyModelId(modelId: string) {
try { await copyToClipboard(modelId)
await navigator.clipboard.writeText(modelId)
showSuccess('已复制到剪贴板')
} catch {
showError('复制失败', '错误')
}
} }
// 加载模型 // 加载模型

View File

@@ -473,6 +473,7 @@
import { ref, watch, computed } from 'vue' import { ref, watch, computed } from 'vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import { useClipboard } from '@/composables/useClipboard'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Separator from '@/components/ui/separator.vue' import Separator from '@/components/ui/separator.vue'
@@ -505,6 +506,7 @@ const copiedStates = ref<Record<string, boolean>>({})
const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare') const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare')
const currentExpandDepth = ref(1) const currentExpandDepth = ref(1)
const dataSource = ref<'client' | 'provider'>('client') const dataSource = ref<'client' | 'provider'>('client')
const { copyToClipboard } = useClipboard()
const historicalPricing = ref<{ const historicalPricing = ref<{
input_price: string input_price: string
output_price: string output_price: string
@@ -784,7 +786,7 @@ function copyJsonToClipboard(tabName: string) {
} }
if (data) { if (data) {
navigator.clipboard.writeText(JSON.stringify(data, null, 2)) copyToClipboard(JSON.stringify(data, null, 2), false)
copiedStates.value[tabName] = true copiedStates.value[tabName] = true
setTimeout(() => { setTimeout(() => {
copiedStates.value[tabName] = false copiedStates.value[tabName] = false

View File

@@ -86,6 +86,34 @@
</p> </p>
</div> </div>
<div
v-if="isEditMode && form.password.length > 0"
class="space-y-2"
>
<Label class="text-sm font-medium">
确认新密码 <span class="text-muted-foreground">*</span>
</Label>
<Input
:id="`pwd-confirm-${formNonce}`"
v-model="form.confirmPassword"
type="password"
autocomplete="new-password"
data-form-type="other"
data-lpignore="true"
:name="`confirm-${formNonce}`"
required
minlength="6"
placeholder="再次输入新密码"
class="h-10"
/>
<p
v-if="form.confirmPassword.length > 0 && form.password !== form.confirmPassword"
class="text-xs text-destructive"
>
两次输入的密码不一致
</p>
</div>
<div class="space-y-2"> <div class="space-y-2">
<Label <Label
for="form-email" for="form-email"
@@ -423,6 +451,7 @@ const apiFormats = ref<Array<{ value: string; label: string }>>([])
const form = ref({ const form = ref({
username: '', username: '',
password: '', password: '',
confirmPassword: '',
email: '', email: '',
quota: 10, quota: 10,
role: 'user' as 'admin' | 'user', role: 'user' as 'admin' | 'user',
@@ -443,6 +472,7 @@ function resetForm() {
form.value = { form.value = {
username: '', username: '',
password: '', password: '',
confirmPassword: '',
email: '', email: '',
quota: 10, quota: 10,
role: 'user', role: 'user',
@@ -461,6 +491,7 @@ function loadUserData() {
form.value = { form.value = {
username: props.user.username, username: props.user.username,
password: '', password: '',
confirmPassword: '',
email: props.user.email || '', email: props.user.email || '',
quota: props.user.quota_usd == null ? 10 : props.user.quota_usd, quota: props.user.quota_usd == null ? 10 : props.user.quota_usd,
role: props.user.role, role: props.user.role,
@@ -486,7 +517,9 @@ const isFormValid = computed(() => {
const hasUsername = form.value.username.trim().length > 0 const hasUsername = form.value.username.trim().length > 0
const hasEmail = form.value.email.trim().length > 0 const hasEmail = form.value.email.trim().length > 0
const hasPassword = isEditMode.value || form.value.password.length >= 6 const hasPassword = isEditMode.value || form.value.password.length >= 6
return hasUsername && hasEmail && hasPassword // 编辑模式下如果填写了密码,必须确认密码一致
const passwordConfirmed = !isEditMode.value || form.value.password.length === 0 || form.value.password === form.value.confirmPassword
return hasUsername && hasEmail && hasPassword && passwordConfirmed
}) })
// 加载访问控制选项 // 加载访问控制选项

View File

@@ -5,7 +5,7 @@
import type { User, LoginResponse } from '@/api/auth' import type { User, LoginResponse } from '@/api/auth'
import type { DashboardStatsResponse, RecentRequest, ProviderStatus, DailyStatsResponse } from '@/api/dashboard' import type { DashboardStatsResponse, RecentRequest, ProviderStatus, DailyStatsResponse } from '@/api/dashboard'
import type { User as AdminUser, ApiKey } from '@/api/users' import type { User as AdminUser } from '@/api/users'
import type { AdminApiKeysResponse } from '@/api/admin' import type { AdminApiKeysResponse } from '@/api/admin'
import type { Profile, UsageResponse } from '@/api/me' import type { Profile, UsageResponse } from '@/api/me'
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types' import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
@@ -185,18 +185,20 @@ export const MOCK_DASHBOARD_STATS: DashboardStatsResponse = {
output: 700000, output: 700000,
cache_creation: 50000, cache_creation: 50000,
cache_read: 200000 cache_read: 200000
} },
// 普通用户专用字段
monthly_cost: 45.67
} }
export const MOCK_RECENT_REQUESTS: RecentRequest[] = [ export const MOCK_RECENT_REQUESTS: RecentRequest[] = [
{ id: 'req-001', user: 'alice', model: 'claude-sonnet-4-20250514', tokens: 15234, time: '2 分钟前' }, { id: 'req-001', user: 'alice', model: 'claude-sonnet-4-5-20250929', tokens: 15234, time: '2 分钟前' },
{ id: 'req-002', user: 'bob', model: 'gpt-4o', tokens: 8765, time: '5 分钟前' }, { id: 'req-002', user: 'bob', model: 'gpt-5.1', tokens: 8765, time: '5 分钟前' },
{ id: 'req-003', user: 'charlie', model: 'claude-opus-4-20250514', tokens: 32100, time: '8 分钟前' }, { id: 'req-003', user: 'charlie', model: 'claude-opus-4-5-20251101', tokens: 32100, time: '8 分钟前' },
{ id: 'req-004', user: 'diana', model: 'gemini-2.0-flash', tokens: 4521, time: '12 分钟前' }, { id: 'req-004', user: 'diana', model: 'gemini-3-pro-preview', tokens: 4521, time: '12 分钟前' },
{ id: 'req-005', user: 'eve', model: 'claude-sonnet-4-20250514', tokens: 9876, time: '15 分钟前' }, { id: 'req-005', user: 'eve', model: 'claude-sonnet-4-5-20250929', tokens: 9876, time: '15 分钟前' },
{ id: 'req-006', user: 'frank', model: 'gpt-4o-mini', tokens: 2345, time: '18 分钟前' }, { id: 'req-006', user: 'frank', model: 'gpt-5.1-codex-mini', tokens: 2345, time: '18 分钟前' },
{ id: 'req-007', user: 'grace', model: 'claude-haiku-3-5-20241022', tokens: 6789, time: '22 分钟前' }, { id: 'req-007', user: 'grace', model: 'claude-haiku-4-5-20251001', tokens: 6789, time: '22 分钟前' },
{ id: 'req-008', user: 'henry', model: 'gemini-2.5-pro', tokens: 12345, time: '25 分钟前' } { id: 'req-008', user: 'henry', model: 'gemini-3-pro-preview', tokens: 12345, time: '25 分钟前' }
] ]
export const MOCK_PROVIDER_STATUS: ProviderStatus[] = [ export const MOCK_PROVIDER_STATUS: ProviderStatus[] = [
@@ -231,11 +233,11 @@ function generateDailyStats(): DailyStatsResponse {
unique_models: 8 + Math.floor(Math.random() * 5), unique_models: 8 + Math.floor(Math.random() * 5),
unique_providers: 4 + Math.floor(Math.random() * 3), unique_providers: 4 + Math.floor(Math.random() * 3),
model_breakdown: [ model_breakdown: [
{ model: 'claude-sonnet-4-20250514', requests: Math.floor(baseRequests * 0.35), tokens: Math.floor(baseTokens * 0.35), cost: Number((baseCost * 0.35).toFixed(2)) }, { model: 'claude-sonnet-4-5-20250929', requests: Math.floor(baseRequests * 0.35), tokens: Math.floor(baseTokens * 0.35), cost: Number((baseCost * 0.35).toFixed(2)) },
{ model: 'gpt-4o', requests: Math.floor(baseRequests * 0.25), tokens: Math.floor(baseTokens * 0.25), cost: Number((baseCost * 0.25).toFixed(2)) }, { model: 'gpt-5.1', requests: Math.floor(baseRequests * 0.25), tokens: Math.floor(baseTokens * 0.25), cost: Number((baseCost * 0.25).toFixed(2)) },
{ model: 'claude-opus-4-20250514', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.20).toFixed(2)) }, { model: 'claude-opus-4-5-20251101', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.20).toFixed(2)) },
{ model: 'gemini-2.0-flash', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.10).toFixed(2)) }, { model: 'gemini-3-pro-preview', requests: Math.floor(baseRequests * 0.15), tokens: Math.floor(baseTokens * 0.15), cost: Number((baseCost * 0.10).toFixed(2)) },
{ model: 'claude-haiku-3-5-20241022', requests: Math.floor(baseRequests * 0.10), tokens: Math.floor(baseTokens * 0.10), cost: Number((baseCost * 0.10).toFixed(2)) } { model: 'claude-haiku-4-5-20251001', requests: Math.floor(baseRequests * 0.10), tokens: Math.floor(baseTokens * 0.10), cost: Number((baseCost * 0.10).toFixed(2)) }
] ]
}) })
} }
@@ -243,11 +245,11 @@ function generateDailyStats(): DailyStatsResponse {
return { return {
daily_stats: dailyStats, daily_stats: dailyStats,
model_summary: [ model_summary: [
{ model: 'claude-sonnet-4-20250514', requests: 2456, tokens: 8500000, cost: 125.45, avg_response_time: 1.2, cost_per_request: 0.051, tokens_per_request: 3461 }, { model: 'claude-sonnet-4-5-20250929', requests: 2456, tokens: 8500000, cost: 125.45, avg_response_time: 1.2, cost_per_request: 0.051, tokens_per_request: 3461 },
{ model: 'gpt-4o', requests: 1823, tokens: 6200000, cost: 98.32, avg_response_time: 0.9, cost_per_request: 0.054, tokens_per_request: 3401 }, { model: 'gpt-5.1', requests: 1823, tokens: 6200000, cost: 98.32, avg_response_time: 0.9, cost_per_request: 0.054, tokens_per_request: 3401 },
{ model: 'claude-opus-4-20250514', requests: 987, tokens: 4100000, cost: 156.78, avg_response_time: 2.1, cost_per_request: 0.159, tokens_per_request: 4154 }, { model: 'claude-opus-4-5-20251101', requests: 987, tokens: 4100000, cost: 156.78, avg_response_time: 2.1, cost_per_request: 0.159, tokens_per_request: 4154 },
{ model: 'gemini-2.0-flash', requests: 1234, tokens: 3800000, cost: 28.56, avg_response_time: 0.6, cost_per_request: 0.023, tokens_per_request: 3079 }, { model: 'gemini-3-pro-preview', requests: 1234, tokens: 3800000, cost: 28.56, avg_response_time: 0.6, cost_per_request: 0.023, tokens_per_request: 3079 },
{ model: 'claude-haiku-3-5-20241022', requests: 2100, tokens: 5200000, cost: 32.10, avg_response_time: 0.5, cost_per_request: 0.015, tokens_per_request: 2476 } { model: 'claude-haiku-4-5-20251001', requests: 2100, tokens: 5200000, cost: 32.10, avg_response_time: 0.5, cost_per_request: 0.015, tokens_per_request: 2476 }
], ],
period: { period: {
start_date: dailyStats[0].date, start_date: dailyStats[0].date,
@@ -336,7 +338,7 @@ export const MOCK_ALL_USERS: AdminUser[] = [
// ========== API Key 数据 ========== // ========== API Key 数据 ==========
export const MOCK_USER_API_KEYS: ApiKey[] = [ export const MOCK_USER_API_KEYS = [
{ {
id: 'key-uuid-001', id: 'key-uuid-001',
key_display: 'sk-ae...x7f9', key_display: 'sk-ae...x7f9',
@@ -346,7 +348,8 @@ export const MOCK_USER_API_KEYS: ApiKey[] = [
is_active: true, is_active: true,
is_standalone: false, is_standalone: false,
total_requests: 1234, total_requests: 1234,
total_cost_usd: 45.67 total_cost_usd: 45.67,
force_capabilities: null
}, },
{ {
id: 'key-uuid-002', id: 'key-uuid-002',
@@ -357,7 +360,8 @@ export const MOCK_USER_API_KEYS: ApiKey[] = [
is_active: true, is_active: true,
is_standalone: false, is_standalone: false,
total_requests: 5678, total_requests: 5678,
total_cost_usd: 123.45 total_cost_usd: 123.45,
force_capabilities: { cache_1h: true }
}, },
{ {
id: 'key-uuid-003', id: 'key-uuid-003',
@@ -367,7 +371,8 @@ export const MOCK_USER_API_KEYS: ApiKey[] = [
is_active: false, is_active: false,
is_standalone: false, is_standalone: false,
total_requests: 100, total_requests: 100,
total_cost_usd: 2.34 total_cost_usd: 2.34,
force_capabilities: null
} }
] ]
@@ -813,16 +818,16 @@ export const MOCK_USAGE_RESPONSE: UsageResponse = {
quota_usd: 100, quota_usd: 100,
used_usd: 45.32, used_usd: 45.32,
summary_by_model: [ summary_by_model: [
{ model: 'claude-sonnet-4-20250514', requests: 456, input_tokens: 650000, output_tokens: 250000, total_tokens: 900000, total_cost_usd: 18.50, actual_total_cost_usd: 13.50 }, { model: 'claude-sonnet-4-5-20250929', requests: 456, input_tokens: 650000, output_tokens: 250000, total_tokens: 900000, total_cost_usd: 18.50, actual_total_cost_usd: 13.50 },
{ model: 'gpt-4o', requests: 312, input_tokens: 480000, output_tokens: 180000, total_tokens: 660000, total_cost_usd: 12.30, actual_total_cost_usd: 9.20 }, { model: 'gpt-5.1', requests: 312, input_tokens: 480000, output_tokens: 180000, total_tokens: 660000, total_cost_usd: 12.30, actual_total_cost_usd: 9.20 },
{ model: 'claude-haiku-3-5-20241022', requests: 289, input_tokens: 420000, output_tokens: 170000, total_tokens: 590000, total_cost_usd: 8.50, actual_total_cost_usd: 6.30 }, { model: 'claude-haiku-4-5-20251001', requests: 289, input_tokens: 420000, output_tokens: 170000, total_tokens: 590000, total_cost_usd: 8.50, actual_total_cost_usd: 6.30 },
{ model: 'gemini-2.0-flash', requests: 177, input_tokens: 250000, output_tokens: 100000, total_tokens: 350000, total_cost_usd: 6.37, actual_total_cost_usd: 4.33 } { model: 'gemini-3-pro-preview', requests: 177, input_tokens: 250000, output_tokens: 100000, total_tokens: 350000, total_cost_usd: 6.37, actual_total_cost_usd: 4.33 }
], ],
records: [ records: [
{ {
id: 'usage-001', id: 'usage-001',
provider: 'anthropic', provider: 'anthropic',
model: 'claude-sonnet-4-20250514', model: 'claude-sonnet-4-5-20250929',
input_tokens: 1500, input_tokens: 1500,
output_tokens: 800, output_tokens: 800,
total_tokens: 2300, total_tokens: 2300,
@@ -837,7 +842,7 @@ export const MOCK_USAGE_RESPONSE: UsageResponse = {
{ {
id: 'usage-002', id: 'usage-002',
provider: 'openai', provider: 'openai',
model: 'gpt-4o', model: 'gpt-5.1',
input_tokens: 2000, input_tokens: 2000,
output_tokens: 500, output_tokens: 500,
total_tokens: 2500, total_tokens: 2500,

View File

@@ -405,10 +405,10 @@ function getUsageRecords() {
// Mock 映射数据 // Mock 映射数据
const MOCK_ALIASES = [ const MOCK_ALIASES = [
{ id: 'alias-001', source_model: 'claude-4-sonnet', target_global_model_id: 'gm-001', target_global_model_name: 'claude-sonnet-4-20250514', target_global_model_display_name: 'Claude Sonnet 4', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' }, { id: 'alias-001', source_model: 'claude-4-sonnet', target_global_model_id: 'gm-003', target_global_model_name: 'claude-sonnet-4-5-20250929', target_global_model_display_name: 'Claude Sonnet 4.5', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
{ id: 'alias-002', source_model: 'claude-4-opus', target_global_model_id: 'gm-002', target_global_model_name: 'claude-opus-4-20250514', target_global_model_display_name: 'Claude Opus 4', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' }, { id: 'alias-002', source_model: 'claude-4-opus', target_global_model_id: 'gm-002', target_global_model_name: 'claude-opus-4-5-20251101', target_global_model_display_name: 'Claude Opus 4.5', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
{ id: 'alias-003', source_model: 'gpt4o', target_global_model_id: 'gm-004', target_global_model_name: 'gpt-4o', target_global_model_display_name: 'GPT-4o', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' }, { id: 'alias-003', source_model: 'gpt5', target_global_model_id: 'gm-006', target_global_model_name: 'gpt-5.1', target_global_model_display_name: 'GPT-5.1', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' },
{ id: 'alias-004', source_model: 'gemini-flash', target_global_model_id: 'gm-005', target_global_model_name: 'gemini-2.0-flash', target_global_model_display_name: 'Gemini 2.0 Flash', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' } { id: 'alias-004', source_model: 'gemini-pro', target_global_model_id: 'gm-005', target_global_model_name: 'gemini-3-pro-preview', target_global_model_display_name: 'Gemini 3 Pro Preview', provider_id: null, provider_name: null, scope: 'global', mapping_type: 'alias', is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: '2024-01-01T00:00:00Z' }
] ]
// Mock Endpoint Keys // Mock Endpoint Keys
@@ -2172,10 +2172,10 @@ function generateIntervalTimelineData(
// 模型列表(用于按模型区分颜色) // 模型列表(用于按模型区分颜色)
const models = [ const models = [
'claude-sonnet-4-20250514', 'claude-sonnet-4-5-20250929',
'claude-3-5-sonnet-20241022', 'claude-haiku-4-5-20251001',
'claude-3-5-haiku-20241022', 'claude-opus-4-5-20251101',
'claude-opus-4-20250514' 'gpt-5.1'
] ]
// 生成模拟的请求间隔数据 // 生成模拟的请求间隔数据

View File

@@ -14,7 +14,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
users.value = await usersApi.getAllUsers() users.value = await usersApi.getAllUsers()
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '获取用户列表失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取用户列表失败'
} finally { } finally {
loading.value = false loading.value = false
} }
@@ -29,7 +29,7 @@ export const useUsersStore = defineStore('users', () => {
users.value.push(newUser) users.value.push(newUser)
return newUser return newUser
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '创建用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -52,7 +52,7 @@ export const useUsersStore = defineStore('users', () => {
} }
return updatedUser return updatedUser
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '更新用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '更新用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -67,7 +67,7 @@ export const useUsersStore = defineStore('users', () => {
await usersApi.deleteUser(userId) await usersApi.deleteUser(userId)
users.value = users.value.filter(u => u.id !== userId) users.value = users.value.filter(u => u.id !== userId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '删除用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -78,7 +78,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
return await usersApi.getUserApiKeys(userId) return await usersApi.getUserApiKeys(userId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '获取 API Keys 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取 API Keys 失败'
throw err throw err
} }
} }
@@ -87,7 +87,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
return await usersApi.createApiKey(userId, name) return await usersApi.createApiKey(userId, name)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '创建 API Key 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建 API Key 失败'
throw err throw err
} }
} }
@@ -96,7 +96,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
await usersApi.deleteApiKey(userId, keyId) await usersApi.deleteApiKey(userId, keyId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '删除 API Key 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除 API Key 失败'
throw err throw err
} }
} }
@@ -110,7 +110,7 @@ export const useUsersStore = defineStore('users', () => {
// 刷新用户列表以获取最新数据 // 刷新用户列表以获取最新数据
await fetchUsers() await fetchUsers()
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '重置配额失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '重置配额失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false

View File

@@ -198,3 +198,49 @@ export function parseApiErrorShort(err: unknown, defaultMessage: string = '操
const lines = fullError.split('\n') const lines = fullError.split('\n')
return lines[0] || defaultMessage return lines[0] || defaultMessage
} }
/**
* 解析模型测试响应的错误信息
* @param result 测试响应结果
* @returns 格式化的错误信息
*/
export function parseTestModelError(result: {
error?: string
data?: {
response?: {
status_code?: number
error?: string | { message?: string }
}
}
}): string {
let errorMsg = result.error || '测试失败'
// 检查HTTP状态码错误
if (result.data?.response?.status_code) {
const status = result.data.response.status_code
if (status === 403) {
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
} else if (status === 401) {
errorMsg = '认证失败: API密钥无效或已过期'
} else if (status === 404) {
errorMsg = '模型不存在: 请检查模型名称是否正确'
} else if (status === 429) {
errorMsg = '请求频率过高: 请稍后重试'
} else if (status >= 500) {
errorMsg = `服务器错误: HTTP ${status}`
} else {
errorMsg = `请求失败: HTTP ${status}`
}
}
// 尝试从错误响应中提取更多信息
if (result.data?.response?.error) {
if (typeof result.data.response.error === 'string') {
errorMsg = result.data.response.error
} else if (result.data.response.error?.message) {
errorMsg = result.data.response.error.message
}
}
return errorMsg
}

View File

@@ -650,6 +650,7 @@
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted } from 'vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useConfirm } from '@/composables/useConfirm' import { useConfirm } from '@/composables/useConfirm'
import { useClipboard } from '@/composables/useClipboard'
import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin' import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin'
import { import {
@@ -693,6 +694,7 @@ import { log } from '@/utils/logger'
const { success, error } = useToast() const { success, error } = useToast()
const { confirmDanger } = useConfirm() const { confirmDanger } = useConfirm()
const { copyToClipboard } = useClipboard()
const apiKeys = ref<AdminApiKey[]>([]) const apiKeys = ref<AdminApiKey[]>([])
const loading = ref(false) const loading = ref(false)
@@ -927,20 +929,14 @@ function selectKey() {
} }
async function copyKey() { async function copyKey() {
try { await copyToClipboard(newKeyValue.value)
await navigator.clipboard.writeText(newKeyValue.value)
success('API Key 已复制到剪贴板')
} catch {
error('复制失败,请手动复制')
}
} }
async function copyKeyPrefix(apiKey: AdminApiKey) { async function copyKeyPrefix(apiKey: AdminApiKey) {
try { try {
// 调用后端 API 获取完整密钥 // 调用后端 API 获取完整密钥
const response = await adminApi.getFullApiKey(apiKey.id) const response = await adminApi.getFullApiKey(apiKey.id)
await navigator.clipboard.writeText(response.key) await copyToClipboard(response.key)
success('完整密钥已复制到剪贴板')
} catch (err) { } catch (err) {
log.error('复制密钥失败:', err) log.error('复制密钥失败:', err)
error('复制失败,请重试') error('复制失败,请重试')
@@ -1046,9 +1042,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
rate_limit: data.rate_limit, rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null), expire_days: data.never_expire ? null : (data.expire_days || null),
auto_delete_on_expiry: data.auto_delete_on_expiry, auto_delete_on_expiry: data.auto_delete_on_expiry,
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined, // 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined, allowed_providers: data.allowed_providers,
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined allowed_api_formats: data.allowed_api_formats,
allowed_models: data.allowed_models
} }
await adminApi.updateApiKey(data.id, updateData) await adminApi.updateApiKey(data.id, updateData)
success('API Key 更新成功') success('API Key 更新成功')
@@ -1064,9 +1061,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
rate_limit: data.rate_limit, rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null), expire_days: data.never_expire ? null : (data.expire_days || null),
auto_delete_on_expiry: data.auto_delete_on_expiry, auto_delete_on_expiry: data.auto_delete_on_expiry,
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined, // 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined, allowed_providers: data.allowed_providers,
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined allowed_api_formats: data.allowed_api_formats,
allowed_models: data.allowed_models
} }
const response = await adminApi.createStandaloneApiKey(createData) const response = await adminApi.createStandaloneApiKey(createData)
newKeyValue.value = response.key newKeyValue.value = response.key

View File

@@ -46,6 +46,7 @@ const clearingRowAffinityKey = ref<string | null>(null)
const currentPage = ref(1) const currentPage = ref(1)
const pageSize = ref(20) const pageSize = ref(20)
const currentTime = ref(Math.floor(Date.now() / 1000)) const currentTime = ref(Math.floor(Date.now() / 1000))
const analysisHoursSelectOpen = ref(false)
// ==================== 模型映射缓存 ==================== // ==================== 模型映射缓存 ====================
@@ -1056,7 +1057,7 @@ onBeforeUnmount(() => {
<span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔推荐合适的缓存 TTL</span> <span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔推荐合适的缓存 TTL</span>
</div> </div>
<div class="flex flex-wrap items-center gap-2"> <div class="flex flex-wrap items-center gap-2">
<Select v-model="analysisHours"> <Select v-model="analysisHours" v-model:open="analysisHoursSelectOpen">
<SelectTrigger class="w-24 sm:w-28 h-8"> <SelectTrigger class="w-24 sm:w-28 h-8">
<SelectValue placeholder="时间段" /> <SelectValue placeholder="时间段" />
</SelectTrigger> </SelectTrigger>

View File

@@ -713,6 +713,7 @@ import ProviderModelFormDialog from '@/features/providers/components/ProviderMod
import type { Model } from '@/api/endpoints' import type { Model } from '@/api/endpoints'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useConfirm } from '@/composables/useConfirm' import { useConfirm } from '@/composables/useConfirm'
import { useClipboard } from '@/composables/useClipboard'
import { useRowClick } from '@/composables/useRowClick' import { useRowClick } from '@/composables/useRowClick'
import { parseApiError } from '@/utils/errorParser' import { parseApiError } from '@/utils/errorParser'
import { import {
@@ -736,6 +737,7 @@ import {
updateGlobalModel, updateGlobalModel,
deleteGlobalModel, deleteGlobalModel,
batchAssignToProviders, batchAssignToProviders,
getGlobalModelProviders,
type GlobalModelResponse, type GlobalModelResponse,
} from '@/api/global-models' } from '@/api/global-models'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
@@ -743,6 +745,7 @@ import { getProvidersSummary } from '@/api/endpoints/providers'
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints' import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
// 状态 // 状态
const loading = ref(false) const loading = ref(false)
@@ -1066,16 +1069,6 @@ function handleRowClick(event: MouseEvent, model: GlobalModelResponse) {
selectModel(model) selectModel(model)
} }
// 复制到剪贴板
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
success('已复制')
} catch {
showError('复制失败')
}
}
async function selectModel(model: GlobalModelResponse) { async function selectModel(model: GlobalModelResponse) {
selectedModel.value = model selectedModel.value = model
detailTab.value = 'basic' detailTab.value = 'basic'
@@ -1088,18 +1081,11 @@ async function selectModel(model: GlobalModelResponse) {
async function loadModelProviders(_globalModelId: string) { async function loadModelProviders(_globalModelId: string) {
loadingModelProviders.value = true loadingModelProviders.value = true
try { try {
// 使用 ModelCatalog API 获取详细的关联提供商信息 // 使用新的 API 获取所有关联提供商(包括非活跃的)
const { getModelCatalog } = await import('@/api/endpoints') const response = await getGlobalModelProviders(_globalModelId)
const catalogResponse = await getModelCatalog()
// 查找当前 GlobalModel 对应的 catalog item // 转换为展示格式
const catalogItem = catalogResponse.models.find( selectedModelProviders.value = response.providers.map(p => ({
m => m.global_model_name === selectedModel.value?.name
)
if (catalogItem) {
// 转换为展示格式,包含完整的模型实现信息
selectedModelProviders.value = catalogItem.providers.map(p => ({
id: p.provider_id, id: p.provider_id,
model_id: p.model_id, model_id: p.model_id,
display_name: p.provider_display_name || p.provider_name, display_name: p.provider_display_name || p.provider_name,
@@ -1121,9 +1107,6 @@ async function loadModelProviders(_globalModelId: string) {
supports_function_calling: p.supports_function_calling, supports_function_calling: p.supports_function_calling,
supports_streaming: p.supports_streaming supports_streaming: p.supports_streaming
})) }))
} else {
selectedModelProviders.value = []
}
} catch (err: any) { } catch (err: any) {
log.error('加载关联提供商失败:', err) log.error('加载关联提供商失败:', err)
showError(parseApiError(err, '加载关联提供商失败'), '错误') showError(parseApiError(err, '加载关联提供商失败'), '错误')

View File

@@ -723,9 +723,19 @@ async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
// 切换提供商状态 // 切换提供商状态
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) { async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
try { try {
await updateProvider(provider.id, { is_active: !provider.is_active }) const newStatus = !provider.is_active
provider.is_active = !provider.is_active await updateProvider(provider.id, { is_active: newStatus })
showSuccess(provider.is_active ? '提供商已启用' : '提供商已停用')
// 更新抽屉内部的 provider 对象
provider.is_active = newStatus
// 同时更新主页面 providers 数组中的对象,实现无感更新
const targetProvider = providers.value.find(p => p.id === provider.id)
if (targetProvider) {
targetProvider.is_active = newStatus
}
showSuccess(newStatus ? '提供商已启用' : '提供商已停用')
} catch (err: any) { } catch (err: any) {
showError(err.response?.data?.detail || '操作失败', '错误') showError(err.response?.data?.detail || '操作失败', '错误')
} }

View File

@@ -701,6 +701,7 @@ import { ref, computed, onMounted, watch } from 'vue'
import { useUsersStore } from '@/stores/users' import { useUsersStore } from '@/stores/users'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useConfirm } from '@/composables/useConfirm' import { useConfirm } from '@/composables/useConfirm'
import { useClipboard } from '@/composables/useClipboard'
import { usageApi, type UsageByUser } from '@/api/usage' import { usageApi, type UsageByUser } from '@/api/usage'
import { adminApi } from '@/api/admin' import { adminApi } from '@/api/admin'
@@ -748,6 +749,7 @@ import { log } from '@/utils/logger'
const { success, error } = useToast() const { success, error } = useToast()
const { confirmDanger, confirmWarning } = useConfirm() const { confirmDanger, confirmWarning } = useConfirm()
const { copyToClipboard } = useClipboard()
const usersStore = useUsersStore() const usersStore = useUsersStore()
// 用户表单对话框状态 // 用户表单对话框状态
@@ -875,7 +877,8 @@ async function toggleUserStatus(user: any) {
const action = user.is_active ? '禁用' : '启用' const action = user.is_active ? '禁用' : '启用'
const confirmed = await confirmDanger( const confirmed = await confirmDanger(
`确定要${action}用户 ${user.username} 吗?`, `确定要${action}用户 ${user.username} 吗?`,
`${action}用户` `${action}用户`,
action
) )
if (!confirmed) return if (!confirmed) return
@@ -884,7 +887,7 @@ async function toggleUserStatus(user: any) {
await usersStore.updateUser(user.id, { is_active: !user.is_active }) await usersStore.updateUser(user.id, { is_active: !user.is_active })
success(`用户已${action}`) success(`用户已${action}`)
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', `${action}用户失败`) error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', `${action}用户失败`)
} }
} }
@@ -955,7 +958,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
closeUserFormDialog() closeUserFormDialog()
} catch (err: any) { } catch (err: any) {
const title = data.id ? '更新用户失败' : '创建用户失败' const title = data.id ? '更新用户失败' : '创建用户失败'
error(err.response?.data?.detail || '未知错误', title) error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', title)
} finally { } finally {
userFormDialogRef.value?.setSaving(false) userFormDialogRef.value?.setSaving(false)
} }
@@ -989,7 +992,7 @@ async function createApiKey() {
showNewApiKeyDialog.value = true showNewApiKeyDialog.value = true
await loadUserApiKeys(selectedUser.value.id) await loadUserApiKeys(selectedUser.value.id)
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '创建 API Key 失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '创建 API Key 失败')
} finally { } finally {
creatingApiKey.value = false creatingApiKey.value = false
} }
@@ -1000,12 +1003,7 @@ function selectApiKey() {
} }
async function copyApiKey() { async function copyApiKey() {
try { await copyToClipboard(newApiKey.value)
await navigator.clipboard.writeText(newApiKey.value)
success('API Key已复制到剪贴板')
} catch {
error('复制失败,请手动复制')
}
} }
async function closeNewApiKeyDialog() { async function closeNewApiKeyDialog() {
@@ -1026,7 +1024,7 @@ async function deleteApiKey(apiKey: any) {
await loadUserApiKeys(selectedUser.value.id) await loadUserApiKeys(selectedUser.value.id)
success('API Key已删除') success('API Key已删除')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '删除 API Key 失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除 API Key 失败')
} }
} }
@@ -1034,11 +1032,10 @@ async function copyFullKey(apiKey: any) {
try { try {
// 调用后端 API 获取完整密钥 // 调用后端 API 获取完整密钥
const response = await adminApi.getFullApiKey(apiKey.id) const response = await adminApi.getFullApiKey(apiKey.id)
await navigator.clipboard.writeText(response.key) await copyToClipboard(response.key)
success('完整密钥已复制到剪贴板')
} catch (err: any) { } catch (err: any) {
log.error('复制密钥失败:', err) log.error('复制密钥失败:', err)
error(err.response?.data?.detail || '未知错误', '复制密钥失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '复制密钥失败')
} }
} }
@@ -1054,7 +1051,7 @@ async function resetQuota(user: any) {
await usersStore.resetUserQuota(user.id) await usersStore.resetUserQuota(user.id)
success('配额已重置') success('配额已重置')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '重置配额失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '重置配额失败')
} }
} }
@@ -1070,7 +1067,7 @@ async function deleteUser(user: any) {
await usersStore.deleteUser(user.id) await usersStore.deleteUser(user.id)
success('用户已删除') success('用户已删除')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '删除用户失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除用户失败')
} }
} }
</script> </script>

View File

@@ -102,9 +102,9 @@
<!-- Main Content --> <!-- Main Content -->
<main class="relative z-10"> <main class="relative z-10">
<!-- Fixed Logo Container --> <!-- Fixed Logo Container -->
<div class="fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden"> <div class="mt-4 fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
<div <div
class="transform-gpu logo-container" class="mt-16 transform-gpu logo-container"
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]" :class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
:style="fixedLogoStyle" :style="fixedLogoStyle"
> >
@@ -151,7 +151,7 @@
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20" class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
> >
<div class="max-w-4xl mx-auto text-center"> <div class="max-w-4xl mx-auto text-center">
<div class="h-80 w-full mb-16" /> <div class="h-80 w-full mb-16 mt-8" />
<h1 <h1
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700" class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
:style="getTitleStyle(SECTIONS.HOME)" :style="getTitleStyle(SECTIONS.HOME)"
@@ -166,7 +166,7 @@
整合 Claude CodeCodex CLIGemini CLI 等多个 AI 编程助手 整合 Claude CodeCodex CLIGemini CLI 等多个 AI 编程助手
</p> </p>
<button <button
class="mt-16 transition-all duration-700 cursor-pointer hover:scale-110" class="mt-8 transition-all duration-700 cursor-pointer hover:scale-110"
:style="getScrollIndicatorStyle(SECTIONS.HOME)" :style="getScrollIndicatorStyle(SECTIONS.HOME)"
@click="scrollToSection(SECTIONS.CLAUDE)" @click="scrollToSection(SECTIONS.CLAUDE)"
> >

View File

@@ -145,10 +145,10 @@
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" /> <DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6"> <div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground"> <p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
实际成 月费用
</p> </p>
<p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground"> <p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground">
{{ formatCurrency(costStats.total_actual_cost) }} {{ formatCurrency(costStats.total_cost) }}
</p> </p>
<Badge <Badge
v-if="costStats.cost_savings > 0" v-if="costStats.cost_savings > 0"
@@ -162,14 +162,14 @@
</div> </div>
</div> </div>
<!-- 普通用户缓存统计 --> <!-- 普通用户月度统计 -->
<div <div
v-else-if="!isAdmin && cacheStats && cacheStats.total_cache_tokens > 0" v-else-if="!isAdmin && (hasCacheData || (userMonthlyCost !== null && userMonthlyCost > 0))"
class="mt-6" class="mt-6"
> >
<div class="mb-3 flex items-center justify-between"> <div class="mb-3 flex items-center justify-between">
<h3 class="text-sm font-medium text-foreground"> <h3 class="text-sm font-medium text-foreground">
本月缓存使用 本月统计
</h3> </h3>
<Badge <Badge
variant="outline" variant="outline"
@@ -178,8 +178,16 @@
Monthly Monthly
</Badge> </Badge>
</div> </div>
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4"> <div
<Card class="relative p-3 sm:p-4 border-book-cloth/30"> :class="[
'grid gap-2 sm:gap-3',
hasCacheData ? 'grid-cols-2 xl:grid-cols-4' : 'grid-cols-1 max-w-xs'
]"
>
<Card
v-if="cacheStats"
class="relative p-3 sm:p-4 border-book-cloth/30"
>
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" /> <Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6"> <div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground"> <p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
@@ -190,7 +198,10 @@
</p> </p>
</div> </div>
</Card> </Card>
<Card class="relative p-3 sm:p-4 border-kraft/30"> <Card
v-if="cacheStats"
class="relative p-3 sm:p-4 border-kraft/30"
>
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" /> <Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6"> <div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground"> <p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
@@ -201,7 +212,10 @@
</p> </p>
</div> </div>
</Card> </Card>
<Card class="relative p-3 sm:p-4 border-book-cloth/25"> <Card
v-if="cacheStats"
class="relative p-3 sm:p-4 border-book-cloth/25"
>
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" /> <Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6"> <div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground"> <p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
@@ -213,19 +227,16 @@
</div> </div>
</Card> </Card>
<Card <Card
v-if="tokenBreakdown" v-if="userMonthlyCost !== null"
class="relative p-3 sm:p-4 border-manilla/40" class="relative p-3 sm:p-4 border-manilla/40"
> >
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" /> <DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
<div class="pr-6"> <div class="pr-6">
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground"> <p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
总Token 本月费用
</p> </p>
<p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground"> <p class="mt-1.5 sm:mt-2 text-lg sm:text-xl font-semibold text-foreground">
{{ formatTokens((tokenBreakdown.input || 0) + (tokenBreakdown.output || 0)) }} {{ formatCurrency(userMonthlyCost) }}
</p>
<p class="mt-0.5 sm:mt-1 text-[9px] sm:text-[10px] text-muted-foreground">
输入 {{ formatTokens(tokenBreakdown.input || 0) }} / 输出 {{ formatTokens(tokenBreakdown.output || 0) }}
</p> </p>
</div> </div>
</Card> </Card>
@@ -831,6 +842,12 @@ const cacheStats = ref<{
total_cache_tokens: number total_cache_tokens: number
} | null>(null) } | null>(null)
const userMonthlyCost = ref<number | null>(null)
const hasCacheData = computed(() =>
cacheStats.value && cacheStats.value.total_cache_tokens > 0
)
const tokenBreakdown = ref<{ const tokenBreakdown = ref<{
input: number input: number
output: number output: number
@@ -1086,6 +1103,7 @@ async function loadDashboardData() {
} else { } else {
if (statsData.cache_stats) cacheStats.value = statsData.cache_stats if (statsData.cache_stats) cacheStats.value = statsData.cache_stats
if (statsData.token_breakdown) tokenBreakdown.value = statsData.token_breakdown if (statsData.token_breakdown) tokenBreakdown.value = statsData.token_breakdown
if (statsData.monthly_cost !== undefined) userMonthlyCost.value = statsData.monthly_cost
} }
} finally { } finally {
loading.value = false loading.value = false

View File

@@ -301,6 +301,7 @@ function stopGlobalAutoRefresh() {
function handleAutoRefreshChange(value: boolean) { function handleAutoRefreshChange(value: boolean) {
globalAutoRefresh.value = value globalAutoRefresh.value = value
if (value) { if (value) {
refreshData() // 立即刷新一次
startGlobalAutoRefresh() startGlobalAutoRefresh()
} else { } else {
stopGlobalAutoRefresh() stopGlobalAutoRefresh()

View File

@@ -342,6 +342,7 @@ import {
Plus, Plus,
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import { import {
Card, Card,
Table, Table,
@@ -370,6 +371,7 @@ import { useRowClick } from '@/composables/useRowClick'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
// 状态 // 状态
const loading = ref(false) const loading = ref(false)
@@ -565,16 +567,6 @@ function hasTieredPricing(model: PublicGlobalModel): boolean {
return (tiered?.tiers?.length || 0) > 1 return (tiered?.tiers?.length || 0) > 1
} }
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
success('已复制')
} catch (err) {
log.error('复制失败:', err)
showError('复制失败')
}
}
onMounted(() => { onMounted(() => {
refreshData() refreshData()
}) })

View File

@@ -352,6 +352,7 @@ import {
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
@@ -375,6 +376,7 @@ const emit = defineEmits<{
}>() }>()
const { success: showSuccess, error: showError } = useToast() const { success: showSuccess, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
interface Props { interface Props {
model: PublicGlobalModel | null model: PublicGlobalModel | null
@@ -408,15 +410,6 @@ function handleClose() {
emit('update:open', false) emit('update:open', false)
} }
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
showSuccess('已复制')
} catch {
showError('复制失败')
}
}
function getFirstTierPrice( function getFirstTierPrice(
tieredPricing: TieredPricingConfig | undefined | null, tieredPricing: TieredPricingConfig | undefined | null,
priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m' priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m'

View File

@@ -13,7 +13,7 @@ authors = [
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: Other/Proprietary License",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",

View File

@@ -80,6 +80,17 @@ async def get_keys_grouped_by_format(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/keys/{key_id}/reveal")
async def reveal_endpoint_key(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> dict:
"""获取完整的 API Key用于查看和复制"""
adapter = AdminRevealEndpointKeyAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/keys/{key_id}") @router.delete("/keys/{key_id}")
async def delete_endpoint_key( async def delete_endpoint_key(
key_id: str, key_id: str,
@@ -293,6 +304,30 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
return EndpointAPIKeyResponse(**response_dict) return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminRevealEndpointKeyAdapter(AdminApiAdapter):
"""获取完整的 API Key用于查看和复制"""
key_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == self.key_id).first()
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception as e:
logger.error(f"解密 Key 失败: ID={self.key_id}, Error={e}")
raise InvalidRequestException(
"无法解密 API Key可能是加密密钥已更改。请重新添加该密钥。"
)
logger.info(f"[REVEAL] 查看完整 Key: ID={self.key_id}, Name={key.name}")
return {"api_key": decrypted_key}
@dataclass @dataclass
class AdminDeleteEndpointKeyAdapter(AdminApiAdapter): class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
key_id: str key_id: str

View File

@@ -5,7 +5,7 @@ GlobalModel Admin API
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional from typing import Optional
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -19,9 +19,11 @@ from src.models.pydantic_models import (
BatchAssignToProvidersResponse, BatchAssignToProvidersResponse,
GlobalModelCreate, GlobalModelCreate,
GlobalModelListResponse, GlobalModelListResponse,
GlobalModelProvidersResponse,
GlobalModelResponse, GlobalModelResponse,
GlobalModelUpdate, GlobalModelUpdate,
GlobalModelWithStats, GlobalModelWithStats,
ModelCatalogProviderDetail,
) )
from src.services.model.global_model import GlobalModelService from src.services.model.global_model import GlobalModelService
@@ -108,6 +110,17 @@ async def batch_assign_to_providers(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/{global_model_id}/providers", response_model=GlobalModelProvidersResponse)
async def get_global_model_providers(
request: Request,
global_model_id: str,
db: Session = Depends(get_db),
) -> GlobalModelProvidersResponse:
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
adapter = AdminGetGlobalModelProvidersAdapter(global_model_id=global_model_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# ========== Adapters ========== # ========== Adapters ==========
@@ -275,3 +288,61 @@ class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}") logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
return BatchAssignToProvidersResponse(**result) return BatchAssignToProvidersResponse(**result)
@dataclass
class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
"""获取 GlobalModel 的所有关联提供商(包括非活跃的)"""
global_model_id: str
async def handle(self, context): # type: ignore[override]
from sqlalchemy.orm import joinedload
from src.models.database import Model
global_model = GlobalModelService.get_global_model(context.db, self.global_model_id)
# 获取所有关联的 Model包括非活跃的
models = (
context.db.query(Model)
.options(joinedload(Model.provider), joinedload(Model.global_model))
.filter(Model.global_model_id == global_model.id)
.all()
)
provider_entries = []
for model in models:
provider = model.provider
if not provider:
continue
effective_tiered = model.get_effective_tiered_pricing()
tier_count = len(effective_tiered.get("tiers", [])) if effective_tiered else 1
provider_entries.append(
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
input_price_per_1m=model.get_effective_input_price(),
output_price_per_1m=model.get_effective_output_price(),
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
cache_read_price_per_1m=model.get_effective_cache_read_price(),
cache_1h_creation_price_per_1m=model.get_effective_1h_cache_creation_price(),
price_per_request=model.get_effective_price_per_request(),
effective_tiered_pricing=effective_tiered,
tier_count=tier_count,
supports_vision=model.get_effective_supports_vision(),
supports_function_calling=model.get_effective_supports_function_calling(),
supports_streaming=model.get_effective_supports_streaming(),
is_active=bool(model.is_active),
)
)
return GlobalModelProvidersResponse(
providers=provider_entries,
total=len(provider_entries),
)

View File

@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
api_key_id: Optional[str] = None api_key_id: Optional[str] = None
class TestModelRequest(BaseModel):
"""模型测试请求"""
provider_id: str
model_name: str
api_key_id: Optional[str] = None
stream: bool = False
message: Optional[str] = "你好"
api_format: Optional[str] = None # 指定使用的API格式如果不指定则使用端点的默认格式
# ============ API Endpoints ============ # ============ API Endpoints ============
@@ -206,3 +217,228 @@ async def query_available_models(
"display_name": provider.display_name, "display_name": provider.display_name,
}, },
} }
@router.post("/test-model")
async def test_model(
request: TestModelRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
测试模型连接性
向指定提供商的指定模型发送测试请求,验证模型是否可用
Args:
request: 测试请求
Returns:
测试结果
"""
# 获取提供商及其端点
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 找到合适的端点和API Key
endpoint_config = None
endpoint = None
api_key = None
if request.api_key_id:
# 使用指定的API Key
for ep in provider.endpoints:
for key in ep.api_keys:
if key.id == request.api_key_id and key.is_active and ep.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
else:
# 使用第一个可用的端点和密钥
for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys:
continue
for key in ep.api_keys:
if key.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
if not endpoint or not api_key:
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 构建请求配置
endpoint_config = {
"api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0,
}
try:
# 获取对应的 Adapter 类
adapter_class = _get_adapter_for_format(endpoint.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {endpoint.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
# 如果请求指定了 api_format优先使用它
target_api_format = request.api_format or endpoint.api_format
if request.api_format and request.api_format != endpoint.api_format:
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
# 重新获取适配器
adapter_class = _get_adapter_for_format(request.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {request.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
# 准备测试请求数据
check_request = {
"model": request.model_name,
"messages": [
{"role": "user", "content": request.message or "Hello! This is a test message."}
],
"max_tokens": 30,
"temperature": 0.7,
}
# 发送测试请求
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
# 非流式测试
logger.debug(f"[test-model] 开始非流式测试...")
response = await adapter_class.check_endpoint(
client,
endpoint_config["base_url"],
endpoint_config["api_key"],
check_request,
endpoint_config.get("extra_headers"),
# 用量计算参数(现在强制记录)
db=db,
user=current_user,
provider_name=provider.name,
provider_id=provider.id,
api_key_id=endpoint_config.get("api_key_id"),
model_name=request.model_name,
)
# 记录提供商返回信息
logger.debug(f"[test-model] 非流式测试结果:")
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
response_data = response.get('response', {})
response_body = response_data.get('response_body', {})
logger.debug(f"[test-model] Response Data: {response_data}")
logger.debug(f"[test-model] Response Body: {response_body}")
# 尝试解析 response_body (通常是 JSON 字符串)
parsed_body = response_body
import json
if isinstance(response_body, str):
try:
parsed_body = json.loads(response_body)
except json.JSONDecodeError:
pass
if isinstance(parsed_body, dict) and 'error' in parsed_body:
error_obj = parsed_body['error']
# 兼容 error 可能是字典或字符串的情况
if isinstance(error_obj, dict):
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
raise HTTPException(status_code=500, detail=error_obj.get('message'))
else:
logger.debug(f"[test-model] Error: {error_obj}")
raise HTTPException(status_code=500, detail=error_obj)
elif 'error' in response:
logger.debug(f"[test-model] Error: {response['error']}")
raise HTTPException(status_code=500, detail=response['error'])
else:
# 如果有选择或消息,记录内容预览
if isinstance(response_data, dict):
if 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice:
content = choice['message'].get('content', '')
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
elif 'content' in response_data and response_data['content']:
content = str(response_data['content'])
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
# 检查测试是否成功基于HTTP状态码
status_code = response.get('status_code', 0)
is_success = status_code == 200 and 'error' not in response
return {
"success": is_success,
"data": {
"stream": False,
"response": response,
},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
},
}
except Exception as e:
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
return {
"success": False,
"error": str(e),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
} if endpoint else None,
}

View File

@@ -18,7 +18,15 @@ from sqlalchemy.orm import Session, joinedload
from src.config.constants import CacheTTL from src.config.constants import CacheTTL
from src.core.cache_service import CacheService from src.core.cache_service import CacheService
from src.core.logger import logger from src.core.logger import logger
from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint from src.models.database import (
ApiKey,
GlobalModel,
Model,
Provider,
ProviderAPIKey,
ProviderEndpoint,
User,
)
# 缓存 key 前缀 # 缓存 key 前缀
_CACHE_KEY_PREFIX = "models:list" _CACHE_KEY_PREFIX = "models:list"
@@ -82,6 +90,7 @@ class ModelInfo:
created_at: Optional[str] # ISO 格式 created_at: Optional[str] # ISO 格式
created_timestamp: int # Unix 时间戳 created_timestamp: int # Unix 时间戳
provider_name: str provider_name: str
provider_id: str = "" # Provider ID用于权限过滤
# 能力配置 # 能力配置
streaming: bool = True streaming: bool = True
vision: bool = False vision: bool = False
@@ -99,6 +108,92 @@ class ModelInfo:
output_modalities: Optional[list[str]] = None output_modalities: Optional[list[str]] = None
@dataclass
class AccessRestrictions:
"""API Key 或 User 的访问限制"""
allowed_providers: Optional[list[str]] = None # 允许的 Provider ID 列表
allowed_models: Optional[list[str]] = None # 允许的模型名称列表
allowed_api_formats: Optional[list[str]] = None # 允许的 API 格式列表
@classmethod
def from_api_key_and_user(
cls, api_key: Optional[ApiKey], user: Optional[User]
) -> "AccessRestrictions":
"""
从 API Key 和 User 合并访问限制
限制逻辑:
- API Key 的限制优先于 User 的限制
- 如果 API Key 有限制,使用 API Key 的限制
- 如果 API Key 无限制但 User 有限制,使用 User 的限制
- 两者都无限制则返回空限制
"""
allowed_providers: Optional[list[str]] = None
allowed_models: Optional[list[str]] = None
allowed_api_formats: Optional[list[str]] = None
# 优先使用 API Key 的限制
if api_key:
if api_key.allowed_providers is not None:
allowed_providers = api_key.allowed_providers
if api_key.allowed_models is not None:
allowed_models = api_key.allowed_models
if api_key.allowed_api_formats is not None:
allowed_api_formats = api_key.allowed_api_formats
# 如果 API Key 没有限制,检查 User 的限制
# 注意: User 没有 allowed_api_formats 字段
if user:
if allowed_providers is None and user.allowed_providers is not None:
allowed_providers = user.allowed_providers
if allowed_models is None and user.allowed_models is not None:
allowed_models = user.allowed_models
return cls(
allowed_providers=allowed_providers,
allowed_models=allowed_models,
allowed_api_formats=allowed_api_formats,
)
def is_api_format_allowed(self, api_format: str) -> bool:
"""
检查 API 格式是否被允许
Args:
api_format: API 格式 (如 "OPENAI", "CLAUDE", "GEMINI")
Returns:
True 如果格式被允许False 否则
"""
if self.allowed_api_formats is None:
return True
return api_format in self.allowed_api_formats
def is_model_allowed(self, model_id: str, provider_id: str) -> bool:
"""
检查模型是否被允许访问
Args:
model_id: 模型 ID
provider_id: Provider ID
Returns:
True 如果模型被允许False 否则
"""
# 检查 Provider 限制
if self.allowed_providers is not None:
if provider_id not in self.allowed_providers:
return False
# 检查模型限制
if self.allowed_models is not None:
if model_id not in self.allowed_models:
return False
return True
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]: def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
""" """
返回有可用端点的 Provider IDs 返回有可用端点的 Provider IDs
@@ -218,6 +313,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
) )
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0 created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
provider_name: str = model.provider.name if model.provider else "unknown" provider_name: str = model.provider.name if model.provider else "unknown"
provider_id: str = model.provider_id or ""
# 从 GlobalModel.config 提取配置信息 # 从 GlobalModel.config 提取配置信息
config: dict = {} config: dict = {}
@@ -233,6 +329,7 @@ def _extract_model_info(model: Any) -> ModelInfo:
created_at=created_at, created_at=created_at,
created_timestamp=created_timestamp, created_timestamp=created_timestamp,
provider_name=provider_name, provider_name=provider_name,
provider_id=provider_id,
# 能力配置 # 能力配置
streaming=config.get("streaming", True), streaming=config.get("streaming", True),
vision=config.get("vision", False), vision=config.get("vision", False),
@@ -255,6 +352,7 @@ async def list_available_models(
db: Session, db: Session,
available_provider_ids: set[str], available_provider_ids: set[str],
api_formats: Optional[list[str]] = None, api_formats: Optional[list[str]] = None,
restrictions: Optional[AccessRestrictions] = None,
) -> list[ModelInfo]: ) -> list[ModelInfo]:
""" """
获取可用模型列表(已去重,带缓存) 获取可用模型列表(已去重,带缓存)
@@ -263,6 +361,7 @@ async def list_available_models(
db: 数据库会话 db: 数据库会话
available_provider_ids: 有可用端点的 Provider ID 集合 available_provider_ids: 有可用端点的 Provider ID 集合
api_formats: API 格式列表,用于检查 Key 的 allowed_models api_formats: API 格式列表,用于检查 Key 的 allowed_models
restrictions: API Key/User 的访问限制
Returns: Returns:
去重后的 ModelInfo 列表,按创建时间倒序 去重后的 ModelInfo 列表,按创建时间倒序
@@ -270,8 +369,16 @@ async def list_available_models(
if not available_provider_ids: if not available_provider_ids:
return [] return []
# 缓存策略:只有完全无访问限制时才使用缓存
# - restrictions is None: 未传入限制对象
# - restrictions 的两个字段都为 None: 传入了限制对象但无实际限制
# 以上两种情况返回的结果相同,可以共享全局缓存
use_cache = restrictions is None or (
restrictions.allowed_providers is None and restrictions.allowed_models is None
)
# 尝试从缓存获取 # 尝试从缓存获取
if api_formats: if api_formats and use_cache:
cached = await _get_cached_models(api_formats) cached = await _get_cached_models(api_formats)
if cached is not None: if cached is not None:
return cached return cached
@@ -306,14 +413,19 @@ async def list_available_models(
if available_model_ids is not None and info.id not in available_model_ids: if available_model_ids is not None and info.id not in available_model_ids:
continue continue
# 检查 API Key/User 访问限制
if restrictions is not None:
if not restrictions.is_model_allowed(info.id, info.provider_id):
continue
if info.id in seen_model_ids: if info.id in seen_model_ids:
continue continue
seen_model_ids.add(info.id) seen_model_ids.add(info.id)
result.append(info) result.append(info)
# 写入缓存 # 只有无限制的情况才写入缓存
if api_formats: if api_formats and use_cache:
await _set_cached_models(api_formats, result) await _set_cached_models(api_formats, result)
return result return result
@@ -324,6 +436,7 @@ def find_model_by_id(
model_id: str, model_id: str,
available_provider_ids: set[str], available_provider_ids: set[str],
api_formats: Optional[list[str]] = None, api_formats: Optional[list[str]] = None,
restrictions: Optional[AccessRestrictions] = None,
) -> Optional[ModelInfo]: ) -> Optional[ModelInfo]:
""" """
按 ID 查找模型 按 ID 查找模型
@@ -338,6 +451,7 @@ def find_model_by_id(
model_id: 模型 ID model_id: 模型 ID
available_provider_ids: 有可用端点的 Provider ID 集合 available_provider_ids: 有可用端点的 Provider ID 集合
api_formats: API 格式列表,用于检查 Key 的 allowed_models api_formats: API 格式列表,用于检查 Key 的 allowed_models
restrictions: API Key/User 的访问限制
Returns: Returns:
ModelInfo 或 None ModelInfo 或 None
@@ -353,6 +467,11 @@ def find_model_by_id(
if available_model_ids is not None and model_id not in available_model_ids: if available_model_ids is not None and model_id not in available_model_ids:
return None return None
# 快速检查:如果 restrictions 明确限制了模型列表且目标模型不在其中,直接返回 None
if restrictions is not None and restrictions.allowed_models is not None:
if model_id not in restrictions.allowed_models:
return None
# 先按 GlobalModel.name 查找 # 先按 GlobalModel.name 查找
models_by_global = ( models_by_global = (
db.query(Model) db.query(Model)
@@ -368,8 +487,19 @@ def find_model_by_id(
.all() .all()
) )
def is_model_accessible(m: Model) -> bool:
"""检查模型是否可访问"""
if m.provider_id not in available_provider_ids:
return False
# 检查 API Key/User 访问限制
if restrictions is not None:
provider_id = m.provider_id or ""
if not restrictions.is_model_allowed(model_id, provider_id):
return False
return True
model = next( model = next(
(m for m in models_by_global if m.provider_id in available_provider_ids), (m for m in models_by_global if is_model_accessible(m)),
None, None,
) )
@@ -393,7 +523,7 @@ def find_model_by_id(
) )
model = next( model = next(
(m for m in models_by_provider_name if m.provider_id in available_provider_ids), (m for m in models_by_provider_name if is_model_accessible(m)),
None, None,
) )

View File

@@ -118,7 +118,9 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
# 转换为 UTC 用于与 stats_daily.date 比较(存储的是业务日期对应的 UTC 开始时间) # 转换为 UTC 用于与 stats_daily.date 比较(存储的是业务日期对应的 UTC 开始时间)
today = today_local.astimezone(timezone.utc) today = today_local.astimezone(timezone.utc)
yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc) yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc)
last_month = (today_local - timedelta(days=30)).astimezone(timezone.utc) # 本月第一天(自然月)
month_start_local = today_local.replace(day=1)
month_start = month_start_local.astimezone(timezone.utc)
# ==================== 使用预聚合数据 ==================== # ==================== 使用预聚合数据 ====================
# 从 stats_summary + 今日实时数据获取全局统计 # 从 stats_summary + 今日实时数据获取全局统计
@@ -208,7 +210,7 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
func.sum(StatsDaily.cache_read_cost).label("cache_read_cost"), func.sum(StatsDaily.cache_read_cost).label("cache_read_cost"),
func.sum(StatsDaily.fallback_count).label("fallback_count"), func.sum(StatsDaily.fallback_count).label("fallback_count"),
) )
.filter(StatsDaily.date >= last_month, StatsDaily.date < today) .filter(StatsDaily.date >= month_start, StatsDaily.date < today)
.first() .first()
) )
@@ -227,24 +229,24 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
else: else:
# 回退到实时查询(没有预聚合数据时) # 回退到实时查询(没有预聚合数据时)
total_requests = ( total_requests = (
db.query(func.count(Usage.id)).filter(Usage.created_at >= last_month).scalar() or 0 db.query(func.count(Usage.id)).filter(Usage.created_at >= month_start).scalar() or 0
) )
total_cost = ( total_cost = (
db.query(func.sum(Usage.total_cost_usd)).filter(Usage.created_at >= last_month).scalar() or 0 db.query(func.sum(Usage.total_cost_usd)).filter(Usage.created_at >= month_start).scalar() or 0
) )
total_actual_cost = ( total_actual_cost = (
db.query(func.sum(Usage.actual_total_cost_usd)) db.query(func.sum(Usage.actual_total_cost_usd))
.filter(Usage.created_at >= last_month).scalar() or 0 .filter(Usage.created_at >= month_start).scalar() or 0
) )
error_requests = ( error_requests = (
db.query(func.count(Usage.id)) db.query(func.count(Usage.id))
.filter( .filter(
Usage.created_at >= last_month, Usage.created_at >= month_start,
(Usage.status_code >= 400) | (Usage.error_message.isnot(None)), (Usage.status_code >= 400) | (Usage.error_message.isnot(None)),
).scalar() or 0 ).scalar() or 0
) )
total_tokens = ( total_tokens = (
db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= last_month).scalar() or 0 db.query(func.sum(Usage.total_tokens)).filter(Usage.created_at >= month_start).scalar() or 0
) )
cache_stats = ( cache_stats = (
db.query( db.query(
@@ -253,7 +255,7 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"), func.sum(Usage.cache_creation_cost_usd).label("cache_creation_cost"),
func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"), func.sum(Usage.cache_read_cost_usd).label("cache_read_cost"),
) )
.filter(Usage.created_at >= last_month) .filter(Usage.created_at >= month_start)
.first() .first()
) )
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0 cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
@@ -267,7 +269,7 @@ class AdminDashboardStatsAdapter(AdminApiAdapter):
RequestCandidate.request_id, func.count(RequestCandidate.id).label("executed_count") RequestCandidate.request_id, func.count(RequestCandidate.id).label("executed_count")
) )
.filter( .filter(
RequestCandidate.created_at >= last_month, RequestCandidate.created_at >= month_start,
RequestCandidate.status.in_(["success", "failed"]), RequestCandidate.status.in_(["success", "failed"]),
) )
.group_by(RequestCandidate.request_id) .group_by(RequestCandidate.request_id)
@@ -447,7 +449,9 @@ class UserDashboardStatsAdapter(DashboardAdapter):
# 转换为 UTC 用于数据库查询 # 转换为 UTC 用于数据库查询
today = today_local.astimezone(timezone.utc) today = today_local.astimezone(timezone.utc)
yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc) yesterday = (today_local - timedelta(days=1)).astimezone(timezone.utc)
last_month = (today_local - timedelta(days=30)).astimezone(timezone.utc) # 本月第一天(自然月)
month_start_local = today_local.replace(day=1)
month_start = month_start_local.astimezone(timezone.utc)
user_api_keys = db.query(func.count(ApiKey.id)).filter(ApiKey.user_id == user.id).scalar() user_api_keys = db.query(func.count(ApiKey.id)).filter(ApiKey.user_id == user.id).scalar()
active_keys = ( active_keys = (
@@ -483,12 +487,12 @@ class UserDashboardStatsAdapter(DashboardAdapter):
# 本月请求统计 # 本月请求统计
user_requests = ( user_requests = (
db.query(func.count(Usage.id)) db.query(func.count(Usage.id))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month)) .filter(and_(Usage.user_id == user.id, Usage.created_at >= month_start))
.scalar() .scalar()
) )
user_cost = ( user_cost = (
db.query(func.sum(Usage.total_cost_usd)) db.query(func.sum(Usage.total_cost_usd))
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month)) .filter(and_(Usage.user_id == user.id, Usage.created_at >= month_start))
.scalar() .scalar()
or 0 or 0
) )
@@ -532,18 +536,19 @@ class UserDashboardStatsAdapter(DashboardAdapter):
func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"), func.sum(Usage.cache_read_input_tokens).label("cache_read_tokens"),
func.sum(Usage.input_tokens).label("total_input_tokens"), func.sum(Usage.input_tokens).label("total_input_tokens"),
) )
.filter(and_(Usage.user_id == user.id, Usage.created_at >= last_month)) .filter(and_(Usage.user_id == user.id, Usage.created_at >= month_start))
.first() .first()
) )
cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0 cache_creation_tokens = int(cache_stats.cache_creation_tokens or 0) if cache_stats else 0
cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0 cache_read_tokens = int(cache_stats.cache_read_tokens or 0) if cache_stats else 0
monthly_input_tokens = int(cache_stats.total_input_tokens or 0) if cache_stats else 0
# 计算缓存命中率cache_read / (input_tokens + cache_read) # 计算本月缓存命中率cache_read / (input_tokens + cache_read)
# input_tokens 是实际发送给模型的输入不含缓存读取cache_read 是从缓存读取的 # input_tokens 是实际发送给模型的输入不含缓存读取cache_read 是从缓存读取的
# 总输入 = input_tokens + cache_read缓存命中率 = cache_read / 总输入 # 总输入 = input_tokens + cache_read缓存命中率 = cache_read / 总输入
total_input_with_cache = all_time_input_tokens + all_time_cache_read total_input_with_cache = monthly_input_tokens + cache_read_tokens
cache_hit_rate = ( cache_hit_rate = (
round((all_time_cache_read / total_input_with_cache) * 100, 1) round((cache_read_tokens / total_input_with_cache) * 100, 1)
if total_input_with_cache > 0 if total_input_with_cache > 0
else 0 else 0
) )
@@ -569,15 +574,15 @@ class UserDashboardStatsAdapter(DashboardAdapter):
quota_value = "无限制" quota_value = "无限制"
quota_change = f"已用 ${user.used_usd:.2f}" quota_change = f"已用 ${user.used_usd:.2f}"
quota_high = False quota_high = False
elif user.quota_usd and user.quota_usd > 0: elif user.quota_usd > 0:
percent = min(100, int((user.used_usd / user.quota_usd) * 100)) percent = min(100, int((user.used_usd / user.quota_usd) * 100))
quota_value = "无限制" quota_value = f"${user.quota_usd:.0f}"
quota_change = f"已用 ${user.used_usd:.2f}" quota_change = f"已用 ${user.used_usd:.2f}"
quota_high = percent > 80 quota_high = percent > 80
else: else:
quota_value = "0%" quota_value = "$0"
quota_change = f"已用 ${user.used_usd:.2f}" quota_change = f"已用 ${user.used_usd:.2f}"
quota_high = False quota_high = True
return { return {
"stats": [ "stats": [
@@ -605,9 +610,15 @@ class UserDashboardStatsAdapter(DashboardAdapter):
"icon": "TrendingUp", "icon": "TrendingUp",
}, },
{ {
"name": "本月费用", "name": "总Token",
"value": f"${user_cost:.2f}", "value": format_tokens(
"icon": "DollarSign", all_time_input_tokens
+ all_time_output_tokens
+ all_time_cache_creation
+ all_time_cache_read
),
"subValue": f"输入 {format_tokens(all_time_input_tokens)} / 输出 {format_tokens(all_time_output_tokens)}",
"icon": "Hash",
}, },
], ],
"today": { "today": {
@@ -631,6 +642,8 @@ class UserDashboardStatsAdapter(DashboardAdapter):
"cache_hit_rate": cache_hit_rate, "cache_hit_rate": cache_hit_rate,
"total_cache_tokens": cache_creation_tokens + cache_read_tokens, "total_cache_tokens": cache_creation_tokens + cache_read_tokens,
}, },
# 本月费用(用于下方缓存区域显示)
"monthly_cost": float(user_cost),
} }

View File

@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base" name: str = "chat.base"
mode = ApiMode.STANDARD mode = ApiMode.STANDARD
# 子类可以配置的特殊方法用于check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建端点URL子类可以覆盖以自定义URL构建逻辑"""
# 默认实现在base_url后添加特定路径
return base_url
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建基础请求头,子类可以覆盖以自定义认证头"""
# 默认实现Bearer token认证
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回不应被extra_headers覆盖的头部key子类可以覆盖"""
# 默认保护认证相关头部
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
# 默认实现:直接使用请求数据
return request_data.copy()
def __init__(self, allowed_api_formats: Optional[list[str]] = None): def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID] self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖 # 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models" return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数(现在强制记录)
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API Key ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 使用子类配置方法构建请求组件
url = cls.build_endpoint_url(base_url)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用通用的endpoint checker执行请求
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=model_name or request_data.get("model"),
)
# ========================================================================= # =========================================================================
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例 # Adapter 注册表 - 用于根据 API format 获取 Adapter 实例

View File

@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖 # 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models" return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
通用的CLI endpoint测试方法使用配置方法模式
- build_endpoint_url(): 构建请求URL
- build_base_headers(): 构建基础认证头
- get_protected_header_keys(): 获取受保护的头部key
- build_request_body(): 构建请求体
- get_cli_user_agent(): 获取CLI User-Agent子类可覆盖
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API密钥ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 构建请求组件
url = cls.build_endpoint_url(base_url, request_data, model_name)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
# 添加CLI User-Agent
cli_user_agent = cls.get_cli_user_agent()
if cli_user_agent:
base_headers["User-Agent"] = cli_user_agent
protected_keys = tuple(list(protected_keys) + ["user-agent"])
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 获取有效的模型名称
effective_model_name = model_name or request_data.get("model")
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
# =========================================================================
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
# =========================================================================
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""
构建CLI API端点URL - 子类应覆盖
Args:
base_url: API基础URL
request_data: 请求数据
model_name: 模型名称某些API需要如Gemini
Returns:
完整的端点URL
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""
构建CLI API认证头 - 子类应覆盖
Args:
api_key: API密钥
Returns:
基础认证头部字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""
返回CLI API的保护头部key - 子类应覆盖
Returns:
保护头部key的元组
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
构建CLI API请求体 - 子类应覆盖
Args:
request_data: 请求数据
Returns:
请求体字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""
获取CLI User-Agent - 子类可覆盖
Returns:
CLI User-Agent字符串如果不需要则为None
"""
return None
# ========================================================================= # =========================================================================
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例 # CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例

File diff suppressed because it is too large Load Diff

View File

@@ -209,6 +209,38 @@ class ClaudeChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}") logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建Claude API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/messages"
else:
return f"{base_url}/v1/messages"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Claude API认证头"""
return {
"x-api-key": api_key,
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Claude API的保护头部key"""
return ("x-api-key", "content-type", "anthropic-version")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Claude API请求体"""
return {
"model": request_data.get("model"),
"max_tokens": request_data.get("max_tokens", 100),
"messages": request_data.get("messages", []),
}
def build_claude_adapter(x_app_header: Optional[str]): def build_claude_adapter(x_app_header: Optional[str]):
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。""" """根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""

View File

@@ -4,7 +4,7 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。 继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -126,5 +126,41 @@ class ClaudeCliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建Claude CLI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/messages"
else:
return f"{base_url}/v1/messages"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Claude CLI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Claude CLI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Claude CLI API请求体"""
return {
"model": request_data.get("model"),
"max_tokens": request_data.get("max_tokens", 100),
"messages": request_data.get("messages", []),
}
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取Claude CLI User-Agent"""
return config.internal_user_agent_claude_cli
__all__ = ["ClaudeCliAdapter"] __all__ = ["ClaudeCliAdapter"]

View File

@@ -4,7 +4,7 @@ Gemini Chat Adapter
处理 Gemini API 格式的请求适配 处理 Gemini API 格式的请求适配
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
from src.core.logger import logger from src.core.logger import logger
from src.models.gemini import GeminiRequest from src.models.gemini import GeminiRequest
@@ -199,6 +200,94 @@ class GeminiChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}") logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建Gemini API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1beta"):
return base_url # 子类需要处理model参数
else:
return f"{base_url}/v1beta"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Gemini API认证头"""
return {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Gemini API的保护头部key"""
return ("x-goog-api-key", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Gemini API请求体"""
return {
"contents": request_data.get("messages", []),
"generationConfig": {
"maxOutputTokens": request_data.get("max_tokens", 100),
"temperature": request_data.get("temperature", 0.7),
},
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
],
}
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""测试 Gemini API 模型连接性(非流式)"""
# Gemini需要从request_data或model_name参数获取model名称
effective_model_name = model_name or request_data.get("model", "")
if not effective_model_name:
return {
"error": "Model name is required for Gemini API",
"status_code": 400,
}
# 使用基类配置方法但重写URL构建逻辑
base_url = cls.build_endpoint_url(base_url)
url = f"{base_url}/models/{effective_model_name}:generateContent"
# 构建请求组件
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用基类的通用endpoint checker
from src.api.handlers.base.endpoint_checker import run_endpoint_check
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter: def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
""" """

View File

@@ -4,7 +4,7 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
继承 CliAdapterBase处理 Gemini CLI 格式的请求。 继承 CliAdapterBase处理 Gemini CLI 格式的请求。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -123,6 +123,52 @@ class GeminiCliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建Gemini CLI API端点URL"""
effective_model_name = model_name or request_data.get("model", "")
if not effective_model_name:
raise ValueError("Model name is required for Gemini API")
base_url = base_url.rstrip("/")
if base_url.endswith("/v1beta"):
prefix = base_url
else:
prefix = f"{base_url}/v1beta"
return f"{prefix}/models/{effective_model_name}:generateContent"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Gemini CLI API认证头"""
return {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Gemini CLI API的保护头部key"""
return ("x-goog-api-key", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Gemini CLI API请求体"""
return {
"contents": request_data.get("messages", []),
"generationConfig": {
"maxOutputTokens": request_data.get("max_tokens", 100),
"temperature": request_data.get("temperature", 0.7),
},
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
],
}
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取Gemini CLI User-Agent"""
return config.internal_user_agent_gemini_cli
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter: def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
""" """

View File

@@ -4,13 +4,14 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。 处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger from src.core.logger import logger
from src.models.openai import OpenAIRequest from src.models.openai import OpenAIRequest
@@ -154,5 +155,32 @@ class OpenAIChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch models from {models_url}: {e}") logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建OpenAI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/chat/completions"
else:
return f"{base_url}/v1/chat/completions"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建OpenAI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回OpenAI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建OpenAI API请求体"""
return request_data.copy()
__all__ = ["OpenAIChatAdapter"] __all__ = ["OpenAIChatAdapter"]

View File

@@ -4,7 +4,7 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。 继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
""" """
from typing import Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -68,5 +68,37 @@ class OpenAICliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建OpenAI CLI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/chat/completions"
else:
return f"{base_url}/v1/chat/completions"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建OpenAI CLI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回OpenAI CLI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建OpenAI CLI API请求体"""
return request_data.copy()
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取OpenAI CLI User-Agent"""
return config.internal_user_agent_openai_cli
__all__ = ["OpenAICliAdapter"] __all__ = ["OpenAICliAdapter"]

View File

@@ -14,6 +14,7 @@ from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.api.base.models_service import ( from src.api.base.models_service import (
AccessRestrictions,
ModelInfo, ModelInfo,
find_model_by_id, find_model_by_id,
get_available_provider_ids, get_available_provider_ids,
@@ -103,6 +104,35 @@ def _get_formats_for_api(api_format: str) -> list[str]:
return _OPENAI_FORMATS return _OPENAI_FORMATS
def _build_empty_list_response(api_format: str) -> dict:
"""根据 API 格式构建空列表响应"""
if api_format == "claude":
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
elif api_format == "gemini":
return {"models": []}
else:
return {"object": "list", "data": []}
def _filter_formats_by_restrictions(
formats: list[str], restrictions: AccessRestrictions, api_format: str
) -> Tuple[list[str], Optional[dict]]:
"""
根据访问限制过滤 API 格式
Returns:
(过滤后的格式列表, 空响应或None)
如果过滤后为空,返回对应格式的空响应
"""
if restrictions.allowed_api_formats is None:
return formats, None
filtered = [f for f in formats if f in restrictions.allowed_api_formats]
if not filtered:
logger.info(f"[Models] API Key 不允许访问格式 {api_format}")
return [], _build_empty_list_response(api_format)
return filtered, None
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]: def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
""" """
认证 API Key 认证 API Key
@@ -375,22 +405,24 @@ async def list_models(
logger.info(f"[Models] GET /v1/models | format={api_format}") logger.info(f"[Models] GET /v1/models | format={api_format}")
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response(api_format) return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
formats, empty_response = _filter_formats_by_restrictions(formats, restrictions, api_format)
if empty_response is not None:
return empty_response
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
if not available_provider_ids: if not available_provider_ids:
if api_format == "claude": return _build_empty_list_response(api_format)
return {"data": [], "has_more": False, "first_id": None, "last_id": None}
elif api_format == "gemini":
return {"models": []}
else:
return {"object": "list", "data": []}
models = await list_available_models(db, available_provider_ids, formats) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
if api_format == "claude": if api_format == "claude":
@@ -419,14 +451,21 @@ async def retrieve_model(
logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}") logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response(api_format) return _build_auth_error_response(api_format)
# 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats = _get_formats_for_api(api_format) formats = _get_formats_for_api(api_format)
formats, _ = _filter_formats_by_restrictions(formats, restrictions, api_format)
if not formats:
return _build_404_response(model_id, api_format)
available_provider_ids = get_available_provider_ids(db, formats) available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(db, model_id, available_provider_ids, formats) model_info = find_model_by_id(db, model_id, available_provider_ids, formats, restrictions)
if not model_info: if not model_info:
return _build_404_response(model_id, api_format) return _build_404_response(model_id, api_format)
@@ -455,15 +494,25 @@ async def list_models_gemini(
api_key = _extract_api_key_from_request(request, gemini_def) api_key = _extract_api_key_from_request(request, gemini_def)
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response("gemini") return _build_auth_error_response("gemini")
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) # 构建访问限制
restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats, empty_response = _filter_formats_by_restrictions(
_GEMINI_FORMATS, restrictions, "gemini"
)
if empty_response is not None:
return empty_response
available_provider_ids = get_available_provider_ids(db, formats)
if not available_provider_ids: if not available_provider_ids:
return {"models": []} return {"models": []}
models = await list_available_models(db, available_provider_ids, _GEMINI_FORMATS) models = await list_available_models(db, available_provider_ids, formats, restrictions)
logger.debug(f"[Models] 返回 {len(models)} 个模型") logger.debug(f"[Models] 返回 {len(models)} 个模型")
response = _build_gemini_list_response(models, page_size, page_token) response = _build_gemini_list_response(models, page_size, page_token)
logger.debug(f"[Models] Gemini 响应: {response}") logger.debug(f"[Models] Gemini 响应: {response}")
@@ -486,12 +535,22 @@ async def get_model_gemini(
api_key = _extract_api_key_from_request(request, gemini_def) api_key = _extract_api_key_from_request(request, gemini_def)
# 认证 # 认证
user, _ = _authenticate(db, api_key) user, key_record = _authenticate(db, api_key)
if not user: if not user:
return _build_auth_error_response("gemini") return _build_auth_error_response("gemini")
available_provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS) # 构建访问限制
model_info = find_model_by_id(db, model_id, available_provider_ids, _GEMINI_FORMATS) restrictions = AccessRestrictions.from_api_key_and_user(key_record, user)
# 检查 API 格式限制
formats, _ = _filter_formats_by_restrictions(_GEMINI_FORMATS, restrictions, "gemini")
if not formats:
return _build_404_response(model_id, "gemini")
available_provider_ids = get_available_provider_ids(db, formats)
model_info = find_model_by_id(
db, model_id, available_provider_ids, formats, restrictions
)
if not model_info: if not model_info:
return _build_404_response(model_id, "gemini") return _build_404_response(model_id, "gemini")

View File

@@ -9,6 +9,7 @@ from urllib.parse import quote, urlparse
import httpx import httpx
from src.config import config
from src.core.logger import logger from src.core.logger import logger
@@ -83,10 +84,10 @@ class HTTPClientPool:
http2=False, # 暂时禁用HTTP/2以提高兼容性 http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证 verify=True, # 启用SSL验证
timeout=httpx.Timeout( timeout=httpx.Timeout(
connect=10.0, # 连接超时 connect=config.http_connect_timeout,
read=300.0, # 读取超时(5分钟,适合流式响应) read=config.http_read_timeout,
write=60.0, # 写入超时(60秒,支持大请求体) write=config.http_write_timeout,
pool=5.0, # 连接池超时 pool=config.http_pool_timeout,
), ),
limits=httpx.Limits( limits=httpx.Limits(
max_connections=100, # 最大连接数 max_connections=100, # 最大连接数
@@ -111,15 +112,20 @@ class HTTPClientPool:
""" """
if name not in cls._clients: if name not in cls._clients:
# 合并默认配置和自定义配置 # 合并默认配置和自定义配置
config = { default_config = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"timeout": httpx.Timeout(10.0, read=300.0), "timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
"follow_redirects": True, "follow_redirects": True,
} }
config.update(kwargs) default_config.update(kwargs)
cls._clients[name] = httpx.AsyncClient(**config) cls._clients[name] = httpx.AsyncClient(**default_config)
logger.debug(f"创建命名HTTP客户端: {name}") logger.debug(f"创建命名HTTP客户端: {name}")
return cls._clients[name] return cls._clients[name]
@@ -151,14 +157,19 @@ class HTTPClientPool:
async with HTTPClientPool.get_temp_client() as client: async with HTTPClientPool.get_temp_client() as client:
response = await client.get('https://example.com') response = await client.get('https://example.com')
""" """
config = { default_config = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"timeout": httpx.Timeout(10.0), "timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
} }
config.update(kwargs) default_config.update(kwargs)
client = httpx.AsyncClient(**config) client = httpx.AsyncClient(**default_config)
try: try:
yield client yield client
finally: finally:
@@ -182,25 +193,30 @@ class HTTPClientPool:
Returns: Returns:
配置好的 httpx.AsyncClient 实例 配置好的 httpx.AsyncClient 实例
""" """
config: Dict[str, Any] = { client_config: Dict[str, Any] = {
"http2": False, "http2": False,
"verify": True, "verify": True,
"follow_redirects": True, "follow_redirects": True,
} }
if timeout: if timeout:
config["timeout"] = timeout client_config["timeout"] = timeout
else: else:
config["timeout"] = httpx.Timeout(10.0, read=300.0) client_config["timeout"] = httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
)
# 添加代理配置 # 添加代理配置
proxy_url = build_proxy_url(proxy_config) if proxy_config else None proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url: if proxy_url:
config["proxy"] = proxy_url client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
config.update(kwargs) client_config.update(kwargs)
return httpx.AsyncClient(**config) return httpx.AsyncClient(**client_config)
# 便捷访问函数 # 便捷访问函数

View File

@@ -148,6 +148,7 @@ class Config:
# HTTP 请求超时配置(秒) # HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0")) self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_read_timeout = float(os.getenv("HTTP_READ_TIMEOUT", "300.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0")) self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))

View File

@@ -96,13 +96,15 @@ if not DISABLE_FILE_LOG:
log_dir.mkdir(exist_ok=True) log_dir.mkdir(exist_ok=True)
# 文件日志通用配置 # 文件日志通用配置
# 注意: enqueue=False 使用同步模式,避免 multiprocessing 信号量泄漏
# 在 macOS 上,进程异常退出时 POSIX 信号量不会自动释放,导致资源耗尽
file_log_config = { file_log_config = {
"format": FILE_FORMAT, "format": FILE_FORMAT,
"filter": _log_filter, "filter": _log_filter,
"rotation": "100 MB", "rotation": "100 MB",
"retention": "30 days", "retention": "30 days",
"compression": "gz", "compression": "gz",
"enqueue": True, "enqueue": False,
"encoding": "utf-8", "encoding": "utf-8",
"catch": True, "catch": True,
} }

View File

@@ -360,6 +360,9 @@ def init_db():
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh 注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
""" """
import sys
from sqlalchemy.exc import OperationalError
logger.info("初始化数据库...") logger.info("初始化数据库...")
# 确保引擎已创建 # 确保引擎已创建
@@ -382,6 +385,38 @@ def init_db():
db.commit() db.commit()
logger.info("数据库初始化完成") logger.info("数据库初始化完成")
except OperationalError as e:
db.rollback()
# 提取数据库连接信息用于提示
db_url = config.database_url
# 隐藏密码,只显示 host:port/database
if "@" in db_url:
db_info = db_url.split("@")[-1]
else:
db_info = db_url
import os
# 直接打印到 stderr确保消息显示
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
print("数据库连接失败", file=sys.stderr)
print("=" * 60, file=sys.stderr)
print("", file=sys.stderr)
print(f"无法连接到数据库: {db_info}", file=sys.stderr)
print("", file=sys.stderr)
print("请检查以下事项:", file=sys.stderr)
print(" 1. PostgreSQL 服务是否正在运行", file=sys.stderr)
print(" 2. 数据库连接配置是否正确 (DATABASE_URL)", file=sys.stderr)
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
print("", file=sys.stderr)
print("如果使用 Docker请先运行:", file=sys.stderr)
print(" docker-compose up -d postgres redis", file=sys.stderr)
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
os._exit(1)
except Exception as e: except Exception as e:
logger.error(f"数据库初始化失败: {e}") logger.error(f"数据库初始化失败: {e}")
db.rollback() db.rollback()

View File

@@ -317,6 +317,7 @@ class UpdateUserRequest(BaseModel):
username: Optional[str] = Field(None, min_length=1, max_length=50) username: Optional[str] = Field(None, min_length=1, max_length=50)
email: Optional[str] = Field(None, max_length=100) email: Optional[str] = Field(None, max_length=100)
password: Optional[str] = Field(None, min_length=6, max_length=128, description="新密码(留空保持不变)")
quota_usd: Optional[float] = Field(None, ge=0) quota_usd: Optional[float] = Field(None, ge=0)
is_active: Optional[bool] = None is_active: Optional[bool] = None
role: Optional[str] = None role: Optional[str] = None

View File

@@ -274,6 +274,13 @@ class GlobalModelListResponse(BaseModel):
total: int total: int
class GlobalModelProvidersResponse(BaseModel):
"""GlobalModel 关联提供商列表响应"""
providers: List[ModelCatalogProviderDetail]
total: int
class BatchAssignToProvidersRequest(BaseModel): class BatchAssignToProvidersRequest(BaseModel):
"""批量为 Provider 添加 GlobalModel 实现""" """批量为 Provider 添加 GlobalModel 实现"""

View File

@@ -30,6 +30,8 @@
from __future__ import annotations from __future__ import annotations
import hashlib
import random
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -956,7 +958,16 @@ class CacheAwareScheduler:
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序 # 获取活跃的 Key 并按 internal_priority + 负载均衡排序
active_keys = [key for key in endpoint.api_keys if key.is_active] active_keys = [key for key in endpoint.api_keys if key.is_active]
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key) # 检查是否所有 Key 都是 TTL=0轮换模式
# 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None则使用随机排序
use_random = all(
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
) if active_keys else False
if use_random and len(active_keys) > 1:
logger.debug(
f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
)
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
for key in keys: for key in keys:
# Key 级别的能力检查(模型级别的能力检查已在上面完成) # Key 级别的能力检查(模型级别的能力检查已在上面完成)
@@ -1170,6 +1181,7 @@ class CacheAwareScheduler:
self, self,
keys: List[ProviderAPIKey], keys: List[ProviderAPIKey],
affinity_key: Optional[str] = None, affinity_key: Optional[str] = None,
use_random: bool = False,
) -> List[ProviderAPIKey]: ) -> List[ProviderAPIKey]:
""" """
对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱 对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱
@@ -1178,10 +1190,12 @@ class CacheAwareScheduler:
- 数字越小越优先使用 - 数字越小越优先使用
- 同优先级 Key 之间实现负载均衡 - 同优先级 Key 之间实现负载均衡
- 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性) - 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性)
- 当 use_random=True 时,使用随机排序实现轮换(用于 TTL=0 的场景)
Args: Args:
keys: API Key 列表 keys: API Key 列表
affinity_key: 亲和性标识符(通常为 API Key ID用于确定性打乱 affinity_key: 亲和性标识符(通常为 API Key ID用于确定性打乱
use_random: 是否使用随机排序TTL=0 时为 True
Returns: Returns:
排序后的 Key 列表 排序后的 Key 列表
@@ -1198,15 +1212,19 @@ class CacheAwareScheduler:
priority = key.internal_priority if key.internal_priority is not None else 999999 priority = key.internal_priority if key.internal_priority is not None else 999999
priority_groups[priority].append(key) priority_groups[priority].append(key)
# 对每个优先级组内的 Key 进行确定性打乱 # 对每个优先级组内的 Key 进行打乱
result = [] result = []
for priority in sorted(priority_groups.keys()): # 数字小的优先级高,排前面 for priority in sorted(priority_groups.keys()): # 数字小的优先级高,排前面
group_keys = priority_groups[priority] group_keys = priority_groups[priority]
if len(group_keys) > 1 and affinity_key: if len(group_keys) > 1:
# 改进的哈希策略:为每个 key 计算独立的哈希值 if use_random:
import hashlib # TTL=0 模式:使用随机排序实现 Key 轮换
shuffled = list(group_keys)
random.shuffle(shuffled)
result.extend(shuffled)
elif affinity_key:
# 正常模式:使用哈希确定性打乱(保持缓存亲和性)
key_scores = [] key_scores = []
for key in group_keys: for key in group_keys:
# 使用 affinity_key + key.id 的组合哈希 # 使用 affinity_key + key.id 的组合哈希
@@ -1218,8 +1236,11 @@ class CacheAwareScheduler:
sorted_group = [key for _, key in sorted(key_scores)] sorted_group = [key for _, key in sorted(key_scores)]
result.extend(sorted_group) result.extend(sorted_group)
else: else:
# 单个 Key 或没有 affinity_key 时保持原顺序 # 没有 affinity_key 时按 ID 排序保持稳定性
result.extend(sorted(group_keys, key=lambda k: k.id)) result.extend(sorted(group_keys, key=lambda k: k.id))
else:
# 单个 Key 直接添加
result.extend(group_keys)
return result return result

View File

@@ -234,8 +234,15 @@ class EndpointHealthService:
for api_format in format_key_mapping.keys() for api_format in format_key_mapping.keys()
} }
# 参数校验API 层已通过 Query(ge=1) 保证,这里做防御性检查)
if lookback_hours <= 0 or segments <= 0:
raise ValueError(
f"lookback_hours and segments must be positive, "
f"got lookback_hours={lookback_hours}, segments={segments}"
)
# 计算时间范围 # 计算时间范围
interval_minutes = (lookback_hours * 60) // segments segment_seconds = (lookback_hours * 3600) / segments
start_time = now - timedelta(hours=lookback_hours) start_time = now - timedelta(hours=lookback_hours)
# 使用 RequestCandidate 表查询所有尝试记录 # 使用 RequestCandidate 表查询所有尝试记录
@@ -243,7 +250,7 @@ class EndpointHealthService:
final_statuses = ["success", "failed", "skipped"] final_statuses = ["success", "failed", "skipped"]
segment_expr = func.floor( segment_expr = func.floor(
func.extract('epoch', RequestCandidate.created_at - start_time) / (interval_minutes * 60) func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
).label('segment_idx') ).label('segment_idx')
candidate_stats = ( candidate_stats = (

View File

@@ -59,14 +59,15 @@ class ApiKeyService:
if expire_days: if expire_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days) expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
# 空数组转为 None表示不限制
api_key = ApiKey( api_key = ApiKey(
user_id=user_id, user_id=user_id,
key_hash=key_hash, key_hash=key_hash,
key_encrypted=key_encrypted, key_encrypted=key_encrypted,
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}", name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
allowed_providers=allowed_providers, allowed_providers=allowed_providers or None,
allowed_api_formats=allowed_api_formats, allowed_api_formats=allowed_api_formats or None,
allowed_models=allowed_models, allowed_models=allowed_models or None,
rate_limit=rate_limit, rate_limit=rate_limit,
concurrent_limit=concurrent_limit, concurrent_limit=concurrent_limit,
expires_at=expires_at, expires_at=expires_at,
@@ -141,8 +142,18 @@ class ApiKeyService:
"auto_delete_on_expiry", "auto_delete_on_expiry",
] ]
# 允许显式设置为空数组/None 的字段(空数组会转为 None表示"全部"
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
for field, value in kwargs.items(): for field, value in kwargs.items():
if field in updatable_fields and value is not None: if field not in updatable_fields:
continue
# 对于 nullable_list_fields空数组应该转为 None表示不限制
if field in nullable_list_fields:
if value is not None:
# 空数组转为 None表示允许全部
setattr(api_key, field, value if value else None)
elif value is not None:
setattr(api_key, field, value) setattr(api_key, field, value)
api_key.updated_at = datetime.now(timezone.utc) api_key.updated_at = datetime.now(timezone.utc)

View File

@@ -1,8 +1,16 @@
"""分布式任务协调器,确保仅有一个 worker 执行特定任务""" """分布式任务协调器,确保仅有一个 worker 执行特定任务
锁清理策略:
- 单实例模式(默认):启动时使用原子操作清理旧锁并获取新锁
- 多实例模式:使用 NX 选项竞争锁,依赖 TTL 处理异常退出
使用方式:
- 默认行为:启动时清理旧锁(适用于单机部署)
- 多实例部署:设置 SINGLE_INSTANCE_MODE=false 禁用启动清理
"""
from __future__ import annotations from __future__ import annotations
import asyncio
import os import os
import pathlib import pathlib
import uuid import uuid
@@ -19,6 +27,10 @@ except ImportError: # pragma: no cover - Windows 环境
class StartupTaskCoordinator: class StartupTaskCoordinator:
"""利用 Redis 或文件锁,保证任务只在单个进程/实例中运行""" """利用 Redis 或文件锁,保证任务只在单个进程/实例中运行"""
# 类级别标记:在当前进程中是否已尝试过启动清理
# 注意:这在 fork 模式下每个 worker 都是独立的
_startup_cleanup_attempted = False
def __init__(self, redis_client=None, lock_dir: Optional[str] = None): def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
self.redis = redis_client self.redis = redis_client
self._tokens: Dict[str, str] = {} self._tokens: Dict[str, str] = {}
@@ -26,6 +38,8 @@ class StartupTaskCoordinator:
self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks")) self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
if not self._lock_dir.exists(): if not self._lock_dir.exists():
self._lock_dir.mkdir(parents=True, exist_ok=True) self._lock_dir.mkdir(parents=True, exist_ok=True)
# 单实例模式:启动时清理旧锁(适用于单机部署,避免残留锁问题)
self._single_instance_mode = os.getenv("SINGLE_INSTANCE_MODE", "true").lower() == "true"
def _redis_key(self, name: str) -> str: def _redis_key(self, name: str) -> str:
return f"task_lock:{name}" return f"task_lock:{name}"
@@ -36,7 +50,46 @@ class StartupTaskCoordinator:
if self.redis: if self.redis:
token = str(uuid.uuid4()) token = str(uuid.uuid4())
try: try:
acquired = await self.redis.set(self._redis_key(name), token, nx=True, ex=ttl) if self._single_instance_mode:
# 单实例模式:使用 Lua 脚本原子性地"清理旧锁 + 竞争获取"
# 只有当锁不存在或成功获取时才返回 1
# 这样第一个执行的 worker 会清理旧锁并获取,后续 worker 会正常竞争
script = """
local key = KEYS[1]
local token = ARGV[1]
local ttl = tonumber(ARGV[2])
local startup_key = KEYS[1] .. ':startup'
-- 检查是否已有 worker 执行过启动清理
local cleaned = redis.call('GET', startup_key)
if not cleaned then
-- 第一个 worker删除旧锁标记已清理
redis.call('DEL', key)
redis.call('SET', startup_key, '1', 'EX', 60)
end
-- 尝试获取锁NX 模式)
local result = redis.call('SET', key, token, 'NX', 'EX', ttl)
if result then
return 1
end
return 0
"""
result = await self.redis.eval(
script, 2,
self._redis_key(name), self._redis_key(name),
token, ttl
)
if result == 1:
self._tokens[name] = token
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
return True
return False
else:
# 多实例模式:直接使用 NX 选项竞争锁
acquired = await self.redis.set(
self._redis_key(name), token, nx=True, ex=ttl
)
if acquired: if acquired:
self._tokens[name] = token self._tokens[name] = token
logger.info(f"任务 {name} 通过 Redis 锁独占执行") logger.info(f"任务 {name} 通过 Redis 锁独占执行")

View File

@@ -8,116 +8,86 @@ from src.api.handlers.base.utils import build_sse_headers, extract_cache_creatio
class TestExtractCacheCreationTokens: class TestExtractCacheCreationTokens:
"""测试 extract_cache_creation_tokens 函数""" """测试 extract_cache_creation_tokens 函数"""
# === 嵌套格式测试(优先级最高)=== def test_new_format_only(self) -> None:
"""测试只有新格式字段"""
def test_nested_cache_creation_format(self) -> None:
"""测试嵌套格式正常情况"""
usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 456,
"ephemeral_1h_input_tokens": 100,
}
}
assert extract_cache_creation_tokens(usage) == 556
def test_nested_cache_creation_with_old_format_fallback(self) -> None:
"""测试嵌套格式为 0 时回退到旧格式"""
usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"cache_creation_input_tokens": 549,
}
assert extract_cache_creation_tokens(usage) == 549
def test_nested_has_priority_over_flat(self) -> None:
"""测试嵌套格式优先于扁平格式"""
usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 100,
"ephemeral_1h_input_tokens": 200,
},
"claude_cache_creation_5_m_tokens": 999, # 应该被忽略
"claude_cache_creation_1_h_tokens": 888, # 应该被忽略
"cache_creation_input_tokens": 777, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
# === 扁平格式测试(优先级第二)===
def test_flat_new_format_still_works(self) -> None:
"""测试扁平新格式兼容性"""
usage = { usage = {
"claude_cache_creation_5_m_tokens": 100, "claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200, "claude_cache_creation_1_h_tokens": 200,
} }
assert extract_cache_creation_tokens(usage) == 300 assert extract_cache_creation_tokens(usage) == 300
def test_flat_new_format_with_old_format_fallback(self) -> None: def test_new_format_5m_only(self) -> None:
"""测试扁平格式为 0 时回退到旧格式""" """测试只有 5 分钟缓存"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 549,
}
assert extract_cache_creation_tokens(usage) == 549
def test_flat_new_format_5m_only(self) -> None:
"""测试只有 5 分钟扁平缓存"""
usage = { usage = {
"claude_cache_creation_5_m_tokens": 150, "claude_cache_creation_5_m_tokens": 150,
"claude_cache_creation_1_h_tokens": 0, "claude_cache_creation_1_h_tokens": 0,
} }
assert extract_cache_creation_tokens(usage) == 150 assert extract_cache_creation_tokens(usage) == 150
def test_flat_new_format_1h_only(self) -> None: def test_new_format_1h_only(self) -> None:
"""测试只有 1 小时扁平缓存""" """测试只有 1 小时缓存"""
usage = { usage = {
"claude_cache_creation_5_m_tokens": 0, "claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 250, "claude_cache_creation_1_h_tokens": 250,
} }
assert extract_cache_creation_tokens(usage) == 250 assert extract_cache_creation_tokens(usage) == 250
# === 旧格式测试(优先级第三)===
def test_old_format_only(self) -> None: def test_old_format_only(self) -> None:
"""测试只有旧格式""" """测试只有旧格式字段"""
usage = { usage = {
"cache_creation_input_tokens": 549, "cache_creation_input_tokens": 500,
} }
assert extract_cache_creation_tokens(usage) == 549 assert extract_cache_creation_tokens(usage) == 500
# === 边界情况测试 === def test_both_formats_prefers_new(self) -> None:
"""测试同时存在时优先使用新格式"""
usage = {
"claude_cache_creation_5_m_tokens": 100,
"claude_cache_creation_1_h_tokens": 200,
"cache_creation_input_tokens": 999, # 应该被忽略
}
assert extract_cache_creation_tokens(usage) == 300
def test_no_cache_creation_tokens(self) -> None: def test_empty_usage(self) -> None:
"""测试没有任何缓存字段""" """测试空字典"""
usage = {} usage = {}
assert extract_cache_creation_tokens(usage) == 0 assert extract_cache_creation_tokens(usage) == 0
def test_all_formats_zero(self) -> None: def test_all_zeros(self) -> None:
"""测试所有格式都为 0""" """测试所有字段都为 0"""
usage = { usage = {
"cache_creation": {
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"claude_cache_creation_5_m_tokens": 0, "claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0, "claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 0, "cache_creation_input_tokens": 0,
} }
assert extract_cache_creation_tokens(usage) == 0 assert extract_cache_creation_tokens(usage) == 0
def test_partial_new_format_with_old_format_fallback(self) -> None:
"""测试新格式字段不存在时回退到旧格式"""
usage = {
"cache_creation_input_tokens": 123,
}
assert extract_cache_creation_tokens(usage) == 123
def test_new_format_zero_should_not_fallback(self) -> None:
"""测试新格式字段存在但为 0 时,不应 fallback 到旧格式"""
usage = {
"claude_cache_creation_5_m_tokens": 0,
"claude_cache_creation_1_h_tokens": 0,
"cache_creation_input_tokens": 456,
}
# 新格式字段存在,即使值为 0 也应该使用新格式(返回 0
# 而不是 fallback 到旧格式(返回 456
assert extract_cache_creation_tokens(usage) == 0
def test_unrelated_fields_ignored(self) -> None: def test_unrelated_fields_ignored(self) -> None:
"""测试忽略无关字段""" """测试忽略无关字段"""
usage = { usage = {
"input_tokens": 1000, "input_tokens": 1000,
"output_tokens": 2000, "output_tokens": 2000,
"cache_read_input_tokens": 300, "cache_read_input_tokens": 300,
"cache_creation": { "claude_cache_creation_5_m_tokens": 50,
"ephemeral_5m_input_tokens": 50, "claude_cache_creation_1_h_tokens": 75,
"ephemeral_1h_input_tokens": 75,
},
} }
assert extract_cache_creation_tokens(usage) == 125 assert extract_cache_creation_tokens(usage) == 125