27 Commits

Author SHA1 Message Date
fawney19
41719a00e7 refactor: 改进分布式任务锁的清理策略
实现两种锁清理模式:
- 单实例模式(默认):启动时使用 Lua 脚本原子性清理旧锁,解决 worker 重启时���锁残留问题
- 多实例模式:使用 NX 选项竞争锁,依赖 TTL 处理异常退出

可通过 SINGLE_INSTANCE_MODE 环境变量控制模式选择。
2025-12-28 21:34:43 +08:00
fawney19
b5c0f85dca refactor: 统一剪贴板复制功能到 useClipboard 组合式函数
将各个组件和视图中重复的剪贴板复制逻辑提取到 useClipboard 组合式函数。
增加 showToast 参数支持静默复制,减少代码重复,提高维护性。
2025-12-28 20:41:52 +08:00
fawney19
7d6d262ed3 feat: 增加用户密码修改时的确认验证
在编辑用户时,如果填写了新密码,需要进行密码确认,确保两次输入一致。
同时更新后端请求模型以支持密码字段。
2025-12-28 20:00:25 +08:00
fawney19
e21acd73eb fix: 修复模型映射中重复关联的问题
在批量分配模型和编辑模型映射时,需要检查不仅是主模型名是否已关联,
还要检查其映射名称是否已存在,防止同一个上游模型被重复关联。
2025-12-28 19:40:07 +08:00
fawney19
702f9bc5f1 fix: 修复缓存监控页面TTL分析时间段选择器点击无响应
为 Select 组件添加 v-model:open 绑定,解决 radix-vue Select 组件
在某些情况下点击无响应的问题。

Fixes #55
2025-12-28 19:14:49 +08:00
fawney19
d0ce798881 fix: TTL=0时启用Key随机轮换模式
- 当所有Key的cache_ttl_minutes都为0时,使用随机排序代替确定性哈希
- 将hashlib和random的import移到文件顶部
- 简化单Key场景的处理逻辑

Closes #57
2025-12-28 19:07:25 +08:00
fawney19
2b1d197047 Merge remote-tracking branch 'gitcode/master' into htmambo/master 2025-12-25 22:47:08 +08:00
fawney19
71bc2e6aab fix: 增加参数校验防止除零错误 2025-12-25 22:44:17 +08:00
fawney19
afb329934a fix: 修复端点健康统计时间分段计算的除零错误 2025-12-25 19:54:16 +08:00
elky0401
1313af45a3 !4 merge htmambo/master into master
refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用

Created-by: elky0401
Commit-by: fawney19;hoping
Merged-by: elky0401
Description: feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。
refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用

See merge request: elky0401/Aether!4
2025-12-25 19:39:33 +08:00
fawney19
dddb327885 refactor: 重构模型测试错误解析逻辑并修复用量统计变量引用
- 将 ModelsTab 和 ModelAliasesTab 中重复的错误解析逻辑提取到 errorParser.ts
- 添加 parseTestModelError 函数统一处理测试响应错误
- 为 testModel API 添加 TypeScript 类型定义 (TestModelRequest/TestModelResponse)
- 修复 endpoint_checker.py 中 usage_data 变量引用错误
2025-12-25 19:36:29 +08:00
hoping
26b4a37323 feat: 引入统一的端点检查器以重构适配器并改进错误处理和用量统计。 2025-12-25 00:02:56 +08:00
fawney19
9dad194130 fix: 修复 API Key 访问限制字段无法清除的问题
- 统一前端创建和更新 API Key 时的空数组处理逻辑
- 后端创建和更新接口都支持空数组转 NULL(表示不限制)
- 开启自动刷新时立即刷新一次数据
2025-12-24 22:35:30 +08:00
fawney19
03ad16ea8a fix: 修复迁移脚本在全新安装时报错及改进统计回填逻辑
迁移脚本修复:
- 移除 AUTOCOMMIT 模式,改为在同一事务中创建索引
- 分别检查每个索引是否存在,只创建缺失的索引
- 修复全新安装时 AUTOCOMMIT 连接看不到未提交表的问题 (#46)

统计回填改进:
- 分别检查 StatsDaily 和 StatsDailyModel 的缺失日期
- 只回填实际缺失的数据而非连续区间
- 添加失败统计计数和 rollback 错误日志
2025-12-24 21:50:05 +08:00
fawney19
2fa64b98e3 fix: deploy.sh 将 Dockerfile.app.local 纳入代码变化检测 2025-12-24 18:10:42 +08:00
fawney19
75d7e89cbb perf: 添加 gunicorn --preload 参数优化内存占用
Worker 进程共享只读内存(代码、常量),可减少约 30-40% 内存占用

Closes #44
2025-12-24 18:10:42 +08:00
fawney19
d73a443484 fix: 修复初次执行 migrate.sh 时 usage 表不存在的问题 (#43)
- 在 baseline 中直接创建 usage 表复合索引
- 在后续迁移中添加表存在性检查,避免 AUTOCOMMIT 连接看不到事务中的表
2025-12-24 18:10:42 +08:00
Hwwwww-dev
15a9b88fc8 feat: enhance extract_cache_creation_tokens function to support three formats[#41] (#42)
- Updated the function to prioritize nested format, followed by flat new format, and finally old format for cache creation tokens.
- Added fallback logic for cases where the preferred formats return zero.
- Expanded unit tests to cover new format scenarios and ensure proper functionality across all formats.

Co-authored-by: heweimin <heweimin@retaileye.ai>
2025-12-24 01:31:45 +08:00
fawney19
03eb7203ec fix(api): 同步 chat_handler_base 使用 aiter_bytes 支持自动解压 2025-12-24 01:13:35 +08:00
hank9999
e38cd6819b fix(api): 优化字节流迭代器以支持自动解压 gzip (#39) 2025-12-24 01:11:35 +08:00
fawney19
d44cfaddf6 fix: rename variable to avoid shadowing in model mapping cache stats
循环内部变量 provider_model_mappings 与外部列表同名,导致外部列表被覆盖为 None 引发 AttributeError
2025-12-23 00:38:37 +08:00
fawney19
65225710a8 refactor: use ConcurrencyDefaults for CACHE_RESERVATION_RATIO constant 2025-12-23 00:34:18 +08:00
fawney19
d7f5b16359 fix: rebuild app image when migration files change
deploy.sh was only running alembic upgrade on the old container when
migration files changed, but the migration files are baked into the
Docker image. Now it rebuilds the app image when migrations change.
2025-12-23 00:23:22 +08:00
fawney19
7185818724 fix: remove index_exists check to avoid transaction conflict in migration
- Remove index_exists function that used op.get_bind() within transaction
- Use IF NOT EXISTS / IF EXISTS SQL syntax instead
- Fixes CREATE INDEX CONCURRENTLY error in Docker migration
2025-12-23 00:21:03 +08:00
fawney19
868f3349e5 fix: use AUTOCOMMIT mode for CREATE INDEX CONCURRENTLY in migration
PostgreSQL 不允许在事务块内执行 CREATE INDEX CONCURRENTLY,
通过创建独立连接并设置 AUTOCOMMIT 隔离级别来解决此问题。
2025-12-23 00:18:11 +08:00
fawney19
d7384e69d9 fix: improve code quality and add type safety for Key updates
- Replace f-string logging with lazy formatting in keys.py (lines 256, 265)
- Add EndpointAPIKeyUpdate type interface for frontend type safety
- Use typed EndpointAPIKeyUpdate instead of any in KeyFormDialog.vue
2025-12-23 00:11:10 +08:00
fawney19
1d5c378343 feat: add TTFB timeout detection and improve stream handling
- Add stream first byte timeout (TTFB) detection to trigger failover
  when provider responds too slowly (configurable via STREAM_FIRST_BYTE_TIMEOUT)
- Add rate limit fail-open/fail-close strategy configuration
- Improve exception handling in stream prefetch with proper error classification
- Refactor UsageService with shared _prepare_usage_record method
- Add batch deletion for old usage records to avoid long transaction locks
- Update CLI adapters to use proper User-Agent headers for each CLI client
- Add composite indexes migration for usage table query optimization
- Fix streaming status display in frontend to show TTFB during streaming
- Remove sensitive JWT secret logging in auth service
2025-12-22 23:44:42 +08:00
63 changed files with 3313 additions and 463 deletions

View File

@@ -105,7 +105,7 @@ RUN printf '%s\n' \
'stderr_logfile=/var/log/nginx/error.log' \ 'stderr_logfile=/var/log/nginx/error.log' \
'' \ '' \
'[program:app]' \ '[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \ 'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \ 'directory=/app' \
'autostart=true' \ 'autostart=true' \
'autorestart=true' \ 'autorestart=true' \

View File

@@ -106,7 +106,7 @@ RUN printf '%s\n' \
'stderr_logfile=/var/log/nginx/error.log' \ 'stderr_logfile=/var/log/nginx/error.log' \
'' \ '' \
'[program:app]' \ '[program:app]' \
'command=gunicorn src.main:app -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \ 'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \ 'directory=/app' \
'autostart=true' \ 'autostart=true' \
'autorestart=true' \ 'autorestart=true' \

View File

@@ -394,6 +394,10 @@ def upgrade() -> None:
index=True, index=True,
), ),
) )
# usage 表复合索引(优化常见查询)
op.create_index("idx_usage_user_created", "usage", ["user_id", "created_at"])
op.create_index("idx_usage_apikey_created", "usage", ["api_key_id", "created_at"])
op.create_index("idx_usage_provider_model_created", "usage", ["provider", "model", "created_at"])
# ==================== user_quotas ==================== # ==================== user_quotas ====================
op.create_table( op.create_table(

View File

@@ -0,0 +1,65 @@
"""add usage table composite indexes for query optimization
Revision ID: b2c3d4e5f6g7
Revises: a1b2c3d4e5f6
Create Date: 2025-12-20 15:00:00.000000+00:00
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = 'b2c3d4e5f6g7'
down_revision = 'a1b2c3d4e5f6'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""为 usage 表添加复合索引以优化常见查询
注意:这些索引已经在 baseline 迁移中创建。
此迁移仅用于从旧版本升级的场景,新安装会跳过。
"""
conn = op.get_bind()
# 检查 usage 表是否存在
result = conn.execute(text(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'usage')"
))
if not result.scalar():
# 表不存在,跳过
return
# 定义需要创建的索引
indexes = [
("idx_usage_user_created", "ON usage (user_id, created_at)"),
("idx_usage_apikey_created", "ON usage (api_key_id, created_at)"),
("idx_usage_provider_model_created", "ON usage (provider, model, created_at)"),
]
# 分别检查并创建每个索引
for index_name, index_def in indexes:
result = conn.execute(text(
f"SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE indexname = '{index_name}')"
))
if result.scalar():
continue # 索引已存在,跳过
conn.execute(text(f"CREATE INDEX {index_name} {index_def}"))
def downgrade() -> None:
"""删除复合索引"""
conn = op.get_bind()
# 使用 IF EXISTS 避免索引不存在时报错
conn.execute(text(
"DROP INDEX IF EXISTS idx_usage_provider_model_created"
))
conn.execute(text(
"DROP INDEX IF EXISTS idx_usage_apikey_created"
))
conn.execute(text(
"DROP INDEX IF EXISTS idx_usage_user_created"
))

View File

@@ -26,10 +26,13 @@ calc_deps_hash() {
cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1 cat pyproject.toml frontend/package.json frontend/package-lock.json Dockerfile.base.local 2>/dev/null | md5sum | cut -d' ' -f1
} }
# 计算代码文件的哈希值 # 计算代码文件的哈希值(包含 Dockerfile.app.local
calc_code_hash() { calc_code_hash() {
find src -type f -name "*.py" 2>/dev/null | sort | xargs cat 2>/dev/null | md5sum | cut -d' ' -f1 {
find frontend/src -type f \( -name "*.vue" -o -name "*.ts" -o -name "*.tsx" -o -name "*.js" \) 2>/dev/null | sort | xargs cat 2>/dev/null | md5sum | cut -d' ' -f1 cat Dockerfile.app.local 2>/dev/null
find src -type f -name "*.py" 2>/dev/null | sort | xargs cat 2>/dev/null
find frontend/src -type f \( -name "*.vue" -o -name "*.ts" -o -name "*.tsx" -o -name "*.js" \) 2>/dev/null | sort | xargs cat 2>/dev/null
} | md5sum | cut -d' ' -f1
} }
# 计算迁移文件的哈希值 # 计算迁移文件的哈希值
@@ -179,7 +182,13 @@ else
echo ">>> Dependencies unchanged." echo ">>> Dependencies unchanged."
fi fi
# 检查代码是否变化,或者 base 重建了app 依赖 base # 检查代码或迁移是否变化,或者 base 重建了app 依赖 base
# 注意:迁移文件打包在镜像中,所以迁移变化也需要重建 app 镜像
MIGRATION_CHANGED=false
if check_migration_changed; then
MIGRATION_CHANGED=true
fi
if ! docker image inspect aether-app:latest >/dev/null 2>&1; then if ! docker image inspect aether-app:latest >/dev/null 2>&1; then
echo ">>> App image not found, building..." echo ">>> App image not found, building..."
build_app build_app
@@ -192,6 +201,10 @@ elif check_code_changed; then
echo ">>> Code changed, rebuilding app image..." echo ">>> Code changed, rebuilding app image..."
build_app build_app
NEED_RESTART=true NEED_RESTART=true
elif [ "$MIGRATION_CHANGED" = true ]; then
echo ">>> Migration files changed, rebuilding app image..."
build_app
NEED_RESTART=true
else else
echo ">>> Code unchanged." echo ">>> Code unchanged."
fi fi
@@ -204,9 +217,9 @@ else
echo ">>> No changes detected, skipping restart." echo ">>> No changes detected, skipping restart."
fi fi
# 检查迁移变化 # 检查迁移变化(如果前面已经检测到变化并重建了镜像,这里直接运行迁移)
if check_migration_changed; then if [ "$MIGRATION_CHANGED" = true ]; then
echo ">>> Migration files changed, running database migration..." echo ">>> Running database migration..."
sleep 3 sleep 3
run_migration run_migration
else else

View File

@@ -58,3 +58,38 @@ export async function deleteProvider(providerId: string): Promise<{ message: str
return response.data return response.data
} }
/**
* 测试模型连接性
*/
export interface TestModelRequest {
provider_id: string
model_name: string
api_key_id?: string
message?: string
api_format?: string
}
export interface TestModelResponse {
success: boolean
error?: string
data?: {
response?: {
status_code?: number
error?: string | { message?: string }
choices?: Array<{ message?: { content?: string } }>
}
content_preview?: string
}
provider?: {
id: string
name: string
display_name: string
}
model?: string
}
export async function testModel(data: TestModelRequest): Promise<TestModelResponse> {
const response = await client.post('/api/admin/provider-query/test-model', data)
return response.data
}

View File

@@ -110,6 +110,24 @@ export interface EndpointAPIKey {
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口 request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
} }
export interface EndpointAPIKeyUpdate {
name?: string
api_key?: string // 仅在需要更新时提供
rate_multiplier?: number
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null // null 表示切换为自适应模式
rate_limit?: number
daily_limit?: number
monthly_limit?: number
allowed_models?: string[] | null
capabilities?: Record<string, boolean> | null
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
note?: string
is_active?: boolean
}
export interface EndpointHealthDetail { export interface EndpointHealthDetail {
api_format: string api_format: string
health_score: number health_score: number

View File

@@ -163,7 +163,9 @@ const contentZIndex = computed(() => (props.zIndex || 60) + 10)
useEscapeKey(() => { useEscapeKey(() => {
if (isOpen.value) { if (isOpen.value) {
handleClose() handleClose()
return true // 阻止其他监听器(如父级抽屉的 ESC 监听器)
} }
return false
}, { }, {
disableOnInput: true, disableOnInput: true,
once: false once: false

View File

@@ -4,11 +4,11 @@ import { log } from '@/utils/logger'
export function useClipboard() { export function useClipboard() {
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
async function copyToClipboard(text: string): Promise<boolean> { async function copyToClipboard(text: string, showToast = true): Promise<boolean> {
try { try {
if (navigator.clipboard && window.isSecureContext) { if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(text) await navigator.clipboard.writeText(text)
success('已复制到剪贴板') if (showToast) success('已复制到剪贴板')
return true return true
} }
@@ -25,17 +25,17 @@ export function useClipboard() {
try { try {
const successful = document.execCommand('copy') const successful = document.execCommand('copy')
if (successful) { if (successful) {
success('已复制到剪贴板') if (showToast) success('已复制到剪贴板')
return true return true
} }
showError('复制失败,请手动复制') if (showToast) showError('复制失败,请手动复制')
return false return false
} finally { } finally {
document.body.removeChild(textArea) document.body.removeChild(textArea)
} }
} catch (err) { } catch (err) {
log.error('复制失败:', err) log.error('复制失败:', err)
showError('复制失败,请手动选择文本进行复制') if (showToast) showError('复制失败,请手动选择文本进行复制')
return false return false
} }
} }

View File

@@ -47,11 +47,11 @@ export function useConfirm() {
/** /**
* 便捷方法:危险操作确认(红色主题) * 便捷方法:危险操作确认(红色主题)
*/ */
const confirmDanger = (message: string, title?: string): Promise<boolean> => { const confirmDanger = (message: string, title?: string, confirmText?: string): Promise<boolean> => {
return confirm({ return confirm({
message, message,
title: title || '危险操作', title: title || '危险操作',
confirmText: '删除', confirmText: confirmText || '删除',
variant: 'danger' variant: 'danger'
}) })
} }

View File

@@ -4,11 +4,11 @@ import { onMounted, onUnmounted, ref } from 'vue'
* ESC 键监听 Composable简化版本直接使用独立监听器 * ESC 键监听 Composable简化版本直接使用独立监听器
* 用于按 ESC 键关闭弹窗或其他可关闭的组件 * 用于按 ESC 键关闭弹窗或其他可关闭的组件
* *
* @param callback - 按 ESC 键时执行的回调函数 * @param callback - 按 ESC 键时执行的回调函数,返回 true 表示已处理事件,阻止其他监听器执行
* @param options - 配置选项 * @param options - 配置选项
*/ */
export function useEscapeKey( export function useEscapeKey(
callback: () => void, callback: () => void | boolean,
options: { options: {
/** 是否在输入框获得焦点时禁用 ESC 键,默认 true */ /** 是否在输入框获得焦点时禁用 ESC 键,默认 true */
disableOnInput?: boolean disableOnInput?: boolean
@@ -42,8 +42,11 @@ export function useEscapeKey(
if (isInputElement) return if (isInputElement) return
} }
// 执行回调 // 执行回调,如果返回 true 则阻止其他监听器
callback() const handled = callback()
if (handled === true) {
event.stopImmediatePropagation()
}
// 移除当前元素的焦点,避免残留样式 // 移除当前元素的焦点,避免残留样式
if (document.activeElement instanceof HTMLElement) { if (document.activeElement instanceof HTMLElement) {

View File

@@ -700,6 +700,7 @@ import {
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
@@ -731,6 +732,7 @@ const emit = defineEmits<{
'refreshProviders': [] 'refreshProviders': []
}>() }>()
const { success: showSuccess, error: showError } = useToast() const { success: showSuccess, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
interface Props { interface Props {
model: GlobalModelResponse | null model: GlobalModelResponse | null
@@ -763,16 +765,6 @@ function handleClose() {
} }
} }
// 复制到剪贴板
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
showSuccess('已复制')
} catch {
showError('复制失败')
}
}
// 格式化日期 // 格式化日期
function formatDate(dateStr: string): string { function formatDate(dateStr: string): string {
if (!dateStr) return '-' if (!dateStr) return '-'

View File

@@ -433,11 +433,17 @@ const availableGlobalModels = computed(() => {
) )
}) })
// 计算可添加的上游模型(排除已关联的) // 计算可添加的上游模型(排除已关联的,包括主模型名和映射名称
const availableUpstreamModelsBase = computed(() => { const availableUpstreamModelsBase = computed(() => {
const existingModelNames = new Set( const existingModelNames = new Set<string>()
existingModels.value.map(m => m.provider_model_name) for (const m of existingModels.value) {
) // 主模型名
existingModelNames.add(m.provider_model_name)
// 映射名称
for (const mapping of m.provider_model_mappings ?? []) {
if (mapping.name) existingModelNames.add(mapping.name)
}
}
return upstreamModels.value.filter(m => !existingModelNames.has(m.id)) return upstreamModels.value.filter(m => !existingModelNames.has(m.id))
}) })

View File

@@ -260,6 +260,7 @@ import {
updateEndpointKey, updateEndpointKey,
getAllCapabilities, getAllCapabilities,
type EndpointAPIKey, type EndpointAPIKey,
type EndpointAPIKeyUpdate,
type ProviderEndpoint, type ProviderEndpoint,
type CapabilityDefinition type CapabilityDefinition
} from '@/api/endpoints' } from '@/api/endpoints'
@@ -386,10 +387,11 @@ function loadKeyData() {
api_key: '', api_key: '',
rate_multiplier: props.editingKey.rate_multiplier || 1.0, rate_multiplier: props.editingKey.rate_multiplier || 1.0,
internal_priority: props.editingKey.internal_priority ?? 50, internal_priority: props.editingKey.internal_priority ?? 50,
max_concurrent: props.editingKey.max_concurrent || undefined, // 保留原始的 null/undefined 状态null 表示自适应模式
rate_limit: props.editingKey.rate_limit || undefined, max_concurrent: props.editingKey.max_concurrent ?? undefined,
daily_limit: props.editingKey.daily_limit || undefined, rate_limit: props.editingKey.rate_limit ?? undefined,
monthly_limit: props.editingKey.monthly_limit || undefined, daily_limit: props.editingKey.daily_limit ?? undefined,
monthly_limit: props.editingKey.monthly_limit ?? undefined,
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5, cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32, max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
note: props.editingKey.note || '', note: props.editingKey.note || '',
@@ -439,12 +441,17 @@ async function handleSave() {
saving.value = true saving.value = true
try { try {
if (props.editingKey) { if (props.editingKey) {
// 更新 // 更新模式
const updateData: any = { // 注意max_concurrent 需要显式发送 null 来切换到自适应模式
// undefined 会在 JSON 中被忽略,所以用 null 表示"清空/自适应"
const updateData: EndpointAPIKeyUpdate = {
name: form.value.name, name: form.value.name,
rate_multiplier: form.value.rate_multiplier, rate_multiplier: form.value.rate_multiplier,
internal_priority: form.value.internal_priority, internal_priority: form.value.internal_priority,
max_concurrent: form.value.max_concurrent, // 显式使用 null 表示自适应模式,这样后端能区分"未提供"和"设置为 null"
// 注意:只有 max_concurrent 需要这种处理,因为它有"自适应模式"的概念
// 其他限制字段rate_limit 等)不支持"清空"操作undefined 会被 JSON 忽略即不更新
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
rate_limit: form.value.rate_limit, rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit, daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit, monthly_limit: form.value.monthly_limit,

View File

@@ -17,7 +17,7 @@
v-model:open="modelSelectOpen" v-model:open="modelSelectOpen"
:model-value="formData.modelId" :model-value="formData.modelId"
:disabled="!!editingGroup" :disabled="!!editingGroup"
@update:model-value="formData.modelId = $event" @update:model-value="handleModelChange"
> >
<SelectTrigger class="h-9"> <SelectTrigger class="h-9">
<SelectValue placeholder="请选择模型" /> <SelectValue placeholder="请选择模型" />
@@ -449,7 +449,17 @@ interface UpstreamModelGroup {
} }
const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => { const groupedAvailableUpstreamModels = computed<UpstreamModelGroup[]>(() => {
// 收集当前表单已添加的名称
const addedNames = new Set(formData.value.aliases.map(a => a.name.trim())) const addedNames = new Set(formData.value.aliases.map(a => a.name.trim()))
// 收集所有已存在的映射名称(包括主模型名和映射名称)
for (const m of props.models) {
addedNames.add(m.provider_model_name)
for (const mapping of m.provider_model_mappings ?? []) {
if (mapping.name) addedNames.add(mapping.name)
}
}
const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id)) const availableModels = filteredUpstreamModels.value.filter(m => !addedNames.has(m.id))
const groups = new Map<string, UpstreamModelGroup>() const groups = new Map<string, UpstreamModelGroup>()
@@ -519,6 +529,15 @@ function initForm() {
} }
} }
// 处理模型选择变更
function handleModelChange(value: string) {
formData.value.modelId = value
const selectedModel = props.models.find(m => m.id === value)
if (selectedModel) {
upstreamModelSearch.value = selectedModel.provider_model_name
}
}
// 切换 API 格式 // 切换 API 格式
function toggleApiFormat(format: string) { function toggleApiFormat(format: string) {
const index = formData.value.apiFormats.indexOf(format) const index = formData.value.apiFormats.indexOf(format)

View File

@@ -483,9 +483,9 @@
<span <span
v-if="key.max_concurrent || key.is_adaptive" v-if="key.max_concurrent || key.is_adaptive"
class="text-muted-foreground" class="text-muted-foreground"
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'}` : '固定并发限制'" :title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'}` : `固定并发限制: ${key.max_concurrent}`"
> >
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.learned_max_concurrent || key.max_concurrent || 3 }} {{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.is_adaptive ? (key.learned_max_concurrent ?? '学习中') : key.max_concurrent }}
</span> </span>
</div> </div>
</div> </div>
@@ -531,6 +531,7 @@
<!-- 模型名称映射 --> <!-- 模型名称映射 -->
<ModelAliasesTab <ModelAliasesTab
v-if="provider" v-if="provider"
ref="modelAliasesTabRef"
:key="`aliases-${provider.id}`" :key="`aliases-${provider.id}`"
:provider="provider" :provider="provider"
@refresh="handleRelatedDataRefresh" @refresh="handleRelatedDataRefresh"
@@ -660,6 +661,7 @@ import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import { getProvider, getProviderEndpoints } from '@/api/endpoints' import { getProvider, getProviderEndpoints } from '@/api/endpoints'
import { import {
KeyFormDialog, KeyFormDialog,
@@ -705,6 +707,7 @@ const emit = defineEmits<{
}>() }>()
const { error: showError, success: showSuccess } = useToast() const { error: showError, success: showSuccess } = useToast()
const { copyToClipboard } = useClipboard()
const loading = ref(false) const loading = ref(false)
const provider = ref<any>(null) const provider = ref<any>(null)
@@ -735,6 +738,9 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null) const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false) const batchAssignDialogOpen = ref(false)
// ModelAliasesTab 组件引用
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
// 拖动排序相关状态 // 拖动排序相关状态
const dragState = ref({ const dragState = ref({
isDragging: false, isDragging: false,
@@ -756,7 +762,9 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value || deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value || modelFormDialogOpen.value ||
deleteModelConfirmOpen.value || deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value batchAssignDialogOpen.value ||
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
modelAliasesTabRef.value?.dialogOpen
) )
// 监听 providerId 变化 // 监听 providerId 变化
@@ -1244,16 +1252,6 @@ function getHealthScoreBarColor(score: number): string {
return 'bg-red-500 dark:bg-red-400' return 'bg-red-500 dark:bg-red-400'
} }
// 复制到剪贴板
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
showSuccess('已复制到剪贴板')
} catch {
showError('复制失败', '错误')
}
}
// 加载 Provider 信息 // 加载 Provider 信息
async function loadProvider() { async function loadProvider() {
if (!props.providerId) return if (!props.providerId) return

View File

@@ -110,16 +110,30 @@
<div <div
v-for="mapping in group.aliases" v-for="mapping in group.aliases"
:key="mapping.name" :key="mapping.name"
class="flex items-center gap-2 py-1" class="flex items-center justify-between gap-2 py-1"
> >
<!-- 优先级标签 --> <div class="flex items-center gap-2 flex-1 min-w-0">
<span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0"> <!-- 优先级标签 -->
{{ mapping.priority }} <span class="inline-flex items-center justify-center w-5 h-5 rounded bg-background border text-xs font-medium shrink-0">
</span> {{ mapping.priority }}
<!-- 映射名称 --> </span>
<span class="font-mono text-sm truncate"> <!-- 映射名称 -->
{{ mapping.name }} <span class="font-mono text-sm truncate">
</span> {{ mapping.name }}
</span>
</div>
<!-- 测试按钮 -->
<Button
variant="ghost"
size="icon"
class="h-7 w-7 shrink-0"
title="测试映射"
:disabled="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`"
@click="testMapping(group, mapping)"
>
<Loader2 v-if="testingMapping === `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`" class="w-3 h-3 animate-spin" />
<Play v-else class="w-3 h-3" />
</Button>
</div> </div>
</div> </div>
</div> </div>
@@ -166,18 +180,20 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted, watch } from 'vue' import { ref, computed, onMounted, watch } from 'vue'
import { Tag, Plus, Edit, Trash2, ChevronRight } from 'lucide-vue-next' import { Tag, Plus, Edit, Trash2, ChevronRight, Loader2, Play } from 'lucide-vue-next'
import { Card, Button, Badge } from '@/components/ui' import { Card, Button, Badge } from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue' import AlertDialog from '@/components/common/AlertDialog.vue'
import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue' import ModelMappingDialog, { type AliasGroup } from '../ModelMappingDialog.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { import {
getProviderModels, getProviderModels,
testModel,
API_FORMAT_LABELS, API_FORMAT_LABELS,
type Model, type Model,
type ProviderModelAlias type ProviderModelAlias
} from '@/api/endpoints' } from '@/api/endpoints'
import { updateModel } from '@/api/endpoints/models' import { updateModel } from '@/api/endpoints/models'
import { parseTestModelError } from '@/utils/errorParser'
const props = defineProps<{ const props = defineProps<{
provider: any provider: any
@@ -196,6 +212,7 @@ const dialogOpen = ref(false)
const deleteConfirmOpen = ref(false) const deleteConfirmOpen = ref(false)
const editingGroup = ref<AliasGroup | null>(null) const editingGroup = ref<AliasGroup | null>(null)
const deletingGroup = ref<AliasGroup | null>(null) const deletingGroup = ref<AliasGroup | null>(null)
const testingMapping = ref<string | null>(null)
// 列表展开状态 // 列表展开状态
const expandedAliasGroups = ref<Set<string>>(new Set()) const expandedAliasGroups = ref<Set<string>>(new Set())
@@ -337,6 +354,49 @@ async function onDialogSaved() {
emit('refresh') emit('refresh')
} }
// 测试模型映射
async function testMapping(group: any, mapping: any) {
const testingKey = `${group.model.id}-${group.apiFormatsKey}-${mapping.name}`
testingMapping.value = testingKey
try {
// 根据分组的 API 格式来确定应该使用的格式
let apiFormat = null
if (group.apiFormats.length === 1) {
apiFormat = group.apiFormats[0]
} else if (group.apiFormats.length === 0) {
// 如果没有指定格式,但分组显示为"全部",则使用模型的默认格式
apiFormat = group.model.effective_api_format || group.model.api_format
}
const result = await testModel({
provider_id: props.provider.id,
model_name: mapping.name, // 使用映射名称进行测试
message: "hello",
api_format: apiFormat
})
if (result.success) {
showSuccess(`映射 "${mapping.name}" 测试成功`)
// 如果有响应内容,可以显示更多信息
if (result.data?.response?.choices?.[0]?.message?.content) {
const content = result.data.response.choices[0].message.content
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
} else if (result.data?.content_preview) {
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
}
} else {
showError(`映射测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`映射测试失败: ${errorMsg}`)
} finally {
testingMapping.value = null
}
}
// 监听 provider 变化 // 监听 provider 变化
watch(() => props.provider?.id, (newId) => { watch(() => props.provider?.id, (newId) => {
if (newId) { if (newId) {
@@ -349,4 +409,9 @@ onMounted(() => {
loadModels() loadModels()
} }
}) })
// 暴露给父组件,用于检测是否有弹窗打开
defineExpose({
dialogOpen: computed(() => dialogOpen.value || deleteConfirmOpen.value)
})
</script> </script>

View File

@@ -156,6 +156,17 @@
</td> </td>
<td class="align-top px-4 py-3"> <td class="align-top px-4 py-3">
<div class="flex justify-center gap-1.5"> <div class="flex justify-center gap-1.5">
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="测试模型"
:disabled="testingModelId === model.id"
@click="testModelConnection(model)"
>
<Loader2 v-if="testingModelId === model.id" class="w-3.5 h-3.5 animate-spin" />
<Play v-else class="w-3.5 h-3.5" />
</Button>
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
@@ -209,12 +220,14 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted } from 'vue'
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next' import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Loader2, Play } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { getProviderModels, type Model } from '@/api/endpoints' import { useClipboard } from '@/composables/useClipboard'
import { getProviderModels, testModel, type Model } from '@/api/endpoints'
import { updateModel } from '@/api/endpoints/models' import { updateModel } from '@/api/endpoints/models'
import { parseTestModelError } from '@/utils/errorParser'
const props = defineProps<{ const props = defineProps<{
provider: any provider: any
@@ -227,11 +240,13 @@ const emit = defineEmits<{
}>() }>()
const { error: showError, success: showSuccess } = useToast() const { error: showError, success: showSuccess } = useToast()
const { copyToClipboard } = useClipboard()
// 状态 // 状态
const loading = ref(false) const loading = ref(false)
const models = ref<Model[]>([]) const models = ref<Model[]>([])
const togglingModelId = ref<string | null>(null) const togglingModelId = ref<string | null>(null)
const testingModelId = ref<string | null>(null)
// 按名称排序的模型列表 // 按名称排序的模型列表
const sortedModels = computed(() => { const sortedModels = computed(() => {
@@ -244,12 +259,7 @@ const sortedModels = computed(() => {
// 复制模型 ID 到剪贴板 // 复制模型 ID 到剪贴板
async function copyModelId(modelId: string) { async function copyModelId(modelId: string) {
try { await copyToClipboard(modelId)
await navigator.clipboard.writeText(modelId)
showSuccess('已复制到剪贴板')
} catch {
showError('复制失败', '错误')
}
} }
// 加载模型 // 加载模型
@@ -380,6 +390,39 @@ async function toggleModelActive(model: Model) {
} }
} }
// 测试模型连接性
async function testModelConnection(model: Model) {
if (testingModelId.value) return
testingModelId.value = model.id
try {
const result = await testModel({
provider_id: props.provider.id,
model_name: model.provider_model_name,
message: "hello"
})
if (result.success) {
showSuccess(`模型 "${model.provider_model_name}" 测试成功`)
// 如果有响应内容,可以显示更多信息
if (result.data?.response?.choices?.[0]?.message?.content) {
const content = result.data.response.choices[0].message.content
showSuccess(`测试成功,响应: ${content.substring(0, 100)}${content.length > 100 ? '...' : ''}`)
} else if (result.data?.content_preview) {
showSuccess(`流式测试成功,预览: ${result.data.content_preview}`)
}
} else {
showError(`模型测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`模型测试失败: ${errorMsg}`)
} finally {
testingModelId.value = null
}
}
onMounted(() => { onMounted(() => {
loadModels() loadModels()
}) })

View File

@@ -473,6 +473,7 @@
import { ref, watch, computed } from 'vue' import { ref, watch, computed } from 'vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import { useClipboard } from '@/composables/useClipboard'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Separator from '@/components/ui/separator.vue' import Separator from '@/components/ui/separator.vue'
@@ -505,6 +506,7 @@ const copiedStates = ref<Record<string, boolean>>({})
const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare') const viewMode = ref<'compare' | 'formatted' | 'raw'>('compare')
const currentExpandDepth = ref(1) const currentExpandDepth = ref(1)
const dataSource = ref<'client' | 'provider'>('client') const dataSource = ref<'client' | 'provider'>('client')
const { copyToClipboard } = useClipboard()
const historicalPricing = ref<{ const historicalPricing = ref<{
input_price: string input_price: string
output_price: string output_price: string
@@ -784,7 +786,7 @@ function copyJsonToClipboard(tabName: string) {
} }
if (data) { if (data) {
navigator.clipboard.writeText(JSON.stringify(data, null, 2)) copyToClipboard(JSON.stringify(data, null, 2), false)
copiedStates.value[tabName] = true copiedStates.value[tabName] = true
setTimeout(() => { setTimeout(() => {
copiedStates.value[tabName] = false copiedStates.value[tabName] = false

View File

@@ -366,14 +366,34 @@
</div> </div>
</TableCell> </TableCell>
<TableCell class="text-right py-4 w-[70px]"> <TableCell class="text-right py-4 w-[70px]">
<!-- pending 状态只显示增长的总时间 -->
<div <div
v-if="record.status === 'pending' || record.status === 'streaming'" v-if="record.status === 'pending'"
class="flex flex-col items-end text-xs gap-0.5" class="flex flex-col items-end text-xs gap-0.5"
> >
<span class="text-muted-foreground">-</span>
<span class="text-primary tabular-nums"> <span class="text-primary tabular-nums">
{{ getElapsedTime(record) }} {{ getElapsedTime(record) }}
</span> </span>
</div> </div>
<!-- streaming 状态首字固定 + 总时间增长 -->
<div
v-else-if="record.status === 'streaming'"
class="flex flex-col items-end text-xs gap-0.5"
>
<span
v-if="record.first_byte_time_ms != null"
class="tabular-nums"
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
<span
v-else
class="text-muted-foreground"
>-</span>
<span class="text-primary tabular-nums">
{{ getElapsedTime(record) }}
</span>
</div>
<!-- 已完成状态首字 + 总耗时 -->
<div <div
v-else-if="record.response_time_ms != null" v-else-if="record.response_time_ms != null"
class="flex flex-col items-end text-xs gap-0.5" class="flex flex-col items-end text-xs gap-0.5"

View File

@@ -86,6 +86,34 @@
</p> </p>
</div> </div>
<div
v-if="isEditMode && form.password.length > 0"
class="space-y-2"
>
<Label class="text-sm font-medium">
确认新密码 <span class="text-muted-foreground">*</span>
</Label>
<Input
:id="`pwd-confirm-${formNonce}`"
v-model="form.confirmPassword"
type="password"
autocomplete="new-password"
data-form-type="other"
data-lpignore="true"
:name="`confirm-${formNonce}`"
required
minlength="6"
placeholder="再次输入新密码"
class="h-10"
/>
<p
v-if="form.confirmPassword.length > 0 && form.password !== form.confirmPassword"
class="text-xs text-destructive"
>
两次输入的密码不一致
</p>
</div>
<div class="space-y-2"> <div class="space-y-2">
<Label <Label
for="form-email" for="form-email"
@@ -423,6 +451,7 @@ const apiFormats = ref<Array<{ value: string; label: string }>>([])
const form = ref({ const form = ref({
username: '', username: '',
password: '', password: '',
confirmPassword: '',
email: '', email: '',
quota: 10, quota: 10,
role: 'user' as 'admin' | 'user', role: 'user' as 'admin' | 'user',
@@ -443,6 +472,7 @@ function resetForm() {
form.value = { form.value = {
username: '', username: '',
password: '', password: '',
confirmPassword: '',
email: '', email: '',
quota: 10, quota: 10,
role: 'user', role: 'user',
@@ -461,6 +491,7 @@ function loadUserData() {
form.value = { form.value = {
username: props.user.username, username: props.user.username,
password: '', password: '',
confirmPassword: '',
email: props.user.email || '', email: props.user.email || '',
quota: props.user.quota_usd == null ? 10 : props.user.quota_usd, quota: props.user.quota_usd == null ? 10 : props.user.quota_usd,
role: props.user.role, role: props.user.role,
@@ -486,7 +517,9 @@ const isFormValid = computed(() => {
const hasUsername = form.value.username.trim().length > 0 const hasUsername = form.value.username.trim().length > 0
const hasEmail = form.value.email.trim().length > 0 const hasEmail = form.value.email.trim().length > 0
const hasPassword = isEditMode.value || form.value.password.length >= 6 const hasPassword = isEditMode.value || form.value.password.length >= 6
return hasUsername && hasEmail && hasPassword // 编辑模式下如果填写了密码,必须确认密码一致
const passwordConfirmed = !isEditMode.value || form.value.password.length === 0 || form.value.password === form.value.confirmPassword
return hasUsername && hasEmail && hasPassword && passwordConfirmed
}) })
// 加载访问控制选项 // 加载访问控制选项

View File

@@ -14,7 +14,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
users.value = await usersApi.getAllUsers() users.value = await usersApi.getAllUsers()
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '获取用户列表失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取用户列表失败'
} finally { } finally {
loading.value = false loading.value = false
} }
@@ -29,7 +29,7 @@ export const useUsersStore = defineStore('users', () => {
users.value.push(newUser) users.value.push(newUser)
return newUser return newUser
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '创建用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -52,7 +52,7 @@ export const useUsersStore = defineStore('users', () => {
} }
return updatedUser return updatedUser
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '更新用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '更新用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -67,7 +67,7 @@ export const useUsersStore = defineStore('users', () => {
await usersApi.deleteUser(userId) await usersApi.deleteUser(userId)
users.value = users.value.filter(u => u.id !== userId) users.value = users.value.filter(u => u.id !== userId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '删除用户失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除用户失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false
@@ -78,7 +78,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
return await usersApi.getUserApiKeys(userId) return await usersApi.getUserApiKeys(userId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '获取 API Keys 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '获取 API Keys 失败'
throw err throw err
} }
} }
@@ -87,7 +87,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
return await usersApi.createApiKey(userId, name) return await usersApi.createApiKey(userId, name)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '创建 API Key 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '创建 API Key 失败'
throw err throw err
} }
} }
@@ -96,7 +96,7 @@ export const useUsersStore = defineStore('users', () => {
try { try {
await usersApi.deleteApiKey(userId, keyId) await usersApi.deleteApiKey(userId, keyId)
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '删除 API Key 失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '删除 API Key 失败'
throw err throw err
} }
} }
@@ -110,7 +110,7 @@ export const useUsersStore = defineStore('users', () => {
// 刷新用户列表以获取最新数据 // 刷新用户列表以获取最新数据
await fetchUsers() await fetchUsers()
} catch (err: any) { } catch (err: any) {
error.value = err.response?.data?.detail || '重置配额失败' error.value = err.response?.data?.error?.message || err.response?.data?.detail || '重置配额失败'
throw err throw err
} finally { } finally {
loading.value = false loading.value = false

View File

@@ -198,3 +198,49 @@ export function parseApiErrorShort(err: unknown, defaultMessage: string = '操
const lines = fullError.split('\n') const lines = fullError.split('\n')
return lines[0] || defaultMessage return lines[0] || defaultMessage
} }
/**
* 解析模型测试响应的错误信息
* @param result 测试响应结果
* @returns 格式化的错误信息
*/
export function parseTestModelError(result: {
error?: string
data?: {
response?: {
status_code?: number
error?: string | { message?: string }
}
}
}): string {
let errorMsg = result.error || '测试失败'
// 检查HTTP状态码错误
if (result.data?.response?.status_code) {
const status = result.data.response.status_code
if (status === 403) {
errorMsg = '认证失败: API密钥无效或客户端类型不被允许'
} else if (status === 401) {
errorMsg = '认证失败: API密钥无效或已过期'
} else if (status === 404) {
errorMsg = '模型不存在: 请检查模型名称是否正确'
} else if (status === 429) {
errorMsg = '请求频率过高: 请稍后重试'
} else if (status >= 500) {
errorMsg = `服务器错误: HTTP ${status}`
} else {
errorMsg = `请求失败: HTTP ${status}`
}
}
// 尝试从错误响应中提取更多信息
if (result.data?.response?.error) {
if (typeof result.data.response.error === 'string') {
errorMsg = result.data.response.error
} else if (result.data.response.error?.message) {
errorMsg = result.data.response.error.message
}
}
return errorMsg
}

View File

@@ -650,6 +650,7 @@
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted } from 'vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useConfirm } from '@/composables/useConfirm' import { useConfirm } from '@/composables/useConfirm'
import { useClipboard } from '@/composables/useClipboard'
import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin' import { adminApi, type AdminApiKey, type CreateStandaloneApiKeyRequest } from '@/api/admin'
import { import {
@@ -693,6 +694,7 @@ import { log } from '@/utils/logger'
const { success, error } = useToast() const { success, error } = useToast()
const { confirmDanger } = useConfirm() const { confirmDanger } = useConfirm()
const { copyToClipboard } = useClipboard()
const apiKeys = ref<AdminApiKey[]>([]) const apiKeys = ref<AdminApiKey[]>([])
const loading = ref(false) const loading = ref(false)
@@ -927,20 +929,14 @@ function selectKey() {
} }
async function copyKey() { async function copyKey() {
try { await copyToClipboard(newKeyValue.value)
await navigator.clipboard.writeText(newKeyValue.value)
success('API Key 已复制到剪贴板')
} catch {
error('复制失败,请手动复制')
}
} }
async function copyKeyPrefix(apiKey: AdminApiKey) { async function copyKeyPrefix(apiKey: AdminApiKey) {
try { try {
// 调用后端 API 获取完整密钥 // 调用后端 API 获取完整密钥
const response = await adminApi.getFullApiKey(apiKey.id) const response = await adminApi.getFullApiKey(apiKey.id)
await navigator.clipboard.writeText(response.key) await copyToClipboard(response.key)
success('完整密钥已复制到剪贴板')
} catch (err) { } catch (err) {
log.error('复制密钥失败:', err) log.error('复制密钥失败:', err)
error('复制失败,请重试') error('复制失败,请重试')
@@ -1046,9 +1042,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
rate_limit: data.rate_limit, rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null), expire_days: data.never_expire ? null : (data.expire_days || null),
auto_delete_on_expiry: data.auto_delete_on_expiry, auto_delete_on_expiry: data.auto_delete_on_expiry,
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined, // 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined, allowed_providers: data.allowed_providers,
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined allowed_api_formats: data.allowed_api_formats,
allowed_models: data.allowed_models
} }
await adminApi.updateApiKey(data.id, updateData) await adminApi.updateApiKey(data.id, updateData)
success('API Key 更新成功') success('API Key 更新成功')
@@ -1064,9 +1061,10 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
rate_limit: data.rate_limit, rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null), expire_days: data.never_expire ? null : (data.expire_days || null),
auto_delete_on_expiry: data.auto_delete_on_expiry, auto_delete_on_expiry: data.auto_delete_on_expiry,
allowed_providers: data.allowed_providers.length > 0 ? data.allowed_providers : undefined, // 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
allowed_api_formats: data.allowed_api_formats.length > 0 ? data.allowed_api_formats : undefined, allowed_providers: data.allowed_providers,
allowed_models: data.allowed_models.length > 0 ? data.allowed_models : undefined allowed_api_formats: data.allowed_api_formats,
allowed_models: data.allowed_models
} }
const response = await adminApi.createStandaloneApiKey(createData) const response = await adminApi.createStandaloneApiKey(createData)
newKeyValue.value = response.key newKeyValue.value = response.key

View File

@@ -46,6 +46,7 @@ const clearingRowAffinityKey = ref<string | null>(null)
const currentPage = ref(1) const currentPage = ref(1)
const pageSize = ref(20) const pageSize = ref(20)
const currentTime = ref(Math.floor(Date.now() / 1000)) const currentTime = ref(Math.floor(Date.now() / 1000))
const analysisHoursSelectOpen = ref(false)
// ==================== 模型映射缓存 ==================== // ==================== 模型映射缓存 ====================
@@ -1056,7 +1057,7 @@ onBeforeUnmount(() => {
<span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔推荐合适的缓存 TTL</span> <span class="text-xs text-muted-foreground hidden sm:inline">分析用户请求间隔推荐合适的缓存 TTL</span>
</div> </div>
<div class="flex flex-wrap items-center gap-2"> <div class="flex flex-wrap items-center gap-2">
<Select v-model="analysisHours"> <Select v-model="analysisHours" v-model:open="analysisHoursSelectOpen">
<SelectTrigger class="w-24 sm:w-28 h-8"> <SelectTrigger class="w-24 sm:w-28 h-8">
<SelectValue placeholder="时间段" /> <SelectValue placeholder="时间段" />
</SelectTrigger> </SelectTrigger>

View File

@@ -713,6 +713,7 @@ import ProviderModelFormDialog from '@/features/providers/components/ProviderMod
import type { Model } from '@/api/endpoints' import type { Model } from '@/api/endpoints'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useConfirm } from '@/composables/useConfirm' import { useConfirm } from '@/composables/useConfirm'
import { useClipboard } from '@/composables/useClipboard'
import { useRowClick } from '@/composables/useRowClick' import { useRowClick } from '@/composables/useRowClick'
import { parseApiError } from '@/utils/errorParser' import { parseApiError } from '@/utils/errorParser'
import { import {
@@ -743,6 +744,7 @@ import { getProvidersSummary } from '@/api/endpoints/providers'
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints' import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
// 状态 // 状态
const loading = ref(false) const loading = ref(false)
@@ -1066,16 +1068,6 @@ function handleRowClick(event: MouseEvent, model: GlobalModelResponse) {
selectModel(model) selectModel(model)
} }
// 复制到剪贴板
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
success('已复制')
} catch {
showError('复制失败')
}
}
async function selectModel(model: GlobalModelResponse) { async function selectModel(model: GlobalModelResponse) {
selectedModel.value = model selectedModel.value = model
detailTab.value = 'basic' detailTab.value = 'basic'

View File

@@ -723,9 +723,19 @@ async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
// 切换提供商状态 // 切换提供商状态
async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) { async function toggleProviderStatus(provider: ProviderWithEndpointsSummary) {
try { try {
await updateProvider(provider.id, { is_active: !provider.is_active }) const newStatus = !provider.is_active
provider.is_active = !provider.is_active await updateProvider(provider.id, { is_active: newStatus })
showSuccess(provider.is_active ? '提供商已启用' : '提供商已停用')
// 更新抽屉内部的 provider 对象
provider.is_active = newStatus
// 同时更新主页面 providers 数组中的对象,实现无感更新
const targetProvider = providers.value.find(p => p.id === provider.id)
if (targetProvider) {
targetProvider.is_active = newStatus
}
showSuccess(newStatus ? '提供商已启用' : '提供商已停用')
} catch (err: any) { } catch (err: any) {
showError(err.response?.data?.detail || '操作失败', '错误') showError(err.response?.data?.detail || '操作失败', '错误')
} }

View File

@@ -701,6 +701,7 @@ import { ref, computed, onMounted, watch } from 'vue'
import { useUsersStore } from '@/stores/users' import { useUsersStore } from '@/stores/users'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useConfirm } from '@/composables/useConfirm' import { useConfirm } from '@/composables/useConfirm'
import { useClipboard } from '@/composables/useClipboard'
import { usageApi, type UsageByUser } from '@/api/usage' import { usageApi, type UsageByUser } from '@/api/usage'
import { adminApi } from '@/api/admin' import { adminApi } from '@/api/admin'
@@ -748,6 +749,7 @@ import { log } from '@/utils/logger'
const { success, error } = useToast() const { success, error } = useToast()
const { confirmDanger, confirmWarning } = useConfirm() const { confirmDanger, confirmWarning } = useConfirm()
const { copyToClipboard } = useClipboard()
const usersStore = useUsersStore() const usersStore = useUsersStore()
// 用户表单对话框状态 // 用户表单对话框状态
@@ -875,7 +877,8 @@ async function toggleUserStatus(user: any) {
const action = user.is_active ? '禁用' : '启用' const action = user.is_active ? '禁用' : '启用'
const confirmed = await confirmDanger( const confirmed = await confirmDanger(
`确定要${action}用户 ${user.username} 吗?`, `确定要${action}用户 ${user.username} 吗?`,
`${action}用户` `${action}用户`,
action
) )
if (!confirmed) return if (!confirmed) return
@@ -884,7 +887,7 @@ async function toggleUserStatus(user: any) {
await usersStore.updateUser(user.id, { is_active: !user.is_active }) await usersStore.updateUser(user.id, { is_active: !user.is_active })
success(`用户已${action}`) success(`用户已${action}`)
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', `${action}用户失败`) error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', `${action}用户失败`)
} }
} }
@@ -955,7 +958,7 @@ async function handleUserFormSubmit(data: UserFormData & { password?: string })
closeUserFormDialog() closeUserFormDialog()
} catch (err: any) { } catch (err: any) {
const title = data.id ? '更新用户失败' : '创建用户失败' const title = data.id ? '更新用户失败' : '创建用户失败'
error(err.response?.data?.detail || '未知错误', title) error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', title)
} finally { } finally {
userFormDialogRef.value?.setSaving(false) userFormDialogRef.value?.setSaving(false)
} }
@@ -989,7 +992,7 @@ async function createApiKey() {
showNewApiKeyDialog.value = true showNewApiKeyDialog.value = true
await loadUserApiKeys(selectedUser.value.id) await loadUserApiKeys(selectedUser.value.id)
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '创建 API Key 失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '创建 API Key 失败')
} finally { } finally {
creatingApiKey.value = false creatingApiKey.value = false
} }
@@ -1000,12 +1003,7 @@ function selectApiKey() {
} }
async function copyApiKey() { async function copyApiKey() {
try { await copyToClipboard(newApiKey.value)
await navigator.clipboard.writeText(newApiKey.value)
success('API Key已复制到剪贴板')
} catch {
error('复制失败,请手动复制')
}
} }
async function closeNewApiKeyDialog() { async function closeNewApiKeyDialog() {
@@ -1026,7 +1024,7 @@ async function deleteApiKey(apiKey: any) {
await loadUserApiKeys(selectedUser.value.id) await loadUserApiKeys(selectedUser.value.id)
success('API Key已删除') success('API Key已删除')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '删除 API Key 失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除 API Key 失败')
} }
} }
@@ -1034,11 +1032,10 @@ async function copyFullKey(apiKey: any) {
try { try {
// 调用后端 API 获取完整密钥 // 调用后端 API 获取完整密钥
const response = await adminApi.getFullApiKey(apiKey.id) const response = await adminApi.getFullApiKey(apiKey.id)
await navigator.clipboard.writeText(response.key) await copyToClipboard(response.key)
success('完整密钥已复制到剪贴板')
} catch (err: any) { } catch (err: any) {
log.error('复制密钥失败:', err) log.error('复制密钥失败:', err)
error(err.response?.data?.detail || '未知错误', '复制密钥失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '复制密钥失败')
} }
} }
@@ -1054,7 +1051,7 @@ async function resetQuota(user: any) {
await usersStore.resetUserQuota(user.id) await usersStore.resetUserQuota(user.id)
success('配额已重置') success('配额已重置')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '重置配额失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '重置配额失败')
} }
} }
@@ -1070,7 +1067,7 @@ async function deleteUser(user: any) {
await usersStore.deleteUser(user.id) await usersStore.deleteUser(user.id)
success('用户已删除') success('用户已删除')
} catch (err: any) { } catch (err: any) {
error(err.response?.data?.detail || '未知错误', '删除用户失败') error(err.response?.data?.error?.message || err.response?.data?.detail || '未知错误', '删除用户失败')
} }
} }
</script> </script>

View File

@@ -102,9 +102,9 @@
<!-- Main Content --> <!-- Main Content -->
<main class="relative z-10"> <main class="relative z-10">
<!-- Fixed Logo Container --> <!-- Fixed Logo Container -->
<div class="fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden"> <div class="mt-4 fixed inset-0 z-20 pointer-events-none flex items-center justify-center overflow-hidden">
<div <div
class="transform-gpu logo-container" class="mt-16 transform-gpu logo-container"
:class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]" :class="[currentSection === SECTIONS.HOME ? 'home-section' : '', `logo-transition-${scrollDirection}`]"
:style="fixedLogoStyle" :style="fixedLogoStyle"
> >
@@ -151,7 +151,7 @@
class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20" class="min-h-screen snap-start flex items-center justify-center px-16 lg:px-20 py-20"
> >
<div class="max-w-4xl mx-auto text-center"> <div class="max-w-4xl mx-auto text-center">
<div class="h-80 w-full mb-16" /> <div class="h-80 w-full mb-16 mt-8" />
<h1 <h1
class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700" class="mb-6 text-5xl md:text-7xl font-bold text-[#191919] dark:text-white leading-tight transition-all duration-700"
:style="getTitleStyle(SECTIONS.HOME)" :style="getTitleStyle(SECTIONS.HOME)"
@@ -166,7 +166,7 @@
整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手 整合 Claude Code、Codex CLI、Gemini CLI 等多个 AI 编程助手
</p> </p>
<button <button
class="mt-16 transition-all duration-700 cursor-pointer hover:scale-110" class="mt-8 transition-all duration-700 cursor-pointer hover:scale-110"
:style="getScrollIndicatorStyle(SECTIONS.HOME)" :style="getScrollIndicatorStyle(SECTIONS.HOME)"
@click="scrollToSection(SECTIONS.CLAUDE)" @click="scrollToSection(SECTIONS.CLAUDE)"
> >


@@ -301,6 +301,7 @@ function stopGlobalAutoRefresh() {
function handleAutoRefreshChange(value: boolean) { function handleAutoRefreshChange(value: boolean) {
globalAutoRefresh.value = value globalAutoRefresh.value = value
if (value) { if (value) {
refreshData() // 立即刷新一次
startGlobalAutoRefresh() startGlobalAutoRefresh()
} else { } else {
stopGlobalAutoRefresh() stopGlobalAutoRefresh()


@@ -342,6 +342,7 @@ import {
Plus, Plus,
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import { import {
Card, Card,
Table, Table,
@@ -370,6 +371,7 @@ import { useRowClick } from '@/composables/useRowClick'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
// 状态 // 状态
const loading = ref(false) const loading = ref(false)
@@ -565,16 +567,6 @@ function hasTieredPricing(model: PublicGlobalModel): boolean {
return (tiered?.tiers?.length || 0) > 1 return (tiered?.tiers?.length || 0) > 1
} }
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
success('已复制')
} catch (err) {
log.error('复制失败:', err)
showError('复制失败')
}
}
onMounted(() => { onMounted(() => {
refreshData() refreshData()
}) })


@@ -352,6 +352,7 @@ import {
} from 'lucide-vue-next' } from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey' import { useEscapeKey } from '@/composables/useEscapeKey'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import Card from '@/components/ui/card.vue' import Card from '@/components/ui/card.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
@@ -375,6 +376,7 @@ const emit = defineEmits<{
}>() }>()
const { success: showSuccess, error: showError } = useToast() const { success: showSuccess, error: showError } = useToast()
const { copyToClipboard } = useClipboard()
interface Props { interface Props {
model: PublicGlobalModel | null model: PublicGlobalModel | null
@@ -408,15 +410,6 @@ function handleClose() {
emit('update:open', false) emit('update:open', false)
} }
async function copyToClipboard(text: string) {
try {
await navigator.clipboard.writeText(text)
showSuccess('已复制')
} catch {
showError('复制失败')
}
}
function getFirstTierPrice( function getFirstTierPrice(
tieredPricing: TieredPricingConfig | undefined | null, tieredPricing: TieredPricingConfig | undefined | null,
priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m' priceKey: 'input_price_per_1m' | 'output_price_per_1m' | 'cache_creation_price_per_1m' | 'cache_read_price_per_1m'


@@ -246,6 +246,15 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
if "api_key" in update_data: if "api_key" in update_data:
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"]) update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
# 特殊处理 max_concurrent需要区分"未提供"和"显式设置为 null"
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
if "max_concurrent" in self.key_data.model_fields_set:
update_data["max_concurrent"] = self.key_data.max_concurrent
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
if self.key_data.max_concurrent is None:
update_data["learned_max_concurrent"] = None
logger.info("Key %s 切换为自适应并发模式", self.key_id)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(key, field, value) setattr(key, field, value)
key.updated_at = datetime.now(timezone.utc) key.updated_at = datetime.now(timezone.utc)
@@ -253,7 +262,7 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit() db.commit()
db.refresh(key) db.refresh(key)
logger.info(f"[OK] 更新 Key: ID={self.key_id}, Updates={list(update_data.keys())}") logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try: try:
decrypted_key = crypto_service.decrypt(key.api_key) decrypted_key = crypto_service.decrypt(key.api_key)


@@ -947,7 +947,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
class AdminCacheConfigAdapter(AdminApiAdapter): class AdminCacheConfigAdapter(AdminApiAdapter):
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override] async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
from src.services.cache.affinity_manager import CacheAffinityManager from src.services.cache.affinity_manager import CacheAffinityManager
from src.services.cache.aware_scheduler import CacheAwareScheduler from src.config.constants import ConcurrencyDefaults
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
# 获取动态预留管理器的配置 # 获取动态预留管理器的配置
@@ -958,7 +958,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
"status": "ok", "status": "ok",
"data": { "data": {
"cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL, "cache_ttl_seconds": CacheAffinityManager.DEFAULT_CACHE_TTL,
"cache_reservation_ratio": CacheAwareScheduler.CACHE_RESERVATION_RATIO, "cache_reservation_ratio": ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
"dynamic_reservation": { "dynamic_reservation": {
"enabled": True, "enabled": True,
"config": reservation_stats["config"], "config": reservation_stats["config"],
@@ -981,7 +981,7 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
context.add_audit_metadata( context.add_audit_metadata(
action="cache_config", action="cache_config",
cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL, cache_ttl_seconds=CacheAffinityManager.DEFAULT_CACHE_TTL,
cache_reservation_ratio=CacheAwareScheduler.CACHE_RESERVATION_RATIO, cache_reservation_ratio=ConcurrencyDefaults.CACHE_RESERVATION_RATIO,
dynamic_reservation_enabled=True, dynamic_reservation_enabled=True,
) )
return response return response
@@ -1236,7 +1236,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
try: try:
cached_data = json.loads(cached_str) cached_data = json.loads(cached_str)
provider_model_name = cached_data.get("provider_model_name") provider_model_name = cached_data.get("provider_model_name")
provider_model_mappings = cached_data.get("provider_model_mappings", []) cached_model_mappings = cached_data.get("provider_model_mappings", [])
# 获取 Provider 和 GlobalModel 信息 # 获取 Provider 和 GlobalModel 信息
provider = provider_map.get(provider_id) provider = provider_map.get(provider_id)
@@ -1245,8 +1245,8 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if provider and global_model: if provider and global_model:
# 提取映射名称 # 提取映射名称
mapping_names = [] mapping_names = []
if provider_model_mappings: if cached_model_mappings:
for mapping_entry in provider_model_mappings: for mapping_entry in cached_model_mappings:
if isinstance(mapping_entry, dict) and mapping_entry.get("name"): if isinstance(mapping_entry, dict) and mapping_entry.get("name"):
mapping_names.append(mapping_entry["name"]) mapping_names.append(mapping_entry["name"])


@@ -32,6 +32,17 @@ class ModelsQueryRequest(BaseModel):
api_key_id: Optional[str] = None api_key_id: Optional[str] = None
class TestModelRequest(BaseModel):
"""模型测试请求"""
provider_id: str
model_name: str
api_key_id: Optional[str] = None
stream: bool = False
message: Optional[str] = "你好"
api_format: Optional[str] = None # 指定使用的API格式如果不指定则使用端点的默认格式
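As a quick sketch of how a client might exercise the new schema, the payload below mirrors TestModelRequest field by field; the route prefix, host, and IDs are placeholders, since the router mount point is not visible in this diff.

import httpx

# Hypothetical payload matching TestModelRequest; IDs and the route prefix are placeholders.
payload = {
    "provider_id": "prov_123",
    "model_name": "claude-sonnet-4",  # example model name
    "api_key_id": None,               # omit/None: backend picks the first active key
    "stream": False,
    "message": "你好",                # default greeting used when omitted
    "api_format": None,               # None: use the endpoint's default API format
}

resp = httpx.post("http://localhost:8000/test-model", json=payload, timeout=60)
body = resp.json()
print(body["success"], body.get("error"))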
# ============ API Endpoints ============ # ============ API Endpoints ============
@@ -206,3 +217,228 @@ async def query_available_models(
"display_name": provider.display_name, "display_name": provider.display_name,
}, },
} }
@router.post("/test-model")
async def test_model(
request: TestModelRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
测试模型连接性
向指定提供商的指定模型发送测试请求,验证模型是否可用
Args:
request: 测试请求
Returns:
测试结果
"""
# 获取提供商及其端点
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.filter(Provider.id == request.provider_id)
.first()
)
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 找到合适的端点和API Key
endpoint_config = None
endpoint = None
api_key = None
if request.api_key_id:
# 使用指定的API Key
for ep in provider.endpoints:
for key in ep.api_keys:
if key.id == request.api_key_id and key.is_active and ep.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
else:
# 使用第一个可用的端点和密钥
for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys:
continue
for key in ep.api_keys:
if key.is_active:
endpoint = ep
api_key = key
break
if endpoint:
break
if not endpoint or not api_key:
raise HTTPException(status_code=404, detail="No active endpoint or API key found")
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 构建请求配置
endpoint_config = {
"api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0,
}
try:
# 获取对应的 Adapter 类
adapter_class = _get_adapter_for_format(endpoint.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {endpoint.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 使用 Adapter: {adapter_class.__name__}")
logger.debug(f"[test-model] 端点 API Format: {endpoint.api_format}")
# 如果请求指定了 api_format优先使用它
target_api_format = request.api_format or endpoint.api_format
if request.api_format and request.api_format != endpoint.api_format:
logger.debug(f"[test-model] 请求指定 API Format: {request.api_format}")
# 重新获取适配器
adapter_class = _get_adapter_for_format(request.api_format)
if not adapter_class:
return {
"success": False,
"error": f"Unknown API format: {request.api_format}",
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
logger.debug(f"[test-model] 重新选择 Adapter: {adapter_class.__name__}")
# 准备测试请求数据
check_request = {
"model": request.model_name,
"messages": [
{"role": "user", "content": request.message or "Hello! This is a test message."}
],
"max_tokens": 30,
"temperature": 0.7,
}
# 发送测试请求
async with httpx.AsyncClient(timeout=endpoint_config["timeout"]) as client:
# 非流式测试
logger.debug(f"[test-model] 开始非流式测试...")
response = await adapter_class.check_endpoint(
client,
endpoint_config["base_url"],
endpoint_config["api_key"],
check_request,
endpoint_config.get("extra_headers"),
# 用量计算参数(现在强制记录)
db=db,
user=current_user,
provider_name=provider.name,
provider_id=provider.id,
api_key_id=endpoint_config.get("api_key_id"),
model_name=request.model_name,
)
# 记录提供商返回信息
logger.debug(f"[test-model] 非流式测试结果:")
logger.debug(f"[test-model] Status Code: {response.get('status_code')}")
logger.debug(f"[test-model] Response Headers: {response.get('headers', {})}")
response_data = response.get('response', {})
response_body = response_data.get('response_body', {})
logger.debug(f"[test-model] Response Data: {response_data}")
logger.debug(f"[test-model] Response Body: {response_body}")
# 尝试解析 response_body (通常是 JSON 字符串)
parsed_body = response_body
import json
if isinstance(response_body, str):
try:
parsed_body = json.loads(response_body)
except json.JSONDecodeError:
pass
if isinstance(parsed_body, dict) and 'error' in parsed_body:
error_obj = parsed_body['error']
# 兼容 error 可能是字典或字符串的情况
if isinstance(error_obj, dict):
logger.debug(f"[test-model] Error Message: {error_obj.get('message')}")
raise HTTPException(status_code=500, detail=error_obj.get('message'))
else:
logger.debug(f"[test-model] Error: {error_obj}")
raise HTTPException(status_code=500, detail=error_obj)
elif 'error' in response:
logger.debug(f"[test-model] Error: {response['error']}")
raise HTTPException(status_code=500, detail=response['error'])
else:
# 如果有选择或消息,记录内容预览
if isinstance(response_data, dict):
if 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice:
content = choice['message'].get('content', '')
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
elif 'content' in response_data and response_data['content']:
content = str(response_data['content'])
logger.debug(f"[test-model] Content Preview: {content[:200]}...")
# 检查测试是否成功基于HTTP状态码
status_code = response.get('status_code', 0)
is_success = status_code == 200 and 'error' not in response
return {
"success": is_success,
"data": {
"stream": False,
"response": response,
},
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
},
}
except Exception as e:
logger.error(f"[test-model] Error testing model {request.model_name}: {e}")
return {
"success": False,
"error": str(e),
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
"id": endpoint.id,
"api_format": endpoint.api_format,
"base_url": endpoint.base_url,
} if endpoint else None,
}
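_get_adapter_for_format is used above but its body is not included in this diff; a rough sketch of what such a lookup might amount to, assuming a simple registry keyed by format ID (the registry and both functions below are illustrative, not the project's actual code):

from typing import Any, Dict, Optional, Type

# Illustrative registry; in the real project adapters register themselves via register_adapter().
_ADAPTER_REGISTRY: Dict[str, Type[Any]] = {}

def register_adapter(format_id: str, adapter_cls: Type[Any]) -> None:
    """Record an adapter class under its API format ID."""
    _ADAPTER_REGISTRY[format_id.lower()] = adapter_cls

def _get_adapter_for_format(api_format: Optional[str]) -> Optional[Type[Any]]:
    """Return the adapter class for an API format, or None so the caller can report an error."""
    if not api_format:
        return None
    return _ADAPTER_REGISTRY.get(api_format.lower())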


@@ -376,6 +376,9 @@ class BaseMessageHandler:
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输 使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
注意TTFB首字节时间由 StreamContext.record_first_byte_time() 记录,
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
Args: Args:
request_id: 请求 ID如果不传则使用 self.request_id request_id: 请求 ID如果不传则使用 self.request_id
""" """
@@ -407,6 +410,9 @@ class BaseMessageHandler:
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输 使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
注意TTFB首字节时间由 StreamContext.record_first_byte_time() 记录,
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
Args: Args:
ctx: 流式上下文,包含 provider_name 和 mapped_model ctx: 流式上下文,包含 provider_name 和 mapped_model
""" """


@@ -63,6 +63,34 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base" name: str = "chat.base"
mode = ApiMode.STANDARD mode = ApiMode.STANDARD
# 子类可以配置的特殊方法用于check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建端点URL子类可以覆盖以自定义URL构建逻辑"""
# 默认实现在base_url后添加特定路径
return base_url
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建基础请求头,子类可以覆盖以自定义认证头"""
# 默认实现Bearer token认证
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回不应被extra_headers覆盖的头部key子类可以覆盖"""
# 默认保护认证相关头部
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建请求体,子类可以覆盖以自定义请求格式转换"""
# 默认实现:直接使用请求数据
return request_data.copy()
def __init__(self, allowed_api_formats: Optional[list[str]] = None): def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID] self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -654,6 +682,65 @@ class ChatAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖 # 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models" return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数(现在强制记录)
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API Key ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 使用子类配置方法构建请求组件
url = cls.build_endpoint_url(base_url)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用通用的endpoint checker执行请求
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=model_name or request_data.get("model"),
)
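build_safe_headers is imported from endpoint_checker and not shown here; judging from the call sites, its job is to overlay the endpoint's extra_headers onto the base headers without letting them clobber the protected keys. A minimal sketch under that assumption:

from typing import Dict, Optional, Tuple

def build_safe_headers(
    base_headers: Dict[str, str],
    extra_headers: Optional[Dict[str, str]],
    protected_keys: Tuple[str, ...],
) -> Dict[str, str]:
    """Merge extra headers into the base headers, skipping protected keys (case-insensitive)."""
    merged = dict(base_headers)
    protected = {key.lower() for key in protected_keys}
    for name, value in (extra_headers or {}).items():
        if name.lower() in protected:
            continue  # never let endpoint config overwrite auth or content-type headers
        merged[name] = value
    return merged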
# ========================================================================= # =========================================================================
# Adapter 注册表 - 用于根据 API format 获取 Adapter 实例 # Adapter 注册表 - 用于根据 API format 获取 Adapter 实例


@@ -484,9 +484,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
stream_response.raise_for_status() stream_response.raise_for_status()
# 使用字节流迭代器(避免 aiter_lines 的性能问题) # 使用字节流迭代器(避免 aiter_lines 的性能问题, aiter_bytes 会自动解压 gzip/deflate
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输 byte_iterator = stream_response.aiter_bytes()
byte_iterator = stream_response.aiter_raw()
# 预读检测嵌套错误 # 预读检测嵌套错误
prefetched_chunks = await stream_processor.prefetch_and_check_error( prefetched_chunks = await stream_processor.prefetch_and_check_error(
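For context on the iterator swap: in httpx, aiter_raw() yields the bytes exactly as received from the wire (still compressed when the upstream sets Content-Encoding: gzip/deflate), while aiter_bytes() decodes the content encoding first. A minimal streaming sketch with a placeholder URL:

import httpx

async def stream_decoded(url: str) -> None:
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            # aiter_bytes() transparently decompresses gzip/deflate response bodies;
            # aiter_raw() would hand back the compressed bytes unchanged.
            async for chunk in response.aiter_bytes():
                print(len(chunk))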


@@ -614,6 +614,146 @@ class CliAdapterBase(ApiAdapter):
# 默认实现返回空列表,子类应覆盖 # 默认实现返回空列表,子类应覆盖
return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models" return [], f"{cls.FORMAT_ID} adapter does not implement fetch_models"
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""
测试模型连接性(非流式)
通用的CLI endpoint测试方法使用配置方法模式
- build_endpoint_url(): 构建请求URL
- build_base_headers(): 构建基础认证头
- get_protected_header_keys(): 获取受保护的头部key
- build_request_body(): 构建请求体
- get_cli_user_agent(): 获取CLI User-Agent子类可覆盖
Args:
client: httpx 异步客户端
base_url: API 基础 URL
api_key: API 密钥(已解密)
request_data: 请求数据
extra_headers: 端点配置的额外请求头
db: 数据库会话
user: 用户对象
provider_name: 提供商名称
provider_id: 提供商ID
api_key_id: API密钥ID
model_name: 模型名称
Returns:
测试响应数据
"""
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
# 构建请求组件
url = cls.build_endpoint_url(base_url, request_data, model_name)
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
# 添加CLI User-Agent
cli_user_agent = cls.get_cli_user_agent()
if cli_user_agent:
base_headers["User-Agent"] = cli_user_agent
protected_keys = tuple(list(protected_keys) + ["user-agent"])
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 获取有效的模型名称
effective_model_name = model_name or request_data.get("model")
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
# =========================================================================
# CLI Adapter 配置方法 - 子类应覆盖这些方法而不是整个 check_endpoint
# =========================================================================
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""
构建CLI API端点URL - 子类应覆盖
Args:
base_url: API基础URL
request_data: 请求数据
model_name: 模型名称某些API需要如Gemini
Returns:
完整的端点URL
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_endpoint_url")
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""
构建CLI API认证头 - 子类应覆盖
Args:
api_key: API密钥
Returns:
基础认证头部字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_base_headers")
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""
返回CLI API的保护头部key - 子类应覆盖
Returns:
保护头部key的元组
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement get_protected_header_keys")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
构建CLI API请求体 - 子类应覆盖
Args:
request_data: 请求数据
Returns:
请求体字典
"""
raise NotImplementedError(f"{cls.FORMAT_ID} adapter must implement build_request_body")
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""
获取CLI User-Agent - 子类可覆盖
Returns:
CLI User-Agent字符串如果不需要则为None
"""
return None
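To make the hook pattern concrete, a toy subclass might fill in the four required methods plus the optional User-Agent as follows; the class name, format ID, and import path are made up for illustration, and other required attributes such as HANDLER_CLASS are omitted for brevity.

from typing import Any, Dict, Optional

# Import path assumed for illustration; the real module for CliAdapterBase is not shown in this diff.
from src.api.handlers.base.cli_adapter_base import CliAdapterBase

class ExampleCliAdapter(CliAdapterBase):
    FORMAT_ID = "example_cli"  # HANDLER_CLASS and other required attributes omitted here

    @classmethod
    def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any],
                           model_name: Optional[str] = None) -> str:
        return f"{base_url.rstrip('/')}/v1/chat/completions"

    @classmethod
    def build_base_headers(cls, api_key: str) -> Dict[str, str]:
        return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    @classmethod
    def get_protected_header_keys(cls) -> tuple:
        return ("authorization", "content-type")

    @classmethod
    def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
        return request_data.copy()

    @classmethod
    def get_cli_user_agent(cls) -> Optional[str]:
        return "example-cli/0.0.1"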
# ========================================================================= # =========================================================================
# CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例 # CLI Adapter 注册表 - 用于根据 API format 获取 CLI Adapter 实例


@@ -57,8 +57,10 @@ from src.models.database import (
ProviderEndpoint, ProviderEndpoint,
User, User,
) )
from src.config.settings import config
from src.services.provider.transport import build_provider_url from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
class CliMessageHandlerBase(BaseMessageHandler): class CliMessageHandlerBase(BaseMessageHandler):
@@ -474,8 +476,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_response.raise_for_status() stream_response.raise_for_status()
# 使用字节流迭代器(避免 aiter_lines 的性能问题) # 使用字节流迭代器(避免 aiter_lines 的性能问题, aiter_bytes 会自动解压 gzip/deflate
byte_iterator = stream_response.aiter_raw() byte_iterator = stream_response.aiter_bytes()
# 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误) # 预读第一个数据块检测嵌套错误HTTP 200 但响应体包含错误)
prefetched_chunks = await self._prefetch_and_check_embedded_error( prefetched_chunks = await self._prefetch_and_check_embedded_error(
@@ -529,7 +531,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 检查是否需要格式转换 # 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx) needs_conversion = self._needs_format_conversion(ctx)
async for chunk in stream_response.aiter_raw(): async for chunk in stream_response.aiter_bytes():
# 在第一次输出数据前更新状态为 streaming # 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated: if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx) self._update_usage_to_streaming_with_ctx(ctx)
@@ -672,6 +674,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。 同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
首次读取时会应用 TTFB首字节超时检测超时则触发故障转移。
Args: Args:
byte_iterator: 字节流迭代器 byte_iterator: 字节流迭代器
provider: Provider 对象 provider: Provider 对象
@@ -684,6 +688,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
Raises: Raises:
EmbeddedErrorException: 如果检测到嵌套错误 EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误) ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
ProviderTimeoutException: 如果首字节超时TTFB timeout
""" """
prefetched_chunks: list = [] prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误 max_prefetch_lines = 5 # 最多预读5行来检测错误
@@ -704,7 +709,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
else: else:
provider_parser = self.parser provider_parser = self.parser
async for chunk in byte_iterator: # 使用共享的 TTFB 超时函数读取首字节
ttfb_timeout = config.stream_first_byte_timeout
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
byte_iterator,
timeout=ttfb_timeout,
request_id=self.request_id,
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk) prefetched_chunks.append(chunk)
buffer += chunk buffer += chunk
@@ -785,12 +802,21 @@ class CliMessageHandlerBase(BaseMessageHandler):
if should_stop or line_count >= max_prefetch_lines: if should_stop or line_count >= max_prefetch_lines:
break break
except EmbeddedErrorException: except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
# 重新抛出嵌套错误 # 重新抛出可重试的 Provider 异常,触发故障转移
raise raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
except Exception as e: except Exception as e:
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断 # 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
)
raise
return prefetched_chunks return prefetched_chunks
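read_first_chunk_with_ttfb_timeout comes from src.utils.timeout and its body is not part of this diff. Based on the call sites, it waits for the first chunk under a deadline and then returns that chunk together with the iterator for the remainder; a sketch along those lines, with a generic TimeoutError standing in for the ProviderTimeoutException the real code presumably raises:

import asyncio
from typing import AsyncIterator, Tuple

async def read_first_chunk_with_ttfb_timeout(
    byte_iterator: AsyncIterator[bytes],
    timeout: float,
    request_id: str,
    provider_name: str,
) -> Tuple[bytes, AsyncIterator[bytes]]:
    """Wait for the first byte chunk within the TTFB deadline, then hand back the rest of the stream."""
    try:
        first_chunk = await asyncio.wait_for(byte_iterator.__anext__(), timeout=timeout)
    except asyncio.TimeoutError as exc:
        # Placeholder: the real implementation is expected to raise ProviderTimeoutException
        # so the caller's failover path is triggered.
        raise TimeoutError(
            f"[{request_id}] no first byte from {provider_name} within {timeout}s"
        ) from exc
    except StopAsyncIteration:
        return b"", byte_iterator  # upstream closed without sending anything
    return first_chunk, byte_iterator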

File diff suppressed because it is too large


@@ -25,10 +25,12 @@ from src.api.handlers.base.content_extractors import (
from src.api.handlers.base.parsers import get_parser_for_format from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext from src.api.handlers.base.stream_context import StreamContext
from src.core.exceptions import EmbeddedErrorException from src.config.settings import config
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
from src.core.logger import logger from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
@dataclass @dataclass
@@ -170,6 +172,8 @@ class StreamProcessor:
某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。 某些 Provider如 Gemini可能返回 HTTP 200但在响应体中包含错误信息。
这种情况需要在流开始输出之前检测,以便触发重试逻辑。 这种情况需要在流开始输出之前检测,以便触发重试逻辑。
首次读取时会应用 TTFB首字节超时检测超时则触发故障转移。
Args: Args:
byte_iterator: 字节流迭代器 byte_iterator: 字节流迭代器
provider: Provider 对象 provider: Provider 对象
@@ -182,6 +186,7 @@ class StreamProcessor:
Raises: Raises:
EmbeddedErrorException: 如果检测到嵌套错误 EmbeddedErrorException: 如果检测到嵌套错误
ProviderTimeoutException: 如果首字节超时TTFB timeout
""" """
prefetched_chunks: list = [] prefetched_chunks: list = []
parser = self.get_parser_for_provider(ctx) parser = self.get_parser_for_provider(ctx)
@@ -192,7 +197,19 @@ class StreamProcessor:
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace") decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
try: try:
async for chunk in byte_iterator: # 使用共享的 TTFB 超时函数读取首字节
ttfb_timeout = config.stream_first_byte_timeout
first_chunk, aiter = await read_first_chunk_with_ttfb_timeout(
byte_iterator,
timeout=ttfb_timeout,
request_id=self.request_id,
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk) prefetched_chunks.append(chunk)
buffer += chunk buffer += chunk
@@ -262,10 +279,21 @@ class StreamProcessor:
if should_stop or line_count >= max_prefetch_lines: if should_stop or line_count >= max_prefetch_lines:
break break
except EmbeddedErrorException: except (EmbeddedErrorException, ProviderTimeoutException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
except Exception as e: except Exception as e:
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}") # 未预期的严重异常:记录错误并重新抛出,避免掩盖问题
logger.error(
f" [{self.request_id}] 预读流时发生严重异常: {type(e).__name__}: {e}",
exc_info=True
)
raise
return prefetched_chunks return prefetched_chunks


@@ -4,17 +4,28 @@ Handler 基础工具函数
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from src.core.logger import logger
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int: def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
""" """
提取缓存创建 tokens兼容新旧格式) 提取缓存创建 tokens兼容三种格式)
Claude API 在不同版本中使用了不同的字段名来表示缓存创建 tokens 根据 Anthropic API 文档,支持三种格式(按优先级)
- 新格式2024年后使用 claude_cache_creation_5_m_tokens 和
claude_cache_creation_1_h_tokens 分别表示 5 分钟和 1 小时缓存
- 旧格式:使用 cache_creation_input_tokens 表示总的缓存创建 tokens
此函数自动检测并适配两种格式优先使用新格式。 1. **嵌套格式优先级最高)**
usage.cache_creation.ephemeral_5m_input_tokens
usage.cache_creation.ephemeral_1h_input_tokens
2. **扁平新格式(优先级第二)**
usage.claude_cache_creation_5_m_tokens
usage.claude_cache_creation_1_h_tokens
3. **旧格式(优先级第三)**
usage.cache_creation_input_tokens
优先使用嵌套格式,如果嵌套格式字段存在但值为 0则智能 fallback 到旧格式。
扁平格式和嵌套格式互斥,按顺序检查。
Args: Args:
usage: API 响应中的 usage 字典 usage: API 响应中的 usage 字典
@@ -22,20 +33,63 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
Returns: Returns:
缓存创建 tokens 总数 缓存创建 tokens 总数
""" """
# 检查新格式字段是否存在(而非值是否为 0 # 1. 检查嵌套格式(最新格式
# 如果字段存在,即使值为 0 也是合法的,不应 fallback 到旧格式 cache_creation = usage.get("cache_creation")
has_new_format = ( if isinstance(cache_creation, dict):
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
total = cache_5m + cache_1h
if total > 0:
logger.debug(
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
)
return total
# 嵌套格式存在但为 0fallback 到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(
f"Nested cache_creation is 0, using old format: {old_format}"
)
return old_format
# 都是 0返回 0
return 0
# 2. 检查扁平新格式
has_flat_format = (
"claude_cache_creation_5_m_tokens" in usage "claude_cache_creation_5_m_tokens" in usage
or "claude_cache_creation_1_h_tokens" in usage or "claude_cache_creation_1_h_tokens" in usage
) )
if has_new_format: if has_flat_format:
cache_5m = usage.get("claude_cache_creation_5_m_tokens", 0) cache_5m = int(usage.get("claude_cache_creation_5_m_tokens", 0))
cache_1h = usage.get("claude_cache_creation_1_h_tokens", 0) cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
return int(cache_5m) + int(cache_1h) total = cache_5m + cache_1h
# 回退到旧格式 if total > 0:
return int(usage.get("cache_creation_input_tokens", 0)) logger.debug(
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
)
return total
# 扁平格式存在但为 0fallback 到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(
f"Flat cache_creation is 0, using old format: {old_format}"
)
return old_format
# 都是 0返回 0
return 0
# 3. 回退到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(f"Using old format: cache_creation_input_tokens={old_format}")
return old_format
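With the function above in scope, a few worked inputs pin down the precedence rules:

# Nested format wins when present and non-zero: 80 + 20 = 100 (the old flat field is ignored)
nested = {"cache_creation": {"ephemeral_5m_input_tokens": 80, "ephemeral_1h_input_tokens": 20},
          "cache_creation_input_tokens": 999}
assert extract_cache_creation_tokens(nested) == 100

# Nested format present but zero: falls back to the old field
nested_zero = {"cache_creation": {"ephemeral_5m_input_tokens": 0}, "cache_creation_input_tokens": 42}
assert extract_cache_creation_tokens(nested_zero) == 42

# Flat new format, no nested dict: 10 + 20 = 30
flat = {"claude_cache_creation_5_m_tokens": 10, "claude_cache_creation_1_h_tokens": 20}
assert extract_cache_creation_tokens(flat) == 30

# Old format only
assert extract_cache_creation_tokens({"cache_creation_input_tokens": 7}) == 7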
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:


@@ -209,6 +209,38 @@ class ClaudeChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}") logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建Claude API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/messages"
else:
return f"{base_url}/v1/messages"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Claude API认证头"""
return {
"x-api-key": api_key,
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Claude API的保护头部key"""
return ("x-api-key", "content-type", "anthropic-version")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Claude API请求体"""
return {
"model": request_data.get("model"),
"max_tokens": request_data.get("max_tokens", 100),
"messages": request_data.get("messages", []),
}
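The URL builder tolerates base URLs with or without a trailing /v1; with ClaudeChatAdapter in scope (the proxy host below is a placeholder):

assert ClaudeChatAdapter.build_endpoint_url("https://api.anthropic.com") == "https://api.anthropic.com/v1/messages"
assert ClaudeChatAdapter.build_endpoint_url("https://proxy.example.com/v1/") == "https://proxy.example.com/v1/messages"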
def build_claude_adapter(x_app_header: Optional[str]): def build_claude_adapter(x_app_header: Optional[str]):
"""根据 x-app 头部构造 Chat 或 Claude Code 适配器。""" """根据 x-app 头部构造 Chat 或 Claude Code 适配器。"""


@@ -4,7 +4,7 @@ Claude CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。 继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -115,7 +115,7 @@ class ClaudeCliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]: ) -> Tuple[list, Optional[str]]:
"""查询 Claude API 支持的模型列表(带 CLI User-Agent""" """查询 Claude API 支持的模型列表(带 CLI User-Agent"""
# 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent # 复用 ClaudeChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_claude} cli_headers = {"User-Agent": config.internal_user_agent_claude_cli}
if extra_headers: if extra_headers:
cli_headers.update(extra_headers) cli_headers.update(extra_headers)
models, error = await ClaudeChatAdapter.fetch_models( models, error = await ClaudeChatAdapter.fetch_models(
@@ -126,5 +126,41 @@ class ClaudeCliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建Claude CLI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/messages"
else:
return f"{base_url}/v1/messages"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Claude CLI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Claude CLI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Claude CLI API请求体"""
return {
"model": request_data.get("model"),
"max_tokens": request_data.get("max_tokens", 100),
"messages": request_data.get("messages", []),
}
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取Claude CLI User-Agent"""
return config.internal_user_agent_claude_cli
__all__ = ["ClaudeCliAdapter"] __all__ = ["ClaudeCliAdapter"]


@@ -4,7 +4,7 @@ Gemini Chat Adapter
处理 Gemini API 格式的请求适配 处理 Gemini API 格式的请求适配
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -12,6 +12,7 @@ from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
from src.core.logger import logger from src.core.logger import logger
from src.models.gemini import GeminiRequest from src.models.gemini import GeminiRequest
@@ -199,6 +200,94 @@ class GeminiChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}") logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建Gemini API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1beta"):
return base_url # 子类需要处理model参数
else:
return f"{base_url}/v1beta"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Gemini API认证头"""
return {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Gemini API的保护头部key"""
return ("x-goog-api-key", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Gemini API请求体"""
return {
"contents": request_data.get("messages", []),
"generationConfig": {
"maxOutputTokens": request_data.get("max_tokens", 100),
"temperature": request_data.get("temperature", 0.7),
},
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
],
}
@classmethod
async def check_endpoint(
cls,
client: httpx.AsyncClient,
base_url: str,
api_key: str,
request_data: Dict[str, Any],
extra_headers: Optional[Dict[str, str]] = None,
# 用量计算参数
db: Optional[Any] = None,
user: Optional[Any] = None,
provider_name: Optional[str] = None,
provider_id: Optional[str] = None,
api_key_id: Optional[str] = None,
model_name: Optional[str] = None,
) -> Dict[str, Any]:
"""测试 Gemini API 模型连接性(非流式)"""
# Gemini需要从request_data或model_name参数获取model名称
effective_model_name = model_name or request_data.get("model", "")
if not effective_model_name:
return {
"error": "Model name is required for Gemini API",
"status_code": 400,
}
# 使用基类配置方法但重写URL构建逻辑
base_url = cls.build_endpoint_url(base_url)
url = f"{base_url}/models/{effective_model_name}:generateContent"
# 构建请求组件
base_headers = cls.build_base_headers(api_key)
protected_keys = cls.get_protected_header_keys()
headers = build_safe_headers(base_headers, extra_headers, protected_keys)
body = cls.build_request_body(request_data)
# 使用基类的通用endpoint checker
from src.api.handlers.base.endpoint_checker import run_endpoint_check
return await run_endpoint_check(
client=client,
url=url,
headers=headers,
json_body=body,
api_format=cls.name,
# 用量计算参数(现在强制记录)
db=db,
user=user,
provider_name=provider_name,
provider_id=provider_id,
api_key_id=api_key_id,
model_name=effective_model_name,
)
def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter: def build_gemini_adapter(x_app_header: str = "") -> GeminiChatAdapter:
""" """


@@ -4,7 +4,7 @@ Gemini CLI Adapter - 基于通用 CLI Adapter 基类的实现
继承 CliAdapterBase处理 Gemini CLI 格式的请求。 继承 CliAdapterBase处理 Gemini CLI 格式的请求。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -112,7 +112,7 @@ class GeminiCliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]: ) -> Tuple[list, Optional[str]]:
"""查询 Gemini API 支持的模型列表(带 CLI User-Agent""" """查询 Gemini API 支持的模型列表(带 CLI User-Agent"""
# 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent # 复用 GeminiChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_gemini} cli_headers = {"User-Agent": config.internal_user_agent_gemini_cli}
if extra_headers: if extra_headers:
cli_headers.update(extra_headers) cli_headers.update(extra_headers)
models, error = await GeminiChatAdapter.fetch_models( models, error = await GeminiChatAdapter.fetch_models(
@@ -123,6 +123,52 @@ class GeminiCliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建Gemini CLI API端点URL"""
effective_model_name = model_name or request_data.get("model", "")
if not effective_model_name:
raise ValueError("Model name is required for Gemini API")
base_url = base_url.rstrip("/")
if base_url.endswith("/v1beta"):
prefix = base_url
else:
prefix = f"{base_url}/v1beta"
return f"{prefix}/models/{effective_model_name}:generateContent"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建Gemini CLI API认证头"""
return {
"x-goog-api-key": api_key,
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回Gemini CLI API的保护头部key"""
return ("x-goog-api-key", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建Gemini CLI API请求体"""
return {
"contents": request_data.get("messages", []),
"generationConfig": {
"maxOutputTokens": request_data.get("max_tokens", 100),
"temperature": request_data.get("temperature", 0.7),
},
"safetySettings": [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}
],
}
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取Gemini CLI User-Agent"""
return config.internal_user_agent_gemini_cli
def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter: def build_gemini_cli_adapter(x_app_header: str = "") -> GeminiCliAdapter:
""" """


@@ -4,13 +4,14 @@ OpenAI Chat Adapter - 基于 ChatAdapterBase 的 OpenAI Chat API 适配器
处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。 处理 /v1/chat/completions 端点的 OpenAI Chat 格式请求。
""" """
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter from src.api.handlers.base.chat_adapter_base import ChatAdapterBase, register_adapter
from src.api.handlers.base.endpoint_checker import build_safe_headers, run_endpoint_check
from src.api.handlers.base.chat_handler_base import ChatHandlerBase from src.api.handlers.base.chat_handler_base import ChatHandlerBase
from src.core.logger import logger from src.core.logger import logger
from src.models.openai import OpenAIRequest from src.models.openai import OpenAIRequest
@@ -154,5 +155,32 @@ class OpenAIChatAdapter(ChatAdapterBase):
logger.warning(f"Failed to fetch models from {models_url}: {e}") logger.warning(f"Failed to fetch models from {models_url}: {e}")
return [], error_msg return [], error_msg
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
"""构建OpenAI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/chat/completions"
else:
return f"{base_url}/v1/chat/completions"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建OpenAI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回OpenAI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建OpenAI API请求体"""
return request_data.copy()
__all__ = ["OpenAIChatAdapter"] __all__ = ["OpenAIChatAdapter"]


@@ -4,7 +4,7 @@ OpenAI CLI Adapter - 基于通用 CLI Adapter 基类的简化实现
继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。 继承 CliAdapterBase只需配置 FORMAT_ID 和 HANDLER_CLASS。
""" """
from typing import Dict, Optional, Tuple, Type from typing import Any, AsyncIterator, Dict, Optional, Tuple, Type, Union
import httpx import httpx
from fastapi import Request from fastapi import Request
@@ -57,7 +57,7 @@ class OpenAICliAdapter(CliAdapterBase):
) -> Tuple[list, Optional[str]]: ) -> Tuple[list, Optional[str]]:
"""查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent""" """查询 OpenAI 兼容 API 支持的模型列表(带 CLI User-Agent"""
# 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent # 复用 OpenAIChatAdapter 的实现,添加 CLI User-Agent
cli_headers = {"User-Agent": config.internal_user_agent_openai} cli_headers = {"User-Agent": config.internal_user_agent_openai_cli}
if extra_headers: if extra_headers:
cli_headers.update(extra_headers) cli_headers.update(extra_headers)
models, error = await OpenAIChatAdapter.fetch_models( models, error = await OpenAIChatAdapter.fetch_models(
@@ -68,5 +68,37 @@ class OpenAICliAdapter(CliAdapterBase):
m["api_format"] = cls.FORMAT_ID m["api_format"] = cls.FORMAT_ID
return models, error return models, error
@classmethod
def build_endpoint_url(cls, base_url: str, request_data: Dict[str, Any], model_name: Optional[str] = None) -> str:
"""构建OpenAI CLI API端点URL"""
base_url = base_url.rstrip("/")
if base_url.endswith("/v1"):
return f"{base_url}/chat/completions"
else:
return f"{base_url}/v1/chat/completions"
@classmethod
def build_base_headers(cls, api_key: str) -> Dict[str, str]:
"""构建OpenAI CLI API认证头"""
return {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
@classmethod
def get_protected_header_keys(cls) -> tuple:
"""返回OpenAI CLI API的保护头部key"""
return ("authorization", "content-type")
@classmethod
def build_request_body(cls, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""构建OpenAI CLI API请求体"""
return request_data.copy()
@classmethod
def get_cli_user_agent(cls) -> Optional[str]:
"""获取OpenAI CLI User-Agent"""
return config.internal_user_agent_openai_cli
__all__ = ["OpenAICliAdapter"] __all__ = ["OpenAICliAdapter"]


@@ -77,7 +77,10 @@ class ConcurrencyDefaults:
MAX_CONCURRENT_LIMIT = 200 MAX_CONCURRENT_LIMIT = 200
# 最小并发限制下限 # 最小并发限制下限
MIN_CONCURRENT_LIMIT = 1 # 设置为 3 而不是 1因为预留机制10%预留给缓存用户)会导致
# 当 learned_max_concurrent=1 时新用户实际可用槽位为 0永远无法命中
# 注意:当 limit < 10 时,预留机制实际不生效(预留槽位 = 0这是可接受的
MIN_CONCURRENT_LIMIT = 3
# === 探测性扩容参数 === # === 探测性扩容参数 ===
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容 # 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容


@@ -56,10 +56,11 @@ class Config:
# Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖) # Redis 依赖策略(生产默认必需,开发默认可选,可通过 REDIS_REQUIRED 覆盖)
redis_required_env = os.getenv("REDIS_REQUIRED") redis_required_env = os.getenv("REDIS_REQUIRED")
if redis_required_env is None: if redis_required_env is not None:
self.require_redis = self.environment not in {"development", "test", "testing"}
else:
self.require_redis = redis_required_env.lower() == "true" self.require_redis = redis_required_env.lower() == "true"
else:
# 保持向后兼容:开发环境可选,生产环境必需
self.require_redis = self.environment not in {"development", "test", "testing"}
# CORS配置 - 使用环境变量配置允许的源 # CORS配置 - 使用环境变量配置允许的源
# 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com" # 格式: 逗号分隔的域名列表,如 "http://localhost:3000,https://example.com"
@@ -133,6 +134,18 @@ class Config:
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600")) self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1")) self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
# 限流降级策略配置
# RATE_LIMIT_FAIL_OPEN: 当限流服务Redis异常时的行为
#
# True (默认): fail-open - 放行请求(优先可用性)
# 风险Redis 故障期间无法限流,可能被滥用
# 适用API 网关作为关键基础设施,必须保持高可用
#
# False: fail-close - 拒绝所有请求(优先安全性)
# 风险Redis 故障会导致 API 网关不可用
# 适用:有严格速率限制要求的安全敏感场景
self.rate_limit_fail_open = os.getenv("RATE_LIMIT_FAIL_OPEN", "true").lower() == "true"
# HTTP 请求超时配置(秒) # HTTP 请求超时配置(秒)
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0")) self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
@@ -141,19 +154,22 @@ class Config:
# 流式处理配置 # 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误 # STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭 # STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
# STREAM_FIRST_BYTE_TIMEOUT: 首字节超时(秒),等待首字节超过此时间触发故障转移
# 范围: 10-120 秒,默认 30 秒(必须小于 http_write_timeout 避免竞态)
self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5")) self.stream_prefetch_lines = int(os.getenv("STREAM_PREFETCH_LINES", "5"))
self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1")) self.stream_stats_delay = float(os.getenv("STREAM_STATS_DELAY", "0.1"))
self.stream_first_byte_timeout = self._parse_ttfb_timeout()
# 内部请求 User-Agent 配置(用于查询上游模型列表等) # 内部请求 User-Agent 配置(用于查询上游模型列表等)
# 可通过环境变量覆盖默认值 # 可通过环境变量覆盖默认值,模拟对应 CLI 客户端
self.internal_user_agent_claude = os.getenv( self.internal_user_agent_claude_cli = os.getenv(
"CLAUDE_USER_AGENT", "claude-cli/1.0" "CLAUDE_CLI_USER_AGENT", "claude-code/1.0.1"
) )
self.internal_user_agent_openai = os.getenv( self.internal_user_agent_openai_cli = os.getenv(
"OPENAI_USER_AGENT", "openai-cli/1.0" "OPENAI_CLI_USER_AGENT", "openai-codex/1.0"
) )
self.internal_user_agent_gemini = os.getenv( self.internal_user_agent_gemini_cli = os.getenv(
"GEMINI_USER_AGENT", "gemini-cli/1.0" "GEMINI_CLI_USER_AGENT", "gemini-cli/0.1.0"
) )
# 验证连接池配置 # 验证连接池配置
@@ -177,6 +193,39 @@ class Config:
"""智能计算最大溢出连接数 - 与 pool_size 相同""" """智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size return self.db_pool_size
def _parse_ttfb_timeout(self) -> float:
"""
解析 TTFB 超时配置,带错误处理和范围限制
TTFB (Time To First Byte) 用于检测慢响应的 Provider超时触发故障转移。
此值必须小于 http_write_timeout避免竞态条件。
Returns:
超时时间(秒),范围 10-120默认 30
"""
default_timeout = 30.0
min_timeout = 10.0
max_timeout = 120.0 # 必须小于 http_write_timeout (默认 60s) 的 2 倍
raw_value = os.getenv("STREAM_FIRST_BYTE_TIMEOUT", str(default_timeout))
try:
timeout = float(raw_value)
except ValueError:
# 延迟导入避免循环依赖Config 初始化时 logger 可能未就绪)
self._ttfb_config_warning = (
f"无效的 STREAM_FIRST_BYTE_TIMEOUT 配置 '{raw_value}',使用默认值 {default_timeout}"
)
return default_timeout
# 范围限制
clamped = max(min_timeout, min(max_timeout, timeout))
if clamped != timeout:
self._ttfb_config_warning = (
f"STREAM_FIRST_BYTE_TIMEOUT={timeout}秒超出范围 [{min_timeout}-{max_timeout}]"
f"已调整为 {clamped}"
)
return clamped
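The clamping rule can be rehearsed standalone; this helper mirrors the range logic above (a rewrite for illustration, not the project's code):

def clamp_ttfb(raw: str, default: float = 30.0, lo: float = 10.0, hi: float = 120.0) -> float:
    """Mirror of the parse-then-clamp rule in _parse_ttfb_timeout."""
    try:
        value = float(raw)
    except ValueError:
        return default
    return max(lo, min(hi, value))

assert clamp_ttfb("5") == 10.0     # below the floor: raised to 10, warning logged at startup
assert clamp_ttfb("300") == 120.0  # above the cap: clamped to 120
assert clamp_ttfb("abc") == 30.0   # unparsable: default
assert clamp_ttfb("45") == 45.0    # in range: used as-is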
def _validate_pool_config(self) -> None: def _validate_pool_config(self) -> None:
"""验证连接池配置是否安全""" """验证连接池配置是否安全"""
total_per_worker = self.db_pool_size + self.db_max_overflow total_per_worker = self.db_pool_size + self.db_max_overflow
@@ -224,6 +273,10 @@ class Config:
if hasattr(self, "_pool_config_warning") and self._pool_config_warning: if hasattr(self, "_pool_config_warning") and self._pool_config_warning:
logger.warning(self._pool_config_warning) logger.warning(self._pool_config_warning)
# TTFB 超时配置警告
if hasattr(self, "_ttfb_config_warning") and self._ttfb_config_warning:
logger.warning(self._ttfb_config_warning)
# 管理员密码检查(必须在环境变量中设置) # 管理员密码检查(必须在环境变量中设置)
if hasattr(self, "_missing_admin_password") and self._missing_admin_password: if hasattr(self, "_missing_admin_password") and self._missing_admin_password:
logger.error("必须设置 ADMIN_PASSWORD 环境变量!") logger.error("必须设置 ADMIN_PASSWORD 环境变量!")


@@ -336,10 +336,44 @@ class PluginMiddleware:
) )
return result return result
return None return None
except ConnectionError as e:
# Redis 连接错误:根据配置决定
logger.warning(f"Rate limit connection error: {e}")
if config.rate_limit_fail_open:
return None
else:
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=30,
message="Rate limit service unavailable"
)
except TimeoutError as e:
# 超时错误:可能是负载过高,根据配置决定
logger.warning(f"Rate limit timeout: {e}")
if config.rate_limit_fail_open:
return None
else:
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=30,
message="Rate limit service timeout"
)
except Exception as e: except Exception as e:
logger.error(f"Rate limit error: {e}") logger.error(f"Rate limit error: {type(e).__name__}: {e}")
# 发生错误时允许请求通过 # 其他异常:根据配置决定
return None if config.rate_limit_fail_open:
# fail-open: 异常时放行请求(优先可用性)
return None
else:
# fail-close: 异常时拒绝请求(优先安全性)
return RateLimitResult(
allowed=False,
remaining=0,
retry_after=60,
message="Rate limit service error"
)
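The three except branches all reduce to the same decision; a condensed sketch (the RateLimitResult fields below mirror the constructor calls above, the real class may differ):

from dataclasses import dataclass
from typing import Optional

@dataclass
class RateLimitResult:
    allowed: bool
    remaining: int
    retry_after: int
    message: str

def degrade(fail_open: bool, retry_after: int, message: str) -> Optional[RateLimitResult]:
    """Return None to let the request through (fail-open) or a denial result (fail-close)."""
    if fail_open:
        return None  # availability first: skip rate limiting while Redis is unhealthy
    return RateLimitResult(allowed=False, remaining=0, retry_after=retry_after, message=message)

# RATE_LIMIT_FAIL_OPEN=true  -> degrade(True, 30, "...") is None, the request proceeds
# RATE_LIMIT_FAIL_OPEN=false -> degrade(False, 30, "...") rejects with a retry hint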
async def _call_pre_request_plugins(self, request: Request) -> None: async def _call_pre_request_plugins(self, request: Request) -> None:
"""调用请求前的插件(当前保留扩展点)""" """调用请求前的插件(当前保留扩展点)"""


@@ -317,6 +317,7 @@ class UpdateUserRequest(BaseModel):
username: Optional[str] = Field(None, min_length=1, max_length=50) username: Optional[str] = Field(None, min_length=1, max_length=50)
email: Optional[str] = Field(None, max_length=100) email: Optional[str] = Field(None, max_length=100)
password: Optional[str] = Field(None, min_length=6, max_length=128, description="新密码(留空保持不变)")
quota_usd: Optional[float] = Field(None, ge=0) quota_usd: Optional[float] = Field(None, ge=0)
is_active: Optional[bool] = None is_active: Optional[bool] = None
role: Optional[str] = None role: Optional[str] = None


@@ -226,8 +226,11 @@ class EndpointAPIKeyUpdate(BaseModel):
global_priority: Optional[int] = Field( global_priority: Optional[int] = Field(
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)" default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
) )
# 注意:max_concurrent=None 表示不更新,要切换为自适应模式请使用专用 API # max_concurrent: 使用特殊标记区分"未提供"和"设置为 null自适应模式"
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数") # - 不提供字段:不更新
# - 提供 null切换为自适应模式
# - 提供数字:设置固定并发限制
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数null=自适应模式)")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制") rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制") daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制") monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
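The tri-state behaviour rests on Pydantic v2's model_fields_set, which records only the fields the client actually sent; a minimal reproduction with a trimmed-down model:

from typing import Optional
from pydantic import BaseModel

class KeyUpdate(BaseModel):
    max_concurrent: Optional[int] = None

omitted = KeyUpdate.model_validate({})                         # field not provided
explicit = KeyUpdate.model_validate({"max_concurrent": None})  # explicitly sent as null
fixed = KeyUpdate.model_validate({"max_concurrent": 8})        # fixed limit

assert "max_concurrent" not in omitted.model_fields_set  # do not update the column
assert "max_concurrent" in explicit.model_fields_set     # switch to adaptive mode
assert fixed.max_concurrent == 8                         # set a fixed concurrency limit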


@@ -27,7 +27,7 @@ if not config.jwt_secret_key:
if config.environment == "production": if config.environment == "production":
raise ValueError("JWT_SECRET_KEY must be set in production environment!") raise ValueError("JWT_SECRET_KEY must be set in production environment!")
config.jwt_secret_key = secrets.token_urlsafe(32) config.jwt_secret_key = secrets.token_urlsafe(32)
logger.warning(f"JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发: {config.jwt_secret_key[:10]}...") logger.warning("JWT_SECRET_KEY未在环境变量中找到已生成随机密钥用于开发")
logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!") logger.warning("生产环境请设置JWT_SECRET_KEY环境变量!")
JWT_SECRET_KEY = config.jwt_secret_key JWT_SECRET_KEY = config.jwt_secret_key


@@ -30,6 +30,8 @@
from __future__ import annotations from __future__ import annotations
import hashlib
import random
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
@@ -956,7 +958,16 @@ class CacheAwareScheduler:
        # 获取活跃的 Key 并按 internal_priority + 负载均衡排序
        active_keys = [key for key in endpoint.api_keys if key.is_active]

        # 检查是否所有 Key 都是 TTL=0(轮换模式)
        # 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None,则使用随机排序
        use_random = all(
            (key.cache_ttl_minutes or 0) == 0 for key in active_keys
        ) if active_keys else False
        if use_random and len(active_keys) > 1:
            logger.debug(
                f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
            )

        keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)

        for key in keys:
            # Key 级别的能力检查(模型级别的能力检查已在上面完成)
@@ -1170,6 +1181,7 @@ class CacheAwareScheduler:
        self,
        keys: List[ProviderAPIKey],
        affinity_key: Optional[str] = None,
        use_random: bool = False,
    ) -> List[ProviderAPIKey]:
        """
        对 API Key 按 internal_priority 分组,同优先级内部基于 affinity_key 进行确定性打乱
@@ -1178,10 +1190,12 @@ class CacheAwareScheduler:
        - 数字越小越优先使用
        - 同优先级 Key 之间实现负载均衡
        - 使用 affinity_key 哈希确保同一请求 Key 的请求稳定(避免破坏缓存亲和性)
        - 当 use_random=True 时,使用随机排序实现轮换(用于 TTL=0 的场景)

        Args:
            keys: API Key 列表
            affinity_key: 亲和性标识符(通常为 API Key ID),用于确定性打乱
            use_random: 是否使用随机排序(TTL=0 时为 True)

        Returns:
            排序后的 Key 列表
@@ -1198,28 +1212,35 @@ class CacheAwareScheduler:
            priority = key.internal_priority if key.internal_priority is not None else 999999
            priority_groups[priority].append(key)

        # 对每个优先级组内的 Key 进行打乱
        result = []
        for priority in sorted(priority_groups.keys()):  # 数字小的优先级高,排前面
            group_keys = priority_groups[priority]
            if len(group_keys) > 1:
                if use_random:
                    # TTL=0 模式:使用随机排序实现 Key 轮换
                    shuffled = list(group_keys)
                    random.shuffle(shuffled)
                    result.extend(shuffled)
                elif affinity_key:
                    # 正常模式:使用哈希确定性打乱(保持缓存亲和性)
                    key_scores = []
                    for key in group_keys:
                        # 使用 affinity_key + key.id 的组合哈希
                        hash_input = f"{affinity_key}:{key.id}"
                        hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest()[:16], 16)
                        key_scores.append((hash_value, key))
                    # 按哈希值排序
                    sorted_group = [key for _, key in sorted(key_scores)]
                    result.extend(sorted_group)
                else:
                    # 没有 affinity_key 时按 ID 排序保持稳定性
                    result.extend(sorted(group_keys, key=lambda k: k.id))
            else:
                # 单个 Key 直接添加
                result.extend(group_keys)

        return result
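To make the two ordering strategies above easier to compare, here is a standalone sketch using plain string IDs instead of ProviderAPIKey objects (order_keys is an invented name): the hash-based branch always returns the same order for the same affinity_key, which preserves cache affinity, while the TTL=0 branch reshuffles on every call so keys rotate.

import hashlib
import random
from typing import List, Optional


def order_keys(key_ids: List[str], affinity_key: Optional[str], use_random: bool) -> List[str]:
    if len(key_ids) <= 1:
        return list(key_ids)
    if use_random:
        shuffled = list(key_ids)
        random.shuffle(shuffled)  # different order on every call -> rotation
        return shuffled
    if affinity_key:
        # same affinity_key always yields the same order -> cache affinity preserved
        return sorted(
            key_ids,
            key=lambda kid: int(hashlib.sha256(f"{affinity_key}:{kid}".encode()).hexdigest()[:16], 16),
        )
    return sorted(key_ids)  # stable fallback without an affinity key


keys = ["key-a", "key-b", "key-c"]
print(order_keys(keys, affinity_key="user-123", use_random=False))  # deterministic
print(order_keys(keys, affinity_key=None, use_random=True))         # rotates per call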

View File

@@ -234,8 +234,15 @@ class EndpointHealthService:
            for api_format in format_key_mapping.keys()
        }

        # 参数校验(API 层已通过 Query(ge=1) 保证,这里做防御性检查)
        if lookback_hours <= 0 or segments <= 0:
            raise ValueError(
                f"lookback_hours and segments must be positive, "
                f"got lookback_hours={lookback_hours}, segments={segments}"
            )

        # 计算时间范围
        segment_seconds = (lookback_hours * 3600) / segments
        start_time = now - timedelta(hours=lookback_hours)

        # 使用 RequestCandidate 表查询所有尝试记录
@@ -243,7 +250,7 @@ class EndpointHealthService:
        final_statuses = ["success", "failed", "skipped"]

        segment_expr = func.floor(
            func.extract('epoch', RequestCandidate.created_at - start_time) / segment_seconds
        ).label('segment_idx')

        candidate_stats = (
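A quick numeric illustration (the values are hypothetical) of why the switch from integer minutes to fractional segment_seconds matters: with a short lookback window and many segments, the old formula rounds down to zero and the later division blows up, while the new one stays well-defined.

lookback_hours, segments = 1, 120

old_interval_minutes = (lookback_hours * 60) // segments   # 60 // 120 == 0
new_segment_seconds = (lookback_hours * 3600) / segments   # 30.0 seconds per segment
print(old_interval_minutes, new_segment_seconds)

elapsed_seconds = 95
# old code: elapsed / (old_interval_minutes * 60) -> ZeroDivisionError
print(int(elapsed_seconds // new_segment_seconds))          # segment index 3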

View File

@@ -208,86 +208,120 @@ class CleanupScheduler:
            return

        # 非首次运行,检查最近是否有缺失的日期需要回填
        from src.models.database import StatsDailyModel

        yesterday_business_date = today_local.date() - timedelta(days=1)
        max_backfill_days: int = SystemConfigService.get_config(
            db, "max_stats_backfill_days", 30
        ) or 30

        # 计算回填检查的起始日期
        check_start_date = yesterday_business_date - timedelta(
            days=max_backfill_days - 1
        )

        # 获取 StatsDaily 和 StatsDailyModel 中已有数据的日期集合
        existing_daily_dates = set()
        existing_model_dates = set()

        daily_stats = (
            db.query(StatsDaily.date)
            .filter(StatsDaily.date >= check_start_date.isoformat())
            .all()
        )
        for (stat_date,) in daily_stats:
            if stat_date.tzinfo is None:
                stat_date = stat_date.replace(tzinfo=timezone.utc)
            existing_daily_dates.add(stat_date.astimezone(app_tz).date())

        model_stats = (
            db.query(StatsDailyModel.date)
            .filter(StatsDailyModel.date >= check_start_date.isoformat())
            .distinct()
            .all()
        )
        for (stat_date,) in model_stats:
            if stat_date.tzinfo is None:
                stat_date = stat_date.replace(tzinfo=timezone.utc)
            existing_model_dates.add(stat_date.astimezone(app_tz).date())

        # 找出需要回填的日期
        all_dates = set()
        current = check_start_date
        while current <= yesterday_business_date:
            all_dates.add(current)
            current += timedelta(days=1)

        # 需要回填 StatsDaily 的日期
        missing_daily_dates = all_dates - existing_daily_dates
        # 需要回填 StatsDailyModel 的日期
        missing_model_dates = all_dates - existing_model_dates
        # 合并所有需要处理的日期
        dates_to_process = missing_daily_dates | missing_model_dates

        if dates_to_process:
            sorted_dates = sorted(dates_to_process)
            logger.info(
                f"检测到 {len(dates_to_process)} 天的统计数据需要回填 "
                f"(StatsDaily 缺失 {len(missing_daily_dates)} 天, "
                f"StatsDailyModel 缺失 {len(missing_model_dates)} 天)"
            )
            users = (
                db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
            )
            failed_dates = 0
            failed_users = 0
            for current_date in sorted_dates:
                try:
                    current_date_local = datetime.combine(
                        current_date, datetime.min.time(), tzinfo=app_tz
                    )
                    # 只在缺失时才聚合对应的表
                    if current_date in missing_daily_dates:
                        StatsAggregatorService.aggregate_daily_stats(
                            db, current_date_local
                        )
                    if current_date in missing_model_dates:
                        StatsAggregatorService.aggregate_daily_model_stats(
                            db, current_date_local
                        )
                    # 用户统计在任一缺失时都回填
                    for (user_id,) in users:
                        try:
                            StatsAggregatorService.aggregate_user_daily_stats(
                                db, user_id, current_date_local
                            )
                        except Exception as e:
                            failed_users += 1
                            logger.warning(
                                f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
                            )
                            try:
                                db.rollback()
                            except Exception as rollback_err:
                                logger.error(f"回滚失败: {rollback_err}")
                except Exception as e:
                    failed_dates += 1
                    logger.warning(f"回填日期 {current_date} 失败: {e}")
                    try:
                        db.rollback()
                    except Exception as rollback_err:
                        logger.error(f"回滚失败: {rollback_err}")

            StatsAggregatorService.update_summary(db)
            if failed_dates > 0 or failed_users > 0:
                logger.warning(
                    f"回填完成,共处理 {len(dates_to_process)} 天,"
                    f"失败: {failed_dates} 天, {failed_users} 个用户记录"
                )
            else:
                logger.info(f"缺失数据回填完成,共处理 {len(dates_to_process)} 天")
        else:
            logger.info("统计数据已是最新,无需回填")
        return

        # 定时任务:聚合昨天的数据
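The backfill selection above boils down to set arithmetic over business dates. A toy, database-free sketch of that idea (all dates and "existing" sets are invented) shows how only the tables actually missing a given day get re-aggregated:

from datetime import date, timedelta

yesterday = date(2025, 12, 27)
check_start = yesterday - timedelta(days=6)

all_dates = {check_start + timedelta(days=i) for i in range((yesterday - check_start).days + 1)}
existing_daily = all_dates - {date(2025, 12, 24)}                      # StatsDaily missing one day
existing_model = all_dates - {date(2025, 12, 24), date(2025, 12, 26)}  # StatsDailyModel missing two

missing_daily = all_dates - existing_daily
missing_model = all_dates - existing_model

for day in sorted(missing_daily | missing_model):
    todo = []
    if day in missing_daily:
        todo.append("aggregate_daily_stats")
    if day in missing_model:
        todo.append("aggregate_daily_model_stats")
    print(day, todo)  # 2025-12-24 gets both, 2025-12-26 only the model table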

View File

@@ -3,6 +3,7 @@
""" """
import uuid import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@@ -16,6 +17,71 @@ from src.services.model.cost import ModelCostService
from src.services.system.config import SystemConfigService

@dataclass
class UsageRecordParams:
"""用量记录参数数据类,用于在内部方法间传递数据"""
db: Session
user: Optional[User]
api_key: Optional[ApiKey]
provider: str
model: str
input_tokens: int
output_tokens: int
cache_creation_input_tokens: int
cache_read_input_tokens: int
request_type: str
api_format: Optional[str]
is_stream: bool
response_time_ms: Optional[int]
first_byte_time_ms: Optional[int]
status_code: int
error_message: Optional[str]
metadata: Optional[Dict[str, Any]]
request_headers: Optional[Dict[str, Any]]
request_body: Optional[Any]
provider_request_headers: Optional[Dict[str, Any]]
response_headers: Optional[Dict[str, Any]]
response_body: Optional[Any]
request_id: str
provider_id: Optional[str]
provider_endpoint_id: Optional[str]
provider_api_key_id: Optional[str]
status: str
cache_ttl_minutes: Optional[int]
use_tiered_pricing: bool
target_model: Optional[str]
def __post_init__(self) -> None:
"""验证关键字段,确保数据完整性"""
# Token 数量不能为负数
if self.input_tokens < 0:
raise ValueError(f"input_tokens 不能为负数: {self.input_tokens}")
if self.output_tokens < 0:
raise ValueError(f"output_tokens 不能为负数: {self.output_tokens}")
if self.cache_creation_input_tokens < 0:
raise ValueError(
f"cache_creation_input_tokens 不能为负数: {self.cache_creation_input_tokens}"
)
if self.cache_read_input_tokens < 0:
raise ValueError(
f"cache_read_input_tokens 不能为负数: {self.cache_read_input_tokens}"
)
# 响应时间不能为负数
if self.response_time_ms is not None and self.response_time_ms < 0:
raise ValueError(f"response_time_ms 不能为负数: {self.response_time_ms}")
if self.first_byte_time_ms is not None and self.first_byte_time_ms < 0:
raise ValueError(f"first_byte_time_ms 不能为负数: {self.first_byte_time_ms}")
# HTTP 状态码范围校验
if not (100 <= self.status_code <= 599):
raise ValueError(f"无效的 HTTP 状态码: {self.status_code}")
# 状态值校验
valid_statuses = {"pending", "streaming", "completed", "failed"}
if self.status not in valid_statuses:
raise ValueError(f"无效的状态值: {self.status},有效值: {valid_statuses}")
class UsageService:
    """用量统计服务"""
@@ -471,6 +537,97 @@ class UsageService:
            cache_ttl_minutes=cache_ttl_minutes,
        )
@classmethod
async def _prepare_usage_record(
cls,
params: UsageRecordParams,
) -> Tuple[Dict[str, Any], float]:
"""准备用量记录的共享逻辑
此方法提取了 record_usage 和 record_usage_async 的公共处理逻辑:
- 获取费率倍数
- 计算成本
- 构建 Usage 参数
Args:
params: 用量记录参数数据类
Returns:
(usage_params 字典, total_cost 总成本)
"""
# 获取费率倍数和是否免费套餐
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
params.db, params.provider_api_key_id, params.provider_id
)
# 计算成本
is_failed_request = params.status_code >= 400 or params.error_message is not None
(
input_price, output_price, cache_creation_price, cache_read_price, request_price,
input_cost, output_cost, cache_creation_cost, cache_read_cost, cache_cost,
request_cost, total_cost, _tier_index
) = await cls._calculate_costs(
db=params.db,
provider=params.provider,
model=params.model,
input_tokens=params.input_tokens,
output_tokens=params.output_tokens,
cache_creation_input_tokens=params.cache_creation_input_tokens,
cache_read_input_tokens=params.cache_read_input_tokens,
api_format=params.api_format,
cache_ttl_minutes=params.cache_ttl_minutes,
use_tiered_pricing=params.use_tiered_pricing,
is_failed_request=is_failed_request,
)
# 构建 Usage 参数
usage_params = cls._build_usage_params(
db=params.db,
user=params.user,
api_key=params.api_key,
provider=params.provider,
model=params.model,
input_tokens=params.input_tokens,
output_tokens=params.output_tokens,
cache_creation_input_tokens=params.cache_creation_input_tokens,
cache_read_input_tokens=params.cache_read_input_tokens,
request_type=params.request_type,
api_format=params.api_format,
is_stream=params.is_stream,
response_time_ms=params.response_time_ms,
first_byte_time_ms=params.first_byte_time_ms,
status_code=params.status_code,
error_message=params.error_message,
metadata=params.metadata,
request_headers=params.request_headers,
request_body=params.request_body,
provider_request_headers=params.provider_request_headers,
response_headers=params.response_headers,
response_body=params.response_body,
request_id=params.request_id,
provider_id=params.provider_id,
provider_endpoint_id=params.provider_endpoint_id,
provider_api_key_id=params.provider_api_key_id,
status=params.status,
target_model=params.target_model,
input_cost=input_cost,
output_cost=output_cost,
cache_creation_cost=cache_creation_cost,
cache_read_cost=cache_read_cost,
cache_cost=cache_cost,
request_cost=request_cost,
total_cost=total_cost,
input_price=input_price,
output_price=output_price,
cache_creation_price=cache_creation_price,
cache_read_price=cache_read_price,
request_price=request_price,
actual_rate_multiplier=actual_rate_multiplier,
is_free_tier=is_free_tier,
)
return usage_params, total_cost
    @classmethod
    async def record_usage_async(
        cls,
@@ -516,76 +673,25 @@ class UsageService:
        if request_id is None:
            request_id = str(uuid.uuid4())[:8]

        # 使用共享逻辑准备记录参数
        params = UsageRecordParams(
            db=db, user=user, api_key=api_key, provider=provider, model=model,
            input_tokens=input_tokens, output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            request_type=request_type, api_format=api_format, is_stream=is_stream,
            response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
            status_code=status_code, error_message=error_message, metadata=metadata,
            request_headers=request_headers, request_body=request_body,
            provider_request_headers=provider_request_headers,
            response_headers=response_headers, response_body=response_body,
            request_id=request_id, provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id, status=status,
            cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
            target_model=target_model,
        )
        usage_params, _ = await cls._prepare_usage_record(params)

        # 创建 Usage 记录
        usage = Usage(**usage_params)
@@ -660,76 +766,25 @@ class UsageService:
        if request_id is None:
            request_id = str(uuid.uuid4())[:8]

        # 使用共享逻辑准备记录参数
        params = UsageRecordParams(
            db=db, user=user, api_key=api_key, provider=provider, model=model,
            input_tokens=input_tokens, output_tokens=output_tokens,
            cache_creation_input_tokens=cache_creation_input_tokens,
            cache_read_input_tokens=cache_read_input_tokens,
            request_type=request_type, api_format=api_format, is_stream=is_stream,
            response_time_ms=response_time_ms, first_byte_time_ms=first_byte_time_ms,
            status_code=status_code, error_message=error_message, metadata=metadata,
            request_headers=request_headers, request_body=request_body,
            provider_request_headers=provider_request_headers,
            response_headers=response_headers, response_body=response_body,
            request_id=request_id, provider_id=provider_id,
            provider_endpoint_id=provider_endpoint_id,
            provider_api_key_id=provider_api_key_id, status=status,
            cache_ttl_minutes=cache_ttl_minutes, use_tiered_pricing=use_tiered_pricing,
            target_model=target_model,
        )
        usage_params, total_cost = await cls._prepare_usage_record(params)

        # 检查是否已存在相同 request_id 的记录
        existing_usage = db.query(Usage).filter(Usage.request_id == request_id).first()
@@ -751,7 +806,7 @@ class UsageService:
            api_key = db.merge(api_key)

        # 使用原子更新避免并发竞态条件
        from sqlalchemy import func as sql_func, update
        from src.models.database import ApiKey as ApiKeyModel, User as UserModel, GlobalModel

        # 更新用户使用量(独立 Key 不计入创建者的使用记录)
@@ -762,7 +817,7 @@ class UsageService:
                .values(
                    used_usd=UserModel.used_usd + total_cost,
                    total_usd=UserModel.total_usd + total_cost,
                    updated_at=sql_func.now(),
                )
            )
@@ -776,8 +831,8 @@ class UsageService:
                    total_requests=ApiKeyModel.total_requests + 1,
                    total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
                    balance_used_usd=ApiKeyModel.balance_used_usd + total_cost,
                    last_used_at=sql_func.now(),
                    updated_at=sql_func.now(),
                )
            )
        else:
@@ -787,8 +842,8 @@ class UsageService:
                .values(
                    total_requests=ApiKeyModel.total_requests + 1,
                    total_cost_usd=ApiKeyModel.total_cost_usd + total_cost,
                    last_used_at=sql_func.now(),
                    updated_at=sql_func.now(),
                )
            )
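The .values(col=col + delta) pattern above performs the increment inside the UPDATE statement, so concurrent writers cannot overwrite each other's totals. A self-contained sketch with a throwaway SQLite table (Counter is invented; the real code updates UserModel/ApiKeyModel) demonstrates the same shape:

from sqlalchemy import Column, DateTime, Integer, create_engine, func as sql_func, update
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Counter(Base):
    __tablename__ = "counters"
    id = Column(Integer, primary_key=True)
    total_requests = Column(Integer, nullable=False, default=0)
    updated_at = Column(DateTime)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as db:
    db.add(Counter(id=1, total_requests=0))
    db.commit()

    # read-modify-write happens in SQL, not in Python, so a concurrent
    # writer cannot lose this increment
    db.execute(
        update(Counter)
        .where(Counter.id == 1)
        .values(
            total_requests=Counter.total_requests + 1,
            updated_at=sql_func.now(),
        )
        .execution_options(synchronize_session=False)
    )
    db.commit()
    print(db.get(Counter, 1).total_requests)  # 1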
@@ -1121,19 +1176,48 @@ class UsageService:
        ]

    @staticmethod
    def cleanup_old_usage_records(
        db: Session, days_to_keep: int = 90, batch_size: int = 1000
    ) -> int:
        """清理旧的使用记录(分批删除避免长事务锁定)

        Args:
            db: 数据库会话
            days_to_keep: 保留天数,默认 90 天
            batch_size: 每批删除数量,默认 1000 条

        Returns:
            删除的总记录数
        """
        cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
        total_deleted = 0

        while True:
            # 查询待删除的 ID(使用新索引 idx_usage_user_created)
            batch_ids = (
                db.query(Usage.id)
                .filter(Usage.created_at < cutoff_date)
                .limit(batch_size)
                .all()
            )
            if not batch_ids:
                break

            # 批量删除
            deleted_count = (
                db.query(Usage)
                .filter(Usage.id.in_([row.id for row in batch_ids]))
                .delete(synchronize_session=False)
            )
            db.commit()
            total_deleted += deleted_count
            logger.debug(f"清理使用记录: 本批删除 {deleted_count} 条")

        logger.info(f"清理使用记录: 共删除 {total_deleted} 条超过 {days_to_keep} 天的记录")
        return total_deleted
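A minimal sketch of the batched-deletion loop, driven by an in-memory list instead of the Usage table, showing how batch_size caps the work committed per iteration while still draining everything past the cutoff (the numbers are made up):

from datetime import datetime, timedelta, timezone

cutoff = datetime.now(timezone.utc) - timedelta(days=90)
records = [cutoff - timedelta(days=i) for i in range(1, 2501)]  # 2500 expired rows
batch_size, total_deleted = 1000, 0

while True:
    batch = [ts for ts in records if ts < cutoff][:batch_size]
    if not batch:
        break
    for ts in batch:               # stands in for delete(synchronize_session=False)
        records.remove(ts)
    total_deleted += len(batch)    # each iteration would be its own commit

print(total_deleted)  # 2500, removed in three batches of at most 1000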
    # ========== 请求状态追踪方法 ==========
@@ -1219,6 +1303,7 @@ class UsageService:
        error_message: Optional[str] = None,
        provider: Optional[str] = None,
        target_model: Optional[str] = None,
        first_byte_time_ms: Optional[int] = None,
    ) -> Optional[Usage]:
        """
        快速更新使用记录状态
@@ -1230,6 +1315,7 @@ class UsageService:
            error_message: 错误消息(仅在 failed 状态时使用)
            provider: 提供商名称(可选,streaming 状态时更新)
            target_model: 映射后的目标模型名(可选)
            first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)

        Returns:
            更新后的 Usage 记录,如果未找到则返回 None
@@ -1247,6 +1333,8 @@ class UsageService:
            usage.provider = provider
        if target_model:
            usage.target_model = target_model
        if first_byte_time_ms is not None:
            usage.first_byte_time_ms = first_byte_time_ms

        db.commit()

View File

@@ -5,6 +5,7 @@
import json
import re
import time
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from sqlalchemy.orm import Session
@@ -457,26 +458,32 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}") logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
# 更新状态为 streaming同时更新 provider
if self.request_id:
try:
from src.services.usage.service import UsageService
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
chunk_count = 0 chunk_count = 0
first_chunk_received = False
try: try:
async for chunk in stream: async for chunk in stream:
chunk_count += 1 chunk_count += 1
# 保存原始字节流(用于错误诊断) # 保存原始字节流(用于错误诊断)
self.raw_chunks.append(chunk) self.raw_chunks.append(chunk)
# 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
if not first_chunk_received:
first_chunk_received = True
if self.request_id:
try:
# 计算 TTFB使用请求原始开始时间或 track_stream 开始时间)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
)
except Exception as e:
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
# 返回原始块给客户端 # 返回原始块给客户端
yield chunk yield chunk
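The TTFB bookkeeping above can be reproduced in isolation: wrap an async stream, note the elapsed time when the first chunk arrives, and keep yielding. In this sketch fake_stream and the print call are stand-ins for the real upstream and for update_usage_status:

import asyncio
import time
from typing import AsyncIterator


async def fake_stream() -> AsyncIterator[bytes]:
    await asyncio.sleep(0.2)  # simulated upstream latency before the first byte
    for part in (b"hello ", b"world"):
        yield part


async def track(stream: AsyncIterator[bytes], start_time: float) -> AsyncIterator[bytes]:
    first_chunk_received = False
    async for chunk in stream:
        if not first_chunk_received:
            first_chunk_received = True
            ttfb_ms = int((time.time() - start_time) * 1000)
            print(f"TTFB: {ttfb_ms} ms")  # the real code persists this via update_usage_status
        yield chunk


async def main() -> None:
    start = time.time()
    async for _ in track(fake_stream(), start):
        pass


asyncio.run(main())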

View File

@@ -59,14 +59,15 @@ class ApiKeyService:
        if expire_days:
            expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)

        # 空数组转为 None(表示不限制)
        api_key = ApiKey(
            user_id=user_id,
            key_hash=key_hash,
            key_encrypted=key_encrypted,
            name=name or f"API Key {datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}",
            allowed_providers=allowed_providers or None,
            allowed_api_formats=allowed_api_formats or None,
            allowed_models=allowed_models or None,
            rate_limit=rate_limit,
            concurrent_limit=concurrent_limit,
            expires_at=expires_at,
@@ -141,8 +142,18 @@ class ApiKeyService:
"auto_delete_on_expiry", "auto_delete_on_expiry",
] ]
# 允许显式设置为空数组/None 的字段(空数组会转为 None表示"全部"
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
for field, value in kwargs.items(): for field, value in kwargs.items():
if field in updatable_fields and value is not None: if field not in updatable_fields:
continue
# 对于 nullable_list_fields空数组应该转为 None表示不限制
if field in nullable_list_fields:
if value is not None:
# 空数组转为 None表示允许全部
setattr(api_key, field, value if value else None)
elif value is not None:
setattr(api_key, field, value) setattr(api_key, field, value)
api_key.updated_at = datetime.now(timezone.utc) api_key.updated_at = datetime.now(timezone.utc)
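Both the create and update branches rely on the same "empty list means unrestricted" normalization. A tiny sketch with a plain dict standing in for the ApiKey ORM object (normalize_restriction is an invented helper) captures that rule:

from typing import Any, Dict, Optional, Sequence


def normalize_restriction(value: Optional[Sequence[str]]) -> Optional[list]:
    """[] or None -> None (no restriction); a non-empty list is kept as-is."""
    return list(value) if value else None


api_key: Dict[str, Any] = {"allowed_models": ["gpt-4o"]}
api_key["allowed_models"] = normalize_restriction([])   # clears the restriction
print(api_key)                                           # {'allowed_models': None}
api_key["allowed_models"] = normalize_restriction(["claude-3-5-sonnet"])
print(api_key)                                           # restriction applied again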

View File

@@ -1,8 +1,16 @@
"""分布式任务协调器,确保仅有一个 worker 执行特定任务""" """分布式任务协调器,确保仅有一个 worker 执行特定任务
锁清理策略:
- 单实例模式(默认):启动时使用原子操作清理旧锁并获取新锁
- 多实例模式:使用 NX 选项竞争锁,依赖 TTL 处理异常退出
使用方式:
- 默认行为:启动时清理旧锁(适用于单机部署)
- 多实例部署:设置 SINGLE_INSTANCE_MODE=false 禁用启动清理
"""
from __future__ import annotations

import asyncio
import os
import pathlib
import uuid
@@ -19,6 +27,10 @@ except ImportError: # pragma: no cover - Windows 环境
class StartupTaskCoordinator:
    """利用 Redis 或文件锁,保证任务只在单个进程/实例中运行"""

    # 类级别标记:在当前进程中是否已尝试过启动清理
    # 注意:这在 fork 模式下每个 worker 都是独立的
    _startup_cleanup_attempted = False

    def __init__(self, redis_client=None, lock_dir: Optional[str] = None):
        self.redis = redis_client
        self._tokens: Dict[str, str] = {}
@@ -26,6 +38,8 @@ class StartupTaskCoordinator:
        self._lock_dir = pathlib.Path(lock_dir or os.getenv("TASK_LOCK_DIR", "./.locks"))
        if not self._lock_dir.exists():
            self._lock_dir.mkdir(parents=True, exist_ok=True)
        # 单实例模式:启动时清理旧锁(适用于单机部署,避免残留锁问题)
        self._single_instance_mode = os.getenv("SINGLE_INSTANCE_MODE", "true").lower() == "true"

    def _redis_key(self, name: str) -> str:
        return f"task_lock:{name}"
@@ -36,12 +50,51 @@ class StartupTaskCoordinator:
        if self.redis:
            token = str(uuid.uuid4())
            try:
                if self._single_instance_mode:
                    # 单实例模式:使用 Lua 脚本原子性地"清理旧锁 + 竞争获取"
                    # 只有当锁不存在或成功获取时才返回 1
                    # 这样第一个执行的 worker 会清理旧锁并获取,后续 worker 会正常竞争
                    script = """
                    local key = KEYS[1]
                    local token = ARGV[1]
                    local ttl = tonumber(ARGV[2])
                    local startup_key = KEYS[1] .. ':startup'

                    -- 检查是否已有 worker 执行过启动清理
                    local cleaned = redis.call('GET', startup_key)
                    if not cleaned then
                        -- 第一个 worker:删除旧锁,标记已清理
                        redis.call('DEL', key)
                        redis.call('SET', startup_key, '1', 'EX', 60)
                    end

                    -- 尝试获取锁(NX 模式)
                    local result = redis.call('SET', key, token, 'NX', 'EX', ttl)
                    if result then
                        return 1
                    end
                    return 0
                    """
                    result = await self.redis.eval(
                        script, 2,
                        self._redis_key(name), self._redis_key(name),
                        token, ttl
                    )
                    if result == 1:
                        self._tokens[name] = token
                        logger.info(f"任务 {name} 通过 Redis 锁独占执行")
                        return True
                    return False
                else:
                    # 多实例模式:直接使用 NX 选项竞争锁
                    acquired = await self.redis.set(
                        self._redis_key(name), token, nx=True, ex=ttl
                    )
                    if acquired:
                        self._tokens[name] = token
                        logger.info(f"任务 {name} 通过 Redis 锁独占执行")
                        return True
                    return False
            except Exception as exc:  # pragma: no cover - Redis 异常回退
                logger.warning(f"Redis 锁获取失败,回退到文件锁: {exc}")

View File

@@ -139,3 +139,83 @@ async def with_timeout_context(timeout: float, operation_name: str = "operation"
        # Python 3.10 及以下版本的兼容实现
        # 注意:这个简单实现不支持嵌套取消
        pass
async def read_first_chunk_with_ttfb_timeout(
byte_iterator: Any,
timeout: float,
request_id: str,
provider_name: str,
) -> tuple[bytes, Any]:
"""
读取流的首字节并应用 TTFB 超时检测
首字节超时(Time To First Byte)用于检测慢响应的 Provider,
超时时触发故障转移到其他可用的 Provider。
Args:
byte_iterator: 异步字节流迭代器
timeout: TTFB 超时时间(秒)
request_id: 请求 ID(用于日志)
provider_name: Provider 名称(用于日志和异常)
Returns:
(first_chunk, aiter): 首个字节块和异步迭代器
Raises:
ProviderTimeoutException: 如果首字节超时
"""
from src.core.exceptions import ProviderTimeoutException
aiter = byte_iterator.__aiter__()
try:
first_chunk = await asyncio.wait_for(aiter.__anext__(), timeout=timeout)
return first_chunk, aiter
except asyncio.TimeoutError:
# 完整的资源清理:先关闭迭代器,再关闭底层响应
await _cleanup_iterator_resources(aiter, request_id)
logger.warning(
f" [{request_id}] 流首字节超时 (TTFB): "
f"Provider={provider_name}, timeout={timeout}s"
)
raise ProviderTimeoutException(
provider_name=provider_name,
timeout=int(timeout),
)
async def _cleanup_iterator_resources(aiter: Any, request_id: str) -> None:
"""
清理异步迭代器及其底层资源
确保在 TTFB 超时后正确释放 HTTP 连接,避免连接泄漏。
Args:
aiter: 异步迭代器
request_id: 请求 ID(用于日志)
"""
# 1. 关闭迭代器本身
if hasattr(aiter, "aclose"):
try:
await aiter.aclose()
except Exception as e:
logger.debug(f" [{request_id}] 关闭迭代器失败: {e}")
# 2. 关闭底层响应对象(httpx.Response)
# 迭代器可能持有 _response 属性指向底层响应
response = getattr(aiter, "_response", None)
if response is not None and hasattr(response, "aclose"):
try:
await response.aclose()
except Exception as e:
logger.debug(f" [{request_id}] 关闭底层响应失败: {e}")
# 3. 尝试关闭 httpx 流(如果迭代器是 httpx 的 aiter_bytes)
# httpx 的 Response.aiter_bytes() 返回的生成器可能有 _stream 属性
stream = getattr(aiter, "_stream", None)
if stream is not None and hasattr(stream, "aclose"):
try:
await stream.aclose()
except Exception as e:
logger.debug(f" [{request_id}] 关闭流对象失败: {e}")