13 Commits

Author SHA1 Message Date
fawney19
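ea35efe440 perf: optimize HTTP client connection pool reuse
- Add get_proxy_client(); identical proxy configurations reuse the same client
- Add an LRU eviction policy that caps proxy clients at 50 to prevent memory leaks
- Add get_default_client_async(), an async thread-safe variant
- Use a module-level lock to avoid a race condition when initializing class attributes
- Optimize ConcurrencyManager to batch-fetch with Redis MGET, reducing round trips
- Add get_pool_stats(), an interface for connection pool statistics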
2026-01-08 13:34:59 +08:00
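The LRU policy described in this commit can be sketched roughly as follows. This is a minimal illustration, not the project's implementation: `LRUClientCache`, its `factory` and `on_evict` parameters, and the string "clients" in the usage note are hypothetical names introduced here. Only the cap of 50 and the idea of reusing one client per proxy configuration come from the commit message.

```python
import threading
from collections import OrderedDict
from typing import Callable, Generic, Hashable, List, Optional, TypeVar

K = TypeVar("K", bound=Hashable)
V = TypeVar("V")


class LRUClientCache(Generic[K, V]):
    """Hypothetical helper: one client per proxy config, least recently used evicted."""

    def __init__(
        self,
        factory: Callable[[K], V],
        max_size: int = 50,  # cap taken from the commit message
        on_evict: Optional[Callable[[V], None]] = None,
    ) -> None:
        self._factory = factory
        self._max_size = max_size
        self._on_evict = on_evict
        self._items: "OrderedDict[K, V]" = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key: K) -> V:
        evicted = None
        with self._lock:
            if key in self._items:
                # Cache hit: mark the entry as most recently used.
                self._items.move_to_end(key)
                return self._items[key]
            client = self._factory(key)
            self._items[key] = client
            if len(self._items) > self._max_size:
                # Drop the least recently used entry (front of the dict).
                _, evicted = self._items.popitem(last=False)
        if evicted is not None and self._on_evict is not None:
            # Release the evicted client outside the lock.
            self._on_evict(evicted)
        return client

    def keys(self) -> List[K]:
        with self._lock:
            return list(self._items)
```

With `max_size=2`, requesting keys `a`, `b`, `a`, `c` in that order leaves `a` and `c` cached and hands the client for `b` to `on_evict`, which is where a real pool would presumably close the underlying connection.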
fawney19
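bf09e740e9 fix(ui): improve interaction on the provider detail page
- The delete button in the model list turns red only on hover
- Batch model association dialog: expand when only global models exist; collapse all groups when there are several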
2026-01-08 11:25:52 +08:00
fawney19
60c77cec56 Merge pull request #77 from AAEE86/ui
style(ui): improve text visibility in dark mode for model badges
2026-01-08 10:52:54 +08:00
fawney19
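0e4a1dddb5 refactor(ui): improve the UI and performance of batch endpoint creation
- Layout: move API URL to the top and the API format selection below it
- Checkbox styling: replace the native style with a custom checkbox
- Sort API formats by column: each base format is vertically aligned with its CLI format
- Switch the request configuration to a more compact 4-column layout
- Create endpoints concurrently with Promise.allSettled to improve performance
- Better error reporting: on failure, show the specific error message directly to the user
- Remove the unused Select component imports and the selectOpen variable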
2026-01-08 10:50:25 +08:00
AAEE86
1cf18b6e12 feat(ui): support batch endpoint creation with multiple API formats (#76)
Replace single API format selector with multi-select checkbox interface in endpoint creation dialog. Users can now select multiple API formats to create multiple endpoints simultaneously with shared configuration (URL, path, timeout, etc.).

- Change API format selection from dropdown to checkbox grid layout
- Add selectedFormats array to track multiple format selections
- Implement batch creation logic with individual error handling
- Update submit button to show endpoint count being created
- Adjust form layout to improve visual hierarchy
- Display appropriate success/failure messages for batch operations
- Reset selectedFormats on form reset
2026-01-08 10:42:14 +08:00
AAEE86
f9a8be898a style(ui): improve text visibility in dark mode for model badges
2026-01-08 10:26:58 +08:00
fawney19
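1521ce5a96 feat: add a load-balancing scheduling mode
- Add the load_balance scheduling mode: random rotation within the same priority
- The frontend supports switching among three scheduling modes: cache affinity, load balance, and fixed order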
2026-01-08 03:20:04 +08:00
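"Random rotation within the same priority" can be read as: keep the priority order between groups, but shuffle the candidates inside each group. The sketch below illustrates that reading only. `order_candidates`, the dict-shaped candidates, and the assumption that a lower `priority` value is tried first are all hypothetical; cache affinity is not modeled, so every mode other than `load_balance` falls back to plain priority order here.

```python
import random
from itertools import groupby
from typing import Dict, List, Optional


def order_candidates(
    candidates: List[Dict],
    mode: str,
    rng: Optional[random.Random] = None,
) -> List[Dict]:
    """Order candidates by priority; in load_balance mode, shuffle within each priority."""
    rng = rng or random.Random()
    # Assumption: a lower priority value means "try first". sorted() is stable,
    # so equal-priority candidates keep their original relative order.
    ordered = sorted(candidates, key=lambda c: c["priority"])
    if mode != "load_balance":
        return ordered
    result: List[Dict] = []
    for _, group in groupby(ordered, key=lambda c: c["priority"]):
        bucket = list(group)
        rng.shuffle(bucket)  # randomize only among equal-priority candidates
        result.extend(bucket)
    return result
```

Passing a seeded `random.Random` makes the shuffle reproducible in tests, while production code would normally use an unseeded generator.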
fawney19
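f2e62dd197 feat: add version update checking
- Backend: add the /api/admin/system/check-update endpoint, which fetches the latest version from GitHub Tags
- Frontend: add the UpdateDialog component; after an administrator logs in, it checks for updates automatically and shows a prompt
- Check only once per session; after clicking "Remind me later", do not prompt again for 24 hours
- CI and deploy.sh automatically generate the _version.py version file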
2026-01-08 03:01:54 +08:00
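The core of the update check is comparing version strings numerically rather than lexically or by tag date. The sketch below condenses the parsing approach visible later in this diff into a standalone form. `parse_version` follows the same rules as the diff's `_parse_version`, while `latest_tag` and `has_update` are hypothetical helpers added for illustration.

```python
import re
from typing import Iterable, Optional, Tuple


def parse_version(version_str: str) -> Tuple[int, ...]:
    """Turn 'v0.2.5-4-g1234567' into (0, 2, 5, 0); unparsable input becomes zeros."""
    main = re.split(r"[-+]", version_str.lstrip("v"))[0]
    try:
        parts = [int(p) for p in main.split(".")]
    except ValueError:
        return (0, 0, 0, 0)
    parts += [0] * (4 - len(parts))  # pad to 4 segments so tuples compare cleanly
    return tuple(parts[:4])


def latest_tag(tag_names: Iterable[str]) -> Optional[str]:
    """Pick the highest version among tags that look like versions."""
    candidates = [t for t in tag_names if t and (t.startswith("v") or t[0].isdigit())]
    return max(candidates, key=parse_version, default=None)


def has_update(current: str, tag_names: Iterable[str]) -> bool:
    latest = latest_tag(tag_names)
    return latest is not None and parse_version(latest) > parse_version(current)
```

Comparing tuples means `v0.2.10` correctly ranks above `v0.2.9`, which a plain string comparison would get wrong. A `git describe` suffix such as `-4-g1234567` is ignored, so a build a few commits past a tag is treated as that tag's version.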
fawney19
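d378630b38 perf: add multi-layer caching to reduce database queries
- Add ProviderCacheService to cache Provider and ProviderAPIKey data
- Add an in-process cache to SystemConfigService (TTL 60 seconds)
- Throttle API Key last_used_at updates (60-second interval)
- Make the HTTP connection pool configurable, with automatic sizing based on the number of workers
- Frontend priority management now displays health using health_score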
2026-01-08 02:34:59 +08:00
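The throttling of `last_used_at` writes can be pictured as a per-key "has enough time passed?" check. The following is a minimal sketch under stated assumptions: `Throttle` and its injectable `clock` parameter are hypothetical, and only the 60-second interval is taken from the commit message. The actual implementation may well keep this state elsewhere, for example in Redis.

```python
import time
from typing import Callable, Dict, Hashable


class Throttle:
    """Hypothetical helper: allow an action per key at most once per interval."""

    def __init__(
        self,
        interval: float = 60.0,  # seconds; value taken from the commit message
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._interval = interval
        self._clock = clock
        self._last: Dict[Hashable, float] = {}

    def should_run(self, key: Hashable) -> bool:
        now = self._clock()
        last = self._last.get(key)
        if last is not None and now - last < self._interval:
            return False  # too soon: skip the database write
        self._last[key] = now
        return True
```

A caller would wrap the write, for example `if throttle.should_run(key_id): update_last_used_at(key_id)`, where `update_last_used_at` stands in for whatever performs the actual update. Injecting the clock keeps the behavior testable without sleeping.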
fawney19
d9e6346911 fix: lower the minimum API Key length to 3 characters
2026-01-08 01:53:16 +08:00
fawney19
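238788e0e9 fix: unify the default endpoint retry count to 2
Sync the default endpoint retry count across the frontend form, mock data, and backend import configuration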
2026-01-08 01:40:40 +08:00
fawney19
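68ff828505 feat: run database migrations automatically on container startup
- Add entrypoint.sh to run alembic upgrade head before the container starts the application
- Update Dockerfile.app and Dockerfile.app.local to use the new entrypoint script
- Remove the manual migration script migrate.sh
- Simplify the deployment instructions in the README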
2026-01-08 01:28:36 +08:00
fawney19
59447fc12b fix: pin the container's internal port to 8084 to avoid port conflicts caused by the PORT environment variable
2026-01-07 21:51:55 +08:00
34 changed files with 1358 additions and 331 deletions

View File

@@ -146,10 +146,33 @@ jobs:
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=
- name: Extract version from tag
id: version
run: |
# Extract the version from the tag, e.g. v0.2.5 -> 0.2.5
VERSION="${GITHUB_REF#refs/tags/v}"
if [ "$VERSION" = "$GITHUB_REF" ]; then
# Not triggered by a tag; fall back to git describe
VERSION=$(git describe --tags --always | sed 's/^v//')
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Extracted version: $VERSION"
- name: Update Dockerfile.app to use registry base image
run: |
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
- name: Generate version file
run: |
# Generate the _version.py file
cat > src/_version.py << EOF
# Auto-generated by CI
__version__ = '${{ steps.version.outputs.version }}'
__version_tuple__ = tuple(int(x) for x in '${{ steps.version.outputs.version }}'.split('-')[0].split('.') if x.isdigit())
version = __version__
version_tuple = __version_tuple__
EOF
- name: Build and push app image
uses: docker/build-push-action@v5
with:

.gitignore (vendored): 3 changed lines
View File

@@ -224,3 +224,6 @@ extracted_*.ts
.deps-hash
.code-hash
.migration-hash
# Version file (auto-generated by hatch-vcs)
src/_version.py

View File

@@ -127,14 +127,14 @@ RUN printf '%s\n' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/8084/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 127.0.0.1:8084 --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
@@ -147,6 +147,10 @@ RUN printf '%s\n' \
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# Entrypoint script (runs migrations before startup)
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
@@ -161,4 +165,5 @@ EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
ENTRYPOINT ["/entrypoint.sh"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@@ -139,6 +139,10 @@ RUN printf '%s\n' \
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# Entrypoint script (runs migrations before startup)
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
@@ -152,4 +156,5 @@ EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
ENTRYPOINT ["/entrypoint.sh"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@@ -57,14 +57,8 @@ cd Aether
cp .env.example .env
python generate_keys.py # Generate the keys and copy them into .env
# 3. Deploy
docker compose up -d
# 4. On first deployment, initialize the database
./migrate.sh
# 5. Update
docker compose pull && docker compose up -d && ./migrate.sh
# 3. Deploy / update (database migrations run automatically)
docker compose pull && docker compose up -d
```
### Docker Compose (build the image locally)

View File

@@ -88,9 +88,28 @@ build_base() {
save_deps_hash
}
# Generate the version file
generate_version_file() {
# Get the version from git
local version
version=$(git describe --tags --always 2>/dev/null | sed 's/^v//')
if [ -z "$version" ]; then
version="unknown"
fi
echo ">>> Generating version file: $version"
cat > src/_version.py << EOF
# Auto-generated by deploy.sh - do not edit
__version__ = '$version'
__version_tuple__ = tuple(int(x) for x in '$version'.split('-')[0].split('.') if x.isdigit())
version = __version__
version_tuple = __version_tuple__
EOF
}
# Build the app image
build_app() {
echo ">>> Building app image (code only)..."
generate_version_file
docker build -f Dockerfile.app.local -t aether-app:latest .
save_code_hash
}

entrypoint.sh (normal file): 8 changed lines
View File

@@ -0,0 +1,8 @@
#!/bin/bash
set -e
echo "Running database migrations..."
alembic upgrade head
echo "Starting application..."
exec "$@"

View File

@@ -159,6 +159,15 @@ export interface EmailTemplateResetResponse {
}
}
// Update check response
export interface CheckUpdateResponse {
current_version: string
latest_version: string | null
has_update: boolean
release_url: string | null
error: string | null
}
// LDAP configuration response
export interface LdapConfigResponse {
server_url: string | null
@@ -526,6 +535,14 @@ export const adminApi = {
return response.data
},
// Check for system updates
async checkUpdate(): Promise<CheckUpdateResponse> {
const response = await apiClient.get<CheckUpdateResponse>(
'/api/admin/system/check-update'
)
return response.data
},
// LDAP configuration
// Get the LDAP configuration
async getLdapConfig(): Promise<LdapConfigResponse> {

View File

@@ -0,0 +1,112 @@
<template>
<Dialog
v-model="isOpen"
size="md"
title=""
>
<div class="flex flex-col items-center text-center py-2">
<!-- Logo -->
<HeaderLogo
size="h-16 w-16"
class-name="text-primary"
/>
<!-- Title -->
<h2 class="text-xl font-semibold text-foreground mt-4 mb-2">
发现新版本
</h2>
<!-- Version Info -->
<div class="flex items-center gap-3 mb-4">
<span class="px-3 py-1.5 rounded-lg bg-muted text-sm font-mono text-muted-foreground">
v{{ currentVersion }}
</span>
<svg
class="h-4 w-4 text-muted-foreground"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 7l5 5m0 0l-5 5m5-5H6"
/>
</svg>
<span class="px-3 py-1.5 rounded-lg bg-primary/10 text-sm font-mono font-medium text-primary">
v{{ latestVersion }}
</span>
</div>
<!-- Description -->
<p class="text-sm text-muted-foreground max-w-xs">
新版本已发布建议更新以获得最新功能和安全修复
</p>
</div>
<template #footer>
<div class="flex w-full gap-3">
<Button
variant="outline"
class="flex-1"
@click="handleLater"
>
稍后提醒
</Button>
<Button
class="flex-1"
@click="handleViewRelease"
>
查看更新
</Button>
</div>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue'
import HeaderLogo from '@/components/HeaderLogo.vue'
const props = defineProps<{
modelValue: boolean
currentVersion: string
latestVersion: string
releaseUrl: string | null
}>()
const emit = defineEmits<{
'update:modelValue': [value: boolean]
}>()
const isOpen = ref(props.modelValue)
watch(() => props.modelValue, (val) => {
isOpen.value = val
})
watch(isOpen, (val) => {
emit('update:modelValue', val)
})
function handleLater() {
// Record the ignored version; do not prompt again for 24 hours
const ignoreKey = 'aether_update_ignore'
const ignoreData = {
version: props.latestVersion,
until: Date.now() + 24 * 60 * 60 * 1000 // 24 hours
}
localStorage.setItem(ignoreKey, JSON.stringify(ignoreData))
isOpen.value = false
}
function handleViewRelease() {
if (props.releaseUrl) {
window.open(props.releaseUrl, '_blank')
}
isOpen.value = false
}
</script>

View File

@@ -531,20 +531,23 @@ watch(() => props.open, async (isOpen) => {
// Load data
async function loadData() {
await Promise.all([loadGlobalModels(), loadExistingModels()])
// Collapse the global model group by default
collapsedGroups.value = new Set(['global'])
// Check the cache; if cached data exists, use it directly
const cachedModels = getCachedModels(props.providerId)
if (cachedModels) {
if (cachedModels && cachedModels.length > 0) {
upstreamModels.value = cachedModels
upstreamModelsLoaded.value = true
// Collapse all upstream model groups
// Collapse everything when there are multiple groups
const allGroups = new Set(['global'])
for (const model of cachedModels) {
if (model.api_format) {
collapsedGroups.value.add(model.api_format)
allGroups.add(model.api_format)
}
}
collapsedGroups.value = allGroups
} else {
// Expand when there are only global models
collapsedGroups.value = new Set()
}
}
@@ -585,8 +588,8 @@ async function fetchUpstreamModels(forceRefresh = false) {
} else {
upstreamModels.value = result.models
upstreamModelsLoaded.value = true
// Collapse all upstream model groups
const allGroups = new Set(collapsedGroups.value)
// Collapse everything when there are multiple groups
const allGroups = new Set(['global'])
for (const model of result.models) {
if (model.api_format) {
allGroups.add(model.api_format)

View File

@@ -20,44 +20,8 @@
API 配置
</h3>
<!-- API URL and custom path -->
<div class="grid grid-cols-2 gap-4">
<!-- API format -->
<div class="space-y-2">
<Label for="api_format">API 格式 *</Label>
<template v-if="isEditMode">
<Input
id="api_format"
v-model="form.api_format"
disabled
class="bg-muted"
/>
<p class="text-xs text-muted-foreground">
API 格式创建后不可修改
</p>
</template>
<template v-else>
<Select
v-model="form.api_format"
v-model:open="selectOpen"
required
>
<SelectTrigger>
<SelectValue placeholder="请选择 API 格式" />
</SelectTrigger>
<SelectContent>
<SelectItem
v-for="format in apiFormats"
:key="format.value"
:value="format.value"
>
{{ format.label }}
</SelectItem>
</SelectContent>
</Select>
</template>
</div>
<!-- API URL -->
<div class="space-y-2">
<Label for="base_url">API URL *</Label>
<Input
@@ -67,16 +31,70 @@
required
/>
</div>
<div class="space-y-2">
<Label for="custom_path">自定义请求路径可选</Label>
<Input
id="custom_path"
v-model="form.custom_path"
:placeholder="isEditMode ? defaultPathPlaceholder : '留空使用各格式的默认路径'"
/>
</div>
</div>
<!-- Custom path -->
<!-- API format -->
<div class="space-y-2">
<Label for="custom_path">自定义请求路径可选</Label>
<Input
id="custom_path"
v-model="form.custom_path"
:placeholder="defaultPathPlaceholder"
/>
<Label for="api_format">API 格式 *</Label>
<template v-if="isEditMode">
<Input
id="api_format"
v-model="form.api_format"
disabled
class="bg-muted"
/>
<p class="text-xs text-muted-foreground">
API 格式创建后不可修改
</p>
</template>
<template v-else>
<div class="grid grid-cols-3 grid-flow-col grid-rows-2 gap-2">
<label
v-for="format in sortedApiFormats"
:key="format.value"
class="flex items-center gap-2 rounded-md border px-3 py-2 cursor-pointer transition-all text-sm"
:class="selectedFormats.includes(format.value)
? 'border-primary bg-primary/10 text-primary font-medium'
: 'border-border hover:border-primary/50 hover:bg-accent'"
>
<input
type="checkbox"
:value="format.value"
v-model="selectedFormats"
class="sr-only"
/>
<span
class="flex h-4 w-4 shrink-0 items-center justify-center rounded border transition-colors"
:class="selectedFormats.includes(format.value)
? 'border-primary bg-primary text-primary-foreground'
: 'border-muted-foreground/30'"
>
<svg
v-if="selectedFormats.includes(format.value)"
class="h-3 w-3"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="3"
stroke-linecap="round"
stroke-linejoin="round"
>
<polyline points="20 6 9 17 4 12" />
</svg>
</span>
<span>{{ format.label }}</span>
</label>
</div>
</template>
</div>
</div>
@@ -86,7 +104,7 @@
请求配置
</h3>
<div class="grid grid-cols-3 gap-4">
<div class="grid grid-cols-4 gap-4">
<div class="space-y-2">
<Label for="timeout">超时</Label>
<Input
@@ -117,11 +135,9 @@
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
/>
</div>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="rate_limit">速率限制(请求/分钟)</Label>
<Label for="rate_limit">速率限制(/分钟)</Label>
<Input
id="rate_limit"
:model-value="form.rate_limit ?? ''"
@@ -217,10 +233,10 @@
取消
</Button>
<Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
:disabled="loading || !form.base_url || (!isEditMode && selectedFormats.length === 0)"
@click="handleSubmit()"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : `创建 ${selectedFormats.length} 个端点`) }}
</Button>
</template>
</Dialog>
@@ -245,11 +261,6 @@ import {
Button,
Input,
Label,
Select,
SelectTrigger,
SelectValue,
SelectContent,
SelectItem,
Switch,
} from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue'
@@ -280,7 +291,6 @@ const emit = defineEmits<{
const { success, error: showError } = useToast()
const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // Confirmation dialog for clearing credentials
@@ -296,7 +306,7 @@ const form = ref({
base_url: '',
custom_path: '',
timeout: 300,
max_retries: 3,
max_retries: 2,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true,
@@ -306,9 +316,28 @@ const form = ref({
proxy_password: '',
})
// Selected API formats (multi-select)
const selectedFormats = ref<string[]>([])
// API format list
const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([])
// Sorted API formats: laid out by column, each column holding a base format and its CLI format
const sortedApiFormats = computed(() => {
const baseFormats = apiFormats.value.filter(f => !f.value.endsWith('_cli'))
const cliFormats = apiFormats.value.filter(f => f.value.endsWith('_cli'))
// Interleave: base1, cli1, base2, cli2, base3, cli3
const result: typeof apiFormats.value = []
for (let i = 0; i < baseFormats.length; i++) {
result.push(baseFormats[i])
const cliFormat = cliFormats.find(f => f.value === baseFormats[i].value + '_cli')
if (cliFormat) {
result.push(cliFormat)
}
}
return result
})
// Load the API format list
const loadApiFormats = async () => {
try {
@@ -330,7 +359,7 @@ const defaultPath = computed(() => {
// Dynamic placeholder
const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}`
return defaultPath.value
})
// Check whether a password is already saved (the backend returns *** when one exists)
@@ -392,7 +421,7 @@ function resetForm() {
base_url: '',
custom_path: '',
timeout: 300,
max_retries: 3,
max_retries: 2,
max_concurrent: undefined,
rate_limit: undefined,
is_active: true,
@@ -400,6 +429,7 @@ function resetForm() {
proxy_username: '',
proxy_password: '',
}
selectedFormats.value = []
proxyEnabled.value = false
}
@@ -479,6 +509,8 @@ const handleSubmit = async (skipCredentialCheck = false) => {
}
loading.value = true
let successCount = 0
try {
const proxyConfig = buildProxyConfig()
@@ -497,27 +529,56 @@ const handleSubmit = async (skipCredentialCheck = false) => {
success('端点已更新', '保存成功')
emit('endpointUpdated')
emit('update:modelValue', false)
} else if (props.provider) {
// Create the endpoint
await createEndpoint(props.provider.id, {
provider_id: props.provider.id,
api_format: form.value.api_format,
base_url: form.value.base_url,
custom_path: form.value.custom_path || undefined,
timeout: form.value.timeout,
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
// Batch-create endpoints, using concurrent requests to improve performance
const results = await Promise.allSettled(
selectedFormats.value.map(apiFormat =>
createEndpoint(props.provider!.id, {
provider_id: props.provider!.id,
api_format: apiFormat,
base_url: form.value.base_url,
custom_path: form.value.custom_path || undefined,
timeout: form.value.timeout,
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
})
)
)
// Tally the results
const errors: string[] = []
results.forEach((result, index) => {
if (result.status === 'fulfilled') {
successCount++
} else {
const apiFormat = selectedFormats.value[index]
const formatLabel = apiFormats.value.find((f: any) => f.value === apiFormat)?.label || apiFormat
const errorMsg = result.reason?.response?.data?.detail || '创建失败'
errors.push(`${formatLabel}: ${errorMsg}`)
}
})
success('端点创建成功', '成功')
emit('endpointCreated')
resetForm()
}
const failCount = errors.length
emit('update:modelValue', false)
// Show the result
if (successCount > 0 && failCount === 0) {
success(`成功创建 ${successCount} 个端点`, '创建成功')
} else if (successCount > 0 && failCount > 0) {
showError(`${failCount} 个端点创建失败:\n${errors.join('\n')}`, `${successCount} 个成功,${failCount} 个失败`)
} else {
showError(errors.join('\n') || '创建端点失败', '创建失败')
}
if (successCount > 0) {
emit('endpointCreated')
resetForm()
emit('update:modelValue', false)
}
}
} catch (error: any) {
const action = isEditMode.value ? '更新' : '创建'
showError(error.response?.data?.detail || `${action}端点失败`, '错误')

View File

@@ -32,11 +32,11 @@
v-for="modelName in selectedModels"
:key="modelName"
variant="secondary"
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm"
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm text-foreground dark:text-white"
>
{{ getModelLabel(modelName) }}
<button
class="ml-0.5 hover:text-destructive focus:outline-none"
class="ml-0.5 hover:text-destructive focus:outline-none text-foreground dark:text-white"
@click.stop="toggleModel(modelName, false)"
>
&times;

View File

@@ -349,8 +349,8 @@ const apiKeyError = computed(() => {
}
// 如果输入了值,检查长度
if (apiKey.length < 10) {
return 'API 密钥至少需要 10 个字符'
if (apiKey.length < 3) {
return 'API 密钥至少需要 3 个字符'
}
return ''

View File

@@ -262,17 +262,17 @@
<div class="shrink-0 flex items-center gap-3">
<!-- Health score -->
<div
v-if="key.success_rate !== null"
v-if="key.health_score != null"
class="text-xs text-right"
>
<div
class="font-medium tabular-nums"
:class="[
key.success_rate >= 0.95 ? 'text-green-600' :
key.success_rate >= 0.8 ? 'text-yellow-600' : 'text-red-500'
key.health_score >= 0.95 ? 'text-green-600' :
key.health_score >= 0.5 ? 'text-yellow-600' : 'text-red-500'
]"
>
{{ (key.success_rate * 100).toFixed(0) }}%
{{ ((key.health_score || 0) * 100).toFixed(0) }}%
</div>
<div class="text-[10px] text-muted-foreground opacity-70">
{{ key.request_count }} reqs
@@ -319,19 +319,6 @@
<div class="flex items-center gap-2 pl-4 border-l border-border">
<span class="text-xs text-muted-foreground">调度:</span>
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
@@ -345,6 +332,32 @@
>
缓存亲和
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'load_balance'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="同优先级内随机轮换,不考虑缓存"
@click="schedulingMode = 'load_balance'"
>
负载均衡
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
</div>
</div>
</div>
@@ -400,6 +413,7 @@ interface KeyWithMeta {
endpoint_base_url: string
api_format: string
capabilities: string[]
health_score: number | null
success_rate: number | null
avg_response_time_ms: number | null
request_count: number
@@ -444,7 +458,7 @@ const saving = ref(false)
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
// Scheduling mode state
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
const schedulingMode = ref<'fixed_order' | 'load_balance' | 'cache_affinity'>('cache_affinity')
// Available API formats
const availableFormats = computed(() => {
@@ -477,7 +491,11 @@ async function loadCurrentPriorityMode() {
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
if (currentSchedulingMode === 'fixed_order' || currentSchedulingMode === 'load_balance' || currentSchedulingMode === 'cache_affinity') {
schedulingMode.value = currentSchedulingMode
} else {
schedulingMode.value = 'cache_affinity'
}
} catch {
activeMainTab.value = 'provider'
schedulingMode.value = 'cache_affinity'

View File

@@ -178,7 +178,7 @@
<Button
variant="ghost"
size="icon"
class="h-8 w-8 text-destructive hover:text-destructive"
class="h-8 w-8 hover:text-destructive"
title="删除"
@click="deleteModel(model)"
>

View File

@@ -295,6 +295,15 @@
</template>
<RouterView />
<!-- Update prompt dialog -->
<UpdateDialog
v-if="updateInfo"
v-model="showUpdateDialog"
:current-version="updateInfo.current_version"
:latest-version="updateInfo.latest_version || ''"
:release-url="updateInfo.release_url"
/>
</AppShell>
</template>
@@ -304,10 +313,12 @@ import { useRoute, useRouter } from 'vue-router'
import { useAuthStore } from '@/stores/auth'
import { useDarkMode } from '@/composables/useDarkMode'
import { isDemoMode } from '@/config/demo'
import { adminApi, type CheckUpdateResponse } from '@/api/admin'
import Button from '@/components/ui/button.vue'
import AppShell from '@/components/layout/AppShell.vue'
import SidebarNav from '@/components/layout/SidebarNav.vue'
import HeaderLogo from '@/components/HeaderLogo.vue'
import UpdateDialog from '@/components/common/UpdateDialog.vue'
import {
Home,
Users,
@@ -345,17 +356,67 @@ const showAuthError = ref(false)
const mobileMenuOpen = ref(false)
let authCheckInterval: number | null = null
// Update check state
const showUpdateDialog = ref(false)
const updateInfo = ref<CheckUpdateResponse | null>(null)
// Close the mobile menu automatically on route changes
watch(() => route.path, () => {
mobileMenuOpen.value = false
})
// Decide whether the update prompt should be shown
function shouldShowUpdatePrompt(latestVersion: string): boolean {
const ignoreKey = 'aether_update_ignore'
const ignoreData = localStorage.getItem(ignoreKey)
if (!ignoreData) return true
try {
const { version, until } = JSON.parse(ignoreData)
// Do not show the prompt if the same version was ignored and the ignore period has not expired
if (version === latestVersion && Date.now() < until) {
return false
}
} catch {
// Parsing failed; show the prompt
}
return true
}
// Check for updates
async function checkForUpdate() {
// Only administrators check for updates
if (authStore.user?.role !== 'admin') return
// Check only once per session
const sessionKey = 'aether_update_checked'
if (sessionStorage.getItem(sessionKey)) return
sessionStorage.setItem(sessionKey, '1')
try {
const result = await adminApi.checkUpdate()
if (result.has_update && result.latest_version) {
if (shouldShowUpdatePrompt(result.latest_version)) {
updateInfo.value = result
showUpdateDialog.value = true
}
}
} catch {
// Fail silently so the user experience is unaffected
}
}
onMounted(() => {
authCheckInterval = setInterval(() => {
if (authStore.user && !authStore.token) {
showAuthError.value = true
}
}, 5000)
// Delay the update check so it does not affect page load
setTimeout(() => {
checkForUpdate()
}, 2000)
})
onUnmounted(() => {

View File

@@ -425,9 +425,9 @@ const MOCK_ENDPOINT_KEYS = [
// Mock Endpoints
const MOCK_ENDPOINTS = [
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 3, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 3, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 3, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 2, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 2, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 2, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
]
// Mock capability definitions
@@ -1224,7 +1224,7 @@ function generateMockEndpointsForProvider(providerId: string) {
'https://generativelanguage.googleapis.com',
auth_type: format.includes('GEMINI') ? 'api_key' : 'bearer',
timeout: 120,
max_retries: 3,
max_retries: 2,
priority: 100 - index * 10,
weight: 100,
health_score: healthDetail?.health_score ?? 1.0,

View File

@@ -54,7 +54,7 @@ const fieldNameMap: Record<string, string> = {
*/
const errorTypeMap: Record<string, (error: ValidationError) => string> = {
'string_too_short': (error) => {
const minLength = error.ctx?.min_length || 10
const minLength = error.ctx?.min_length || 3
return `长度不能少于 ${minLength} 个字符`
},
'string_too_long': (error) => {

View File

@@ -1,12 +0,0 @@
#!/bin/bash
# Database migration script: runs Alembic migrations inside the Docker container
set -e
CONTAINER_NAME="aether-app"
echo "Running database migrations in container: $CONTAINER_NAME"
docker exec $CONTAINER_NAME alembic upgrade head
echo "Database migration completed successfully"

View File

@@ -1,34 +0,0 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.2.5'
__version_tuple__ = version_tuple = (0, 2, 5)
__commit_id__ = commit_id = None

View File

@@ -18,6 +18,7 @@ from src.core.key_capabilities import get_capability
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.provider_cache import ProviderCacheService
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
@@ -411,6 +412,10 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit()
db.refresh(key)
# If rate_multiplier was updated, invalidate the cache
if "rate_multiplier" in update_data:
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try:
@@ -550,6 +555,7 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
"endpoint_base_url": endpoint.base_url,
"api_format": api_format,
"capabilities": caps_list,
"health_score": key.health_score,
"success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count,

View File

@@ -11,9 +11,11 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
from src.services.cache.provider_cache import ProviderCacheService
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
@@ -296,6 +298,11 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
db.commit()
db.refresh(provider)
# If billing_type was updated, invalidate the cache
if "billing_type" in update_data:
await ProviderCacheService.invalidate_provider_cache(provider.id)
logger.debug(f"已清除 Provider 缓存: {provider.id}")
context.add_audit_metadata(
action="update_provider",
provider_id=provider.id,

View File

@@ -42,6 +42,42 @@ def _get_version_from_git() -> str | None:
return None
def _get_current_version() -> str:
"""Get the current version."""
version = _get_version_from_git()
if version:
return version
try:
from src._version import __version__
return __version__
except ImportError:
return "unknown"
def _parse_version(version_str: str) -> tuple:
"""Parse a version string into a comparable tuple; supports 3- or 4-segment versions.
Examples:
- '0.2.5' -> (0, 2, 5, 0)
- '0.2.5.1' -> (0, 2, 5, 1)
- 'v0.2.5-4-g1234567' -> (0, 2, 5, 0)
"""
import re
version_str = version_str.lstrip("v")
main_version = re.split(r"[-+]", version_str)[0]
try:
parts = main_version.split(".")
# Normalize to 4 segments so tuples compare cleanly
int_parts = [int(p) for p in parts]
while len(int_parts) < 4:
int_parts.append(0)
return tuple(int_parts[:4])
except ValueError:
return (0, 0, 0, 0)
@router.get("/version")
async def get_system_version():
"""
@@ -52,18 +88,111 @@ async def get_system_version():
**Response fields**:
- `version`: version string
"""
# Prefer the version reported by git
version = _get_version_from_git()
if version:
return {"version": version}
return {"version": _get_current_version()}
@router.get("/check-update")
async def check_update():
"""
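Check for system updates
Fetches the latest version from GitHub Tags and compares it with the current version.
**Response fields**:
- `current_version`: current version
- `latest_version`: latest version
- `has_update`: whether an update is available
- `release_url`: GitHub page URL of the latest version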
"""
import httpx
from src.clients.http_client import HTTPClientPool
current_version = _get_current_version()
github_repo = "Aethersailor/Aether"
github_tags_url = f"https://api.github.com/repos/{github_repo}/tags"
# Fall back to the static version file
try:
from src._version import __version__
async with HTTPClientPool.get_temp_client(
timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0)
) as client:
response = await client.get(
github_tags_url,
headers={
"Accept": "application/vnd.github.v3+json",
"User-Agent": f"Aether/{current_version}",
},
params={"per_page": 10},
)
return {"version": __version__}
except ImportError:
return {"version": "unknown"}
if response.status_code != 200:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": f"GitHub API 返回错误: {response.status_code}",
}
tags = response.json()
if not tags:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": None,
}
# Find the latest version tag (sorted by version number, not by time)
version_tags = []
for tag in tags:
tag_name = tag.get("name", "")
if tag_name and (tag_name.startswith("v") or tag_name[0].isdigit()):
version_tags.append((tag_name, _parse_version(tag_name)))
if not version_tags:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": None,
}
# Sort by version number and take the highest
version_tags.sort(key=lambda x: x[1], reverse=True)
latest_tag = version_tags[0][0]
latest_version = latest_tag.lstrip("v")
current_tuple = _parse_version(current_version)
latest_tuple = _parse_version(latest_version)
has_update = latest_tuple > current_tuple
return {
"current_version": current_version,
"latest_version": latest_version,
"has_update": has_update,
"release_url": f"https://github.com/{github_repo}/releases/tag/{latest_tag}",
"error": None,
}
except httpx.TimeoutException:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": "检查更新超时",
}
except Exception as e:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": f"检查更新失败: {str(e)}",
}
pipeline = ApiRequestPipeline()
@@ -887,7 +1016,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
)
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3)
existing_ep.max_retries = ep_data.get("max_retries", 2)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
@@ -903,7 +1032,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
base_url=ep_data["base_url"],
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3),
max_retries=ep_data.get("max_retries", 2),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
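The update check earlier in this file relies on turning tag names into comparable tuples and picking the highest one. The sketch below restates that logic as two standalone functions so it can be run in isolation. `parse_version` follows `_parse_version` from the hunk above; `pick_latest` is a hypothetical helper that condenses the filter-and-sort loop and additionally skips empty tag names.

```python
import re
from typing import Iterable, Optional, Tuple

Version = Tuple[int, int, int, int]


def parse_version(version_str: str) -> Version:
    """Normalize a version string to 4 ints; (0, 0, 0, 0) if unparseable."""
    # Drop a leading "v" and any "-suffix" / "+build" part (e.g. git describe output).
    core = re.split(r"[-+]", version_str.lstrip("v"))[0]
    try:
        parts = [int(p) for p in core.split(".")]
    except ValueError:
        return (0, 0, 0, 0)
    parts += [0] * (4 - len(parts))  # pad 3-segment versions with a trailing 0
    return (parts[0], parts[1], parts[2], parts[3])


def pick_latest(tag_names: Iterable[str]) -> Optional[str]:
    """Return the tag with the highest parsed version, or None if none qualify."""
    candidates = [t for t in tag_names if t and (t[0] == "v" or t[0].isdigit())]
    return max(candidates, key=parse_version, default=None)


if __name__ == "__main__":
    print(parse_version("v0.2.5-4-g1234567"))
    print(pick_latest(["v0.2.9", "v0.2.10", "latest"]))
```

Comparing tuples rather than strings is what makes `0.2.10` sort above `0.2.9`; a plain string comparison would order them the other way.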

View File

@@ -691,64 +691,70 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
)
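# Create an HTTP client (with proxy support)
# endpoint.timeout is used as the overall request timeout
# Get a reusable HTTP client (with proxy support)
# Note: get_proxy_client reuses the connection pool instead of creating a new client per request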
from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy(
http_client = await HTTPClientPool.get_proxy_client(
proxy_config=endpoint.proxy,
)
# Note: do not use async with, because a reused client must not be closed
# The timeout is controlled through the timeout parameter
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_hdrs,
timeout=httpx.Timeout(request_timeout),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code
response_headers = dict(resp.headers)
status_code = resp.status_code
response_headers = dict(resp.headers)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=str(provider.name),
response_headers=response_headers,
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=str(provider.name),
response_headers=response_headers,
)
elif resp.status_code >= 500:
# Log the response body for debugging
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
)
elif resp.status_code >= 500:
# Log the response body for debugging
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
elif resp.status_code != 200:
# Log non-200 responses for debugging
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
elif resp.status_code != 200:
# Log non-200 responses for debugging
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
response_json = resp.json()
return response_json if isinstance(response_json, dict) else {}
response_json = resp.json()
return response_json if isinstance(response_json, dict) else {}
try:
# Resolve capability requirements

View File

@@ -1534,72 +1534,78 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
)
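# Create an HTTP client (with proxy support)
# endpoint.timeout is used as the overall request timeout
# Get a reusable HTTP client (with proxy support)
# Note: get_proxy_client reuses the connection pool instead of creating a new client per request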
from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy(
http_client = await HTTPClientPool.get_proxy_client(
proxy_config=endpoint.proxy,
)
# Note: do not use async with, because a reused client must not be closed
# The timeout is controlled through the timeout parameter
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_headers,
timeout=httpx.Timeout(request_timeout),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code
response_headers = dict(resp.headers)
status_code = resp.status_code
response_headers = dict(resp.headers)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=str(provider.name),
response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
elif resp.status_code >= 500:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
)
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
provider_name=str(provider.name),
response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
elif resp.status_code >= 500:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
)
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
# Parse the JSON response defensively, handling possible encoding errors
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# Log raw response details for debugging
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes"
)
raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
)
# Parse the JSON response defensively, handling possible encoding errors
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# Log raw response details for debugging
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes"
)
raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
)
# Extract provider response metadata (subclasses may override)
response_metadata_result = self._extract_response_metadata(response_json)
# Extract provider response metadata (subclasses may override)
response_metadata_result = self._extract_response_metadata(response_json)
return response_json if isinstance(response_json, dict) else {}
return response_json if isinstance(response_json, dict) else {}
try:
# Resolve capability requirements

View File

@@ -1,10 +1,18 @@
"""
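Global HTTP client pool management
Avoids creating a new AsyncClient for every request, improving performance
Performance notes:
1. Default client: a single globally shared client for the no-proxy case
2. Proxy client cache: identical proxy configs reuse the same client, avoiding repeated creation
3. Connection pool reuse: keep-alive connections reduce TCP handshake overhead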
"""
import asyncio
import hashlib
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib.parse import quote, urlparse
import httpx
@@ -12,6 +20,32 @@ import httpx
from src.config import config
from src.core.logger import logger
# Module-level locks, avoiding race conditions from lazily initialized class attributes
_proxy_clients_lock = asyncio.Lock()
_default_client_lock = asyncio.Lock()
def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str:
"""
Compute the cache key for a proxy configuration
Args:
proxy_config: proxy configuration dict
Returns:
Cache key string; "__no_proxy__" when no proxy is configured
"""
if not proxy_config:
return "__no_proxy__"
# Build the proxy URL as the basis for the cache key
proxy_url = build_proxy_url(proxy_config)
if not proxy_url:
return "__no_proxy__"
# Use an MD5 hash to avoid overly long key names
return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
"""
@@ -61,11 +95,20 @@ class HTTPClientPool:
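Global HTTP client pool singleton
Manages reusable httpx.AsyncClient instances, avoiding frequent connection setup/teardown
Performance optimizations:
1. Default client: reused for the no-proxy case
2. Proxy client cache: identical proxy configs reuse the same client
3. LRU eviction: when proxy clients exceed the limit, the least recently used one is evicted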
"""
_instance: Optional["HTTPClientPool"] = None
_default_client: Optional[httpx.AsyncClient] = None
_clients: Dict[str, httpx.AsyncClient] = {}
# Proxy client cache: {cache_key: (client, last_used_time)}
_proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {}
# Upper bound on cached proxy clients (avoids memory leaks)
_max_proxy_clients: int = 50
def __new__(cls):
if cls._instance is None:
@@ -73,12 +116,50 @@ class HTTPClientPool:
return cls._instance
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
async def get_default_client_async(cls) -> httpx.AsyncClient:
"""
Get the default HTTP client
Get the default HTTP client (async, concurrency-safe version)
Used for most HTTP requests, with sensible default settings
"""
if cls._default_client is not None:
return cls._default_client
async with _default_client_lock:
# Double-checked locking to avoid creating the client twice
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # HTTP/2 temporarily disabled for better compatibility
verify=True, # Enable SSL verification
timeout=httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
limits=httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
follow_redirects=True, # Follow redirects
)
logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
"""
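Get the default HTTP client (synchronous version, kept for backward compatibility)
⚠️ Note: this method may race when first called under high concurrency;
prefer the async version, get_default_client_async().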
"""
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # HTTP/2 temporarily disabled for better compatibility
@@ -90,13 +171,18 @@ class HTTPClientPool:
pool=config.http_pool_timeout,
),
limits=httpx.Limits(
max_connections=100, # Maximum number of connections
max_keepalive_connections=20, # Maximum number of keep-alive connections
keepalive_expiry=30.0, # Keep-alive expiry (seconds)
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
follow_redirects=True, # Follow redirects
)
logger.info("全局HTTP客户端池已初始化")
logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client
@classmethod
@@ -130,6 +216,101 @@ class HTTPClientPool:
return cls._clients[name]
@classmethod
def _get_proxy_clients_lock(cls) -> asyncio.Lock:
"""Return the proxy client cache lock (module-level singleton, avoids race conditions)"""
return _proxy_clients_lock
@classmethod
async def _evict_lru_proxy_client(cls) -> None:
"""Evict the least recently used proxy client"""
if len(cls._proxy_clients) < cls._max_proxy_clients:
return
# Find the least recently used client
oldest_key = min(cls._proxy_clients.keys(), key=lambda k: cls._proxy_clients[k][1])
old_client, _ = cls._proxy_clients.pop(oldest_key)
# Close the old client asynchronously
try:
await old_client.aclose()
logger.debug(f"淘汰代理客户端: {oldest_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
@classmethod
async def get_proxy_client(
cls,
proxy_config: Optional[Dict[str, Any]] = None,
) -> httpx.AsyncClient:
"""
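Get a proxy client (cached and reused)
Identical proxy configs reuse the same client, greatly reducing connection setup overhead.
Note: the returned client uses the default timeout settings; pass a timeout argument per request to override them.
Args:
proxy_config: proxy configuration dict containing url, username, password
Returns:
A reusable httpx.AsyncClient instance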
"""
cache_key = _compute_proxy_cache_key(proxy_config)
# Return the default client when no proxy is configured
if cache_key == "__no_proxy__":
return await cls.get_default_client_async()
lock = cls._get_proxy_clients_lock()
async with lock:
# Check the cache
if cache_key in cls._proxy_clients:
client, _ = cls._proxy_clients[cache_key]
# Health check: if the client has been closed, drop it and create a new one
if client.is_closed:
del cls._proxy_clients[cache_key]
logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}")
else:
# Update the last-used timestamp
cls._proxy_clients[cache_key] = (client, time.time())
return client
# Evict an old client (if the limit has been reached)
await cls._evict_lru_proxy_client()
# Create a new client (default timeout; can be overridden per request)
client_config: Dict[str, Any] = {
"http2": False,
"verify": True,
"follow_redirects": True,
"limits": httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
"timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
}
# Add the proxy configuration
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
client_config["proxy"] = proxy_url
client = httpx.AsyncClient(**client_config)
cls._proxy_clients[cache_key] = (client, time.time())
logger.debug(
f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, "
f"缓存数量: {len(cls._proxy_clients)}"
)
return client
@classmethod
async def close_all(cls):
"""Close all HTTP clients"""
@@ -143,6 +324,16 @@ class HTTPClientPool:
logger.debug(f"命名HTTP客户端已关闭: {name}")
cls._clients.clear()
# Close the cached proxy clients
for cache_key, (client, _) in cls._proxy_clients.items():
try:
await client.aclose()
logger.debug(f"代理客户端已关闭: {cache_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
cls._proxy_clients.clear()
logger.info("所有HTTP客户端已关闭")
@classmethod
@@ -185,13 +376,15 @@ class HTTPClientPool:
"""
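Create an HTTP client with proxy configuration
⚠️ Performance warning: this method creates a new client on every call; prefer get_proxy_client() to reuse connections.
Args:
proxy_config: proxy configuration dict containing url, username, password
timeout: timeout configuration
**kwargs: other httpx.AsyncClient configuration parameters
Returns:
A configured httpx.AsyncClient instance
A configured httpx.AsyncClient instance (the caller is responsible for closing it)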
"""
client_config: Dict[str, Any] = {
"http2": False,
@@ -213,11 +406,21 @@ class HTTPClientPool:
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}")
client_config.update(kwargs)
return httpx.AsyncClient(**client_config)
@classmethod
def get_pool_stats(cls) -> Dict[str, Any]:
"""Return connection pool statistics"""
return {
"default_client_active": cls._default_client is not None,
"named_clients_count": len(cls._clients),
"proxy_clients_count": len(cls._proxy_clients),
"max_proxy_clients": cls._max_proxy_clients,
}
# Convenience accessor functions
def get_http_client() -> httpx.AsyncClient:
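The proxy client cache in this file combines two ideas: a short hashed cache key derived from the proxy URL, and timestamp-based LRU eviction once the cache is full. The sketch below isolates both under simplifying assumptions: `cache_key_for` and `TimestampLRU` are hypothetical names, the cached values are plain objects rather than `httpx.AsyncClient` instances, and eviction simply drops the entry where the real pool would also await `aclose()`.

```python
import hashlib
import time
from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar

T = TypeVar("T")


def cache_key_for(proxy_url: Optional[str]) -> str:
    """Short, stable key for a proxy URL (hypothetical helper)."""
    if not proxy_url:
        return "__no_proxy__"
    return "proxy:" + hashlib.md5(proxy_url.encode()).hexdigest()[:16]


class TimestampLRU(Generic[T]):
    """Cache of key -> (value, last_used); evicts the oldest entry when full."""

    def __init__(self, max_size: int, clock: Callable[[], float] = time.monotonic):
        self.max_size = max_size
        self._clock = clock
        self._items: Dict[str, Tuple[T, float]] = {}

    def get_or_create(self, key: str, factory: Callable[[], T]) -> T:
        if key in self._items:
            value, _ = self._items[key]
            # Refresh the timestamp so this entry counts as recently used.
            self._items[key] = (value, self._clock())
            return value
        if len(self._items) >= self.max_size:
            # O(n) scan for the oldest timestamp; acceptable for a small cap.
            oldest = min(self._items, key=lambda k: self._items[k][1])
            del self._items[oldest]  # a real pool would also close the client here
        value = factory()
        self._items[key] = (value, self._clock())
        return value

    def __contains__(self, key: str) -> bool:
        return key in self._items

    def __len__(self) -> int:
        return len(self._items)
```

The linear scan for the oldest entry mirrors the `min(...)` call in `_evict_lru_proxy_client`; with a cap of 50 entries that cost is negligible, though an `OrderedDict` would avoid it for larger caches.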

View File

@@ -145,6 +145,24 @@ class Config:
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
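# HTTP connection pool configuration
# HTTP_MAX_CONNECTIONS: maximum number of connections; affects concurrency
# - Each connection holds one socket; too many will exhaust system resources
# - By default computed from the worker count: 200 for a single worker, split proportionally across multiple workers
# HTTP_KEEPALIVE_CONNECTIONS: number of keep-alive connections; affects connection reuse efficiency
# - Increase this for high-frequency request workloads
# - Defaults to 30% of max_connections (more efficient for long-lived connections)
# HTTP_KEEPALIVE_EXPIRY: keep-alive expiry (seconds)
# - Too short causes frequent reconnects; too long ties up resources
# - Defaults to 30 seconds; can be raised for long-lived workloads such as image generation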
self.http_max_connections = int(
os.getenv("HTTP_MAX_CONNECTIONS") or self._auto_http_max_connections()
)
self.http_keepalive_connections = int(
os.getenv("HTTP_KEEPALIVE_CONNECTIONS") or self._auto_http_keepalive_connections()
)
self.http_keepalive_expiry = float(os.getenv("HTTP_KEEPALIVE_EXPIRY", "30.0"))
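# Streaming configuration
# STREAM_PREFETCH_LINES: number of lines to prefetch, used to detect nested errors
# STREAM_STATS_DELAY: delay (seconds) before recording stats, to wait for the stream to fully close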
@@ -224,6 +242,53 @@ class Config:
"""Compute the maximum overflow connections automatically (same as pool_size)"""
return self.db_pool_size
def _auto_http_max_connections(self) -> int:
"""
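Compute the HTTP maximum connection count automatically
Rationale:
1. System socket resources are limited (on Linux the default ulimit -n is usually 1024)
2. In multi-worker deployments each process has its own connection pool
3. Resources must be reserved for database connections, Redis connections, and so on
Formula: base_connections / workers
- Single worker: 200 connections (suitable for development / low load)
- Multiple workers: split proportionally so the total stays within system limits
Range: 50 - 500
"""
# Base connection count: assume roughly 800 system sockets are available for HTTP
# (the rest is reserved for DB, Redis, internal services, etc.)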
base_connections = 800
workers = max(self.worker_processes, 1)
# Connections allocated to each worker
per_worker = base_connections // workers
# Clamp the range: at least 50 (basic concurrency), at most 500 (avoid resource exhaustion)
return max(50, min(per_worker, 500))
def _auto_http_keepalive_connections(self) -> int:
"""
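Compute the HTTP keep-alive connection count automatically
Rationale:
1. Keep-alive connections are reused, reducing TCP handshake overhead
2. In an API gateway, upstream requests are frequent, so the keep-alive ratio should be fairly high
3. In long-lived workloads such as image generation, connections are held for a long time
Formula: max_connections * 0.3
- A 30% ratio balances reuse efficiency against resource usage
- For long-lived workloads, consider raising it manually to 50-70%
Range: 10 - max_connections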
"""
# Keep-alive connections are 30% of the maximum connection count
keepalive = int(self.http_max_connections * 0.3)
# At least 10 keep-alive connections, and never more than max_connections
return max(10, min(keepalive, self.http_max_connections))
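The two sizing formulas above are easy to check by hand. The sketch below restates them as free functions (the names `auto_max_connections` and `auto_keepalive` are hypothetical; the constants 800, 50, 500, 0.3 and 10 are taken from the hunk) and prints the resulting pool sizes for a few worker counts. With these constants a single worker resolves to the 500 cap, since 800 // 1 exceeds it, rather than the 200 mentioned in the comments; 200 is what four workers receive.

```python
def auto_max_connections(workers: int, base: int = 800) -> int:
    """Per-worker connection cap: base // workers, clamped to [50, 500]."""
    per_worker = base // max(workers, 1)
    return max(50, min(per_worker, 500))


def auto_keepalive(max_connections: int) -> int:
    """Keep-alive pool size: 30% of max_connections, at least 10."""
    keepalive = int(max_connections * 0.3)
    return max(10, min(keepalive, max_connections))


if __name__ == "__main__":
    for workers in (1, 2, 4, 8, 16, 32):
        max_conn = auto_max_connections(workers)
        print(f"workers={workers:>2} max_connections={max_conn:>3} "
              f"keepalive={auto_keepalive(max_conn)}")
```

Setting `HTTP_MAX_CONNECTIONS` or `HTTP_KEEPALIVE_CONNECTIONS` explicitly bypasses these helpers, which is the simplest way to get a different single-worker value.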
def _parse_ttfb_timeout(self) -> float:
"""
Parse the TTFB timeout setting, with error handling and range limits

View File

@@ -134,7 +134,7 @@ class EndpointAPIKeyCreate(BaseModel):
"""Add an API Key to an Endpoint"""
endpoint_id: str = Field(..., description="Endpoint ID")
api_key: str = Field(..., min_length=10, max_length=500, description="API Key将自动加密")
api_key: str = Field(..., min_length=3, max_length=500, description="API Key将自动加密")
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
# Cost calculation
@@ -175,13 +175,9 @@ class EndpointAPIKeyCreate(BaseModel):
@classmethod
def validate_api_key(cls, v: str) -> str:
"""Validate API Key safety"""
# Strip leading and trailing whitespace
# Strip leading and trailing whitespace (length is validated by Field min_length)
v = v.strip()
# Check the minimum length
if len(v) < 10:
raise ValueError("API Key 长度不能少于 10 个字符")
# Check for dangerous characters (SQL injection protection)
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
for char in dangerous_chars:
@@ -219,7 +215,7 @@ class EndpointAPIKeyUpdate(BaseModel):
"""Update an Endpoint API Key"""
api_key: Optional[str] = Field(
default=None, min_length=10, max_length=500, description="API Key将自动加密"
default=None, min_length=3, max_length=500, description="API Key将自动加密"
)
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")

View File

@@ -8,7 +8,9 @@ import hashlib
import secrets
import time
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, Optional
import jwt
@@ -30,6 +32,44 @@ from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService
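# Throttling settings for API Key last_used_at updates
# The same API Key updates last_used_at at most once per interval
_LAST_USED_UPDATE_INTERVAL = 60 # seconds
_LAST_USED_CACHE_MAX_SIZE = 10000 # maximum number of LRU cache entries
# In-process cache: time of the most recent last_used_at update for each API Key
# Uses an OrderedDict as an LRU to keep memory from growing without bound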
_api_key_last_update_times: OrderedDict[str, float] = OrderedDict()
_last_update_lock = Lock()
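def _should_update_last_used(api_key_id: str) -> bool:
"""Decide whether the API Key's last_used_at should be updated
Uses throttling: the same key is updated at most once per interval.
Thread-safe; uses an LRU policy to bound the cache size.
Returns:
True if the update should happen, False to skip it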
"""
now = time.time()
with _last_update_lock:
last_update = _api_key_last_update_times.get(api_key_id, 0)
if now - last_update >= _LAST_USED_UPDATE_INTERVAL:
_api_key_last_update_times[api_key_id] = now
# LRU: move to the end (most recently used)
_api_key_last_update_times.move_to_end(api_key_id)
# When over capacity, remove the oldest entries
while len(_api_key_last_update_times) > _LAST_USED_CACHE_MAX_SIZE:
_api_key_last_update_times.popitem(last=False)
return True
return False
# JWT settings are read from config
if not config.jwt_secret_key:
# If not configured, generate a random key and emit a warning
@@ -367,9 +407,10 @@ class AuthService:
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
return None
# Update the last-used time
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # Commit immediately to release the database lock and avoid blocking later requests
# Update the last-used time (throttled to reduce database writes)
if _should_update_last_used(key_record.id):
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # Commit immediately to release the database lock and avoid blocking later requests
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
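The throttling helper in this file can be read as a small reusable pattern: remember when each key was last written, allow a write only after the interval has elapsed, and bound the bookkeeping with an LRU. The sketch below packages that pattern as a class so it can be exercised with an injected clock. `UpdateThrottle` and its parameters are hypothetical names, and unlike the module-level version it uses `None` rather than `0` for "never seen", which keeps it correct with a monotonic clock.

```python
import time
from collections import OrderedDict
from threading import Lock
from typing import Callable, Optional


class UpdateThrottle:
    """Allow at most one update per key per interval; bounded by an LRU."""

    def __init__(
        self,
        interval: float = 60.0,
        max_size: int = 10000,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._interval = interval
        self._max_size = max_size
        self._clock = clock
        self._last: "OrderedDict[str, float]" = OrderedDict()
        self._lock = Lock()

    def should_update(self, key: str) -> bool:
        now = self._clock()
        with self._lock:
            last: Optional[float] = self._last.get(key)
            if last is not None and now - last < self._interval:
                return False  # updated recently, skip the write
            self._last[key] = now
            self._last.move_to_end(key)  # mark as most recently used
            while len(self._last) > self._max_size:
                self._last.popitem(last=False)  # drop the oldest entry
            return True
```

Because the state is per process, each worker throttles independently; with N workers a key may still be written up to N times per interval, which is usually acceptable for a coarse `last_used_at` field.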

View File

@@ -121,11 +121,13 @@ class CacheAwareScheduler:
PRIORITY_MODE_GLOBAL_KEY,
}
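# Scheduling mode constants
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # Fixed-order mode
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # Cache-affinity mode
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # Fixed-order mode: strictly by priority, ignoring the cache
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # Cache-affinity mode: prefer cached candidates, hash-distribute within the same priority
SCHEDULING_MODE_LOAD_BALANCE = "load_balance" # Load-balance mode: ignore the cache, rotate randomly within the same priority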
ALLOWED_SCHEDULING_MODES = {
SCHEDULING_MODE_FIXED_ORDER,
SCHEDULING_MODE_CACHE_AFFINITY,
SCHEDULING_MODE_LOAD_BALANCE,
}
def __init__(
@@ -680,8 +682,9 @@ class CacheAwareScheduler:
f"(api_format={target_format.value}, model={model_name})"
)
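# 4. Apply cache-affinity ordering (only in cache-affinity mode)
# 4. Apply a sorting strategy according to the scheduling mode
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
# Cache-affinity mode: prefer cached candidates, hash-distribute within the same priority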
if affinity_key and candidates:
candidates = await self._apply_cache_affinity(
candidates=candidates,
@@ -689,8 +692,13 @@ class CacheAwareScheduler:
api_format=target_format,
global_model_id=global_model_id,
)
elif self.scheduling_mode == self.SCHEDULING_MODE_LOAD_BALANCE:
# Load-balance mode: ignore the cache, rotate randomly within the same priority
candidates = self._apply_load_balance(candidates)
for candidate in candidates:
candidate.is_cached = False
else:
# Fixed-order mode: mark all candidates as not cached
# Fixed-order mode: strictly by priority, ignoring the cache
for candidate in candidates:
candidate.is_cached = False
@@ -1163,6 +1171,57 @@ class CacheAwareScheduler:
return result
def _apply_load_balance(
self, candidates: List[ProviderCandidate]
) -> List[ProviderCandidate]:
"""
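Load-balance mode: random rotation within the same priority
Ordering logic:
1. Group by priority ((provider_priority, internal_priority) or global_priority)
2. Shuffle randomly within each priority group
3. Ignore cache affinity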
"""
if not candidates:
return candidates
from collections import defaultdict
# Use a tuple as a uniform key type so both modes are supported
priority_groups: Dict[tuple, List[ProviderCandidate]] = defaultdict(list)
# Choose how to group based on the priority mode
if self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY:
# Global-key-first mode: group by global_priority
for candidate in candidates:
global_priority = (
candidate.key.global_priority
if candidate.key and candidate.key.global_priority is not None
else 999999
)
priority_groups[(global_priority,)].append(candidate)
else:
# Provider-first mode: group by (provider_priority, internal_priority)
for candidate in candidates:
key = (
candidate.provider.provider_priority or 999999,
candidate.key.internal_priority if candidate.key else 999999,
)
priority_groups[key].append(candidate)
result: List[ProviderCandidate] = []
for priority in sorted(priority_groups.keys()):
group = priority_groups[priority]
if len(group) > 1:
# Shuffle randomly within the same priority
shuffled = list(group)
random.shuffle(shuffled)
result.extend(shuffled)
else:
result.extend(group)
return result
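Stripped of the scheduler's types, the load-balance ordering above is "group by priority, shuffle each group, concatenate groups in ascending priority order". A minimal generic sketch follows; `shuffle_within_priority` and the `priority_of` callback are hypothetical names, and plain tuples stand in for `ProviderCandidate`.

```python
import random
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def shuffle_within_priority(
    items: Sequence[T],
    priority_of: Callable[[T], Hashable],
    rng: Optional[random.Random] = None,
) -> List[T]:
    """Keep priority order across groups, randomize order inside each group."""
    shuffler = rng if rng is not None else random.Random()
    groups: Dict[Hashable, List[T]] = defaultdict(list)
    for item in items:
        groups[priority_of(item)].append(item)

    result: List[T] = []
    for priority in sorted(groups):  # lower value means higher priority
        group = list(groups[priority])
        shuffler.shuffle(group)
        result.extend(group)
    return result


if __name__ == "__main__":
    candidates = [("a", 1), ("b", 1), ("c", 2), ("d", 3), ("e", 3)]
    print(shuffle_within_priority(candidates, lambda c: c[1]))
```

Passing a seeded `random.Random` makes the ordering reproducible in tests, while production code can rely on the default generator.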
def _shuffle_keys_by_internal_priority(
self,
keys: List[ProviderAPIKey],

src/services/cache/provider_cache.py (new file, 171 lines)
View File

@@ -0,0 +1,171 @@
"""
Provider cache service: reduces Provider and ProviderAPIKey queries
Caches the Provider's billing_type and the ProviderAPIKey's rate_multiplier.
These values are queried frequently in UsageService.record_usage() but change rarely.
"""
from typing import Optional, Tuple
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey
class ProviderCacheService:
"""Provider cache service
Provides cached lookups for Provider and ProviderAPIKey to reduce database access.
Mainly used by UsageService to obtain the rate multiplier and billing type.
"""
CACHE_TTL = CacheTTL.PROVIDER # 5 minutes
@staticmethod
async def get_provider_api_key_rate_multiplier(
db: Session, provider_api_key_id: str
) -> Optional[float]:
"""
Get the ProviderAPIKey's rate_multiplier (cached)
Args:
db: database session
provider_api_key_id: ProviderAPIKey ID
Returns:
rate_multiplier, or None if not found
"""
cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}"
# 1. Try the cache first
cached_data = await CacheService.get(cache_key)
if cached_data is not None:
logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}...")
# A cached "NOT_FOUND" means the row does not exist in the database
if cached_data == "NOT_FOUND":
return None
return float(cached_data)
# 2. Cache miss: query the database
provider_key = (
db.query(ProviderAPIKey.rate_multiplier)
.filter(ProviderAPIKey.id == provider_api_key_id)
.first()
)
# 3. Write to the cache
if provider_key:
rate_multiplier = provider_key.rate_multiplier or 1.0
await CacheService.set(
cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}...")
return rate_multiplier
else:
# Cache the negative result
await CacheService.set(
cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
)
return None
@staticmethod
async def get_provider_billing_type(
db: Session, provider_id: str
) -> Optional[ProviderBillingType]:
"""
Get the Provider's billing_type (cached)
Args:
db: database session
provider_id: Provider ID
Returns:
billing_type, or None if not found
"""
cache_key = f"provider:billing_type:{provider_id}"
# 1. Try the cache first
cached_data = await CacheService.get(cache_key)
if cached_data is not None:
logger.debug(f"Provider billing_type 缓存命中: {provider_id[:8]}...")
if cached_data == "NOT_FOUND":
return None
try:
return ProviderBillingType(cached_data)
except ValueError:
# The cached value is invalid: delete it and query again
await CacheService.delete(cache_key)
# 2. Cache miss: query the database
provider = (
db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
)
# 3. Write to the cache
if provider:
billing_type = provider.billing_type
await CacheService.set(
cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"Provider billing_type 已缓存: {provider_id[:8]}...")
return billing_type
else:
# Cache the negative result
await CacheService.set(
cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
)
return None
@staticmethod
async def get_rate_multiplier_and_free_tier(
db: Session,
provider_api_key_id: Optional[str],
provider_id: Optional[str],
) -> Tuple[float, bool]:
"""
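Get the rate multiplier and whether the provider is on a free tier (cached)
This is the cached version of UsageService._get_rate_multiplier_and_free_tier.
Args:
db: database session
provider_api_key_id: ProviderAPIKey ID (optional)
provider_id: Provider ID (optional)
Returns:
A (rate_multiplier, is_free_tier) tuple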
"""
actual_rate_multiplier = 1.0
is_free_tier = False
# Get the rate multiplier
if provider_api_key_id:
rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
db, provider_api_key_id
)
if rate_multiplier is not None:
actual_rate_multiplier = rate_multiplier
# Get the billing type
if provider_id:
billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
if billing_type == ProviderBillingType.FREE_TIER:
is_free_tier = True
return actual_rate_multiplier, is_free_tier
@staticmethod
async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
"""Invalidate the ProviderAPIKey cache"""
await CacheService.delete(f"provider_api_key:rate_multiplier:{provider_api_key_id}")
logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")
@staticmethod
async def invalidate_provider_cache(provider_id: str) -> None:
"""Invalidate the Provider cache"""
await CacheService.delete(f"provider:billing_type:{provider_id}")
logger.debug(f"Provider 缓存已清除: {provider_id[:8]}...")
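The service above follows a read-through pattern with negative caching: a miss in the database is stored as the sentinel `"NOT_FOUND"` so repeated lookups for a missing row do not hit the database again. The sketch below shows that flow in isolation. It is only an illustration: `InMemoryCache` stands in for the project's `CacheService` (whose real interface is not shown in full here), `load` stands in for the SQLAlchemy query, and TTL handling is omitted.

```python
import asyncio
from typing import Any, Callable, Dict, Optional

NOT_FOUND = "NOT_FOUND"  # sentinel stored for rows that do not exist


class InMemoryCache:
    """Tiny async cache used only for this sketch (no TTL)."""

    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    async def get(self, key: str) -> Any:
        return self._data.get(key)

    async def set(self, key: str, value: Any) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)


async def get_rate_multiplier(
    cache: InMemoryCache,
    key_id: str,
    load: Callable[[str], Optional[float]],
) -> Optional[float]:
    cache_key = f"provider_api_key:rate_multiplier:{key_id}"
    cached = await cache.get(cache_key)
    if cached is not None:
        # Either a real value or the negative-result sentinel.
        return None if cached == NOT_FOUND else float(cached)
    value = load(key_id)  # the database lookup in the real service
    await cache.set(cache_key, NOT_FOUND if value is None else value)
    return value
```

Deleting the cache key after an update, as the admin adapters in this change set do for `rate_multiplier` and `billing_type`, is what keeps the cached value from outliving an edit by more than one request.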

View File

@@ -85,6 +85,8 @@ class ConcurrencyManager:
"""
Get the current concurrency counts
Performance: uses MGET to fetch in one batch, reducing Redis round trips
Args:
endpoint_id: Endpoint ID (optional)
key_id: ProviderAPIKey ID (optional)
@@ -104,15 +106,21 @@ class ConcurrencyManager:
key_count = 0
try:
# Use MGET to fetch in one batch, reducing Redis round trips (2 GETs -> 1 MGET)
keys_to_fetch = []
if endpoint_id:
endpoint_key = self._get_endpoint_key(endpoint_id)
result = await self._redis.get(endpoint_key)
endpoint_count = int(result) if result else 0
keys_to_fetch.append(self._get_endpoint_key(endpoint_id))
if key_id:
key_key = self._get_key_key(key_id)
result = await self._redis.get(key_key)
key_count = int(result) if result else 0
keys_to_fetch.append(self._get_key_key(key_id))
if keys_to_fetch:
results = await self._redis.mget(keys_to_fetch)
idx = 0
if endpoint_id:
endpoint_count = int(results[idx]) if results[idx] else 0
idx += 1
if key_id:
key_count = int(results[idx]) if results[idx] else 0
except Exception as e:
logger.error(f"获取并发数失败: {e}")
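The MGET change replaces two sequential GETs with one batched call and then maps the positional results back to the optional inputs. A sketch of that mapping follows, using a stand-in `FakeRedis` so it runs without a server. `FakeRedis`, `get_counts`, and the key names are hypothetical; the only Redis behaviour assumed is that `MGET` returns values in the order the keys were requested, with `None` for missing keys.

```python
import asyncio
from typing import Dict, List, Optional, Tuple, Union

Value = Union[str, bytes]


class FakeRedis:
    """Minimal stand-in exposing an async mget(keys) for this sketch."""

    def __init__(self, data: Dict[str, Value]) -> None:
        self._data = dict(data)
        self.calls = 0

    async def mget(self, keys: List[str]) -> List[Optional[Value]]:
        self.calls += 1
        return [self._data.get(k) for k in keys]


async def get_counts(
    redis: FakeRedis,
    endpoint_key: Optional[str],
    key_key: Optional[str],
) -> Tuple[int, int]:
    """Return (endpoint_count, key_count) using a single round trip."""
    keys = [k for k in (endpoint_key, key_key) if k]
    if not keys:
        return 0, 0  # nothing to fetch, no Redis call at all
    values = dict(zip(keys, await redis.mget(keys)))

    def to_int(k: Optional[str]) -> int:
        raw = values.get(k) if k else None
        return int(raw) if raw else 0

    return to_int(endpoint_key), to_int(key_key)
```

Zipping keys with results avoids the manual index counter used in the hunk, at the cost of assuming the two keys are distinct, which holds when endpoint and key counters use different prefixes.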

View File

@@ -3,8 +3,9 @@
"""
import json
import time
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
@@ -20,6 +21,49 @@ class LogLevel(str, Enum):
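FULL = "full" # Log the full request and response (including body; sensitive data is redacted)
# In-process cache TTL (seconds): system config changes infrequently, so a relatively long TTL is used
_CONFIG_CACHE_TTL = 60 # 1 minute
# In-process cache storage: {key: (value, expire_time)}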
_config_cache: Dict[str, Tuple[Any, float]] = {}
def _get_cached_config(key: str) -> Tuple[bool, Any]:
"""从进程内缓存获取配置值
Returns:
(hit, value): hit=True 表示缓存命中value 为缓存的值
"""
if key in _config_cache:
value, expire_time = _config_cache[key]
if time.time() < expire_time:
return True, value
# 缓存过期,安全删除(避免并发时 KeyError
_config_cache.pop(key, None)
return False, None
def _set_cached_config(key: str, value: Any) -> None:
"""设置进程内缓存"""
_config_cache[key] = (value, time.time() + _CONFIG_CACHE_TTL)
def invalidate_config_cache(key: Optional[str] = None) -> None:
"""清除配置缓存
Args:
key: 配置键,如果为 None 则清除所有缓存
"""
global _config_cache
if key is None:
_config_cache = {}
logger.debug("已清除所有系统配置缓存")
else:
# 使用 pop 安全删除,避免并发时 KeyError
if _config_cache.pop(key, None) is not None:
logger.debug(f"已清除系统配置缓存: {key}")
class SystemConfigService:
"""系统配置服务类"""
@@ -127,14 +171,23 @@ class SystemConfigService:
     @classmethod
     def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]:
-        """Get a system config value."""
+        """Get a system config value (with in-process caching)."""
+        # 1. Check the in-process cache
+        hit, cached_value = _get_cached_config(key)
+        if hit:
+            return cached_value
+
+        # 2. Query the database
         config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
         if config:
+            _set_cached_config(key, config.value)
             return config.value

-        # If the config does not exist, check the defaults
+        # 3. If the config does not exist, use the default value
         if key in cls.DEFAULT_CONFIGS:
-            return cls.DEFAULT_CONFIGS[key]["value"]
+            value = cls.DEFAULT_CONFIGS[key]["value"]
+            _set_cached_config(key, value)
+            return value

         return default
@@ -185,6 +238,9 @@ class SystemConfigService:
         db.commit()
         db.refresh(config)

+        # Invalidate the cache
+        invalidate_config_cache(key)
+
         return config

     @staticmethod
@@ -243,6 +299,8 @@ class SystemConfigService:
         if config:
             db.delete(config)
             db.commit()
+            # Invalidate the cache
+            invalidate_config_cache(key)
             return True
         return False
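
The expiry and invalidation behavior of the module-level cache added in this file can be exercised deterministically if the clock is injectable. The following is a minimal sketch under that assumption: the `TTLCache` class and its `clock` parameter are hypothetical and not part of the commit, which uses module-level functions and `time.time()` directly.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class TTLCache:
    """Hypothetical class-based variant of the module-level config cache."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.time) -> None:
        self._ttl = ttl
        self._clock = clock  # injectable so expiry can be tested without sleeping
        self._store: Dict[str, Tuple[Any, float]] = {}

    def get(self, key: str) -> Tuple[bool, Any]:
        """Return (hit, value); the hit flag lets a cached None be a valid hit."""
        entry = self._store.get(key)
        if entry is not None:
            value, expire_time = entry
            if self._clock() < expire_time:
                return True, value
            self._store.pop(key, None)  # expired: drop it without risking KeyError
        return False, None

    def set(self, key: str, value: Any) -> None:
        self._store[key] = (value, self._clock() + self._ttl)

    def invalidate(self, key: Optional[str] = None) -> None:
        """Drop one entry, or everything when key is None."""
        if key is None:
            self._store.clear()
        else:
            self._store.pop(key, None)


# Usage with a fake clock: the entry is served until the TTL elapses.
now = [100.0]
cache = TTLCache(ttl=60, clock=lambda: now[0])
cache.set("log_level", "full")
fresh = cache.get("log_level")    # hit while 100.0 < 160.0
now[0] = 161.0
expired = cache.get("log_level")  # miss once the clock passes 160.0
```

Because such a cache lives inside one process, an invalidation only affects the process that performs the write; other worker processes, if any, would keep serving their own copy until the TTL runs out.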


@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Tuple
 from sqlalchemy import func
 from sqlalchemy.orm import Session

-from src.core.enums import ProviderBillingType
 from src.core.logger import logger
 from src.models.database import ApiKey, Provider, ProviderAPIKey, Usage, User, UserRole
 from src.services.model.cost import ModelCostService
@@ -362,22 +361,12 @@ class UsageService:
         provider_api_key_id: Optional[str],
         provider_id: Optional[str],
     ) -> Tuple[float, bool]:
-        """Get the rate multiplier and whether the provider is on the free tier."""
-        actual_rate_multiplier = 1.0
-        if provider_api_key_id:
-            provider_key = (
-                db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first()
-            )
-            if provider_key and provider_key.rate_multiplier:
-                actual_rate_multiplier = provider_key.rate_multiplier
+        """Get the rate multiplier and whether the provider is on the free tier (cached)."""
+        from src.services.cache.provider_cache import ProviderCacheService

-        is_free_tier = False
-        if provider_id:
-            provider_obj = db.query(Provider).filter(Provider.id == provider_id).first()
-            if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER:
-                is_free_tier = True
-
-        return actual_rate_multiplier, is_free_tier
+        return await ProviderCacheService.get_rate_multiplier_and_free_tier(
+            db, provider_api_key_id, provider_id
+        )

     @classmethod
     async def _calculate_costs(