mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
Compare commits
27 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41719a00e7 | ||
|
|
b5c0f85dca | ||
|
|
7d6d262ed3 | ||
|
|
e21acd73eb | ||
|
|
702f9bc5f1 | ||
|
|
d0ce798881 | ||
|
|
2b1d197047 | ||
|
|
71bc2e6aab | ||
|
|
afb329934a | ||
|
|
1313af45a3 | ||
|
|
dddb327885 | ||
|
|
26b4a37323 | ||
|
|
9dad194130 | ||
|
|
03ad16ea8a | ||
|
|
2fa64b98e3 | ||
|
|
75d7e89cbb | ||
|
|
d73a443484 | ||
|
|
15a9b88fc8 | ||
|
|
03eb7203ec | ||
|
|
e38cd6819b | ||
|
|
d44cfaddf6 | ||
|
|
65225710a8 | ||
|
|
d7f5b16359 | ||
|
|
7185818724 | ||
|
|
868f3349e5 | ||
|
|
d7384e69d9 | ||
|
|
1d5c378343 |
@@ -105,7 +105,7 @@ RUN printf '%s\n' \
|
|||||||
'stderr_logfile=/var/log/nginx/error.log' \
|
'stderr_logfile=/var/log/nginx/error.log' \
|
||||||
'' \
|
'' \
|
||||||
'[program:app]' \
|
'[program:app]' \
|
||||||
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
|
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
|
||||||
'directory=/app' \
|
'directory=/app' \
|
||||||
'autostart=true' \
|
'autostart=true' \
|
||||||
'autorestart=true' \
|
'autorestart=true' \
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ RUN printf '%s\n' \
|
|||||||
'stderr_logfile=/var/log/nginx/error.log' \
|
'stderr_logfile=/var/log/nginx/error.log' \
|
||||||
'' \
|
'' \
|
||||||
'[program:app]' \
|
'[program:app]' \
|
||||||
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
|
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
|
||||||
'directory=/app' \
|
'directory=/app' \
|
||||||
'autostart=true' \
|
'autostart=true' \
|
||||||
'autorestart=true' \
|
'autorestart=true' \
|
||||||
|
|||||||
@@ -394,6 +394,10 @@ def upgrade() -> None:
|
|||||||
index=True,
|
index=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# usage 表复合索引(优化常见查询)
|
||||||
|
op.create_index("idx_usage_user_created", "usage", ["user_id", "created_at"])
|
||||||
|
op.create_index("idx_usage_apikey_created", "usage", ["api_key_id", "created_at"])
|
||||||
|
op.create_index("idx_usage_provider_model_created", "usage", ["provider", "model", "created_at"])
|
||||||
|
|
||||||
# ==================== user_quotas ====================
|
# ==================== user_quotas ====================
|
||||||
op.create_table(
|
op.create_table(
|
||||||
|
|||||||
@@ -0,0 +1,65 @@
|
|||||||
|
"""add usage table composite indexes for query optimization
|
||||||
|
|
||||||
|
Revision ID: b2c3d4e5f6g7
|
||||||
|
Revises: a1b2c3d4e5f6
|
||||||
|
Create Date: 2025-12-20 15:00:00.000000+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'b2c3d4e5f6g7'
|
||||||
|
down_revision = 'a1b2c3d4e5f6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
"""为 usage 表添加复合索引以优化常见查询
|
||||||
|
|
||||||
|
注意:这些索引已经在 baseline 迁移中创建。
|
||||||
|
此迁移仅用于从旧版本升级的场景,新安装会跳过。
|
||||||
|
"""
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
# 检查 usage 表是否存在
|
||||||
|
result = conn.execute(text(
|
||||||
|
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'usage')"
|
||||||
|
))
|
||||||
|
if not result.scalar():
|
||||||
|
# 表不存在,跳过
|
||||||
|
return
|
||||||
|
|
||||||
|
# 定义需要创建的索引
|
||||||
|
indexes = [
|
||||||
|
("idx_usage_user_created", "ON usage (user_id, created_at)"),
|
||||||
|
("idx_usage_apikey_created", "ON usage (api_key_id, created_at)"),
|
||||||
|
("idx_usage_provider_model_created", "ON usage (provider, model, created_at)"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 分别检查并创建每个索引
|
||||||
|
for index_name, index_def in indexes:
|
||||||
|
result = conn.execute(text(
|
||||||
|
f"SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = '{index_name}')"
|
||||||
|
))
|
||||||
|
if result.scalar():
|
||||||
|
continue # 索引已存在,跳过
|
||||||
|
|
||||||
|
conn.execute(text(f"CREATE INDEX {index_name} {index_def}"))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""删除复合索引"""
|
||||||
|
conn = op.get_bind()
|
||||||
|
|
||||||
|
# 使用 IF EXISTS 避免索引不存在时报错
|
||||||
|
conn.execute(text(
|
||||||
|
"DROP INDEX IF EXISTS idx_usage_provider_model_created"
|
||||||
|
))
|
||||||
|
conn.execute(text(
|
||||||
|
"DROP INDEX IF EXISTS idx_usage_apikey_created"
|
||||||
|
))
|
||||||
|
conn.execute(text(
|
||||||
|
"DROP INDEX IF EXISTS idx_usage_user_created"
|
||||||
|
))
|
||||||
27
deploy.sh
27
deploy.sh
@@ -26,10 +26,13 @@ calc_deps_hash() {
|
|||||||
cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
|
cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算代码文件的哈希值
|
# 计算代码文件的哈希值(包含 Dockerfile.app.local)
|
||||||
calc_code_hash() {
|
calc_code_hash() {
|
||||||
find src -type f -name "*.py" 2>/dev/null | sort | xargs cat 2>/dev/null | md5sum | cut -d' ' -f1
|
{
|
||||||
find frontend/src -type f \( -name "*.vue" -o -name "*.ts" -o -name "*.tsx" -o -name "*.js" \) 2>/dev/null | sort | xargs cat 2>/dev/null | md5sum | cut -d' ' -f1
|
cat Dockerfile.app.local 2>/dev/null
|
||||||
|
find src -type f -name "*.py" 2>/dev/null | sort | xargs cat 2>/dev/null
|
||||||
|
find frontend/src -type f \( -name "*.vue" -o -name "*.ts" -o -name "*.tsx" -o -name "*.js" \) 2>/dev/null | sort | xargs cat 2>/dev/null
|
||||||
|
} | md5sum | cut -d' ' -f1
|
||||||
}
|
}
|
||||||
|
|
||||||
# 计算迁移文件的哈希值
|
# 计算迁移文件的哈希值
|
||||||
@@ -179,7 +182,13 @@ else
|
|||||||
echo ">>> Dependencies unchanged."
|
echo ">>> Dependencies unchanged."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 检查代码是否变化,或者 base 重建了(app 依赖 base)
|
# 检查代码或迁移是否变化,或者 base 重建了(app 依赖 base)
|
||||||
|
# 注意:迁移文件打包在镜像中,所以迁移变化也需要重建 app 镜像
|
||||||
|
MIGRATION_CHANGED=false
|
||||||
|
if check_migration_changed; then
|
||||||
|
MIGRATION_CHANGED=true
|
||||||
|
fi
|
||||||
|
|
||||||
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
|
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
|
||||||
echo ">>> App image not found, building..."
|
echo ">>> App image not found, building..."
|
||||||
build_app
|
build_app
|
||||||
@@ -192,6 +201,10 @@ elif check_code_changed; then
|
|||||||
echo ">>> Code changed, rebuilding app image..."
|
echo ">>> Code changed, rebuilding app image..."
|
||||||
build_app
|
build_app
|
||||||
NEED_RESTART=true
|
NEED_RESTART=true
|
||||||
|
elif [ "$MIGRATION_CHANGED" = true ]; then
|
||||||
|
echo ">>> Migration files changed, rebuilding app image..."
|
||||||
|
build_app
|
||||||
|
NEED_RESTART=true
|
||||||
else
|
else
|
||||||
echo ">>> Code unchanged."
|
echo ">>> Code unchanged."
|
||||||
fi
|
fi
|
||||||
@@ -204,9 +217,9 @@ else
|
|||||||
echo ">>> No changes detected, skipping restart."
|
echo ">>> No changes detected, skipping restart."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 检查迁移变化
|
# 检查迁移变化(如果前面已经检测到变化并重建了镜像,这里直接运行迁移)
|
||||||
if check_migration_changed; then
|
if [ "$MIGRATION_CHANGED" = true ]; then
|
||||||
echo ">>> Migration files changed, running database migration..."
|
echo ">>> Running database migration..."
|
||||||
sleep 3
|
sleep 3
|
||||||
run_migration
|
run_migration
|
||||||
else
|
else
|
||||||
|
|||||||
@@ -58,3 +58,38 @@ export async function deleteProvider(providerId: string): Promise<{ message: str
|
|||||||
return response.data
|
return response.data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 测试模型连接性
|
||||||
|
*/
|
||||||
|
export interface TestModelRequest {
|
||||||
|
provider_id: string
|
||||||
|
model_name: string
|
||||||
|
api_key_id?: string
|
||||||
|
message?: string
|
||||||
|
api_format?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TestModelResponse {
|
||||||
|
success: boolean
|
||||||
|
error?: string
|
||||||
|
data?: {
|
||||||
|
response?: {
|
||||||
|
status_code?: number
|
||||||
|
error?: string | { message?: string }
|
||||||
|
choices?: Array<{ message?: { content?: string } }>
|
||||||
|
}
|
||||||
|
content_preview?: string
|
||||||
|
}
|
||||||
|
provider?: {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
display_name: string
|
||||||
|
}
|
||||||
|
model?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function testModel(data: TestModelRequest): Promise<TestModelResponse> {
|
||||||
|
const response = await client.post('/api/admin/provider-query/test-model', data)
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -110,6 +110,24 @@ export interface EndpointAPIKey {
|
|||||||
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
|
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface EndpointAPIKeyUpdate {
|
||||||
|
name?: string
|
||||||
|
api_key?: string // 仅在需要更新时提供
|
||||||
|
rate_multiplier?: number
|
||||||
|
internal_priority?: number
|
||||||
|
global_priority?: number | null
|
||||||
|
max_concurrent?: number | null // null 表示切换为自适应模式
|
||||||
|
rate_limit?: number
|
||||||
|
daily_limit?: number
|
||||||
|
monthly_limit?: number
|
||||||
|
allowed_models?: string[] | null
|
||||||
|
capabilities?: Record<string, boolean> | null
|
||||||
|
cache_ttl_minutes?: number
|
||||||
|
max_probe_interval_minutes?: number
|
||||||
|
note?: string
|
||||||
|
is_active?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
export interface EndpointHealthDetail {
|
export interface EndpointHealthDetail {
|
||||||
api_format: string
|
api_format: string
|
||||||
health_score: number
|
health_score: number
|
||||||
|
|||||||
@@ -163,7 +163,9 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
|
|||||||
useEscapeKey(() => {
|
useEscapeKey(() => {
|
||||||
if (isOpen.value) {
|
if (isOpen.value) {
|
||||||
handleClose()
|
handleClose()
|
||||||
|
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}, {
|
}, {
|
||||||
disableOnInput: true,
|
disableOnInput: true,
|
||||||
once: false
|
once: false
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import { log } from '@/utils/logger'
|
|||||||
export function useClipboard() {
|
export function useClipboard() {
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
|
||||||
async function copyToClipboard(text: string): Promise<boolean> {
|
async function copyToClipboard(text: string, showToast = true): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
if (navigator.clipboard && window.isSecureContext) {
|
||||||
await navigator.clipboard.writeText(text)
|
await navigator.clipboard.writeText(text)
|
||||||
success('已复制到剪贴板')
|
if (showToast) success('已复制到剪贴板')
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,17 +25,17 @@ export function useClipboard() {
|
|||||||
try {
|
try {
|
||||||
const successful = document.execCommand('copy')
|
const successful = document.execCommand('copy')
|
||||||
if (successful) {
|
if (successful) {
|
||||||
success('已复制到剪贴板')
|
if (showToast) success('已复制到剪贴板')
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
showError('复制失败,请手动复制')
|
if (showToast) showError('复制失败,请手动复制')
|
||||||
return false
|
return false
|
||||||
} finally {
|
} finally {
|
||||||
document.body.removeChild(textArea)
|
document.body.removeChild(textArea)
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
log.error('复制失败:', err)
|
log.error('复制失败:', err)
|
||||||
showError('复制失败,请手动选择文本进行复制')
|
if (showToast) showError('复制失败,请手动选择文本进行复制')
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,11 +47,11 @@ export function useConfirm() {
|
|||||||
/**
|
/**
|
||||||
* 便捷方法:危险操作确认(红色主题)
|
* 便捷方法:危险操作确认(红色主题)
|
||||||
*/
|
*/
|
||||||
const confirmDanger = (message: string, title?: string): Promise<boolean> => {
|
const confirmDanger = (message: string, title?: string, confirmText?: string): Promise<boolean> => {
|
||||||
return confirm({
|
return confirm({
|
||||||
message,
|
message,
|
||||||
title: title || '危险操作',
|
title: title || '危险操作',
|
||||||
confirmText: '删除',
|
confirmText: confirmText || '删除',
|
||||||
variant: 'danger'
|
variant: 'danger'
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,11 +4,11 @@ import { onMounted, onUnmounted, ref } from 'vue'
|
|||||||
* ESC 键监听 Composable(简化版本,直接使用独立监听器)
|
* ESC 键监听 Composable(简化版本,直接使用独立监听器)
|
||||||
* 用于按 ESC 键关闭弹窗或其他可关闭的组件
|
* 用于按 ESC 键关闭弹窗或其他可关闭的组件
|
||||||
*
|
*
|
||||||
* @param callback - 按 ESC 键时执行的回调函数
|
* @param callback - 按 ESC 键时执行的回调函数,返回 true 表示已处理事件,阻止其他监听器执行
|
||||||
* @param options - 配置选项
|
* @param options - 配置选项
|
||||||
*/
|
*/
|
||||||
export function useEscapeKey(
|
export function useEscapeKey(
|
||||||
callback: () => void,
|
callback: () => void | boolean,
|
||||||
options: {
|
options: {
|
||||||
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
|
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
|
||||||
disableOnInput?: boolean
|
disableOnInput?: boolean
|
||||||
@@ -42,8 +42,11 @@ export function useEscapeKey(
|
|||||||
if (isInputElement) return
|
if (isInputElement) return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行回调
|
// 执行回调,如果返回 true 则阻止其他监听器
|
||||||
callback()
|
const handled = callback()
|
||||||
|
if (handled === true) {
|
||||||
|
event.stopImmediatePropagation()
|
||||||
|
}
|
||||||
|
|
||||||
// 移除当前元素的焦点,避免残留样式
|
// 移除当前元素的焦点,避免残留样式
|
||||||
if (document.activeElement instanceof HTMLElement) {
|
if (document.activeElement instanceof HTMLElement) {
|
||||||
|
|||||||
@@ -700,6 +700,7 @@ import {
|
|||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
@@ -731,6 +732,7 @@ const emit = defineEmits<{
|
|||||||
'refreshProviders': []
|
'refreshProviders': []
|
||||||
}>()
|
}>()
|
||||||
const { success: showSuccess, error: showError } = useToast()
|
const { success: showSuccess, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
model: GlobalModelResponse | null
|
model: GlobalModelResponse | null
|
||||||
@@ -763,16 +765,6 @@ function handleClose() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制到剪贴板
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
showSuccess('已复制')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 格式化日期
|
// 格式化日期
|
||||||
function formatDate(dateStr: string): string {
|
function formatDate(dateStr: string): string {
|
||||||
if (!dateStr) return '-'
|
if (!dateStr) return '-'
|
||||||
|
|||||||
@@ -433,11 +433,17 @@ const availableGlobalModels = computed(() => {
|
|||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 计算可添加的上游模型(排除已关联的)
|
// 计算可添加的上游模型(排除已关联的,包括主模型名和映射名称)
|
||||||
const availableUpstreamModelsBase = computed(() => {
|
const availableUpstreamModelsBase = computed(() => {
|
||||||
const existingModelNames = new Set(
|
const existingModelNames = new Set<string>()
|
||||||
existingModels.value.map(m => m.provider_model_name)
|
for (const m of existingModels.value) {
|
||||||
)
|
// 主模型名
|
||||||
|
existingModelNames.add(m.provider_model_name)
|
||||||
|
// 映射名称
|
||||||
|
for (const mapping of m.provider_model_mappings ?? []) {
|
||||||
|
if (mapping.name) existingModelNames.add(mapping.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
|
return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ import {
|
|||||||
updateEndpointKey,
|
updateEndpointKey,
|
||||||
getAllCapabilities,
|
getAllCapabilities,
|
||||||
type EndpointAPIKey,
|
type EndpointAPIKey,
|
||||||
|
type EndpointAPIKeyUpdate,
|
||||||
type ProviderEndpoint,
|
type ProviderEndpoint,
|
||||||
type CapabilityDefinition
|
type CapabilityDefinition
|
||||||
} from '@/api/endpoints'
|
} from '@/api/endpoints'
|
||||||
@@ -386,10 +387,11 @@ function loadKeyData() {
|
|||||||
api_key: '',
|
api_key: '',
|
||||||
rate_multiplier: props.editingKey.rate_multiplier || 1.0,
|
rate_multiplier: props.editingKey.rate_multiplier || 1.0,
|
||||||
internal_priority: props.editingKey.internal_priority ?? 50,
|
internal_priority: props.editingKey.internal_priority ?? 50,
|
||||||
max_concurrent: props.editingKey.max_concurrent || undefined,
|
// 保留原始的 null/undefined 状态,null 表示自适应模式
|
||||||
rate_limit: props.editingKey.rate_limit || undefined,
|
max_concurrent: props.editingKey.max_concurrent ?? undefined,
|
||||||
daily_limit: props.editingKey.daily_limit || undefined,
|
rate_limit: props.editingKey.rate_limit ?? undefined,
|
||||||
monthly_limit: props.editingKey.monthly_limit || undefined,
|
daily_limit: props.editingKey.daily_limit ?? undefined,
|
||||||
|
monthly_limit: props.editingKey.monthly_limit ?? undefined,
|
||||||
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
|
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
|
||||||
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
|
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
|
||||||
note: props.editingKey.note || '',
|
note: props.editingKey.note || '',
|
||||||
@@ -439,12 +441,17 @@ async function handleSave() {
|
|||||||
saving.value = true
|
saving.value = true
|
||||||
try {
|
try {
|
||||||
if (props.editingKey) {
|
if (props.editingKey) {
|
||||||
// 更新
|
// 更新模式
|
||||||
const updateData: any = {
|
// 注意:max_concurrent 需要显式发送 null 来切换到自适应模式
|
||||||
|
// undefined 会在 JSON 中被忽略,所以用 null 表示"清空/自适应"
|
||||||
|
const updateData: EndpointAPIKeyUpdate = {
|
||||||
name: form.value.name,
|
name: form.value.name,
|
||||||
rate_multiplier: form.value.rate_multiplier,
|
rate_multiplier: form.value.rate_multiplier,
|
||||||
internal_priority: form.value.internal_priority,
|
internal_priority: form.value.internal_priority,
|
||||||
max_concurrent: form.value.max_concurrent,
|
// 显式使用 null 表示自适应模式,这样后端能区分"未提供"和"设置为 null"
|
||||||
|
// 注意:只有 max_concurrent 需要这种处理,因为它有"自适应模式"的概念
|
||||||
|
// 其他限制字段(rate_limit 等)不支持"清空"操作,undefined 会被 JSON 忽略即不更新
|
||||||
|
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
|
||||||
rate_limit: form.value.rate_limit,
|
rate_limit: form.value.rate_limit,
|
||||||
daily_limit: form.value.daily_limit,
|
daily_limit: form.value.daily_limit,
|
||||||
monthly_limit: form.value.monthly_limit,
|
monthly_limit: form.value.monthly_limit,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
v-model:open="modelSelectOpen"
|
v-model:open="modelSelectOpen"
|
||||||
:model-value="formData.modelId"
|
:model-value="formData.modelId"
|
||||||
:disabled="!!editingGroup"
|
:disabled="!!editingGroup"
|
||||||
@update:model-value="formData.modelId = $event"
|
@update:model-value="handleModelChange"
|
||||||
>
|
>
|
||||||
<SelectTrigger class="h-9">
|
<SelectTrigger class="h-9">
|
||||||
<SelectValue placeholder="请选择模型" />
|
<SelectValue placeholder="请选择模型" />
|
||||||
@@ -449,7 +449,17 @@ interface UpstreamModelGroup {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
|
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
|
||||||
|
// 收集当前表单已添加的名称
|
||||||
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
|
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
|
||||||
|
|
||||||
|
// 收集所有已存在的映射名称(包括主模型名和映射名称)
|
||||||
|
for (const m of props.models) {
|
||||||
|
addedNames.add(m.provider_model_name)
|
||||||
|
for (const mapping of m.provider_model_mappings ?? []) {
|
||||||
|
if (mapping.name) addedNames.add(mapping.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
|
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
|
||||||
|
|
||||||
const groups = new Map<string, UpstreamModelGroup>()
|
const groups = new Map<string, UpstreamModelGroup>()
|
||||||
@@ -519,6 +529,15 @@ function initForm() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 处理模型选择变更
|
||||||
|
function handleModelChange(value: string) {
|
||||||
|
formData.value.modelId = value
|
||||||
|
const selectedModel = props.models.find(m => m.id === value)
|
||||||
|
if (selectedModel) {
|
||||||
|
upstreamModelSearch.value = selectedModel.provider_model_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 切换 API 格式
|
// 切换 API 格式
|
||||||
function toggleApiFormat(format: string) {
|
function toggleApiFormat(format: string) {
|
||||||
const index = formData.value.apiFormats.indexOf(format)
|
const index = formData.value.apiFormats.indexOf(format)
|
||||||
|
|||||||
@@ -483,9 +483,9 @@
|
|||||||
<span
|
<span
|
||||||
v-if="key.max_concurrent || key.is_adaptive"
|
v-if="key.max_concurrent || key.is_adaptive"
|
||||||
class="text-muted-foreground"
|
class="text-muted-foreground"
|
||||||
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'})` : '固定并发限制'"
|
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'})` : `固定并发限制: ${key.max_concurrent}`"
|
||||||
>
|
>
|
||||||
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.learned_max_concurrent || key.max_concurrent || 3 }}
|
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.is_adaptive ? (key.learned_max_concurrent ?? '学习中') : key.max_concurrent }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -531,6 +531,7 @@
|
|||||||
<!-- 模型名称映射 -->
|
<!-- 模型名称映射 -->
|
||||||
<ModelAliasesTab
|
<ModelAliasesTab
|
||||||
v-if="provider"
|
v-if="provider"
|
||||||
|
ref="modelAliasesTabRef"
|
||||||
:key="`aliases-${provider.id}`"
|
:key="`aliases-${provider.id}`"
|
||||||
:provider="provider"
|
:provider="provider"
|
||||||
@refresh="handleRelatedDataRefresh"
|
@refresh="handleRelatedDataRefresh"
|
||||||
@@ -660,6 +661,7 @@ import Button from '@/components/ui/button.vue'
|
|||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { getProvider, getProviderEndpoints } from '@/api/endpoints'
|
import { getProvider, getProviderEndpoints } from '@/api/endpoints'
|
||||||
import {
|
import {
|
||||||
KeyFormDialog,
|
KeyFormDialog,
|
||||||
@@ -705,6 +707,7 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { error: showError, success: showSuccess } = useToast()
|
const { error: showError, success: showSuccess } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const provider = ref<any>(null)
|
const provider = ref<any>(null)
|
||||||
@@ -735,6 +738,9 @@ const deleteModelConfirmOpen = ref(false)
|
|||||||
const modelToDelete = ref<Model | null>(null)
|
const modelToDelete = ref<Model | null>(null)
|
||||||
const batchAssignDialogOpen = ref(false)
|
const batchAssignDialogOpen = ref(false)
|
||||||
|
|
||||||
|
// ModelAliasesTab 组件引用
|
||||||
|
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
|
||||||
|
|
||||||
// 拖动排序相关状态
|
// 拖动排序相关状态
|
||||||
const dragState = ref({
|
const dragState = ref({
|
||||||
isDragging: false,
|
isDragging: false,
|
||||||
@@ -756,7 +762,9 @@ const hasBlockingDialogOpen = computed(() =>
|
|||||||
deleteKeyConfirmOpen.value ||
|
deleteKeyConfirmOpen.value ||
|
||||||
modelFormDialogOpen.value ||
|
modelFormDialogOpen.value ||
|
||||||
deleteModelConfirmOpen.value ||
|
deleteModelConfirmOpen.value ||
|
||||||
batchAssignDialogOpen.value
|
batchAssignDialogOpen.value ||
|
||||||
|
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
|
||||||
|
modelAliasesTabRef.value?.dialogOpen
|
||||||
)
|
)
|
||||||
|
|
||||||
// 监听 providerId 变化
|
// 监听 providerId 变化
|
||||||
@@ -1244,16 +1252,6 @@ function getHealthScoreBarColor(score: number): string {
|
|||||||
return 'bg-red-500 dark:bg-red-400'
|
return 'bg-red-500 dark:bg-red-400'
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制到剪贴板
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
showSuccess('已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败', '错误')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 加载 Provider 信息
|
// 加载 Provider 信息
|
||||||
async function loadProvider() {
|
async function loadProvider() {
|
||||||
if (!props.providerId) return
|
if (!props.providerId) return
|
||||||
|
|||||||
@@ -110,16 +110,30 @@
|
|||||||
<div
|
<div
|
||||||
v-for="mapping in group.aliases"
|
v-for="mapping in group.aliases"
|
||||||
:key="mapping.name"
|
:key="mapping.name"
|
||||||
class="flex items-center gap-2 py-1"
|
class="flex items-center justify-between gap-2 py-1"
|
||||||
>
|
>
|
||||||
<!-- 优先级标签 -->
|
<div class="flex items-center gap-2 flex-1 min-w-0">
|
||||||
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
<!-- 优先级标签 -->
|
||||||
{{ mapping.priority }}
|
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
|
||||||
</span>
|
{{ mapping.priority }}
|
||||||
<!-- 映射名称 -->
|
</span>
|
||||||
<span class="font-mono text-sm truncate">
|
<!-- 映射名称 -->
|
||||||
{{ mapping.name }}
|
<span class="font-mono text-sm truncate">
|
||||||
</span>
|
{{ mapping.name }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<!-- 测试按钮 -->
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-7 w-7 shrink-0"
|
||||||
|
title="测试映射"
|
||||||
|
:disabled="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`"
|
||||||
|
@click="testMapping(group, mapping)"
|
||||||
|
>
|
||||||
|
<Loader2 v-if="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`" class="w-3 h-3 animate-spin" />
|
||||||
|
<Play v-else class="w-3 h-3" />
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -166,18 +180,20 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted, watch } from 'vue'
|
import { ref, computed, onMounted, watch } from 'vue'
|
||||||
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next'
|
import { Tag, Plus, Edit, Trash2, ChevronRight, Loader2, Play } from 'lucide-vue-next'
|
||||||
import { Card, Button, Badge } from '@/components/ui'
|
import { Card, Button, Badge } from '@/components/ui'
|
||||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||||
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import {
|
import {
|
||||||
getProviderModels,
|
getProviderModels,
|
||||||
|
testModel,
|
||||||
API_FORMAT_LABELS,
|
API_FORMAT_LABELS,
|
||||||
type Model,
|
type Model,
|
||||||
type ProviderModelAlias
|
type ProviderModelAlias
|
||||||
} from '@/api/endpoints'
|
} from '@/api/endpoints'
|
||||||
import { updateModel } from '@/api/endpoints/models'
|
import { updateModel } from '@/api/endpoints/models'
|
||||||
|
import { parseTestModelError } from '@/utils/errorParser'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
provider: any
|
provider: any
|
||||||
@@ -196,6 +212,7 @@ const dialogOpen = ref(false)
|
|||||||
const deleteConfirmOpen = ref(false)
|
const deleteConfirmOpen = ref(false)
|
||||||
const editingGroup = ref<AliasGroup | null>(null)
|
const editingGroup = ref<AliasGroup | null>(null)
|
||||||
const deletingGroup = ref<AliasGroup | null>(null)
|
const deletingGroup = ref<AliasGroup | null>(null)
|
||||||
|
const testingMapping = ref<string | null>(null)
|
||||||
|
|
||||||
// 列表展开状态
|
// 列表展开状态
|
||||||
const expandedAliasGroups = ref<Set<string>>(new Set())
|
const expandedAliasGroups = ref<Set<string>>(new Set())
|
||||||
@@ -337,6 +354,49 @@ async function onDialogSaved() {
|
|||||||
emit('refresh')
|
emit('refresh')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 测试模型映射
|
||||||
|
async function testMapping(group: any, mapping: any) {
|
||||||
|
const testingKey = `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`
|
||||||
|
testingMapping.value = testingKey
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 根据分组的 API 格式来确定应该使用的格式
|
||||||
|
let apiFormat = null
|
||||||
|
if (group.apiFormats.length === 1) {
|
||||||
|
apiFormat = group.apiFormats[0]
|
||||||
|
} else if (group.apiFormats.length === 0) {
|
||||||
|
// 如果没有指定格式,但分组显示为"全部",则使用模型的默认格式
|
||||||
|
apiFormat = group.model.effective_api_format || group.model.api_format
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await testModel({
|
||||||
|
provider_id: props.provider.id,
|
||||||
|
model_name: mapping.name, // 使用映射名称进行测试
|
||||||
|
message: "hello",
|
||||||
|
api_format: apiFormat
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
showSuccess(`映射 "${mapping.name}" 测试成功`)
|
||||||
|
|
||||||
|
// 如果有响应内容,可以显示更多信息
|
||||||
|
if (result.data?.response?.choices?.[0]?.message?.content) {
|
||||||
|
const content = result.data.response.choices[0].message.content
|
||||||
|
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
|
||||||
|
} else if (result.data?.content_preview) {
|
||||||
|
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showError(`映射测试失败: ${parseTestModelError(result)}`)
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
||||||
|
showError(`映射测试失败: ${errorMsg}`)
|
||||||
|
} finally {
|
||||||
|
testingMapping.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 监听 provider 变化
|
// 监听 provider 变化
|
||||||
watch(() => props.provider?.id, (newId) => {
|
watch(() => props.provider?.id, (newId) => {
|
||||||
if (newId) {
|
if (newId) {
|
||||||
@@ -349,4 +409,9 @@ onMounted(() => {
|
|||||||
loadModels()
|
loadModels()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 暴露给父组件,用于检测是否有弹窗打开
|
||||||
|
defineExpose({
|
||||||
|
dialogOpen: computed(() => dialogOpen.value || deleteConfirmOpen.value)
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -156,6 +156,17 @@
|
|||||||
</td>
|
</td>
|
||||||
<td class="align-top px-4 py-3">
|
<td class="align-top px-4 py-3">
|
||||||
<div class="flex justify-center gap-1.5">
|
<div class="flex justify-center gap-1.5">
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
class="h-8 w-8"
|
||||||
|
title="测试模型"
|
||||||
|
:disabled="testingModelId === model.id"
|
||||||
|
@click="testModelConnection(model)"
|
||||||
|
>
|
||||||
|
<Loader2 v-if="testingModelId === model.id" class="w-3.5 h-3.5 animate-spin" />
|
||||||
|
<Play v-else class="w-3.5 h-3.5" />
|
||||||
|
</Button>
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
@@ -209,12 +220,14 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Loader2, Play } from 'lucide-vue-next'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { getProviderModels, type Model } from '@/api/endpoints'
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
|
import { getProviderModels, testModel, type Model } from '@/api/endpoints'
|
||||||
import { updateModel } from '@/api/endpoints/models'
|
import { updateModel } from '@/api/endpoints/models'
|
||||||
|
import { parseTestModelError } from '@/utils/errorParser'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
provider: any
|
provider: any
|
||||||
@@ -227,11 +240,13 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { error: showError, success: showSuccess } = useToast()
|
const { error: showError, success: showSuccess } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
// 状态
|
// 状态
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
const models = ref<Model[]>([])
|
const models = ref<Model[]>([])
|
||||||
const togglingModelId = ref<string | null>(null)
|
const togglingModelId = ref<string | null>(null)
|
||||||
|
const testingModelId = ref<string | null>(null)
|
||||||
|
|
||||||
// 按名称排序的模型列表
|
// 按名称排序的模型列表
|
||||||
const sortedModels = computed(() => {
|
const sortedModels = computed(() => {
|
||||||
@@ -244,12 +259,7 @@ const sortedModels = computed(() => {
|
|||||||
|
|
||||||
// 复制模型 ID 到剪贴板
|
// 复制模型 ID 到剪贴板
|
||||||
async function copyModelId(modelId: string) {
|
async function copyModelId(modelId: string) {
|
||||||
try {
|
await copyToClipboard(modelId)
|
||||||
await navigator.clipboard.writeText(modelId)
|
|
||||||
showSuccess('已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败', '错误')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 加载模型
|
// 加载模型
|
||||||
@@ -380,6 +390,39 @@ async function toggleModelActive(model: Model) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 测试模型连接性
|
||||||
|
async function testModelConnection(model: Model) {
|
||||||
|
if (testingModelId.value) return
|
||||||
|
|
||||||
|
testingModelId.value = model.id
|
||||||
|
try {
|
||||||
|
const result = await testModel({
|
||||||
|
provider_id: props.provider.id,
|
||||||
|
model_name: model.provider_model_name,
|
||||||
|
message: "hello"
|
||||||
|
})
|
||||||
|
|
||||||
|
if (result.success) {
|
||||||
|
showSuccess(`模型 "${model.provider_model_name}" 测试成功`)
|
||||||
|
|
||||||
|
// 如果有响应内容,可以显示更多信息
|
||||||
|
if (result.data?.response?.choices?.[0]?.message?.content) {
|
||||||
|
const content = result.data.response.choices[0].message.content
|
||||||
|
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
|
||||||
|
} else if (result.data?.content_preview) {
|
||||||
|
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showError(`模型测试失败: ${parseTestModelError(result)}`)
|
||||||
|
}
|
||||||
|
} catch (err: any) {
|
||||||
|
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
|
||||||
|
showError(`模型测试失败: ${errorMsg}`)
|
||||||
|
} finally {
|
||||||
|
testingModelId.value = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
loadModels()
|
loadModels()
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -473,6 +473,7 @@
|
|||||||
import { ref, watch, computed } from 'vue'
|
import { ref, watch, computed } from 'vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Separator from '@/components/ui/separator.vue'
|
import Separator from '@/components/ui/separator.vue'
|
||||||
@@ -505,6 +506,7 @@ const copiedStates = ref<Record<string, boolean>>({})
|
|||||||
const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare')
|
const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare')
|
||||||
const currentExpandDepth = ref(1)
|
const currentExpandDepth = ref(1)
|
||||||
const dataSource = ref<'client' | 'provider'>('client')
|
const dataSource = ref<'client' | 'provider'>('client')
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
const historicalPricing = ref<{
|
const historicalPricing = ref<{
|
||||||
input_price: string
|
input_price: string
|
||||||
output_price: string
|
output_price: string
|
||||||
@@ -784,7 +786,7 @@ function copyJsonToClipboard(tabName: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (data) {
|
if (data) {
|
||||||
navigator.clipboard.writeText(JSON.stringify(data, null, 2))
|
copyToClipboard(JSON.stringify(data, null, 2), false)
|
||||||
copiedStates.value[tabName] = true
|
copiedStates.value[tabName] = true
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
copiedStates.value[tabName] = false
|
copiedStates.value[tabName] = false
|
||||||
|
|||||||
@@ -366,14 +366,34 @@
|
|||||||
</div>
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell class="text-right py-4 w-[70px]">
|
<TableCell class="text-right py-4 w-[70px]">
|
||||||
|
<!-- pending 状态:只显示增长的总时间 -->
|
||||||
<div
|
<div
|
||||||
v-if="record.status === 'pending' || record.status === 'streaming'"
|
v-if="record.status === 'pending'"
|
||||||
class="flex flex-col items-end text-xs gap-0.5"
|
class="flex flex-col items-end text-xs gap-0.5"
|
||||||
>
|
>
|
||||||
|
<span class="text-muted-foreground">-</span>
|
||||||
<span class="text-primary tabular-nums">
|
<span class="text-primary tabular-nums">
|
||||||
{{ getElapsedTime(record) }}
|
{{ getElapsedTime(record) }}
|
||||||
</span>
|
</span>
|
||||||
</div>
|
</div>
|
||||||
|
<!-- streaming 状态:首字固定 + 总时间增长 -->
|
||||||
|
<div
|
||||||
|
v-else-if="record.status === 'streaming'"
|
||||||
|
class="flex flex-col items-end text-xs gap-0.5"
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
v-if="record.first_byte_time_ms != null"
|
||||||
|
class="tabular-nums"
|
||||||
|
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
|
||||||
|
<span
|
||||||
|
v-else
|
||||||
|
class="text-muted-foreground"
|
||||||
|
>-</span>
|
||||||
|
<span class="text-primary tabular-nums">
|
||||||
|
{{ getElapsedTime(record) }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<!-- 已完成状态:首字 + 总耗时 -->
|
||||||
<div
|
<div
|
||||||
v-else-if="record.response_time_ms != null"
|
v-else-if="record.response_time_ms != null"
|
||||||
class="flex flex-col items-end text-xs gap-0.5"
|
class="flex flex-col items-end text-xs gap-0.5"
|
||||||
|
|||||||
@@ -86,6 +86,34 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div
|
||||||
|
v-if="isEditMode && form.password.length > 0"
|
||||||
|
class="space-y-2"
|
||||||
|
>
|
||||||
|
<Label class="text-sm font-medium">
|
||||||
|
确认新密码 <span class="text-muted-foreground">*</span>
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
:id="`pwd-confirm-${formNonce}`"
|
||||||
|
v-model="form.confirmPassword"
|
||||||
|
type="password"
|
||||||
|
autocomplete="new-password"
|
||||||
|
data-form-type="other"
|
||||||
|
data-lpignore="true"
|
||||||
|
:name="`confirm-${formNonce}`"
|
||||||
|
required
|
||||||
|
minlength="6"
|
||||||
|
placeholder="再次输入新密码"
|
||||||
|
class="h-10"
|
||||||
|
/>
|
||||||
|
<p
|
||||||
|
v-if="form.confirmPassword.length > 0 && form.password !== form.confirmPassword"
|
||||||
|
class="text-xs text-destructive"
|
||||||
|
>
|
||||||
|
两次输入的密码不一致
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="space-y-2">
|
<div class="space-y-2">
|
||||||
<Label
|
<Label
|
||||||
for="form-email"
|
for="form-email"
|
||||||
@@ -423,6 +451,7 @@ const apiFormats = ref<Array<{ value: string; label: string }>>([])
|
|||||||
const form = ref({
|
const form = ref({
|
||||||
username: '',
|
username: '',
|
||||||
password: '',
|
password: '',
|
||||||
|
confirmPassword: '',
|
||||||
email: '',
|
email: '',
|
||||||
quota: 10,
|
quota: 10,
|
||||||
role: 'user' as 'admin' | 'user',
|
role: 'user' as 'admin' | 'user',
|
||||||
@@ -443,6 +472,7 @@ function resetForm() {
|
|||||||
form.value = {
|
form.value = {
|
||||||
username: '',
|
username: '',
|
||||||
password: '',
|
password: '',
|
||||||
|
confirmPassword: '',
|
||||||
email: '',
|
email: '',
|
||||||
quota: 10,
|
quota: 10,
|
||||||
role: 'user',
|
role: 'user',
|
||||||
@@ -461,6 +491,7 @@ function loadUserData() {
|
|||||||
form.value = {
|
form.value = {
|
||||||
username: props.user.username,
|
username: props.user.username,
|
||||||
password: '',
|
password: '',
|
||||||
|
confirmPassword: '',
|
||||||
email: props.user.email || '',
|
email: props.user.email || '',
|
||||||
quota: props.user.quota_usd == null ? 10 : props.user.quota_usd,
|
quota: props.user.quota_usd == null ? 10 : props.user.quota_usd,
|
||||||
role: props.user.role,
|
role: props.user.role,
|
||||||
@@ -486,7 +517,9 @@ const isFormValid = computed(() => {
|
|||||||
const hasUsername = form.value.username.trim().length > 0
|
const hasUsername = form.value.username.trim().length > 0
|
||||||
const hasEmail = form.value.email.trim().length > 0
|
const hasEmail = form.value.email.trim().length > 0
|
||||||
const hasPassword = isEditMode.value || form.value.password.length >= 6
|
const hasPassword = isEditMode.value || form.value.password.length >= 6
|
||||||
return hasUsername && hasEmail && hasPassword
|
// 编辑模式下如果填写了密码,必须确认密码一致
|
||||||
|
const passwordConfirmed = !isEditMode.value || form.value.password.length === 0 || form.value.password === form.value.confirmPassword
|
||||||
|
return hasUsername && hasEmail && hasPassword && passwordConfirmed
|
||||||
})
|
})
|
||||||
|
|
||||||
// 加载访问控制选项
|
// 加载访问控制选项
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
users.value = await usersApi.getAllUsers()
|
users.value = await usersApi.getAllUsers()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '获取用户列表失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取用户列表失败'
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
@@ -29,7 +29,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
users.value.push(newUser)
|
users.value.push(newUser)
|
||||||
return newUser
|
return newUser
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '创建用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -52,7 +52,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
}
|
}
|
||||||
return updatedUser
|
return updatedUser
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '更新用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '更新用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -67,7 +67,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
await usersApi.deleteUser(userId)
|
await usersApi.deleteUser(userId)
|
||||||
users.value = users.value.filter(u => u.id !== userId)
|
users.value = users.value.filter(u => u.id !== userId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '删除用户失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除用户失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
@@ -78,7 +78,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
return await usersApi.getUserApiKeys(userId)
|
return await usersApi.getUserApiKeys(userId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '获取 API Keys 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取 API Keys 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,7 +87,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
return await usersApi.createApiKey(userId, name)
|
return await usersApi.createApiKey(userId, name)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '创建 API Key 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建 API Key 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
try {
|
try {
|
||||||
await usersApi.deleteApiKey(userId, keyId)
|
await usersApi.deleteApiKey(userId, keyId)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '删除 API Key 失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除 API Key 失败'
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -110,7 +110,7 @@ export const useUsersStore = defineStore('users', () => {
|
|||||||
// 刷新用户列表以获取最新数据
|
// 刷新用户列表以获取最新数据
|
||||||
await fetchUsers()
|
await fetchUsers()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error.value = err.response?.data?.detail || '重置配额失败'
|
error.value = err.response?.data?.error?.message || err.response?.data?.detail || '重置配额失败'
|
||||||
throw err
|
throw err
|
||||||
} finally {
|
} finally {
|
||||||
loading.value = false
|
loading.value = false
|
||||||
|
|||||||
@@ -198,3 +198,49 @@ export function parseApiErrorShort(err: unknown, defaultMessage: string = '操
|
|||||||
const lines = fullError.split('\n')
|
const lines = fullError.split('\n')
|
||||||
return lines[0] || defaultMessage
|
return lines[0] || defaultMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析模型测试响应的错误信息
|
||||||
|
* @param result 测试响应结果
|
||||||
|
* @returns 格式化的错误信息
|
||||||
|
*/
|
||||||
|
export function parseTestModelError(result: {
|
||||||
|
error?: string
|
||||||
|
data?: {
|
||||||
|
response?: {
|
||||||
|
status_code?: number
|
||||||
|
error?: string | { message?: string }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}): string {
|
||||||
|
let errorMsg = result.error || '测试失败'
|
||||||
|
|
||||||
|
// 检查HTTP状态码错误
|
||||||
|
if (result.data?.response?.status_code) {
|
||||||
|
const status = result.data.response.status_code
|
||||||
|
if (status === 403) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
|
||||||
|
} else if (status === 401) {
|
||||||
|
errorMsg = '认证失败: API密钥无效或已过期'
|
||||||
|
} else if (status === 404) {
|
||||||
|
errorMsg = '模型不存在: 请检查模型名称是否正确'
|
||||||
|
} else if (status === 429) {
|
||||||
|
errorMsg = '请求频率过高: 请稍后重试'
|
||||||
|
} else if (status >= 500) {
|
||||||
|
errorMsg = `服务器错误: HTTP ${status}`
|
||||||
|
} else {
|
||||||
|
errorMsg = `请求失败: HTTP ${status}`
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试从错误响应中提取更多信息
|
||||||
|
if (result.data?.response?.error) {
|
||||||
|
if (typeof result.data.response.error === 'string') {
|
||||||
|
errorMsg = result.data.response.error
|
||||||
|
} else if (result.data.response.error?.message) {
|
||||||
|
errorMsg = result.data.response.error.message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errorMsg
|
||||||
|
}
|
||||||
|
|||||||
@@ -650,6 +650,7 @@
|
|||||||
import { ref, computed, onMounted } from 'vue'
|
import { ref, computed, onMounted } from 'vue'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useConfirm } from '@/composables/useConfirm'
|
import { useConfirm } from '@/composables/useConfirm'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin'
|
import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@@ -693,6 +694,7 @@ import { log } from '@/utils/logger'
|
|||||||
|
|
||||||
const { success, error } = useToast()
|
const { success, error } = useToast()
|
||||||
const { confirmDanger } = useConfirm()
|
const { confirmDanger } = useConfirm()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
const apiKeys = ref<AdminApiKey[]>([])
|
const apiKeys = ref<AdminApiKey[]>([])
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -927,20 +929,14 @@ function selectKey() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function copyKey() {
|
async function copyKey() {
|
||||||
try {
|
await copyToClipboard(newKeyValue.value)
|
||||||
await navigator.clipboard.writeText(newKeyValue.value)
|
|
||||||
success('API Key 已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
error('复制失败,请手动复制')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function copyKeyPrefix(apiKey: AdminApiKey) {
|
async function copyKeyPrefix(apiKey: AdminApiKey) {
|
||||||
try {
|
try {
|
||||||
// 调用后端 API 获取完整密钥
|
// 调用后端 API 获取完整密钥
|
||||||
const response = await adminApi.getFullApiKey(apiKey.id)
|
const response = await adminApi.getFullApiKey(apiKey.id)
|
||||||
await navigator.clipboard.writeText(response.key)
|
await copyToClipboard(response.key)
|
||||||
success('完整密钥已复制到剪贴板')
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
log.error('复制密钥失败:', err)
|
log.error('复制密钥失败:', err)
|
||||||
error('复制失败,请重试')
|
error('复制失败,请重试')
|
||||||
@@ -1046,9 +1042,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
|||||||
rate_limit: data.rate_limit,
|
rate_limit: data.rate_limit,
|
||||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
expire_days: data.never_expire ? null : (data.expire_days || null),
|
||||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||||
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined,
|
// 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
|
||||||
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined,
|
allowed_providers: data.allowed_providers,
|
||||||
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined
|
allowed_api_formats: data.allowed_api_formats,
|
||||||
|
allowed_models: data.allowed_models
|
||||||
}
|
}
|
||||||
await adminApi.updateApiKey(data.id, updateData)
|
await adminApi.updateApiKey(data.id, updateData)
|
||||||
success('API Key 更新成功')
|
success('API Key 更新成功')
|
||||||
@@ -1064,9 +1061,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
|||||||
rate_limit: data.rate_limit,
|
rate_limit: data.rate_limit,
|
||||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
expire_days: data.never_expire ? null : (data.expire_days || null),
|
||||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||||
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined,
|
// 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
|
||||||
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined,
|
allowed_providers: data.allowed_providers,
|
||||||
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined
|
allowed_api_formats: data.allowed_api_formats,
|
||||||
|
allowed_models: data.allowed_models
|
||||||
}
|
}
|
||||||
const response = await adminApi.createStandaloneApiKey(createData)
|
const response = await adminApi.createStandaloneApiKey(createData)
|
||||||
newKeyValue.value = response.key
|
newKeyValue.value = response.key
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ const clearingRowAffinityKey = ref<string | null>(null)
|
|||||||
const currentPage = ref(1)
|
const currentPage = ref(1)
|
||||||
const pageSize = ref(20)
|
const pageSize = ref(20)
|
||||||
const currentTime = ref(Math.floor(Date.now() / 1000))
|
const currentTime = ref(Math.floor(Date.now() / 1000))
|
||||||
|
const analysisHoursSelectOpen = ref(false)
|
||||||
|
|
||||||
// ==================== 模型映射缓存 ====================
|
// ==================== 模型映射缓存 ====================
|
||||||
|
|
||||||
@@ -1056,7 +1057,7 @@ onBeforeUnmount(() => {
|
|||||||
<span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔,推荐合适的缓存 TTL</span>
|
<span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔,推荐合适的缓存 TTL</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex flex-wrap items-center gap-2">
|
<div class="flex flex-wrap items-center gap-2">
|
||||||
<Select v-model="analysisHours">
|
<Select v-model="analysisHours" v-model:open="analysisHoursSelectOpen">
|
||||||
<SelectTrigger class="w-24 sm:w-28 h-8">
|
<SelectTrigger class="w-24 sm:w-28 h-8">
|
||||||
<SelectValue placeholder="时间段" />
|
<SelectValue placeholder="时间段" />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
|
|||||||
@@ -713,6 +713,7 @@ import ProviderModelFormDialog from '@/features/providers/components/ProviderMod
|
|||||||
import type { Model } from '@/api/endpoints'
|
import type { Model } from '@/api/endpoints'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useConfirm } from '@/composables/useConfirm'
|
import { useConfirm } from '@/composables/useConfirm'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { useRowClick } from '@/composables/useRowClick'
|
import { useRowClick } from '@/composables/useRowClick'
|
||||||
import { parseApiError } from '@/utils/errorParser'
|
import { parseApiError } from '@/utils/errorParser'
|
||||||
import {
|
import {
|
||||||
@@ -743,6 +744,7 @@ import { getProvidersSummary } from '@/api/endpoints/providers'
|
|||||||
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
|
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
|
||||||
|
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
// 状态
|
// 状态
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -1066,16 +1068,6 @@ function handleRowClick(event: MouseEvent, model: GlobalModelResponse) {
|
|||||||
selectModel(model)
|
selectModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制到剪贴板
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
success('已复制')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function selectModel(model: GlobalModelResponse) {
|
async function selectModel(model: GlobalModelResponse) {
|
||||||
selectedModel.value = model
|
selectedModel.value = model
|
||||||
detailTab.value = 'basic'
|
detailTab.value = 'basic'
|
||||||
|
|||||||
@@ -723,9 +723,19 @@ async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
|
|||||||
// 切换提供商状态
|
// 切换提供商状态
|
||||||
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
|
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
|
||||||
try {
|
try {
|
||||||
await updateProvider(provider.id, { is_active: !provider.is_active })
|
const newStatus = !provider.is_active
|
||||||
provider.is_active = !provider.is_active
|
await updateProvider(provider.id, { is_active: newStatus })
|
||||||
showSuccess(provider.is_active ? '提供商已启用' : '提供商已停用')
|
|
||||||
|
// 更新抽屉内部的 provider 对象
|
||||||
|
provider.is_active = newStatus
|
||||||
|
|
||||||
|
// 同时更新主页面 providers 数组中的对象,实现无感更新
|
||||||
|
const targetProvider = providers.value.find(p => p.id === provider.id)
|
||||||
|
if (targetProvider) {
|
||||||
|
targetProvider.is_active = newStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
showSuccess(newStatus ? '提供商已启用' : '提供商已停用')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
showError(err.response?.data?.detail || '操作失败', '错误')
|
showError(err.response?.data?.detail || '操作失败', '错误')
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -701,6 +701,7 @@ import { ref, computed, onMounted, watch } from 'vue'
|
|||||||
import { useUsersStore } from '@/stores/users'
|
import { useUsersStore } from '@/stores/users'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
import { useConfirm } from '@/composables/useConfirm'
|
import { useConfirm } from '@/composables/useConfirm'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import { usageApi, type UsageByUser } from '@/api/usage'
|
import { usageApi, type UsageByUser } from '@/api/usage'
|
||||||
import { adminApi } from '@/api/admin'
|
import { adminApi } from '@/api/admin'
|
||||||
|
|
||||||
@@ -748,6 +749,7 @@ import { log } from '@/utils/logger'
|
|||||||
|
|
||||||
const { success, error } = useToast()
|
const { success, error } = useToast()
|
||||||
const { confirmDanger, confirmWarning } = useConfirm()
|
const { confirmDanger, confirmWarning } = useConfirm()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
const usersStore = useUsersStore()
|
const usersStore = useUsersStore()
|
||||||
|
|
||||||
// 用户表单对话框状态
|
// 用户表单对话框状态
|
||||||
@@ -875,7 +877,8 @@ async function toggleUserStatus(user: any) {
|
|||||||
const action = user.is_active ? '禁用' : '启用'
|
const action = user.is_active ? '禁用' : '启用'
|
||||||
const confirmed = await confirmDanger(
|
const confirmed = await confirmDanger(
|
||||||
`确定要${action}用户 ${user.username} 吗?`,
|
`确定要${action}用户 ${user.username} 吗?`,
|
||||||
`${action}用户`
|
`${action}用户`,
|
||||||
|
action
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!confirmed) return
|
if (!confirmed) return
|
||||||
@@ -884,7 +887,7 @@ async function toggleUserStatus(user: any) {
|
|||||||
await usersStore.updateUser(user.id, { is_active: !user.is_active })
|
await usersStore.updateUser(user.id, { is_active: !user.is_active })
|
||||||
success(`用户已${action}`)
|
success(`用户已${action}`)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', `${action}用户失败`)
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', `${action}用户失败`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -955,7 +958,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
|
|||||||
closeUserFormDialog()
|
closeUserFormDialog()
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
const title = data.id ? '更新用户失败' : '创建用户失败'
|
const title = data.id ? '更新用户失败' : '创建用户失败'
|
||||||
error(err.response?.data?.detail || '未知错误', title)
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', title)
|
||||||
} finally {
|
} finally {
|
||||||
userFormDialogRef.value?.setSaving(false)
|
userFormDialogRef.value?.setSaving(false)
|
||||||
}
|
}
|
||||||
@@ -989,7 +992,7 @@ async function createApiKey() {
|
|||||||
showNewApiKeyDialog.value = true
|
showNewApiKeyDialog.value = true
|
||||||
await loadUserApiKeys(selectedUser.value.id)
|
await loadUserApiKeys(selectedUser.value.id)
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '创建 API Key 失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '创建 API Key 失败')
|
||||||
} finally {
|
} finally {
|
||||||
creatingApiKey.value = false
|
creatingApiKey.value = false
|
||||||
}
|
}
|
||||||
@@ -1000,12 +1003,7 @@ function selectApiKey() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function copyApiKey() {
|
async function copyApiKey() {
|
||||||
try {
|
await copyToClipboard(newApiKey.value)
|
||||||
await navigator.clipboard.writeText(newApiKey.value)
|
|
||||||
success('API Key已复制到剪贴板')
|
|
||||||
} catch {
|
|
||||||
error('复制失败,请手动复制')
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async function closeNewApiKeyDialog() {
|
async function closeNewApiKeyDialog() {
|
||||||
@@ -1026,7 +1024,7 @@ async function deleteApiKey(apiKey: any) {
|
|||||||
await loadUserApiKeys(selectedUser.value.id)
|
await loadUserApiKeys(selectedUser.value.id)
|
||||||
success('API Key已删除')
|
success('API Key已删除')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '删除 API Key 失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除 API Key 失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1034,11 +1032,10 @@ async function copyFullKey(apiKey: any) {
|
|||||||
try {
|
try {
|
||||||
// 调用后端 API 获取完整密钥
|
// 调用后端 API 获取完整密钥
|
||||||
const response = await adminApi.getFullApiKey(apiKey.id)
|
const response = await adminApi.getFullApiKey(apiKey.id)
|
||||||
await navigator.clipboard.writeText(response.key)
|
await copyToClipboard(response.key)
|
||||||
success('完整密钥已复制到剪贴板')
|
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
log.error('复制密钥失败:', err)
|
log.error('复制密钥失败:', err)
|
||||||
error(err.response?.data?.detail || '未知错误', '复制密钥失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '复制密钥失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1054,7 +1051,7 @@ async function resetQuota(user: any) {
|
|||||||
await usersStore.resetUserQuota(user.id)
|
await usersStore.resetUserQuota(user.id)
|
||||||
success('配额已重置')
|
success('配额已重置')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '重置配额失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '重置配额失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1070,7 +1067,7 @@ async function deleteUser(user: any) {
|
|||||||
await usersStore.deleteUser(user.id)
|
await usersStore.deleteUser(user.id)
|
||||||
success('用户已删除')
|
success('用户已删除')
|
||||||
} catch (err: any) {
|
} catch (err: any) {
|
||||||
error(err.response?.data?.detail || '未知错误', '删除用户失败')
|
error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除用户失败')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|||||||
@@ -102,9 +102,9 @@
|
|||||||
<!-- Main Content -->
|
<!-- Main Content -->
|
||||||
<main class="relative z-10">
|
<main class="relative z-10">
|
||||||
<!-- Fixed Logo Container -->
|
<!-- Fixed Logo Container -->
|
||||||
<div class="fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
|
<div class="mt-4 fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
|
||||||
<div
|
<div
|
||||||
class="transform-gpu logo-container"
|
class="mt-16 transform-gpu logo-container"
|
||||||
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
|
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
|
||||||
:style="fixedLogoStyle"
|
:style="fixedLogoStyle"
|
||||||
>
|
>
|
||||||
@@ -151,7 +151,7 @@
|
|||||||
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
|
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
|
||||||
>
|
>
|
||||||
<div class="max-w-4xl mx-auto text-center">
|
<div class="max-w-4xl mx-auto text-center">
|
||||||
<div class="h-80 w-full mb-16" />
|
<div class="h-80 w-full mb-16 mt-8" />
|
||||||
<h1
|
<h1
|
||||||
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
|
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
|
||||||
:style="getTitleStyle(SECTIONS.HOME)"
|
:style="getTitleStyle(SECTIONS.HOME)"
|
||||||
@@ -166,7 +166,7 @@
|
|||||||
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
|
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
|
||||||
</p>
|
</p>
|
||||||
<button
|
<button
|
||||||
class="mt-16 transition-all duration-700 cursor-pointer hover:scale-110"
|
class="mt-8 transition-all duration-700 cursor-pointer hover:scale-110"
|
||||||
:style="getScrollIndicatorStyle(SECTIONS.HOME)"
|
:style="getScrollIndicatorStyle(SECTIONS.HOME)"
|
||||||
@click="scrollToSection(SECTIONS.CLAUDE)"
|
@click="scrollToSection(SECTIONS.CLAUDE)"
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -301,6 +301,7 @@ function stopGlobalAutoRefresh() {
|
|||||||
function handleAutoRefreshChange(value: boolean) {
|
function handleAutoRefreshChange(value: boolean) {
|
||||||
globalAutoRefresh.value = value
|
globalAutoRefresh.value = value
|
||||||
if (value) {
|
if (value) {
|
||||||
|
refreshData() // 立即刷新一次
|
||||||
startGlobalAutoRefresh()
|
startGlobalAutoRefresh()
|
||||||
} else {
|
} else {
|
||||||
stopGlobalAutoRefresh()
|
stopGlobalAutoRefresh()
|
||||||
|
|||||||
@@ -342,6 +342,7 @@ import {
|
|||||||
Plus,
|
Plus,
|
||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import {
|
import {
|
||||||
Card,
|
Card,
|
||||||
Table,
|
Table,
|
||||||
@@ -370,6 +371,7 @@ import { useRowClick } from '@/composables/useRowClick'
|
|||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
|
|
||||||
const { success, error: showError } = useToast()
|
const { success, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
// 状态
|
// 状态
|
||||||
const loading = ref(false)
|
const loading = ref(false)
|
||||||
@@ -565,16 +567,6 @@ function hasTieredPricing(model: PublicGlobalModel): boolean {
|
|||||||
return (tiered?.tiers?.length || 0) > 1
|
return (tiered?.tiers?.length || 0) > 1
|
||||||
}
|
}
|
||||||
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
success('已复制')
|
|
||||||
} catch (err) {
|
|
||||||
log.error('复制失败:', err)
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
refreshData()
|
refreshData()
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -352,6 +352,7 @@ import {
|
|||||||
} from 'lucide-vue-next'
|
} from 'lucide-vue-next'
|
||||||
import { useEscapeKey } from '@/composables/useEscapeKey'
|
import { useEscapeKey } from '@/composables/useEscapeKey'
|
||||||
import { useToast } from '@/composables/useToast'
|
import { useToast } from '@/composables/useToast'
|
||||||
|
import { useClipboard } from '@/composables/useClipboard'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import Badge from '@/components/ui/badge.vue'
|
import Badge from '@/components/ui/badge.vue'
|
||||||
import Button from '@/components/ui/button.vue'
|
import Button from '@/components/ui/button.vue'
|
||||||
@@ -375,6 +376,7 @@ const emit = defineEmits<{
|
|||||||
}>()
|
}>()
|
||||||
|
|
||||||
const { success: showSuccess, error: showError } = useToast()
|
const { success: showSuccess, error: showError } = useToast()
|
||||||
|
const { copyToClipboard } = useClipboard()
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
model: PublicGlobalModel | null
|
model: PublicGlobalModel | null
|
||||||
@@ -408,15 +410,6 @@ function handleClose() {
|
|||||||
emit('update:open', false)
|
emit('update:open', false)
|
||||||
}
|
}
|
||||||
|
|
||||||
async function copyToClipboard(text: string) {
|
|
||||||
try {
|
|
||||||
await navigator.clipboard.writeText(text)
|
|
||||||
showSuccess('已复制')
|
|
||||||
} catch {
|
|
||||||
showError('复制失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function getFirstTierPrice(
|
function getFirstTierPrice(
|
||||||
tieredPricing: TieredPricingConfig | undefined | null,
|
tieredPricing: TieredPricingConfig | undefined | null,
|
||||||
priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m'
|
priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m'
|
||||||
|
|||||||
@@ -246,6 +246,15 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
if "api_key" in update_data:
|
if "api_key" in update_data:
|
||||||
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
|
||||||
|
|
||||||
|
# 特殊处理 max_concurrent:需要区分"未提供"和"显式设置为 null"
|
||||||
|
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
|
||||||
|
if "max_concurrent" in self.key_data.model_fields_set:
|
||||||
|
update_data["max_concurrent"] = self.key_data.max_concurrent
|
||||||
|
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
|
||||||
|
if self.key_data.max_concurrent is None:
|
||||||
|
update_data["learned_max_concurrent"] = None
|
||||||
|
logger.info("Key %s 切换为自适应并发模式", self.key_id)
|
||||||
|
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(key, field, value)
|
setattr(key, field, value)
|
||||||
key.updated_at = datetime.now(timezone.utc)
|
key.updated_at = datetime.now(timezone.utc)
|
||||||
@@ -253,7 +262,7 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(key)
|
db.refresh(key)
|
||||||
|
|
||||||
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}")
|
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
decrypted_key = crypto_service.decrypt(key.api_key)
|
decrypted_key = crypto_service.decrypt(key.api_key)
|
||||||
|
|||||||
@@ -947,7 +947,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
|||||||
class AdminCacheConfigAdapter(AdminApiAdapter):
|
class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||||
from src.services.cache.affinity_manager import CacheAffinityManager
|
from src.services.cache.affinity_manager import CacheAffinityManager
|
||||||
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
from src.config.constants import ConcurrencyDefaults
|
||||||
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
||||||
|
|
||||||
# 获取动态预留管理器的配置
|
# 获取动态预留管理器的配置
|
||||||
@@ -958,7 +958,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
|||||||
"status": "ok",
|
"status": "ok",
|
||||||
"data": {
|
"data": {
|
||||||
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
|
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
|
||||||
"cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO,
|
"cache_reservation_ratio": ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
|
||||||
"dynamic_reservation": {
|
"dynamic_reservation": {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
"config": reservation_stats["config"],
|
"config": reservation_stats["config"],
|
||||||
@@ -981,7 +981,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
|||||||
context.add_audit_metadata(
|
context.add_audit_metadata(
|
||||||
action="cache_config",
|
action="cache_config",
|
||||||
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
|
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
|
||||||
cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO,
|
cache_reservation_ratio=ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
|
||||||
dynamic_reservation_enabled=True,
|
dynamic_reservation_enabled=True,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
@@ -1236,7 +1236,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
try:
|
try:
|
||||||
cached_data = json.loads(cached_str)
|
cached_data = json.loads(cached_str)
|
||||||
provider_model_name = cached_data.get("provider_model_name")
|
provider_model_name = cached_data.get("provider_model_name")
|
||||||
provider_model_mappings = cached_data.get("provider_model_mappings", [])
|
cached_model_mappings = cached_data.get("provider_model_mappings", [])
|
||||||
|
|
||||||
# 获取 Provider 和 GlobalModel 信息
|
# 获取 Provider 和 GlobalModel 信息
|
||||||
provider = provider_map.get(provider_id)
|
provider = provider_map.get(provider_id)
|
||||||
@@ -1245,8 +1245,8 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
|
|||||||
if provider and global_model:
|
if provider and global_model:
|
||||||
# 提取映射名称
|
# 提取映射名称
|
||||||
mapping_names = []
|
mapping_names = []
|
||||||
if provider_model_mappings:
|
if cached_model_mappings:
|
||||||
for mapping_entry in provider_model_mappings:
|
for mapping_entry in cached_model_mappings:
|
||||||
if isinstance(mapping_entry, dict) and mapping_entry.get("name"):
|
if isinstance(mapping_entry, dict) and mapping_entry.get("name"):
|
||||||
mapping_names.append(mapping_entry["name"])
|
mapping_names.append(mapping_entry["name"])
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
|
|||||||
api_key_id: Optional[str] = None
|
api_key_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelRequest(BaseModel):
|
||||||
|
"""模型测试请求"""
|
||||||
|
|
||||||
|
provider_id: str
|
||||||
|
model_name: str
|
||||||
|
api_key_id: Optional[str] = None
|
||||||
|
stream: bool = False
|
||||||
|
message: Optional[str] = "你好"
|
||||||
|
api_format: Optional[str] = None # 指定使用的API格式,如果不指定则使用端点的默认格式
|
||||||
|
|
||||||
|
|
||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
||||||
|
|
||||||
@@ -206,3 +217,228 @@ async def query_available_models(
|
|||||||
"display_name": provider.display_name,
|
"display_name": provider.display_name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-model")
|
||||||
|
async def test_model(
|
||||||
|
request: TestModelRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试模型连接性
|
||||||
|
|
||||||
|
向指定提供商的指定模型发送测试请求,验证模型是否可用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 测试请求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试结果
|
||||||
|
"""
|
||||||
|
# 获取提供商及其端点
|
||||||
|
provider = (
|
||||||
|
db.query(Provider)
|
||||||
|
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||||
|
.filter(Provider.id == request.provider_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
|
# 找到合适的端点和API Key
|
||||||
|
endpoint_config = None
|
||||||
|
endpoint = None
|
||||||
|
api_key = None
|
||||||
|
|
||||||
|
if request.api_key_id:
|
||||||
|
# 使用指定的API Key
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
for key in ep.api_keys:
|
||||||
|
if key.id == request.api_key_id and key.is_active and ep.is_active:
|
||||||
|
endpoint = ep
|
||||||
|
api_key = key
|
||||||
|
break
|
||||||
|
if endpoint:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 使用第一个可用的端点和密钥
|
||||||
|
for ep in provider.endpoints:
|
||||||
|
if not ep.is_active or not ep.api_keys:
|
||||||
|
continue
|
||||||
|
for key in ep.api_keys:
|
||||||
|
if key.is_active:
|
||||||
|
endpoint = ep
|
||||||
|
api_key = key
|
||||||
|
break
|
||||||
|
if endpoint:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not endpoint or not api_key:
|
||||||
|
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[test-model] Failed to decrypt API key: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||||
|
|
||||||
|
# 构建请求配置
|
||||||
|
endpoint_config = {
|
||||||
|
"api_key": api_key_value,
|
||||||
|
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"extra_headers": endpoint.headers,
|
||||||
|
"timeout": endpoint.timeout or 30.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取对应的 Adapter 类
|
||||||
|
adapter_class = _get_adapter_for_format(endpoint.api_format)
|
||||||
|
if not adapter_class:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown API format: {endpoint.api_format}",
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
|
||||||
|
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
|
||||||
|
|
||||||
|
# 如果请求指定了 api_format,优先使用它
|
||||||
|
target_api_format = request.api_format or endpoint.api_format
|
||||||
|
if request.api_format and request.api_format != endpoint.api_format:
|
||||||
|
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
|
||||||
|
# 重新获取适配器
|
||||||
|
adapter_class = _get_adapter_for_format(request.api_format)
|
||||||
|
if not adapter_class:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown API format: {request.api_format}",
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
}
|
||||||
|
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
|
||||||
|
|
||||||
|
# 准备测试请求数据
|
||||||
|
check_request = {
|
||||||
|
"model": request.model_name,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": request.message or "Hello! This is a test message."}
|
||||||
|
],
|
||||||
|
"max_tokens": 30,
|
||||||
|
"temperature": 0.7,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送测试请求
|
||||||
|
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
|
||||||
|
# 非流式测试
|
||||||
|
logger.debug(f"[test-model] 开始非流式测试...")
|
||||||
|
|
||||||
|
response = await adapter_class.check_endpoint(
|
||||||
|
client,
|
||||||
|
endpoint_config["base_url"],
|
||||||
|
endpoint_config["api_key"],
|
||||||
|
check_request,
|
||||||
|
endpoint_config.get("extra_headers"),
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=current_user,
|
||||||
|
provider_name=provider.name,
|
||||||
|
provider_id=provider.id,
|
||||||
|
api_key_id=endpoint_config.get("api_key_id"),
|
||||||
|
model_name=request.model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 记录提供商返回信息
|
||||||
|
logger.debug(f"[test-model] 非流式测试结果:")
|
||||||
|
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
|
||||||
|
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
|
||||||
|
response_data = response.get('response', {})
|
||||||
|
response_body = response_data.get('response_body', {})
|
||||||
|
logger.debug(f"[test-model] Response Data: {response_data}")
|
||||||
|
logger.debug(f"[test-model] Response Body: {response_body}")
|
||||||
|
# 尝试解析 response_body (通常是 JSON 字符串)
|
||||||
|
parsed_body = response_body
|
||||||
|
import json
|
||||||
|
if isinstance(response_body, str):
|
||||||
|
try:
|
||||||
|
parsed_body = json.loads(response_body)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(parsed_body, dict) and 'error' in parsed_body:
|
||||||
|
error_obj = parsed_body['error']
|
||||||
|
# 兼容 error 可能是字典或字符串的情况
|
||||||
|
if isinstance(error_obj, dict):
|
||||||
|
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
|
||||||
|
raise HTTPException(status_code=500, detail=error_obj.get('message'))
|
||||||
|
else:
|
||||||
|
logger.debug(f"[test-model] Error: {error_obj}")
|
||||||
|
raise HTTPException(status_code=500, detail=error_obj)
|
||||||
|
elif 'error' in response:
|
||||||
|
logger.debug(f"[test-model] Error: {response['error']}")
|
||||||
|
raise HTTPException(status_code=500, detail=response['error'])
|
||||||
|
else:
|
||||||
|
# 如果有选择或消息,记录内容预览
|
||||||
|
if isinstance(response_data, dict):
|
||||||
|
if 'choices' in response_data and response_data['choices']:
|
||||||
|
choice = response_data['choices'][0]
|
||||||
|
if 'message' in choice:
|
||||||
|
content = choice['message'].get('content', '')
|
||||||
|
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
|
||||||
|
elif 'content' in response_data and response_data['content']:
|
||||||
|
content = str(response_data['content'])
|
||||||
|
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
|
||||||
|
|
||||||
|
# 检查测试是否成功(基于HTTP状态码)
|
||||||
|
status_code = response.get('status_code', 0)
|
||||||
|
is_success = status_code == 200 and 'error' not in response
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": is_success,
|
||||||
|
"data": {
|
||||||
|
"stream": False,
|
||||||
|
"response": response,
|
||||||
|
},
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
"endpoint": {
|
||||||
|
"id": endpoint.id,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"provider": {
|
||||||
|
"id": provider.id,
|
||||||
|
"name": provider.name,
|
||||||
|
"display_name": provider.display_name,
|
||||||
|
},
|
||||||
|
"model": request.model_name,
|
||||||
|
"endpoint": {
|
||||||
|
"id": endpoint.id,
|
||||||
|
"api_format": endpoint.api_format,
|
||||||
|
"base_url": endpoint.base_url,
|
||||||
|
} if endpoint else None,
|
||||||
|
}
|
||||||
|
|||||||
@@ -376,6 +376,9 @@ class BaseMessageHandler:
|
|||||||
|
|
||||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
注意:TTFB(首字节时间)由 StreamContext.record_first_byte_time() 记录,
|
||||||
|
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: 请求 ID,如果不传则使用 self.request_id
|
request_id: 请求 ID,如果不传则使用 self.request_id
|
||||||
"""
|
"""
|
||||||
@@ -407,6 +410,9 @@ class BaseMessageHandler:
|
|||||||
|
|
||||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
|
注意:TTFB(首字节时间)由 StreamContext.record_first_byte_time() 记录,
|
||||||
|
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
name: str = "chat.base"
|
name: str = "chat.base"
|
||||||
mode = ApiMode.STANDARD
|
mode = ApiMode.STANDARD
|
||||||
|
|
||||||
|
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建端点URL,子类可以覆盖以自定义URL构建逻辑"""
|
||||||
|
# 默认实现:在base_url后添加特定路径
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建基础请求头,子类可以覆盖以自定义认证头"""
|
||||||
|
# 默认实现:Bearer token认证
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回不应被extra_headers覆盖的头部key,子类可以覆盖"""
|
||||||
|
# 默认保护认证相关头部
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
|
||||||
|
# 默认实现:直接使用请求数据
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||||
|
|
||||||
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
# 默认实现返回空列表,子类应覆盖
|
# 默认实现返回空列表,子类应覆盖
|
||||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试模型连接性(非流式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
request_data: 请求数据
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
db: 数据库会话
|
||||||
|
user: 用户对象
|
||||||
|
provider_name: 提供商名称
|
||||||
|
provider_id: 提供商ID
|
||||||
|
api_key_id: API Key ID
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试响应数据
|
||||||
|
"""
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
|
|
||||||
|
# 使用子类配置方法构建请求组件
|
||||||
|
url = cls.build_endpoint_url(base_url)
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 使用通用的endpoint checker执行请求
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=model_name or request_data.get("model"),
|
||||||
|
)
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例
|
||||||
|
|||||||
@@ -484,9 +484,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
|
|
||||||
stream_response.raise_for_status()
|
stream_response.raise_for_status()
|
||||||
|
|
||||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
# 使用字节流迭代器(避免 aiter_lines 的性能问题, aiter_bytes 会自动解压 gzip/deflate)
|
||||||
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
|
byte_iterator = stream_response.aiter_bytes()
|
||||||
byte_iterator = stream_response.aiter_raw()
|
|
||||||
|
|
||||||
# 预读检测嵌套错误
|
# 预读检测嵌套错误
|
||||||
prefetched_chunks = await stream_processor.prefetch_and_check_error(
|
prefetched_chunks = await stream_processor.prefetch_and_check_error(
|
||||||
|
|||||||
@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
# 默认实现返回空列表,子类应覆盖
|
# 默认实现返回空列表,子类应覆盖
|
||||||
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
测试模型连接性(非流式)
|
||||||
|
|
||||||
|
通用的CLI endpoint测试方法,使用配置方法模式:
|
||||||
|
- build_endpoint_url(): 构建请求URL
|
||||||
|
- build_base_headers(): 构建基础认证头
|
||||||
|
- get_protected_header_keys(): 获取受保护的头部key
|
||||||
|
- build_request_body(): 构建请求体
|
||||||
|
- get_cli_user_agent(): 获取CLI User-Agent(子类可覆盖)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client: httpx 异步客户端
|
||||||
|
base_url: API 基础 URL
|
||||||
|
api_key: API 密钥(已解密)
|
||||||
|
request_data: 请求数据
|
||||||
|
extra_headers: 端点配置的额外请求头
|
||||||
|
db: 数据库会话
|
||||||
|
user: 用户对象
|
||||||
|
provider_name: 提供商名称
|
||||||
|
provider_id: 提供商ID
|
||||||
|
api_key_id: API密钥ID
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
测试响应数据
|
||||||
|
"""
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
|
|
||||||
|
# 构建请求组件
|
||||||
|
url = cls.build_endpoint_url(base_url, request_data, model_name)
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
|
||||||
|
# 添加CLI User-Agent
|
||||||
|
cli_user_agent = cls.get_cli_user_agent()
|
||||||
|
if cli_user_agent:
|
||||||
|
base_headers["User-Agent"] = cli_user_agent
|
||||||
|
protected_keys = tuple(list(protected_keys) + ["user-agent"])
|
||||||
|
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 获取有效的模型名称
|
||||||
|
effective_model_name = model_name or request_data.get("model")
|
||||||
|
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=effective_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
构建CLI API端点URL - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_url: API基础URL
|
||||||
|
request_data: 请求数据
|
||||||
|
model_name: 模型名称(某些API需要,如Gemini)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的端点URL
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
构建CLI API认证头 - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API密钥
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
基础认证头部字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""
|
||||||
|
返回CLI API的保护头部key - 子类应覆盖
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
保护头部key的元组
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
构建CLI API请求体 - 子类应覆盖
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_data: 请求数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
请求体字典
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
获取CLI User-Agent - 子类可覆盖
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CLI User-Agent字符串,如果不需要则为None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例
|
||||||
|
|||||||
@@ -57,8 +57,10 @@ from src.models.database import (
|
|||||||
ProviderEndpoint,
|
ProviderEndpoint,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
from src.config.settings import config
|
||||||
from src.services.provider.transport import build_provider_url
|
from src.services.provider.transport import build_provider_url
|
||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
|
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||||
|
|
||||||
|
|
||||||
class CliMessageHandlerBase(BaseMessageHandler):
|
class CliMessageHandlerBase(BaseMessageHandler):
|
||||||
@@ -474,8 +476,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
stream_response.raise_for_status()
|
stream_response.raise_for_status()
|
||||||
|
|
||||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
# 使用字节流迭代器(避免 aiter_lines 的性能问题, aiter_bytes 会自动解压 gzip/deflate)
|
||||||
byte_iterator = stream_response.aiter_raw()
|
byte_iterator = stream_response.aiter_bytes()
|
||||||
|
|
||||||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||||||
prefetched_chunks = await self._prefetch_and_check_embedded_error(
|
prefetched_chunks = await self._prefetch_and_check_embedded_error(
|
||||||
@@ -529,7 +531,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
async for chunk in stream_response.aiter_raw():
|
async for chunk in stream_response.aiter_bytes():
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
# 在第一次输出数据前更新状态为 streaming
|
||||||
if not streaming_status_updated:
|
if not streaming_status_updated:
|
||||||
self._update_usage_to_streaming_with_ctx(ctx)
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
@@ -672,6 +674,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||||||
|
|
||||||
|
首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
byte_iterator: 字节流迭代器
|
byte_iterator: 字节流迭代器
|
||||||
provider: Provider 对象
|
provider: Provider 对象
|
||||||
@@ -684,6 +688,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
Raises:
|
Raises:
|
||||||
EmbeddedErrorException: 如果检测到嵌套错误
|
EmbeddedErrorException: 如果检测到嵌套错误
|
||||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||||
|
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||||
"""
|
"""
|
||||||
prefetched_chunks: list = []
|
prefetched_chunks: list = []
|
||||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||||
@@ -704,7 +709,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
else:
|
else:
|
||||||
provider_parser = self.parser
|
provider_parser = self.parser
|
||||||
|
|
||||||
async for chunk in byte_iterator:
|
# 使用共享的 TTFB 超时函数读取首字节
|
||||||
|
ttfb_timeout = config.stream_first_byte_timeout
|
||||||
|
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
|
||||||
|
byte_iterator,
|
||||||
|
timeout=ttfb_timeout,
|
||||||
|
request_id=self.request_id,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
)
|
||||||
|
prefetched_chunks.append(first_chunk)
|
||||||
|
buffer += first_chunk
|
||||||
|
|
||||||
|
# 继续读取剩余的预读数据
|
||||||
|
async for chunk in aiter:
|
||||||
prefetched_chunks.append(chunk)
|
prefetched_chunks.append(chunk)
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
|
||||||
@@ -785,12 +802,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if should_stop or line_count >= max_prefetch_lines:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
break
|
break
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
|
||||||
# 重新抛出嵌套错误
|
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||||
raise
|
raise
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
# 网络 I/O 异常:记录警告,可能需要重试
|
||||||
|
logger.warning(
|
||||||
|
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
logger.error(
|
||||||
|
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return prefetched_chunks
|
return prefetched_chunks
|
||||||
|
|
||||||
|
|||||||
1252
src/api/handlers/base/endpoint_checker.py
Normal file
1252
src/api/handlers/base/endpoint_checker.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -25,10 +25,12 @@ from src.api.handlers.base.content_extractors import (
|
|||||||
from src.api.handlers.base.parsers import get_parser_for_format
|
from src.api.handlers.base.parsers import get_parser_for_format
|
||||||
from src.api.handlers.base.response_parser import ResponseParser
|
from src.api.handlers.base.response_parser import ResponseParser
|
||||||
from src.api.handlers.base.stream_context import StreamContext
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
from src.core.exceptions import EmbeddedErrorException
|
from src.config.settings import config
|
||||||
|
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.database import Provider, ProviderEndpoint
|
from src.models.database import Provider, ProviderEndpoint
|
||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
|
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -170,6 +172,8 @@ class StreamProcessor:
|
|||||||
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
某些 Provider(如 Gemini)可能返回 HTTP 200,但在响应体中包含错误信息。
|
||||||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||||||
|
|
||||||
|
首次读取时会应用 TTFB(首字节超时)检测,超时则触发故障转移。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
byte_iterator: 字节流迭代器
|
byte_iterator: 字节流迭代器
|
||||||
provider: Provider 对象
|
provider: Provider 对象
|
||||||
@@ -182,6 +186,7 @@ class StreamProcessor:
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
EmbeddedErrorException: 如果检测到嵌套错误
|
EmbeddedErrorException: 如果检测到嵌套错误
|
||||||
|
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||||
"""
|
"""
|
||||||
prefetched_chunks: list = []
|
prefetched_chunks: list = []
|
||||||
parser = self.get_parser_for_provider(ctx)
|
parser = self.get_parser_for_provider(ctx)
|
||||||
@@ -192,7 +197,19 @@ class StreamProcessor:
|
|||||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in byte_iterator:
|
# 使用共享的 TTFB 超时函数读取首字节
|
||||||
|
ttfb_timeout = config.stream_first_byte_timeout
|
||||||
|
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
|
||||||
|
byte_iterator,
|
||||||
|
timeout=ttfb_timeout,
|
||||||
|
request_id=self.request_id,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
)
|
||||||
|
prefetched_chunks.append(first_chunk)
|
||||||
|
buffer += first_chunk
|
||||||
|
|
||||||
|
# 继续读取剩余的预读数据
|
||||||
|
async for chunk in aiter:
|
||||||
prefetched_chunks.append(chunk)
|
prefetched_chunks.append(chunk)
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
|
||||||
@@ -262,10 +279,21 @@ class StreamProcessor:
|
|||||||
if should_stop or line_count >= max_prefetch_lines:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
break
|
break
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except (EmbeddedErrorException, ProviderTimeoutException):
|
||||||
|
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||||
raise
|
raise
|
||||||
|
except (OSError, IOError) as e:
|
||||||
|
# 网络 I/O <20><><EFBFBD>常:记录警告,可能需要重试
|
||||||
|
logger.warning(
|
||||||
|
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
# 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
|
||||||
|
logger.error(
|
||||||
|
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return prefetched_chunks
|
return prefetched_chunks
|
||||||
|
|
||||||
|
|||||||
@@ -4,17 +4,28 @@ Handler 基础工具函数
|
|||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
|
||||||
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
||||||
"""
|
"""
|
||||||
提取缓存创建 tokens(兼容新旧格式)
|
提取缓存创建 tokens(兼容三种格式)
|
||||||
|
|
||||||
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens:
|
根据 Anthropic API 文档,支持三种格式(按优先级):
|
||||||
- 新格式(2024年后):使用 claude_cache_creation_5_m_tokens 和
|
|
||||||
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
|
|
||||||
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
|
|
||||||
|
|
||||||
此函数自动检测并适配两种格式,优先使用新格式。
|
1. **嵌套格式(优先级最高)**:
|
||||||
|
usage.cache_creation.ephemeral_5m_input_tokens
|
||||||
|
usage.cache_creation.ephemeral_1h_input_tokens
|
||||||
|
|
||||||
|
2. **扁平新格式(优先级第二)**:
|
||||||
|
usage.claude_cache_creation_5_m_tokens
|
||||||
|
usage.claude_cache_creation_1_h_tokens
|
||||||
|
|
||||||
|
3. **旧格式(优先级第三)**:
|
||||||
|
usage.cache_creation_input_tokens
|
||||||
|
|
||||||
|
优先使用嵌套格式,如果嵌套格式字段存在但值为 0,则智能 fallback 到旧格式。
|
||||||
|
扁平格式和嵌套格式互斥,按顺序检查。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
usage: API 响应中的 usage 字典
|
usage: API 响应中的 usage 字典
|
||||||
@@ -22,20 +33,63 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
缓存创建 tokens 总数
|
缓存创建 tokens 总数
|
||||||
"""
|
"""
|
||||||
# 检查新格式字段是否存在(而非值是否为 0)
|
# 1. 检查嵌套格式(最新格式)
|
||||||
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式
|
cache_creation = usage.get("cache_creation")
|
||||||
has_new_format = (
|
if isinstance(cache_creation, dict):
|
||||||
|
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
|
||||||
|
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
|
||||||
|
total = cache_5m + cache_1h
|
||||||
|
|
||||||
|
if total > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
||||||
|
)
|
||||||
|
return total
|
||||||
|
|
||||||
|
# 嵌套格式存在但为 0,fallback 到旧格式
|
||||||
|
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||||
|
if old_format > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Nested cache_creation is 0, using old format: {old_format}"
|
||||||
|
)
|
||||||
|
return old_format
|
||||||
|
|
||||||
|
# 都是 0,返回 0
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 2. 检查扁平新格式
|
||||||
|
has_flat_format = (
|
||||||
"claude_cache_creation_5_m_tokens" in usage
|
"claude_cache_creation_5_m_tokens" in usage
|
||||||
or "claude_cache_creation_1_h_tokens" in usage
|
or "claude_cache_creation_1_h_tokens" in usage
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_new_format:
|
if has_flat_format:
|
||||||
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0)
|
cache_5m = int(usage.get("claude_cache_creation_5_m_tokens", 0))
|
||||||
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0)
|
cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
|
||||||
return int(cache_5m) + int(cache_1h)
|
total = cache_5m + cache_1h
|
||||||
|
|
||||||
# 回退到旧格式
|
if total > 0:
|
||||||
return int(usage.get("cache_creation_input_tokens", 0))
|
logger.debug(
|
||||||
|
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
|
||||||
|
)
|
||||||
|
return total
|
||||||
|
|
||||||
|
# 扁平格式存在但为 0,fallback 到旧格式
|
||||||
|
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||||
|
if old_format > 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Flat cache_creation is 0, using old format: {old_format}"
|
||||||
|
)
|
||||||
|
return old_format
|
||||||
|
|
||||||
|
# 都是 0,返回 0
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 3. 回退到旧格式
|
||||||
|
old_format = int(usage.get("cache_creation_input_tokens", 0))
|
||||||
|
if old_format > 0:
|
||||||
|
logger.debug(f"Using old format: cache_creation_input_tokens={old_format}")
|
||||||
|
return old_format
|
||||||
|
|
||||||
|
|
||||||
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||||
|
|||||||
@@ -209,6 +209,38 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建Claude API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/messages"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/messages"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Claude API认证头"""
|
||||||
|
return {
|
||||||
|
"x-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Claude API的保护头部key"""
|
||||||
|
return ("x-api-key", "content-type", "anthropic-version")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Claude API请求体"""
|
||||||
|
return {
|
||||||
|
"model": request_data.get("model"),
|
||||||
|
"max_tokens": request_data.get("max_tokens", 100),
|
||||||
|
"messages": request_data.get("messages", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_claude_adapter(x_app_header: Optional[str]):
|
def build_claude_adapter(x_app_header: Optional[str]):
|
||||||
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -115,7 +115,7 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
) -> Tuple[list, Optional[str]]:
|
) -> Tuple[list, Optional[str]]:
|
||||||
"""查询 Claude API 支持的模型列表(带 CLI User-Agent)"""
|
"""查询 Claude API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
|
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
|
||||||
cli_headers = {"User-Agent": config.internal_user_agent_claude}
|
cli_headers = {"User-Agent": config.internal_user_agent_claude_cli}
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
cli_headers.update(extra_headers)
|
cli_headers.update(extra_headers)
|
||||||
models, error = await ClaudeChatAdapter.fetch_models(
|
models, error = await ClaudeChatAdapter.fetch_models(
|
||||||
@@ -126,5 +126,41 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建Claude CLI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/messages"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/messages"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Claude CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Claude CLI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Claude CLI API请求体"""
|
||||||
|
return {
|
||||||
|
"model": request_data.get("model"),
|
||||||
|
"max_tokens": request_data.get("max_tokens", 100),
|
||||||
|
"messages": request_data.get("messages", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取Claude CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_claude_cli
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ClaudeCliAdapter"]
|
__all__ = ["ClaudeCliAdapter"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Gemini Chat Adapter
|
|||||||
处理 Gemini API 格式的请求适配
|
处理 Gemini API 格式的请求适配
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
|
|||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.gemini import GeminiRequest
|
from src.models.gemini import GeminiRequest
|
||||||
|
|
||||||
@@ -199,6 +200,94 @@ class GeminiChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建Gemini API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1beta"):
|
||||||
|
return base_url # 子类需要处理model参数
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1beta"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Gemini API认证头"""
|
||||||
|
return {
|
||||||
|
"x-goog-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Gemini API的保护头部key"""
|
||||||
|
return ("x-goog-api-key", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Gemini API请求体"""
|
||||||
|
return {
|
||||||
|
"contents": request_data.get("messages", []),
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": request_data.get("max_tokens", 100),
|
||||||
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def check_endpoint(
|
||||||
|
cls,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
base_url: str,
|
||||||
|
api_key: str,
|
||||||
|
request_data: Dict[str, Any],
|
||||||
|
extra_headers: Optional[Dict[str, str]] = None,
|
||||||
|
# 用量计算参数
|
||||||
|
db: Optional[Any] = None,
|
||||||
|
user: Optional[Any] = None,
|
||||||
|
provider_name: Optional[str] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
api_key_id: Optional[str] = None,
|
||||||
|
model_name: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""测试 Gemini API 模型连接性(非流式)"""
|
||||||
|
# Gemini需要从request_data或model_name参数获取model名称
|
||||||
|
effective_model_name = model_name or request_data.get("model", "")
|
||||||
|
if not effective_model_name:
|
||||||
|
return {
|
||||||
|
"error": "Model name is required for Gemini API",
|
||||||
|
"status_code": 400,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用基类配置方法,但重写URL构建逻辑
|
||||||
|
base_url = cls.build_endpoint_url(base_url)
|
||||||
|
url = f"{base_url}/models/{effective_model_name}:generateContent"
|
||||||
|
|
||||||
|
# 构建请求组件
|
||||||
|
base_headers = cls.build_base_headers(api_key)
|
||||||
|
protected_keys = cls.get_protected_header_keys()
|
||||||
|
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
|
||||||
|
body = cls.build_request_body(request_data)
|
||||||
|
|
||||||
|
# 使用基类的通用endpoint checker
|
||||||
|
from src.api.handlers.base.endpoint_checker import run_endpoint_check
|
||||||
|
return await run_endpoint_check(
|
||||||
|
client=client,
|
||||||
|
url=url,
|
||||||
|
headers=headers,
|
||||||
|
json_body=body,
|
||||||
|
api_format=cls.name,
|
||||||
|
# 用量计算参数(现在强制记录)
|
||||||
|
db=db,
|
||||||
|
user=user,
|
||||||
|
provider_name=provider_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
api_key_id=api_key_id,
|
||||||
|
model_name=effective_model_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
|
|||||||
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
继承 CliAdapterBase,处理 Gemini CLI 格式的请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -112,7 +112,7 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
) -> Tuple[list, Optional[str]]:
|
) -> Tuple[list, Optional[str]]:
|
||||||
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent)"""
|
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
|
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
|
||||||
cli_headers = {"User-Agent": config.internal_user_agent_gemini}
|
cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli}
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
cli_headers.update(extra_headers)
|
cli_headers.update(extra_headers)
|
||||||
models, error = await GeminiChatAdapter.fetch_models(
|
models, error = await GeminiChatAdapter.fetch_models(
|
||||||
@@ -123,6 +123,52 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建Gemini CLI API端点URL"""
|
||||||
|
effective_model_name = model_name or request_data.get("model", "")
|
||||||
|
if not effective_model_name:
|
||||||
|
raise ValueError("Model name is required for Gemini API")
|
||||||
|
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1beta"):
|
||||||
|
prefix = base_url
|
||||||
|
else:
|
||||||
|
prefix = f"{base_url}/v1beta"
|
||||||
|
return f"{prefix}/models/{effective_model_name}:generateContent"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建Gemini CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"x-goog-api-key": api_key,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回Gemini CLI API的保护头部key"""
|
||||||
|
return ("x-goog-api-key", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建Gemini CLI API请求体"""
|
||||||
|
return {
|
||||||
|
"contents": request_data.get("messages", []),
|
||||||
|
"generationConfig": {
|
||||||
|
"maxOutputTokens": request_data.get("max_tokens", 100),
|
||||||
|
"temperature": request_data.get("temperature", 0.7),
|
||||||
|
},
|
||||||
|
"safetySettings": [
|
||||||
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取Gemini CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_gemini_cli
|
||||||
|
|
||||||
|
|
||||||
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
|
|||||||
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
|
||||||
|
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
|
||||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.openai import OpenAIRequest
|
from src.models.openai import OpenAIRequest
|
||||||
@@ -154,5 +155,32 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
|||||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||||
return [], error_msg
|
return [], error_msg
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
|
"""构建OpenAI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建OpenAI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回OpenAI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建OpenAI API请求体"""
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAIChatAdapter"]
|
__all__ = ["OpenAIChatAdapter"]
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
|
|||||||
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
继承 CliAdapterBase,只需配置 FORMAT_ID 和 HANDLER_CLASS。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Optional, Tuple, Type
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
@@ -57,7 +57,7 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
) -> Tuple[list, Optional[str]]:
|
) -> Tuple[list, Optional[str]]:
|
||||||
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)"""
|
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent)"""
|
||||||
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
|
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
|
||||||
cli_headers = {"User-Agent": config.internal_user_agent_openai}
|
cli_headers = {"User-Agent": config.internal_user_agent_openai_cli}
|
||||||
if extra_headers:
|
if extra_headers:
|
||||||
cli_headers.update(extra_headers)
|
cli_headers.update(extra_headers)
|
||||||
models, error = await OpenAIChatAdapter.fetch_models(
|
models, error = await OpenAIChatAdapter.fetch_models(
|
||||||
@@ -68,5 +68,37 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
m["api_format"] = cls.FORMAT_ID
|
m["api_format"] = cls.FORMAT_ID
|
||||||
return models, error
|
return models, error
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
|
||||||
|
"""构建OpenAI CLI API端点URL"""
|
||||||
|
base_url = base_url.rstrip("/")
|
||||||
|
if base_url.endswith("/v1"):
|
||||||
|
return f"{base_url}/chat/completions"
|
||||||
|
else:
|
||||||
|
return f"{base_url}/v1/chat/completions"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
|
||||||
|
"""构建OpenAI CLI API认证头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_protected_header_keys(cls) -> tuple:
|
||||||
|
"""返回OpenAI CLI API的保护头部key"""
|
||||||
|
return ("authorization", "content-type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""构建OpenAI CLI API请求体"""
|
||||||
|
return request_data.copy()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cli_user_agent(cls) -> Optional[str]:
|
||||||
|
"""获取OpenAI CLI User-Agent"""
|
||||||
|
return config.internal_user_agent_openai_cli
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["OpenAICliAdapter"]
|
__all__ = ["OpenAICliAdapter"]
|
||||||
|
|||||||
@@ -77,7 +77,10 @@ class ConcurrencyDefaults:
|
|||||||
MAX_CONCURRENT_LIMIT = 200
|
MAX_CONCURRENT_LIMIT = 200
|
||||||
|
|
||||||
# 最小并发限制下限
|
# 最小并发限制下限
|
||||||
MIN_CONCURRENT_LIMIT = 1
|
# 设置为 3 而不是 1,因为预留机制(10%预留给缓存用户)会导致
|
||||||
|
# 当 learned_max_concurrent=1 时新用户实际可用槽位为 0,永远无法命中
|
||||||
|
# 注意:当 limit < 10 时,预留机制实际不生效(预留槽位 = 0),这是可接受的
|
||||||
|
MIN_CONCURRENT_LIMIT = 3
|
||||||
|
|
||||||
# === 探测性扩容参数 ===
|
# === 探测性扩容参数 ===
|
||||||
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
||||||
|
|||||||
@@ -56,10 +56,11 @@ class Config:
|
|||||||
|
|
||||||
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
|
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
|
||||||
redis_required_env = os.getenv("REDIS_REQUIRED")
|
redis_required_env = os.getenv("REDIS_REQUIRED")
|
||||||
if redis_required_env is None:
|
if redis_required_env is not None:
|
||||||
self.require_redis = self.environment not in {"development", "test", "testing"}
|
|
||||||
else:
|
|
||||||
self.require_redis = redis_required_env.lower() == "true"
|
self.require_redis = redis_required_env.lower() == "true"
|
||||||
|
else:
|
||||||
|
# 保持向后兼容:开发环境可选,生产环境必需
|
||||||
|
self.require_redis = self.environment not in {"development", "test", "testing"}
|
||||||
|
|
||||||
# CORS配置 - 使用环境变量配置允许的源
|
# CORS配置 - 使用环境变量配置允许的源
|
||||||
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
|
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
|
||||||
@@ -133,6 +134,18 @@ class Config:
|
|||||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
||||||
|
|
||||||
|
# 限流降级策略配置
|
||||||
|
# RATE_LIMIT_FAIL_OPEN: 当限流服务(Redis)异常时的行为
|
||||||
|
#
|
||||||
|
# True (默认): fail-open - 放行请求(优先可用性)
|
||||||
|
# 风险:Redis 故障期间无法限流,可能被滥用
|
||||||
|
# 适用:API 网关作为关键基础设施,必须保持高可用
|
||||||
|
#
|
||||||
|
# False: fail-close - 拒绝所有请求(优先安全性)
|
||||||
|
# 风险:Redis 故障会导致 API 网关不可用
|
||||||
|
# 适用:有严格速率限制要求的安全敏感场景
|
||||||
|
self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true"
|
||||||
|
|
||||||
# HTTP 请求超时配置(秒)
|
# HTTP 请求超时配置(秒)
|
||||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||||
@@ -141,19 +154,22 @@ class Config:
|
|||||||
# 流式处理配置
|
# 流式处理配置
|
||||||
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
||||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||||
|
# STREAM_FIRST_BYTE_TIMEOUT: 首字节超时(秒),等待首字节超过此时间触发故障转移
|
||||||
|
# 范围: 10-120 秒,默认 30 秒(必须小于 http_write_timeout 避免竞态)
|
||||||
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
|
||||||
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
|
||||||
|
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
|
||||||
|
|
||||||
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
# 内部请求 User-Agent 配置(用于查询上游模型列表等)
|
||||||
# 可通过环境变量覆盖默认值
|
# 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
|
||||||
self.internal_user_agent_claude = os.getenv(
|
self.internal_user_agent_claude_cli = os.getenv(
|
||||||
"CLAUDE_USER_AGENT", "claude-cli/1.0"
|
"CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1"
|
||||||
)
|
)
|
||||||
self.internal_user_agent_openai = os.getenv(
|
self.internal_user_agent_openai_cli = os.getenv(
|
||||||
"OPENAI_USER_AGENT", "openai-cli/1.0"
|
"OPENAI_CLI_USER_AGENT", "openai-codex/1.0"
|
||||||
)
|
)
|
||||||
self.internal_user_agent_gemini = os.getenv(
|
self.internal_user_agent_gemini_cli = os.getenv(
|
||||||
"GEMINI_USER_AGENT", "gemini-cli/1.0"
|
"GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证连接池配置
|
# 验证连接池配置
|
||||||
@@ -177,6 +193,39 @@ class Config:
|
|||||||
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
||||||
return self.db_pool_size
|
return self.db_pool_size
|
||||||
|
|
||||||
|
def _parse_ttfb_timeout(self) -> float:
|
||||||
|
"""
|
||||||
|
解析 TTFB 超时配置,带错误处理和范围限制
|
||||||
|
|
||||||
|
TTFB (Time To First Byte) 用于检测慢响应的 Provider,超时触发故障转移。
|
||||||
|
此值必须小于 http_write_timeout,避免竞态条件。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
超时时间(秒),范围 10-120,默认 30
|
||||||
|
"""
|
||||||
|
default_timeout = 30.0
|
||||||
|
min_timeout = 10.0
|
||||||
|
max_timeout = 120.0 # 必须小于 http_write_timeout (默认 60s) 的 2 倍
|
||||||
|
|
||||||
|
raw_value = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_timeout))
|
||||||
|
try:
|
||||||
|
timeout = float(raw_value)
|
||||||
|
except ValueError:
|
||||||
|
# 延迟导入,避免循环依赖(Config 初始化时 logger 可能未就绪)
|
||||||
|
self._ttfb_config_warning = (
|
||||||
|
f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw_value}',使用默认值 {default_timeout}秒"
|
||||||
|
)
|
||||||
|
return default_timeout
|
||||||
|
|
||||||
|
# 范围限制
|
||||||
|
clamped = max(min_timeout, min(max_timeout, timeout))
|
||||||
|
if clamped != timeout:
|
||||||
|
self._ttfb_config_warning = (
|
||||||
|
f"STREAM_FIRST_BYTE_TIMEOUT={timeout}秒超出范围 [{min_timeout}-{max_timeout}],"
|
||||||
|
f"已调整为 {clamped}秒"
|
||||||
|
)
|
||||||
|
return clamped
|
||||||
|
|
||||||
def _validate_pool_config(self) -> None:
|
def _validate_pool_config(self) -> None:
|
||||||
"""验证连接池配置是否安全"""
|
"""验证连接池配置是否安全"""
|
||||||
total_per_worker = self.db_pool_size + self.db_max_overflow
|
total_per_worker = self.db_pool_size + self.db_max_overflow
|
||||||
@@ -224,6 +273,10 @@ class Config:
|
|||||||
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
|
if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
|
||||||
logger.warning(self._pool_config_warning)
|
logger.warning(self._pool_config_warning)
|
||||||
|
|
||||||
|
# TTFB 超时配置警告
|
||||||
|
if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning:
|
||||||
|
logger.warning(self._ttfb_config_warning)
|
||||||
|
|
||||||
# 管理员密码检查(必须在环境变量中设置)
|
# 管理员密码检查(必须在环境变量中设置)
|
||||||
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
|
if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
|
||||||
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
|
logger.error("必须设置 ADMIN_PASSWORD 环境变量!")
|
||||||
|
|||||||
@@ -336,10 +336,44 @@ class PluginMiddleware:
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
|
except ConnectionError as e:
|
||||||
|
# Redis 连接错误:根据配置决定
|
||||||
|
logger.warning(f"Rate limit connection error: {e}")
|
||||||
|
if config.rate_limit_fail_open:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=30,
|
||||||
|
message="Rate limit service unavailable"
|
||||||
|
)
|
||||||
|
except TimeoutError as e:
|
||||||
|
# 超时错误:可能是负载过高,根据配置决定
|
||||||
|
logger.warning(f"Rate limit timeout: {e}")
|
||||||
|
if config.rate_limit_fail_open:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=30,
|
||||||
|
message="Rate limit service timeout"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Rate limit error: {e}")
|
logger.error(f"Rate limit error: {type(e).__name__}: {e}")
|
||||||
# 发生错误时允许请求通过
|
# 其他异常:根据配置决定
|
||||||
return None
|
if config.rate_limit_fail_open:
|
||||||
|
# fail-open: 异常时放行请求(优先可用性)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# fail-close: 异常时拒绝请求(优先安全性)
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
remaining=0,
|
||||||
|
retry_after=60,
|
||||||
|
message="Rate limit service error"
|
||||||
|
)
|
||||||
|
|
||||||
async def _call_pre_request_plugins(self, request: Request) -> None:
|
async def _call_pre_request_plugins(self, request: Request) -> None:
|
||||||
"""调用请求前的插件(当前保留扩展点)"""
|
"""调用请求前的插件(当前保留扩展点)"""
|
||||||
|
|||||||
@@ -317,6 +317,7 @@ class UpdateUserRequest(BaseModel):
|
|||||||
|
|
||||||
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
username: Optional[str] = Field(None, min_length=1, max_length=50)
|
||||||
email: Optional[str] = Field(None, max_length=100)
|
email: Optional[str] = Field(None, max_length=100)
|
||||||
|
password: Optional[str] = Field(None, min_length=6, max_length=128, description="新密码(留空保持不变)")
|
||||||
quota_usd: Optional[float] = Field(None, ge=0)
|
quota_usd: Optional[float] = Field(None, ge=0)
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
role: Optional[str] = None
|
role: Optional[str] = None
|
||||||
|
|||||||
@@ -226,8 +226,11 @@ class EndpointAPIKeyUpdate(BaseModel):
|
|||||||
global_priority: Optional[int] = Field(
|
global_priority: Optional[int] = Field(
|
||||||
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
|
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
|
||||||
)
|
)
|
||||||
# 注意:max_concurrent=None 表示不更新,要切换为自适应模式请使用专用 API
|
# max_concurrent: 使用特殊标记区分"未提供"和"设置为 null(自适应模式)"
|
||||||
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
|
# - 不提供字段:不更新
|
||||||
|
# - 提供 null:切换为自适应模式
|
||||||
|
# - 提供数字:设置固定并发限制
|
||||||
|
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数(null=自适应模式)")
|
||||||
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
|
||||||
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
|
||||||
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
|
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ if not config.jwt_secret_key:
|
|||||||
if config.environment == "production":
|
if config.environment == "production":
|
||||||
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
|
raise ValueError("JWT_SECRET_KEY must be set in production environment!")
|
||||||
config.jwt_secret_key = secrets.token_urlsafe(32)
|
config.jwt_secret_key = secrets.token_urlsafe(32)
|
||||||
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...")
|
logger.warning("JWT_SECRET_KEY未在环境变量中找到,已生成随机密钥用于开发")
|
||||||
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
|
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
|
||||||
|
|
||||||
JWT_SECRET_KEY = config.jwt_secret_key
|
JWT_SECRET_KEY = config.jwt_secret_key
|
||||||
|
|||||||
55
src/services/cache/aware_scheduler.py
vendored
55
src/services/cache/aware_scheduler.py
vendored
@@ -30,6 +30,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
@@ -956,7 +958,16 @@ class CacheAwareScheduler:
|
|||||||
|
|
||||||
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
|
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
|
||||||
active_keys = [key for key in endpoint.api_keys if key.is_active]
|
active_keys = [key for key in endpoint.api_keys if key.is_active]
|
||||||
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key)
|
# 检查是否所有 Key 都是 TTL=0(轮换模式)
|
||||||
|
# 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None,则使用随机排序
|
||||||
|
use_random = all(
|
||||||
|
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
|
||||||
|
) if active_keys else False
|
||||||
|
if use_random and len(active_keys) > 1:
|
||||||
|
logger.debug(
|
||||||
|
f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
|
||||||
|
)
|
||||||
|
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
||||||
@@ -1170,6 +1181,7 @@ class CacheAwareScheduler:
|
|||||||
self,
|
self,
|
||||||
keys: List[ProviderAPIKey],
|
keys: List[ProviderAPIKey],
|
||||||
affinity_key: Optional[str] = None,
|
affinity_key: Optional[str] = None,
|
||||||
|
use_random: bool = False,
|
||||||
) -> List[ProviderAPIKey]:
|
) -> List[ProviderAPIKey]:
|
||||||
"""
|
"""
|
||||||
对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱
|
对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱
|
||||||
@@ -1178,10 +1190,12 @@ class CacheAwareScheduler:
|
|||||||
- 数字越小越优先使用
|
- 数字越小越优先使用
|
||||||
- 同优先级 Key 之间实现负载均衡
|
- 同优先级 Key 之间实现负载均衡
|
||||||
- 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性)
|
- 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性)
|
||||||
|
- 当 use_random=True 时,使用随机排序实现轮换(用于 TTL=0 的场景)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
keys: API Key 列表
|
keys: API Key 列表
|
||||||
affinity_key: 亲和性标识符(通常为 API Key ID,用于确定性打乱)
|
affinity_key: 亲和性标识符(通常为 API Key ID,用于确定性打乱)
|
||||||
|
use_random: 是否使用随机排序(TTL=0 时为 True)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
排序后的 Key 列表
|
排序后的 Key 列表
|
||||||
@@ -1198,28 +1212,35 @@ class CacheAwareScheduler:
|
|||||||
priority = key.internal_priority if key.internal_priority is not None else 999999
|
priority = key.internal_priority if key.internal_priority is not None else 999999
|
||||||
priority_groups[priority].append(key)
|
priority_groups[priority].append(key)
|
||||||
|
|
||||||
# 对每个优先级组内的 Key 进行确定性打乱
|
# 对每个优先级组内的 Key 进行打乱
|
||||||
result = []
|
result = []
|
||||||
for priority in sorted(priority_groups.keys()): # 数字小的优先级高,排前面
|
for priority in sorted(priority_groups.keys()): # 数字小的优先级高,排前面
|
||||||
group_keys = priority_groups[priority]
|
group_keys = priority_groups[priority]
|
||||||
|
|
||||||
if len(group_keys) > 1 and affinity_key:
|
if len(group_keys) > 1:
|
||||||
# 改进的哈希策略:为每个 key 计算独立的哈希值
|
if use_random:
|
||||||
import hashlib
|
# TTL=0 模式:使用随机排序实现 Key 轮换
|
||||||
|
shuffled = list(group_keys)
|
||||||
|
random.shuffle(shuffled)
|
||||||
|
result.extend(shuffled)
|
||||||
|
elif affinity_key:
|
||||||
|
# 正常模式:使用哈希确定性打乱(保持缓存亲和性)
|
||||||
|
key_scores = []
|
||||||
|
for key in group_keys:
|
||||||
|
# 使用 affinity_key + key.id 的组合哈希
|
||||||
|
hash_input = f"{affinity_key}:{key.id}"
|
||||||
|
hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
|
||||||
|
key_scores.append((hash_value, key))
|
||||||
|
|
||||||
key_scores = []
|
# 按哈希值排序
|
||||||
for key in group_keys:
|
sorted_group = [key for _, key in sorted(key_scores)]
|
||||||
# 使用 affinity_key + key.id 的组合哈希
|
result.extend(sorted_group)
|
||||||
hash_input = f"{affinity_key}:{key.id}"
|
else:
|
||||||
hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
|
# 没有 affinity_key 时按 ID 排序保持稳定性
|
||||||
key_scores.append((hash_value, key))
|
result.extend(sorted(group_keys, key=lambda k: k.id))
|
||||||
|
|
||||||
# 按哈希值排序
|
|
||||||
sorted_group = [key for _, key in sorted(key_scores)]
|
|
||||||
result.extend(sorted_group)
|
|
||||||
else:
|
else:
|
||||||
# 单个 Key 或没有 affinity_key 时保持原顺序
|
# 单个 Key 直接添加
|
||||||
result.extend(sorted(group_keys, key=lambda k: k.id))
|
result.extend(group_keys)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -234,8 +234,15 @@ class EndpointHealthService:
|
|||||||
for api_format in format_key_mapping.keys()
|
for api_format in format_key_mapping.keys()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 参数校验(API 层已通过 Query(ge=1) 保证,这里做防御性检查)
|
||||||
|
if lookback_hours <= 0 or segments <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"lookback_hours and segments must be positive, "
|
||||||
|
f"got lookback_hours={lookback_hours}, segments={segments}"
|
||||||
|
)
|
||||||
|
|
||||||
# 计算时间范围
|
# 计算时间范围
|
||||||
interval_minutes = (lookback_hours * 60) // segments
|
segment_seconds = (lookback_hours * 3600) / segments
|
||||||
start_time = now - timedelta(hours=lookback_hours)
|
start_time = now - timedelta(hours=lookback_hours)
|
||||||
|
|
||||||
# 使用 RequestCandidate 表查询所有尝试记录
|
# 使用 RequestCandidate 表查询所有尝试记录
|
||||||
@@ -243,7 +250,7 @@ class EndpointHealthService:
|
|||||||
final_statuses = ["success", "failed", "skipped"]
|
final_statuses = ["success", "failed", "skipped"]
|
||||||
|
|
||||||
segment_expr = func.floor(
|
segment_expr = func.floor(
|
||||||
func.extract('epoch', RequestCandidate.created_at - start_time) / (interval_minutes * 60)
|
func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
|
||||||
).label('segment_idx')
|
).label('segment_idx')
|
||||||
|
|
||||||
candidate_stats = (
|
candidate_stats = (
|
||||||
|
|||||||
@@ -208,86 +208,120 @@ class CleanupScheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||||
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
|
from src.models.database import StatsDailyModel
|
||||||
|
|
||||||
if latest_stat:
|
yesterday_business_date = today_local.date() - timedelta(days=1)
|
||||||
latest_date_utc = latest_stat.date
|
max_backfill_days: int = SystemConfigService.get_config(
|
||||||
if latest_date_utc.tzinfo is None:
|
db, "max_stats_backfill_days", 30
|
||||||
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
) or 30
|
||||||
else:
|
|
||||||
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
|
||||||
|
|
||||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
# 计算回填检查的起始日期
|
||||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
check_start_date = yesterday_business_date - timedelta(
|
||||||
yesterday_business_date = today_local.date() - timedelta(days=1)
|
days=max_backfill_days - 1
|
||||||
missing_start_date = latest_business_date + timedelta(days=1)
|
)
|
||||||
|
|
||||||
if missing_start_date <= yesterday_business_date:
|
# 获取 StatsDaily 和 StatsDailyModel 中已有数据的日期集合
|
||||||
missing_days = (
|
existing_daily_dates = set()
|
||||||
yesterday_business_date - missing_start_date
|
existing_model_dates = set()
|
||||||
).days + 1
|
|
||||||
|
|
||||||
# 限制最大回填天数,防止停机很久后一次性回填太多
|
daily_stats = (
|
||||||
max_backfill_days: int = SystemConfigService.get_config(
|
db.query(StatsDaily.date)
|
||||||
db, "max_stats_backfill_days", 30
|
.filter(StatsDaily.date >= check_start_date.isoformat())
|
||||||
) or 30
|
.all()
|
||||||
if missing_days > max_backfill_days:
|
)
|
||||||
logger.warning(
|
for (stat_date,) in daily_stats:
|
||||||
f"缺失 {missing_days} 天数据超过最大回填限制 "
|
if stat_date.tzinfo is None:
|
||||||
f"{max_backfill_days} 天,只回填最近 {max_backfill_days} 天"
|
stat_date = stat_date.replace(tzinfo=timezone.utc)
|
||||||
|
existing_daily_dates.add(stat_date.astimezone(app_tz).date())
|
||||||
|
|
||||||
|
model_stats = (
|
||||||
|
db.query(StatsDailyModel.date)
|
||||||
|
.filter(StatsDailyModel.date >= check_start_date.isoformat())
|
||||||
|
.distinct()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
for (stat_date,) in model_stats:
|
||||||
|
if stat_date.tzinfo is None:
|
||||||
|
stat_date = stat_date.replace(tzinfo=timezone.utc)
|
||||||
|
existing_model_dates.add(stat_date.astimezone(app_tz).date())
|
||||||
|
|
||||||
|
# 找出需要回填的日期
|
||||||
|
all_dates = set()
|
||||||
|
current = check_start_date
|
||||||
|
while current <= yesterday_business_date:
|
||||||
|
all_dates.add(current)
|
||||||
|
current += timedelta(days=1)
|
||||||
|
|
||||||
|
# 需要回填 StatsDaily 的日期
|
||||||
|
missing_daily_dates = all_dates - existing_daily_dates
|
||||||
|
# 需要回填 StatsDailyModel 的日期
|
||||||
|
missing_model_dates = all_dates - existing_model_dates
|
||||||
|
# 合并所有需要处理的日期
|
||||||
|
dates_to_process = missing_daily_dates | missing_model_dates
|
||||||
|
|
||||||
|
if dates_to_process:
|
||||||
|
sorted_dates = sorted(dates_to_process)
|
||||||
|
logger.info(
|
||||||
|
f"检测到 {len(dates_to_process)} 天的统计数据需要回填 "
|
||||||
|
f"(StatsDaily 缺失 {len(missing_daily_dates)} 天, "
|
||||||
|
f"StatsDailyModel 缺失 {len(missing_model_dates)} 天)"
|
||||||
|
)
|
||||||
|
|
||||||
|
users = (
|
||||||
|
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
failed_dates = 0
|
||||||
|
failed_users = 0
|
||||||
|
|
||||||
|
for current_date in sorted_dates:
|
||||||
|
try:
|
||||||
|
current_date_local = datetime.combine(
|
||||||
|
current_date, datetime.min.time(), tzinfo=app_tz
|
||||||
)
|
)
|
||||||
missing_start_date = yesterday_business_date - timedelta(
|
# 只在缺失时才聚合对应的表
|
||||||
days=max_backfill_days - 1
|
if current_date in missing_daily_dates:
|
||||||
)
|
|
||||||
missing_days = max_backfill_days
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
|
||||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
|
||||||
)
|
|
||||||
|
|
||||||
current_date = missing_start_date
|
|
||||||
users = (
|
|
||||||
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
while current_date <= yesterday_business_date:
|
|
||||||
try:
|
|
||||||
current_date_local = datetime.combine(
|
|
||||||
current_date, datetime.min.time(), tzinfo=app_tz
|
|
||||||
)
|
|
||||||
StatsAggregatorService.aggregate_daily_stats(
|
StatsAggregatorService.aggregate_daily_stats(
|
||||||
db, current_date_local
|
db, current_date_local
|
||||||
)
|
)
|
||||||
|
if current_date in missing_model_dates:
|
||||||
StatsAggregatorService.aggregate_daily_model_stats(
|
StatsAggregatorService.aggregate_daily_model_stats(
|
||||||
db, current_date_local
|
db, current_date_local
|
||||||
)
|
)
|
||||||
for (user_id,) in users:
|
# 用户统计在任一缺失时都回填
|
||||||
try:
|
for (user_id,) in users:
|
||||||
StatsAggregatorService.aggregate_user_daily_stats(
|
|
||||||
db, user_id, current_date_local
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(
|
|
||||||
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
db.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
|
||||||
try:
|
try:
|
||||||
db.rollback()
|
StatsAggregatorService.aggregate_user_daily_stats(
|
||||||
except Exception:
|
db, user_id, current_date_local
|
||||||
pass
|
)
|
||||||
|
except Exception as e:
|
||||||
|
failed_users += 1
|
||||||
|
logger.warning(
|
||||||
|
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception as rollback_err:
|
||||||
|
logger.error(f"回滚失败: {rollback_err}")
|
||||||
|
except Exception as e:
|
||||||
|
failed_dates += 1
|
||||||
|
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
||||||
|
try:
|
||||||
|
db.rollback()
|
||||||
|
except Exception as rollback_err:
|
||||||
|
logger.error(f"回滚失败: {rollback_err}")
|
||||||
|
|
||||||
current_date += timedelta(days=1)
|
StatsAggregatorService.update_summary(db)
|
||||||
|
|
||||||
StatsAggregatorService.update_summary(db)
|
if failed_dates > 0 or failed_users > 0:
|
||||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
logger.warning(
|
||||||
|
f"回填完成,共处理 {len(dates_to_process)} 天,"
|
||||||
|
f"失败: {failed_dates} 天, {failed_users} 个用户记录"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("统计数据已是最新,无需回填")
|
logger.info(f"缺失数据回填完成,共处理 {len(dates_to_process)} 天")
|
||||||
|
else:
|
||||||
|
logger.info("统计数据已是最新,无需回填")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 定时任务:聚合昨天的数据
|
# 定时任务:聚合昨天的数据
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
|
|||||||
from src.services.system.config import SystemConfigService
|
from src.services.system.config import SystemConfigService
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UsageRecordParams:
|
||||||
|
"""用量记录参数数据类,用于在内部方法间传递数据"""
|
||||||
|
db: Session
|
||||||
|
user: Optional[User]
|
||||||
|
api_key: Optional[ApiKey]
|
||||||
|
provider: str
|
||||||
|
model: str
|
||||||
|
input_tokens: int
|
||||||
|
output_tokens: int
|
||||||
|
cache_creation_input_tokens: int
|
||||||
|
cache_read_input_tokens: int
|
||||||
|
request_type: str
|
||||||
|
api_format: Optional[str]
|
||||||
|
is_stream: bool
|
||||||
|
response_time_ms: Optional[int]
|
||||||
|
first_byte_time_ms: Optional[int]
|
||||||
|
status_code: int
|
||||||
|
error_message: Optional[str]
|
||||||
|
metadata: Optional[Dict[str, Any]]
|
||||||
|
request_headers: Optional[Dict[str, Any]]
|
||||||
|
request_body: Optional[Any]
|
||||||
|
provider_request_headers: Optional[Dict[str, Any]]
|
||||||
|
response_headers: Optional[Dict[str, Any]]
|
||||||
|
response_body: Optional[Any]
|
||||||
|
request_id: str
|
||||||
|
provider_id: Optional[str]
|
||||||
|
provider_endpoint_id: Optional[str]
|
||||||
|
provider_api_key_id: Optional[str]
|
||||||
|
status: str
|
||||||
|
cache_ttl_minutes: Optional[int]
|
||||||
|
use_tiered_pricing: bool
|
||||||
|
target_model: Optional[str]
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""验证关键字段,确保数据完整性"""
|
||||||
|
# Token 数量不能为负数
|
||||||
|
if self.input_tokens < 0:
|
||||||
|
raise ValueError(f"input_tokens 不能为负数: {self.input_tokens}")
|
||||||
|
if self.output_tokens < 0:
|
||||||
|
raise ValueError(f"output_tokens 不能为负数: {self.output_tokens}")
|
||||||
|
if self.cache_creation_input_tokens < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"cache_creation_input_tokens 不能为负数: {self.cache_creation_input_tokens}"
|
||||||
|
)
|
||||||
|
if self.cache_read_input_tokens < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"cache_read_input_tokens 不能为负数: {self.cache_read_input_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 响应时间不能为负数
|
||||||
|
if self.response_time_ms is not None and self.response_time_ms < 0:
|
||||||
|
raise ValueError(f"response_time_ms 不能为负数: {self.response_time_ms}")
|
||||||
|
if self.first_byte_time_ms is not None and self.first_byte_time_ms < 0:
|
||||||
|
raise ValueError(f"first_byte_time_ms 不能为负数: {self.first_byte_time_ms}")
|
||||||
|
|
||||||
|
# HTTP 状态码范围校验
|
||||||
|
if not (100 <= self.status_code <= 599):
|
||||||
|
raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")
|
||||||
|
|
||||||
|
# 状态值校验
|
||||||
|
valid_statuses = {"pending", "streaming", "completed", "failed"}
|
||||||
|
if self.status not in valid_statuses:
|
||||||
|
raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
|
||||||
|
|
||||||
|
|
||||||
class UsageService:
|
class UsageService:
|
||||||
"""用量统计服务"""
|
"""用量统计服务"""
|
||||||
@@ -471,6 +537,97 @@ class UsageService:
|
|||||||
cache_ttl_minutes=cache_ttl_minutes,
|
cache_ttl_minutes=cache_ttl_minutes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _prepare_usage_record(
|
||||||
|
cls,
|
||||||
|
params: UsageRecordParams,
|
||||||
|
) -> Tuple[Dict[str, Any], float]:
|
||||||
|
"""准备用量记录的共享逻辑
|
||||||
|
|
||||||
|
此方法提取了 record_usage 和 record_usage_async 的公共处理逻辑:
|
||||||
|
- 获取费率倍数
|
||||||
|
- 计算成本
|
||||||
|
- 构建 Usage 参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 用量记录参数数据类
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(usage_params 字典, total_cost 总成本)
|
||||||
|
"""
|
||||||
|
# 获取费率倍数和是否免费套餐
|
||||||
|
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
||||||
|
params.db, params.provider_api_key_id, params.provider_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算成本
|
||||||
|
is_failed_request = params.status_code >= 400 or params.error_message is not None
|
||||||
|
(
|
||||||
|
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
||||||
|
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
||||||
|
request_cost, total_cost, _tier_index
|
||||||
|
) = await cls._calculate_costs(
|
||||||
|
db=params.db,
|
||||||
|
provider=params.provider,
|
||||||
|
model=params.model,
|
||||||
|
input_tokens=params.input_tokens,
|
||||||
|
output_tokens=params.output_tokens,
|
||||||
|
cache_creation_input_tokens=params.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens=params.cache_read_input_tokens,
|
||||||
|
api_format=params.api_format,
|
||||||
|
cache_ttl_minutes=params.cache_ttl_minutes,
|
||||||
|
use_tiered_pricing=params.use_tiered_pricing,
|
||||||
|
is_failed_request=is_failed_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建 Usage 参数
|
||||||
|
usage_params = cls._build_usage_params(
|
||||||
|
db=params.db,
|
||||||
|
user=params.user,
|
||||||
|
api_key=params.api_key,
|
||||||
|
provider=params.provider,
|
||||||
|
model=params.model,
|
||||||
|
input_tokens=params.input_tokens,
|
||||||
|
output_tokens=params.output_tokens,
|
||||||
|
cache_creation_input_tokens=params.cache_creation_input_tokens,
|
||||||
|
cache_read_input_tokens=params.cache_read_input_tokens,
|
||||||
|
request_type=params.request_type,
|
||||||
|
api_format=params.api_format,
|
||||||
|
is_stream=params.is_stream,
|
||||||
|
response_time_ms=params.response_time_ms,
|
||||||
|
first_byte_time_ms=params.first_byte_time_ms,
|
||||||
|
status_code=params.status_code,
|
||||||
|
error_message=params.error_message,
|
||||||
|
metadata=params.metadata,
|
||||||
|
request_headers=params.request_headers,
|
||||||
|
request_body=params.request_body,
|
||||||
|
provider_request_headers=params.provider_request_headers,
|
||||||
|
response_headers=params.response_headers,
|
||||||
|
response_body=params.response_body,
|
||||||
|
request_id=params.request_id,
|
||||||
|
provider_id=params.provider_id,
|
||||||
|
provider_endpoint_id=params.provider_endpoint_id,
|
||||||
|
provider_api_key_id=params.provider_api_key_id,
|
||||||
|
status=params.status,
|
||||||
|
target_model=params.target_model,
|
||||||
|
input_cost=input_cost,
|
||||||
|
output_cost=output_cost,
|
||||||
|
cache_creation_cost=cache_creation_cost,
|
||||||
|
cache_read_cost=cache_read_cost,
|
||||||
|
cache_cost=cache_cost,
|
||||||
|
request_cost=request_cost,
|
||||||
|
total_cost=total_cost,
|
||||||
|
input_price=input_price,
|
||||||
|
output_price=output_price,
|
||||||
|
cache_creation_price=cache_creation_price,
|
||||||
|
cache_read_price=cache_read_price,
|
||||||
|
request_price=request_price,
|
||||||
|
actual_rate_multiplier=actual_rate_multiplier,
|
||||||
|
is_free_tier=is_free_tier,
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage_params, total_cost
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def record_usage_async(
|
async def record_usage_async(
|
||||||
cls,
|
cls,
|
||||||
@@ -516,76 +673,25 @@ class UsageService:
|
|||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())[:8]
|
request_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
# 获取费率倍数和是否免费套餐
|
# 使用共享逻辑准备记录参数
|
||||||
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
params = UsageRecordParams(
|
||||||
db, provider_api_key_id, provider_id
|
db=db, user=user, api_key=api_key, provider=provider, model=model,
|
||||||
)
|
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||||
|
|
||||||
# 计算成本
|
|
||||||
is_failed_request = status_code >= 400 or error_message is not None
|
|
||||||
(
|
|
||||||
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
|
||||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
|
||||||
request_cost, total_cost, tier_index
|
|
||||||
) = await cls._calculate_costs(
|
|
||||||
db=db,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
api_format=api_format,
|
request_type=request_type, api_format=api_format, is_stream=is_stream,
|
||||||
cache_ttl_minutes=cache_ttl_minutes,
|
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
|
||||||
use_tiered_pricing=use_tiered_pricing,
|
status_code=status_code, error_message=error_message, metadata=metadata,
|
||||||
is_failed_request=is_failed_request,
|
request_headers=request_headers, request_body=request_body,
|
||||||
)
|
|
||||||
|
|
||||||
# 构建 Usage 参数
|
|
||||||
usage_params = cls._build_usage_params(
|
|
||||||
db=db,
|
|
||||||
user=user,
|
|
||||||
api_key=api_key,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
|
||||||
request_type=request_type,
|
|
||||||
api_format=api_format,
|
|
||||||
is_stream=is_stream,
|
|
||||||
response_time_ms=response_time_ms,
|
|
||||||
first_byte_time_ms=first_byte_time_ms,
|
|
||||||
status_code=status_code,
|
|
||||||
error_message=error_message,
|
|
||||||
metadata=metadata,
|
|
||||||
request_headers=request_headers,
|
|
||||||
request_body=request_body,
|
|
||||||
provider_request_headers=provider_request_headers,
|
provider_request_headers=provider_request_headers,
|
||||||
response_headers=response_headers,
|
response_headers=response_headers, response_body=response_body,
|
||||||
response_body=response_body,
|
request_id=request_id, provider_id=provider_id,
|
||||||
request_id=request_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_endpoint_id=provider_endpoint_id,
|
provider_endpoint_id=provider_endpoint_id,
|
||||||
provider_api_key_id=provider_api_key_id,
|
provider_api_key_id=provider_api_key_id, status=status,
|
||||||
status=status,
|
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
input_cost=input_cost,
|
|
||||||
output_cost=output_cost,
|
|
||||||
cache_creation_cost=cache_creation_cost,
|
|
||||||
cache_read_cost=cache_read_cost,
|
|
||||||
cache_cost=cache_cost,
|
|
||||||
request_cost=request_cost,
|
|
||||||
total_cost=total_cost,
|
|
||||||
input_price=input_price,
|
|
||||||
output_price=output_price,
|
|
||||||
cache_creation_price=cache_creation_price,
|
|
||||||
cache_read_price=cache_read_price,
|
|
||||||
request_price=request_price,
|
|
||||||
actual_rate_multiplier=actual_rate_multiplier,
|
|
||||||
is_free_tier=is_free_tier,
|
|
||||||
)
|
)
|
||||||
|
usage_params, _ = await cls._prepare_usage_record(params)
|
||||||
|
|
||||||
# 创建 Usage 记录
|
# 创建 Usage 记录
|
||||||
usage = Usage(**usage_params)
|
usage = Usage(**usage_params)
|
||||||
@@ -660,76 +766,25 @@ class UsageService:
|
|||||||
if request_id is None:
|
if request_id is None:
|
||||||
request_id = str(uuid.uuid4())[:8]
|
request_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
# 获取费率倍数和是否免费套餐
|
# 使用共享逻辑准备记录参数
|
||||||
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
|
params = UsageRecordParams(
|
||||||
db, provider_api_key_id, provider_id
|
db=db, user=user, api_key=api_key, provider=provider, model=model,
|
||||||
)
|
input_tokens=input_tokens, output_tokens=output_tokens,
|
||||||
|
|
||||||
# 计算成本
|
|
||||||
is_failed_request = status_code >= 400 or error_message is not None
|
|
||||||
(
|
|
||||||
input_price, output_price, cache_creation_price, cache_read_price, request_price,
|
|
||||||
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
|
|
||||||
request_cost, total_cost, _tier_index
|
|
||||||
) = await cls._calculate_costs(
|
|
||||||
db=db,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
api_format=api_format,
|
request_type=request_type, api_format=api_format, is_stream=is_stream,
|
||||||
cache_ttl_minutes=cache_ttl_minutes,
|
response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
|
||||||
use_tiered_pricing=use_tiered_pricing,
|
status_code=status_code, error_message=error_message, metadata=metadata,
|
||||||
is_failed_request=is_failed_request,
|
request_headers=request_headers, request_body=request_body,
|
||||||
)
|
|
||||||
|
|
||||||
# 构建 Usage 参数
|
|
||||||
usage_params = cls._build_usage_params(
|
|
||||||
db=db,
|
|
||||||
user=user,
|
|
||||||
api_key=api_key,
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
output_tokens=output_tokens,
|
|
||||||
cache_creation_input_tokens=cache_creation_input_tokens,
|
|
||||||
cache_read_input_tokens=cache_read_input_tokens,
|
|
||||||
request_type=request_type,
|
|
||||||
api_format=api_format,
|
|
||||||
is_stream=is_stream,
|
|
||||||
response_time_ms=response_time_ms,
|
|
||||||
first_byte_time_ms=first_byte_time_ms,
|
|
||||||
status_code=status_code,
|
|
||||||
error_message=error_message,
|
|
||||||
metadata=metadata,
|
|
||||||
request_headers=request_headers,
|
|
||||||
request_body=request_body,
|
|
||||||
provider_request_headers=provider_request_headers,
|
provider_request_headers=provider_request_headers,
|
||||||
response_headers=response_headers,
|
response_headers=response_headers, response_body=response_body,
|
||||||
response_body=response_body,
|
request_id=request_id, provider_id=provider_id,
|
||||||
request_id=request_id,
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_endpoint_id=provider_endpoint_id,
|
provider_endpoint_id=provider_endpoint_id,
|
||||||
provider_api_key_id=provider_api_key_id,
|
provider_api_key_id=provider_api_key_id, status=status,
|
||||||
status=status,
|
cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
input_cost=input_cost,
|
|
||||||
output_cost=output_cost,
|
|
||||||
cache_creation_cost=cache_creation_cost,
|
|
||||||
cache_read_cost=cache_read_cost,
|
|
||||||
cache_cost=cache_cost,
|
|
||||||
request_cost=request_cost,
|
|
||||||
total_cost=total_cost,
|
|
||||||
input_price=input_price,
|
|
||||||
output_price=output_price,
|
|
||||||
cache_creation_price=cache_creation_price,
|
|
||||||
cache_read_price=cache_read_price,
|
|
||||||
request_price=request_price,
|
|
||||||
actual_rate_multiplier=actual_rate_multiplier,
|
|
||||||
is_free_tier=is_free_tier,
|
|
||||||
)
|
)
|
||||||
|
usage_params, total_cost = await cls._prepare_usage_record(params)
|
||||||
|
|
||||||
# 检查是否已存在相同 request_id 的记录
|
# 检查是否已存在相同 request_id 的记录
|
||||||
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
|
existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
|
||||||
@@ -751,7 +806,7 @@ class UsageService:
|
|||||||
api_key = db.merge(api_key)
|
api_key = db.merge(api_key)
|
||||||
|
|
||||||
# 使用原子更新避免并发竞态条件
|
# 使用原子更新避免并发竞态条件
|
||||||
from sqlalchemy import func, update
|
from sqlalchemy import func as sql_func, update
|
||||||
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
|
from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel
|
||||||
|
|
||||||
# 更新用户使用量(独立 Key 不计入创建者的使用记录)
|
# 更新用户使用量(独立 Key 不计入创建者的使用记录)
|
||||||
@@ -762,7 +817,7 @@ class UsageService:
|
|||||||
.values(
|
.values(
|
||||||
used_usd=UserModel.used_usd + total_cost,
|
used_usd=UserModel.used_usd + total_cost,
|
||||||
total_usd=UserModel.total_usd + total_cost,
|
total_usd=UserModel.total_usd + total_cost,
|
||||||
updated_at=func.now(),
|
updated_at=sql_func.now(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -776,8 +831,8 @@ class UsageService:
|
|||||||
total_requests=ApiKeyModel.total_requests + 1,
|
total_requests=ApiKeyModel.total_requests + 1,
|
||||||
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
||||||
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
|
balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
|
||||||
last_used_at=func.now(),
|
last_used_at=sql_func.now(),
|
||||||
updated_at=func.now(),
|
updated_at=sql_func.now(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -787,8 +842,8 @@ class UsageService:
|
|||||||
.values(
|
.values(
|
||||||
total_requests=ApiKeyModel.total_requests + 1,
|
total_requests=ApiKeyModel.total_requests + 1,
|
||||||
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
|
||||||
last_used_at=func.now(),
|
last_used_at=sql_func.now(),
|
||||||
updated_at=func.now(),
|
updated_at=sql_func.now(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1121,19 +1176,48 @@ class UsageService:
|
|||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cleanup_old_usage_records(db: Session, days_to_keep: int = 90) -> int:
|
def cleanup_old_usage_records(
|
||||||
"""清理旧的使用记录"""
|
db: Session, days_to_keep: int = 90, batch_size: int = 1000
|
||||||
|
) -> int:
|
||||||
|
"""清理旧的使用记录(分批删除避免长事务锁定)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
days_to_keep: 保留天数,默认 90 天
|
||||||
|
batch_size: 每批删除数量,默认 1000 条
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
删除的总记录数
|
||||||
|
"""
|
||||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||||
|
total_deleted = 0
|
||||||
|
|
||||||
# 删除旧记录
|
while True:
|
||||||
deleted = db.query(Usage).filter(Usage.created_at < cutoff_date).delete()
|
# 查询待删除的 ID(使用新索引 idx_usage_user_created)
|
||||||
|
batch_ids = (
|
||||||
|
db.query(Usage.id)
|
||||||
|
.filter(Usage.created_at < cutoff_date)
|
||||||
|
.limit(batch_size)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
db.commit()
|
if not batch_ids:
|
||||||
|
break
|
||||||
|
|
||||||
logger.info(f"清理使用记录: 删除 {deleted} 条超过 {days_to_keep} 天的记录")
|
# 批量删除
|
||||||
|
deleted_count = (
|
||||||
|
db.query(Usage)
|
||||||
|
.filter(Usage.id.in_([row.id for row in batch_ids]))
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
total_deleted += deleted_count
|
||||||
|
|
||||||
return deleted
|
logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条")
|
||||||
|
|
||||||
|
logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")
|
||||||
|
|
||||||
|
return total_deleted
|
||||||
|
|
||||||
# ========== 请求状态追踪方法 ==========
|
# ========== 请求状态追踪方法 ==========
|
||||||
|
|
||||||
@@ -1219,6 +1303,7 @@ class UsageService:
|
|||||||
error_message: Optional[str] = None,
|
error_message: Optional[str] = None,
|
||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
target_model: Optional[str] = None,
|
target_model: Optional[str] = None,
|
||||||
|
first_byte_time_ms: Optional[int] = None,
|
||||||
) -> Optional[Usage]:
|
) -> Optional[Usage]:
|
||||||
"""
|
"""
|
||||||
快速更新使用记录状态
|
快速更新使用记录状态
|
||||||
@@ -1230,6 +1315,7 @@ class UsageService:
|
|||||||
error_message: 错误消息(仅在 failed 状态时使用)
|
error_message: 错误消息(仅在 failed 状态时使用)
|
||||||
provider: 提供商名称(可选,streaming 状态时更新)
|
provider: 提供商名称(可选,streaming 状态时更新)
|
||||||
target_model: 映射后的目标模型名(可选)
|
target_model: 映射后的目标模型名(可选)
|
||||||
|
first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
更新后的 Usage 记录,如果未找到则返回 None
|
更新后的 Usage 记录,如果未找到则返回 None
|
||||||
@@ -1247,6 +1333,8 @@ class UsageService:
|
|||||||
usage.provider = provider
|
usage.provider = provider
|
||||||
if target_model:
|
if target_model:
|
||||||
usage.target_model = target_model
|
usage.target_model = target_model
|
||||||
|
if first_byte_time_ms is not None:
|
||||||
|
usage.first_byte_time_ms = first_byte_time_ms
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -457,26 +458,32 @@ class StreamUsageTracker:
|
|||||||
|
|
||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
# 更新状态为 streaming,同时更新 provider
|
|
||||||
if self.request_id:
|
|
||||||
try:
|
|
||||||
from src.services.usage.service import UsageService
|
|
||||||
UsageService.update_usage_status(
|
|
||||||
db=self.db,
|
|
||||||
request_id=self.request_id,
|
|
||||||
status="streaming",
|
|
||||||
provider=self.provider,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
first_chunk_received = False
|
||||||
try:
|
try:
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
# 保存原始字节流(用于错误诊断)
|
# 保存原始字节流(用于错误诊断)
|
||||||
self.raw_chunks.append(chunk)
|
self.raw_chunks.append(chunk)
|
||||||
|
|
||||||
|
# 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
|
||||||
|
if not first_chunk_received:
|
||||||
|
first_chunk_received = True
|
||||||
|
if self.request_id:
|
||||||
|
try:
|
||||||
|
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||||
|
base_time = self.request_start_time or self.start_time
|
||||||
|
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=self.db,
|
||||||
|
request_id=self.request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
|
first_byte_time_ms=first_byte_time_ms,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
# 返回原始块给客户端
|
# 返回原始块给客户端
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|||||||
@@ -59,14 +59,15 @@ class ApiKeyService:
|
|||||||
if expire_days:
|
if expire_days:
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
||||||
|
|
||||||
|
# 空数组转为 None(表示不限制)
|
||||||
api_key = ApiKey(
|
api_key = ApiKey(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
key_hash=key_hash,
|
key_hash=key_hash,
|
||||||
key_encrypted=key_encrypted,
|
key_encrypted=key_encrypted,
|
||||||
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
|
name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
|
||||||
allowed_providers=allowed_providers,
|
allowed_providers=allowed_providers or None,
|
||||||
allowed_api_formats=allowed_api_formats,
|
allowed_api_formats=allowed_api_formats or None,
|
||||||
allowed_models=allowed_models,
|
allowed_models=allowed_models or None,
|
||||||
rate_limit=rate_limit,
|
rate_limit=rate_limit,
|
||||||
concurrent_limit=concurrent_limit,
|
concurrent_limit=concurrent_limit,
|
||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
@@ -141,8 +142,18 @@ class ApiKeyService:
|
|||||||
"auto_delete_on_expiry",
|
"auto_delete_on_expiry",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 允许显式设置为空数组/None 的字段(空数组会转为 None,表示"全部")
|
||||||
|
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
|
||||||
|
|
||||||
for field, value in kwargs.items():
|
for field, value in kwargs.items():
|
||||||
if field in updatable_fields and value is not None:
|
if field not in updatable_fields:
|
||||||
|
continue
|
||||||
|
# 对于 nullable_list_fields,空数组应该转为 None(表示不限制)
|
||||||
|
if field in nullable_list_fields:
|
||||||
|
if value is not None:
|
||||||
|
# 空数组转为 None(表示允许全部)
|
||||||
|
setattr(api_key, field, value if value else None)
|
||||||
|
elif value is not None:
|
||||||
setattr(api_key, field, value)
|
setattr(api_key, field, value)
|
||||||
|
|
||||||
api_key.updated_at = datetime.now(timezone.utc)
|
api_key.updated_at = datetime.now(timezone.utc)
|
||||||
|
|||||||
@@ -1,8 +1,16 @@
|
|||||||
"""分布式任务协调器,确保仅有一个 worker 执行特定任务"""
|
"""分布式任务协调器,确保仅有一个 worker 执行特定任务
|
||||||
|
|
||||||
|
锁清理策略:
|
||||||
|
- 单实例模式(默认):启动时使用原子操作清理旧锁并获取新锁
|
||||||
|
- 多实例模式:使用 NX 选项竞争锁,依赖 TTL 处理异常退出
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
- 默认行为:启动时清理旧锁(适用于单机部署)
|
||||||
|
- 多实例部署:设置 SINGLE_INSTANCE_MODE=false 禁用启动清理
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import uuid
|
import uuid
|
||||||
@@ -19,6 +27,10 @@ except ImportError: # pragma: no cover - Windows 环境
|
|||||||
class StartupTaskCoordinator:
|
class StartupTaskCoordinator:
|
||||||
"""利用 Redis 或文件锁,保证任务只在单个进程/实例中运行"""
|
"""利用 Redis 或文件锁,保证任务只在单个进程/实例中运行"""
|
||||||
|
|
||||||
|
# 类级别标记:在当前进程中是否已尝试过启动清理
|
||||||
|
# 注意:这在 fork 模式下每个 worker 都是独立的
|
||||||
|
_startup_cleanup_attempted = False
|
||||||
|
|
||||||
def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
|
def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
|
||||||
self.redis = redis_client
|
self.redis = redis_client
|
||||||
self._tokens: Dict[str, str] = {}
|
self._tokens: Dict[str, str] = {}
|
||||||
@@ -26,6 +38,8 @@ class StartupTaskCoordinator:
|
|||||||
self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
|
self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
|
||||||
if not self._lock_dir.exists():
|
if not self._lock_dir.exists():
|
||||||
self._lock_dir.mkdir(parents=True, exist_ok=True)
|
self._lock_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
# 单实例模式:启动时清理旧锁(适用于单机部署,避免残留锁问题)
|
||||||
|
self._single_instance_mode = os.getenv("SINGLE_INSTANCE_MODE", "true").lower() == "true"
|
||||||
|
|
||||||
def _redis_key(self, name: str) -> str:
|
def _redis_key(self, name: str) -> str:
|
||||||
return f"task_lock:{name}"
|
return f"task_lock:{name}"
|
||||||
@@ -36,12 +50,51 @@ class StartupTaskCoordinator:
|
|||||||
if self.redis:
|
if self.redis:
|
||||||
token = str(uuid.uuid4())
|
token = str(uuid.uuid4())
|
||||||
try:
|
try:
|
||||||
acquired = await self.redis.set(self._redis_key(name), token, nx=True, ex=ttl)
|
if self._single_instance_mode:
|
||||||
if acquired:
|
# 单实例模式:使用 Lua 脚本原子性地"清理旧锁 + 竞争获取"
|
||||||
self._tokens[name] = token
|
# 只有当锁不存在或成功获取时才返回 1
|
||||||
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
|
# 这样第一个执行的 worker 会清理旧锁并获取,后续 worker 会正常竞争
|
||||||
return True
|
script = """
|
||||||
return False
|
local key = KEYS[1]
|
||||||
|
local token = ARGV[1]
|
||||||
|
local ttl = tonumber(ARGV[2])
|
||||||
|
local startup_key = KEYS[1] .. ':startup'
|
||||||
|
|
||||||
|
-- 检查是否已有 worker 执行过启动清理
|
||||||
|
local cleaned = redis.call('GET', startup_key)
|
||||||
|
if not cleaned then
|
||||||
|
-- 第一个 worker:删除旧锁,标记已清理
|
||||||
|
redis.call('DEL', key)
|
||||||
|
redis.call('SET', startup_key, '1', 'EX', 60)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- 尝试获取锁(NX 模式)
|
||||||
|
local result = redis.call('SET', key, token, 'NX', 'EX', ttl)
|
||||||
|
if result then
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
result = await self.redis.eval(
|
||||||
|
script, 2,
|
||||||
|
self._redis_key(name), self._redis_key(name),
|
||||||
|
token, ttl
|
||||||
|
)
|
||||||
|
if result == 1:
|
||||||
|
self._tokens[name] = token
|
||||||
|
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# 多实例模式:直接使用 NX 选项竞争锁
|
||||||
|
acquired = await self.redis.set(
|
||||||
|
self._redis_key(name), token, nx=True, ex=ttl
|
||||||
|
)
|
||||||
|
if acquired:
|
||||||
|
self._tokens[name] = token
|
||||||
|
logger.info(f"任务 {name} 通过 Redis 锁独占执行")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
except Exception as exc: # pragma: no cover - Redis 异常回退
|
except Exception as exc: # pragma: no cover - Redis 异常回退
|
||||||
logger.warning(f"Redis 锁获取失败,回退到文件锁: {exc}")
|
logger.warning(f"Redis 锁获取失败,回退到文件锁: {exc}")
|
||||||
|
|
||||||
|
|||||||
@@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation"
|
|||||||
# Python 3.10 及以下版本的兼容实现
|
# Python 3.10 及以下版本的兼容实现
|
||||||
# 注意:这个简单实现不支持嵌套取消
|
# 注意:这个简单实现不支持嵌套取消
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def read_first_chunk_with_ttfb_timeout(
|
||||||
|
byte_iterator: Any,
|
||||||
|
timeout: float,
|
||||||
|
request_id: str,
|
||||||
|
provider_name: str,
|
||||||
|
) -> tuple[bytes, Any]:
|
||||||
|
"""
|
||||||
|
读取流的首字节并应用 TTFB 超时检测
|
||||||
|
|
||||||
|
首字节超时(Time To First Byte)用于检测慢响应的 Provider,
|
||||||
|
超时时触发故障转移到其他可用的 Provider。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
byte_iterator: 异步字节流迭代器
|
||||||
|
timeout: TTFB 超时时间(秒)
|
||||||
|
request_id: 请求 ID(用于日志)
|
||||||
|
provider_name: Provider 名称(用于日志和异常)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(first_chunk, aiter): 首个字节块和异步迭代器
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ProviderTimeoutException: 如果首字节超时
|
||||||
|
"""
|
||||||
|
from src.core.exceptions import ProviderTimeoutException
|
||||||
|
|
||||||
|
aiter = byte_iterator.__aiter__()
|
||||||
|
|
||||||
|
try:
|
||||||
|
first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout)
|
||||||
|
return first_chunk, aiter
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 完整的资源清理:先关闭迭代器,再关闭底层响应
|
||||||
|
await _cleanup_iterator_resources(aiter, request_id)
|
||||||
|
logger.warning(
|
||||||
|
f" [{request_id}] 流首字节超时 (TTFB): "
|
||||||
|
f"Provider={provider_name}, timeout={timeout}s"
|
||||||
|
)
|
||||||
|
raise ProviderTimeoutException(
|
||||||
|
provider_name=provider_name,
|
||||||
|
timeout=int(timeout),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None:
|
||||||
|
"""
|
||||||
|
清理异步迭代器及其底层资源
|
||||||
|
|
||||||
|
确保在 TTFB 超时后正确释放 HTTP 连接,避免连接泄漏。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
aiter: 异步迭代器
|
||||||
|
request_id: 请求 ID(用于日志)
|
||||||
|
"""
|
||||||
|
# 1. 关闭迭代器本身
|
||||||
|
if hasattr(aiter, "aclose"):
|
||||||
|
try:
|
||||||
|
await aiter.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f" [{request_id}] 关闭迭代器失败: {e}")
|
||||||
|
|
||||||
|
# 2. 关闭底层响应对象(httpx.Response)
|
||||||
|
# 迭代器可能持有 _response 属性指向底层响应
|
||||||
|
response = getattr(aiter, "_response", None)
|
||||||
|
if response is not None and hasattr(response, "aclose"):
|
||||||
|
try:
|
||||||
|
await response.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f" [{request_id}] 关闭底层响应失败: {e}")
|
||||||
|
|
||||||
|
# 3. 尝试关闭 httpx 流(如果迭代器是 httpx 的 aiter_bytes)
|
||||||
|
# httpx 的 Response.aiter_bytes() 返回的生成器可能有 _stream 属性
|
||||||
|
stream = getattr(aiter, "_stream", None)
|
||||||
|
if stream is not None and hasattr(stream, "aclose"):
|
||||||
|
try:
|
||||||
|
await stream.aclose()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f" [{request_id}] 关闭流对象失败: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user