mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
Compare commits
13 Commits
v0.2.5
...
perf/optim
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea35efe440 | ||
|
|
bf09e740e9 | ||
|
|
60c77cec56 | ||
|
|
0e4a1dddb5 | ||
|
|
1cf18b6e12 | ||
|
|
f9a8be898a | ||
|
|
1521ce5a96 | ||
|
|
f2e62dd197 | ||
|
|
d378630b38 | ||
|
|
d9e6346911 | ||
|
|
238788e0e9 | ||
|
|
68ff828505 | ||
|
|
59447fc12b |
23
.github/workflows/docker-publish.yml
vendored
23
.github/workflows/docker-publish.yml
vendored
@@ -146,10 +146,33 @@ jobs:
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=sha,prefix=
|
||||
|
||||
- name: Extract version from tag
|
||||
id: version
|
||||
run: |
|
||||
# 从 tag 提取版本号,如 v0.2.5 -> 0.2.5
|
||||
VERSION="${GITHUB_REF#refs/tags/v}"
|
||||
if [ "$VERSION" = "$GITHUB_REF" ]; then
|
||||
# 不是 tag 触发,使用 git describe
|
||||
VERSION=$(git describe --tags --always | sed 's/^v//')
|
||||
fi
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "Extracted version: $VERSION"
|
||||
|
||||
- name: Update Dockerfile.app to use registry base image
|
||||
run: |
|
||||
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
|
||||
|
||||
- name: Generate version file
|
||||
run: |
|
||||
# 生成 _version.py 文件
|
||||
cat > src/_version.py << EOF
|
||||
# Auto-generated by CI
|
||||
__version__ = '${{ steps.version.outputs.version }}'
|
||||
__version_tuple__ = tuple(int(x) for x in '${{ steps.version.outputs.version }}'.split('.') if x.isdigit())
|
||||
version = __version__
|
||||
version_tuple = __version_tuple__
|
||||
EOF
|
||||
|
||||
- name: Build and push app image
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -224,3 +224,6 @@ extracted_*.ts
|
||||
.deps-hash
|
||||
.code-hash
|
||||
.migration-hash
|
||||
|
||||
# Version file (auto-generated by hatch-vcs)
|
||||
src/_version.py
|
||||
|
||||
@@ -127,14 +127,14 @@ RUN printf '%s\n' \
|
||||
'pidfile=/var/run/supervisord.pid' \
|
||||
'' \
|
||||
'[program:nginx]' \
|
||||
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
|
||||
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/8084/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
|
||||
'autostart=true' \
|
||||
'autorestart=true' \
|
||||
'stdout_logfile=/var/log/nginx/access.log' \
|
||||
'stderr_logfile=/var/log/nginx/error.log' \
|
||||
'' \
|
||||
'[program:app]' \
|
||||
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
|
||||
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 127.0.0.1:8084 --timeout 120 --access-logfile - --error-logfile - --log-level info' \
|
||||
'directory=/app' \
|
||||
'autostart=true' \
|
||||
'autorestart=true' \
|
||||
@@ -147,6 +147,10 @@ RUN printf '%s\n' \
|
||||
# 创建目录
|
||||
RUN mkdir -p /var/log/supervisor /app/logs /app/data
|
||||
|
||||
# 入口脚本(启动前执行迁移)
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
# 环境变量
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
@@ -161,4 +165,5 @@ EXPOSE 80
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
||||
|
||||
@@ -139,6 +139,10 @@ RUN printf '%s\n' \
|
||||
# 创建目录
|
||||
RUN mkdir -p /var/log/supervisor /app/logs /app/data
|
||||
|
||||
# 入口脚本(启动前执行迁移)
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
# 环境变量
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONDONTWRITEBYTECODE=1 \
|
||||
@@ -152,4 +156,5 @@ EXPOSE 80
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD curl -f http://localhost/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
||||
|
||||
10
README.md
10
README.md
@@ -57,14 +57,8 @@ cd Aether
|
||||
cp .env.example .env
|
||||
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||
|
||||
# 3. 部署
|
||||
docker compose up -d
|
||||
|
||||
# 4. 首次部署时, 初始化数据库
|
||||
./migrate.sh
|
||||
|
||||
# 5. 更新
|
||||
docker compose pull && docker compose up -d && ./migrate.sh
|
||||
# 3. 部署 / 更新(自动执行数据库迁移)
|
||||
docker compose pull && docker compose up -d
|
||||
```
|
||||
|
||||
### Docker Compose(本地构建镜像)
|
||||
|
||||
19
deploy.sh
19
deploy.sh
@@ -88,9 +88,28 @@ build_base() {
|
||||
save_deps_hash
|
||||
}
|
||||
|
||||
# 生成版本文件
|
||||
generate_version_file() {
|
||||
# 从 git 获取版本号
|
||||
local version
|
||||
version=$(git describe --tags --always 2>/dev/null | sed 's/^v//')
|
||||
if [ -z "$version" ]; then
|
||||
version="unknown"
|
||||
fi
|
||||
echo ">>> Generating version file: $version"
|
||||
cat > src/_version.py << EOF
|
||||
# Auto-generated by deploy.sh - do not edit
|
||||
__version__ = '$version'
|
||||
__version_tuple__ = tuple(int(x) for x in '$version'.split('-')[0].split('.') if x.isdigit())
|
||||
version = __version__
|
||||
version_tuple = __version_tuple__
|
||||
EOF
|
||||
}
|
||||
|
||||
# 构建应用镜像
|
||||
build_app() {
|
||||
echo ">>> Building app image (code only)..."
|
||||
generate_version_file
|
||||
docker build -f Dockerfile.app.local -t aether-app:latest .
|
||||
save_code_hash
|
||||
}
|
||||
|
||||
8
entrypoint.sh
Normal file
8
entrypoint.sh
Normal file
@@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
echo "Running database migrations..."
|
||||
alembic upgrade head
|
||||
|
||||
echo "Starting application..."
|
||||
exec "$@"
|
||||
@@ -159,6 +159,15 @@ export interface EmailTemplateResetResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// 检查更新响应
|
||||
export interface CheckUpdateResponse {
|
||||
current_version: string
|
||||
latest_version: string | null
|
||||
has_update: boolean
|
||||
release_url: string | null
|
||||
error: string | null
|
||||
}
|
||||
|
||||
// LDAP 配置响应
|
||||
export interface LdapConfigResponse {
|
||||
server_url: string | null
|
||||
@@ -526,6 +535,14 @@ export const adminApi = {
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 检查系统更新
|
||||
async checkUpdate(): Promise<CheckUpdateResponse> {
|
||||
const response = await apiClient.get<CheckUpdateResponse>(
|
||||
'/api/admin/system/check-update'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// LDAP 配置相关
|
||||
// 获取 LDAP 配置
|
||||
async getLdapConfig(): Promise<LdapConfigResponse> {
|
||||
|
||||
112
frontend/src/components/common/UpdateDialog.vue
Normal file
112
frontend/src/components/common/UpdateDialog.vue
Normal file
@@ -0,0 +1,112 @@
|
||||
<template>
|
||||
<Dialog
|
||||
v-model="isOpen"
|
||||
size="md"
|
||||
title=""
|
||||
>
|
||||
<div class="flex flex-col items-center text-center py-2">
|
||||
<!-- Logo -->
|
||||
<HeaderLogo
|
||||
size="h-16 w-16"
|
||||
class-name="text-primary"
|
||||
/>
|
||||
|
||||
<!-- Title -->
|
||||
<h2 class="text-xl font-semibold text-foreground mt-4 mb-2">
|
||||
发现新版本
|
||||
</h2>
|
||||
|
||||
<!-- Version Info -->
|
||||
<div class="flex items-center gap-3 mb-4">
|
||||
<span class="px-3 py-1.5 rounded-lg bg-muted text-sm font-mono text-muted-foreground">
|
||||
v{{ currentVersion }}
|
||||
</span>
|
||||
<svg
|
||||
class="h-4 w-4 text-muted-foreground"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M13 7l5 5m0 0l-5 5m5-5H6"
|
||||
/>
|
||||
</svg>
|
||||
<span class="px-3 py-1.5 rounded-lg bg-primary/10 text-sm font-mono font-medium text-primary">
|
||||
v{{ latestVersion }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Description -->
|
||||
<p class="text-sm text-muted-foreground max-w-xs">
|
||||
新版本已发布,建议更新以获得最新功能和安全修复
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<div class="flex w-full gap-3">
|
||||
<Button
|
||||
variant="outline"
|
||||
class="flex-1"
|
||||
@click="handleLater"
|
||||
>
|
||||
稍后提醒
|
||||
</Button>
|
||||
<Button
|
||||
class="flex-1"
|
||||
@click="handleViewRelease"
|
||||
>
|
||||
查看更新
|
||||
</Button>
|
||||
</div>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import { Dialog } from '@/components/ui'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import HeaderLogo from '@/components/HeaderLogo.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
modelValue: boolean
|
||||
currentVersion: string
|
||||
latestVersion: string
|
||||
releaseUrl: string | null
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:modelValue': [value: boolean]
|
||||
}>()
|
||||
|
||||
const isOpen = ref(props.modelValue)
|
||||
|
||||
watch(() => props.modelValue, (val) => {
|
||||
isOpen.value = val
|
||||
})
|
||||
|
||||
watch(isOpen, (val) => {
|
||||
emit('update:modelValue', val)
|
||||
})
|
||||
|
||||
function handleLater() {
|
||||
// 记录忽略的版本,24小时内不再提醒
|
||||
const ignoreKey = 'aether_update_ignore'
|
||||
const ignoreData = {
|
||||
version: props.latestVersion,
|
||||
until: Date.now() + 24 * 60 * 60 * 1000 // 24小时
|
||||
}
|
||||
localStorage.setItem(ignoreKey, JSON.stringify(ignoreData))
|
||||
isOpen.value = false
|
||||
}
|
||||
|
||||
function handleViewRelease() {
|
||||
if (props.releaseUrl) {
|
||||
window.open(props.releaseUrl, '_blank')
|
||||
}
|
||||
isOpen.value = false
|
||||
}
|
||||
</script>
|
||||
@@ -531,20 +531,23 @@ watch(() => props.open, async (isOpen) => {
|
||||
// 加载数据
|
||||
async function loadData() {
|
||||
await Promise.all([loadGlobalModels(), loadExistingModels()])
|
||||
// 默认折叠全局模型组
|
||||
collapsedGroups.value = new Set(['global'])
|
||||
|
||||
// 检查缓存,如果有缓存数据则直接使用
|
||||
const cachedModels = getCachedModels(props.providerId)
|
||||
if (cachedModels) {
|
||||
if (cachedModels && cachedModels.length > 0) {
|
||||
upstreamModels.value = cachedModels
|
||||
upstreamModelsLoaded.value = true
|
||||
// 折叠所有上游模型组
|
||||
// 有多个分组时全部折叠
|
||||
const allGroups = new Set(['global'])
|
||||
for (const model of cachedModels) {
|
||||
if (model.api_format) {
|
||||
collapsedGroups.value.add(model.api_format)
|
||||
allGroups.add(model.api_format)
|
||||
}
|
||||
}
|
||||
collapsedGroups.value = allGroups
|
||||
} else {
|
||||
// 只有全局模型时展开
|
||||
collapsedGroups.value = new Set()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -585,8 +588,8 @@ async function fetchUpstreamModels(forceRefresh = false) {
|
||||
} else {
|
||||
upstreamModels.value = result.models
|
||||
upstreamModelsLoaded.value = true
|
||||
// 折叠所有上游模型组
|
||||
const allGroups = new Set(collapsedGroups.value)
|
||||
// 有多个分组时全部折叠
|
||||
const allGroups = new Set(['global'])
|
||||
for (const model of result.models) {
|
||||
if (model.api_format) {
|
||||
allGroups.add(model.api_format)
|
||||
|
||||
@@ -20,44 +20,8 @@
|
||||
API 配置
|
||||
</h3>
|
||||
|
||||
<!-- API URL 和自定义路径 -->
|
||||
<div class="grid grid-cols-2 gap-4">
|
||||
<!-- API 格式 -->
|
||||
<div class="space-y-2">
|
||||
<Label for="api_format">API 格式 *</Label>
|
||||
<template v-if="isEditMode">
|
||||
<Input
|
||||
id="api_format"
|
||||
v-model="form.api_format"
|
||||
disabled
|
||||
class="bg-muted"
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
API 格式创建后不可修改
|
||||
</p>
|
||||
</template>
|
||||
<template v-else>
|
||||
<Select
|
||||
v-model="form.api_format"
|
||||
v-model:open="selectOpen"
|
||||
required
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="请选择 API 格式" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="format in apiFormats"
|
||||
:key="format.value"
|
||||
:value="format.value"
|
||||
>
|
||||
{{ format.label }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</template>
|
||||
</div>
|
||||
|
||||
<!-- API URL -->
|
||||
<div class="space-y-2">
|
||||
<Label for="base_url">API URL *</Label>
|
||||
<Input
|
||||
@@ -67,16 +31,70 @@
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="space-y-2">
|
||||
<Label for="custom_path">自定义请求路径(可选)</Label>
|
||||
<Input
|
||||
id="custom_path"
|
||||
v-model="form.custom_path"
|
||||
:placeholder="isEditMode ? defaultPathPlaceholder : '留空使用各格式的默认路径'"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 自定义路径 -->
|
||||
<!-- API 格式 -->
|
||||
<div class="space-y-2">
|
||||
<Label for="custom_path">自定义请求路径(可选)</Label>
|
||||
<Input
|
||||
id="custom_path"
|
||||
v-model="form.custom_path"
|
||||
:placeholder="defaultPathPlaceholder"
|
||||
/>
|
||||
<Label for="api_format">API 格式 *</Label>
|
||||
<template v-if="isEditMode">
|
||||
<Input
|
||||
id="api_format"
|
||||
v-model="form.api_format"
|
||||
disabled
|
||||
class="bg-muted"
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
API 格式创建后不可修改
|
||||
</p>
|
||||
</template>
|
||||
<template v-else>
|
||||
<div class="grid grid-cols-3 grid-flow-col grid-rows-2 gap-2">
|
||||
<label
|
||||
v-for="format in sortedApiFormats"
|
||||
:key="format.value"
|
||||
class="flex items-center gap-2 rounded-md border px-3 py-2 cursor-pointer transition-all text-sm"
|
||||
:class="selectedFormats.includes(format.value)
|
||||
? 'border-primary bg-primary/10 text-primary font-medium'
|
||||
: 'border-border hover:border-primary/50 hover:bg-accent'"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:value="format.value"
|
||||
v-model="selectedFormats"
|
||||
class="sr-only"
|
||||
/>
|
||||
<span
|
||||
class="flex h-4 w-4 shrink-0 items-center justify-center rounded border transition-colors"
|
||||
:class="selectedFormats.includes(format.value)
|
||||
? 'border-primary bg-primary text-primary-foreground'
|
||||
: 'border-muted-foreground/30'"
|
||||
>
|
||||
<svg
|
||||
v-if="selectedFormats.includes(format.value)"
|
||||
class="h-3 w-3"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="3"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<polyline points="20 6 9 17 4 12" />
|
||||
</svg>
|
||||
</span>
|
||||
<span>{{ format.label }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -86,7 +104,7 @@
|
||||
请求配置
|
||||
</h3>
|
||||
|
||||
<div class="grid grid-cols-3 gap-4">
|
||||
<div class="grid grid-cols-4 gap-4">
|
||||
<div class="space-y-2">
|
||||
<Label for="timeout">超时(秒)</Label>
|
||||
<Input
|
||||
@@ -117,11 +135,9 @@
|
||||
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 gap-4">
|
||||
<div class="space-y-2">
|
||||
<Label for="rate_limit">速率限制(请求/分钟)</Label>
|
||||
<Label for="rate_limit">速率限制(/分钟)</Label>
|
||||
<Input
|
||||
id="rate_limit"
|
||||
:model-value="form.rate_limit ?? ''"
|
||||
@@ -217,10 +233,10 @@
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
|
||||
:disabled="loading || !form.base_url || (!isEditMode && selectedFormats.length === 0)"
|
||||
@click="handleSubmit()"
|
||||
>
|
||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
|
||||
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : `创建 ${selectedFormats.length} 个端点`) }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
@@ -245,11 +261,6 @@ import {
|
||||
Button,
|
||||
Input,
|
||||
Label,
|
||||
Select,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
Switch,
|
||||
} from '@/components/ui'
|
||||
import AlertDialog from '@/components/common/AlertDialog.vue'
|
||||
@@ -280,7 +291,6 @@ const emit = defineEmits<{
|
||||
|
||||
const { success, error: showError } = useToast()
|
||||
const loading = ref(false)
|
||||
const selectOpen = ref(false)
|
||||
const proxyEnabled = ref(false)
|
||||
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
|
||||
|
||||
@@ -296,7 +306,7 @@ const form = ref({
|
||||
base_url: '',
|
||||
custom_path: '',
|
||||
timeout: 300,
|
||||
max_retries: 3,
|
||||
max_retries: 2,
|
||||
max_concurrent: undefined as number | undefined,
|
||||
rate_limit: undefined as number | undefined,
|
||||
is_active: true,
|
||||
@@ -306,9 +316,28 @@ const form = ref({
|
||||
proxy_password: '',
|
||||
})
|
||||
|
||||
// 选中的 API 格式(多选)
|
||||
const selectedFormats = ref<string[]>([])
|
||||
|
||||
// API 格式列表
|
||||
const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([])
|
||||
|
||||
// 排序后的 API 格式:按列排列,每列是基础格式+CLI格式
|
||||
const sortedApiFormats = computed(() => {
|
||||
const baseFormats = apiFormats.value.filter(f => !f.value.endsWith('_cli'))
|
||||
const cliFormats = apiFormats.value.filter(f => f.value.endsWith('_cli'))
|
||||
// 交错排列:base1, cli1, base2, cli2, base3, cli3
|
||||
const result: typeof apiFormats.value = []
|
||||
for (let i = 0; i < baseFormats.length; i++) {
|
||||
result.push(baseFormats[i])
|
||||
const cliFormat = cliFormats.find(f => f.value === baseFormats[i].value + '_cli')
|
||||
if (cliFormat) {
|
||||
result.push(cliFormat)
|
||||
}
|
||||
}
|
||||
return result
|
||||
})
|
||||
|
||||
// 加载API格式列表
|
||||
const loadApiFormats = async () => {
|
||||
try {
|
||||
@@ -330,7 +359,7 @@ const defaultPath = computed(() => {
|
||||
|
||||
// 动态 placeholder
|
||||
const defaultPathPlaceholder = computed(() => {
|
||||
return `留空使用默认路径:${defaultPath.value}`
|
||||
return defaultPath.value
|
||||
})
|
||||
|
||||
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
|
||||
@@ -392,7 +421,7 @@ function resetForm() {
|
||||
base_url: '',
|
||||
custom_path: '',
|
||||
timeout: 300,
|
||||
max_retries: 3,
|
||||
max_retries: 2,
|
||||
max_concurrent: undefined,
|
||||
rate_limit: undefined,
|
||||
is_active: true,
|
||||
@@ -400,6 +429,7 @@ function resetForm() {
|
||||
proxy_username: '',
|
||||
proxy_password: '',
|
||||
}
|
||||
selectedFormats.value = []
|
||||
proxyEnabled.value = false
|
||||
}
|
||||
|
||||
@@ -479,6 +509,8 @@ const handleSubmit = async (skipCredentialCheck = false) => {
|
||||
}
|
||||
|
||||
loading.value = true
|
||||
let successCount = 0
|
||||
|
||||
try {
|
||||
const proxyConfig = buildProxyConfig()
|
||||
|
||||
@@ -497,27 +529,56 @@ const handleSubmit = async (skipCredentialCheck = false) => {
|
||||
|
||||
success('端点已更新', '保存成功')
|
||||
emit('endpointUpdated')
|
||||
emit('update:modelValue', false)
|
||||
} else if (props.provider) {
|
||||
// 创建端点
|
||||
await createEndpoint(props.provider.id, {
|
||||
provider_id: props.provider.id,
|
||||
api_format: form.value.api_format,
|
||||
base_url: form.value.base_url,
|
||||
custom_path: form.value.custom_path || undefined,
|
||||
timeout: form.value.timeout,
|
||||
max_retries: form.value.max_retries,
|
||||
max_concurrent: form.value.max_concurrent,
|
||||
rate_limit: form.value.rate_limit,
|
||||
is_active: form.value.is_active,
|
||||
proxy: proxyConfig,
|
||||
// 批量创建端点 - 使用并发请求提升性能
|
||||
const results = await Promise.allSettled(
|
||||
selectedFormats.value.map(apiFormat =>
|
||||
createEndpoint(props.provider!.id, {
|
||||
provider_id: props.provider!.id,
|
||||
api_format: apiFormat,
|
||||
base_url: form.value.base_url,
|
||||
custom_path: form.value.custom_path || undefined,
|
||||
timeout: form.value.timeout,
|
||||
max_retries: form.value.max_retries,
|
||||
max_concurrent: form.value.max_concurrent,
|
||||
rate_limit: form.value.rate_limit,
|
||||
is_active: form.value.is_active,
|
||||
proxy: proxyConfig,
|
||||
})
|
||||
)
|
||||
)
|
||||
|
||||
// 统计结果
|
||||
const errors: string[] = []
|
||||
results.forEach((result, index) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
successCount++
|
||||
} else {
|
||||
const apiFormat = selectedFormats.value[index]
|
||||
const formatLabel = apiFormats.value.find((f: any) => f.value === apiFormat)?.label || apiFormat
|
||||
const errorMsg = result.reason?.response?.data?.detail || '创建失败'
|
||||
errors.push(`${formatLabel}: ${errorMsg}`)
|
||||
}
|
||||
})
|
||||
|
||||
success('端点创建成功', '成功')
|
||||
emit('endpointCreated')
|
||||
resetForm()
|
||||
}
|
||||
const failCount = errors.length
|
||||
|
||||
emit('update:modelValue', false)
|
||||
// 显示结果
|
||||
if (successCount > 0 && failCount === 0) {
|
||||
success(`成功创建 ${successCount} 个端点`, '创建成功')
|
||||
} else if (successCount > 0 && failCount > 0) {
|
||||
showError(`${failCount} 个端点创建失败:\n${errors.join('\n')}`, `${successCount} 个成功,${failCount} 个失败`)
|
||||
} else {
|
||||
showError(errors.join('\n') || '创建端点失败', '创建失败')
|
||||
}
|
||||
|
||||
if (successCount > 0) {
|
||||
emit('endpointCreated')
|
||||
resetForm()
|
||||
emit('update:modelValue', false)
|
||||
}
|
||||
}
|
||||
} catch (error: any) {
|
||||
const action = isEditMode.value ? '更新' : '创建'
|
||||
showError(error.response?.data?.detail || `${action}端点失败`, '错误')
|
||||
|
||||
@@ -32,11 +32,11 @@
|
||||
v-for="modelName in selectedModels"
|
||||
:key="modelName"
|
||||
variant="secondary"
|
||||
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm"
|
||||
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm text-foreground dark:text-white"
|
||||
>
|
||||
{{ getModelLabel(modelName) }}
|
||||
<button
|
||||
class="ml-0.5 hover:text-destructive focus:outline-none"
|
||||
class="ml-0.5 hover:text-destructive focus:outline-none text-foreground dark:text-white"
|
||||
@click.stop="toggleModel(modelName, false)"
|
||||
>
|
||||
×
|
||||
|
||||
@@ -349,8 +349,8 @@ const apiKeyError = computed(() => {
|
||||
}
|
||||
|
||||
// 如果输入了值,检查长度
|
||||
if (apiKey.length < 10) {
|
||||
return 'API 密钥至少需要 10 个字符'
|
||||
if (apiKey.length < 3) {
|
||||
return 'API 密钥至少需要 3 个字符'
|
||||
}
|
||||
|
||||
return ''
|
||||
|
||||
@@ -262,17 +262,17 @@
|
||||
<div class="shrink-0 flex items-center gap-3">
|
||||
<!-- 健康度 -->
|
||||
<div
|
||||
v-if="key.success_rate !== null"
|
||||
v-if="key.health_score != null"
|
||||
class="text-xs text-right"
|
||||
>
|
||||
<div
|
||||
class="font-medium tabular-nums"
|
||||
:class="[
|
||||
key.success_rate >= 0.95 ? 'text-green-600' :
|
||||
key.success_rate >= 0.8 ? 'text-yellow-600' : 'text-red-500'
|
||||
key.health_score >= 0.95 ? 'text-green-600' :
|
||||
key.health_score >= 0.5 ? 'text-yellow-600' : 'text-red-500'
|
||||
]"
|
||||
>
|
||||
{{ (key.success_rate * 100).toFixed(0) }}%
|
||||
{{ ((key.health_score || 0) * 100).toFixed(0) }}%
|
||||
</div>
|
||||
<div class="text-[10px] text-muted-foreground opacity-70">
|
||||
{{ key.request_count }} reqs
|
||||
@@ -319,19 +319,6 @@
|
||||
<div class="flex items-center gap-2 pl-4 border-l border-border">
|
||||
<span class="text-xs text-muted-foreground">调度:</span>
|
||||
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'fixed_order'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="严格按优先级顺序,不考虑缓存"
|
||||
@click="schedulingMode = 'fixed_order'"
|
||||
>
|
||||
固定顺序
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
@@ -345,6 +332,32 @@
|
||||
>
|
||||
缓存亲和
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'load_balance'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="同优先级内随机轮换,不考虑缓存"
|
||||
@click="schedulingMode = 'load_balance'"
|
||||
>
|
||||
负载均衡
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'fixed_order'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="严格按优先级顺序,不考虑缓存"
|
||||
@click="schedulingMode = 'fixed_order'"
|
||||
>
|
||||
固定顺序
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -400,6 +413,7 @@ interface KeyWithMeta {
|
||||
endpoint_base_url: string
|
||||
api_format: string
|
||||
capabilities: string[]
|
||||
health_score: number | null
|
||||
success_rate: number | null
|
||||
avg_response_time_ms: number | null
|
||||
request_count: number
|
||||
@@ -444,7 +458,7 @@ const saving = ref(false)
|
||||
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
||||
|
||||
// 调度模式状态
|
||||
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
|
||||
const schedulingMode = ref<'fixed_order' | 'load_balance' | 'cache_affinity'>('cache_affinity')
|
||||
|
||||
// 可用的 API 格式
|
||||
const availableFormats = computed(() => {
|
||||
@@ -477,7 +491,11 @@ async function loadCurrentPriorityMode() {
|
||||
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
||||
|
||||
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
|
||||
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
|
||||
if (currentSchedulingMode === 'fixed_order' || currentSchedulingMode === 'load_balance' || currentSchedulingMode === 'cache_affinity') {
|
||||
schedulingMode.value = currentSchedulingMode
|
||||
} else {
|
||||
schedulingMode.value = 'cache_affinity'
|
||||
}
|
||||
} catch {
|
||||
activeMainTab.value = 'provider'
|
||||
schedulingMode.value = 'cache_affinity'
|
||||
|
||||
@@ -178,7 +178,7 @@
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8 text-destructive hover:text-destructive"
|
||||
class="h-8 w-8 hover:text-destructive"
|
||||
title="删除"
|
||||
@click="deleteModel(model)"
|
||||
>
|
||||
|
||||
@@ -295,6 +295,15 @@
|
||||
</template>
|
||||
|
||||
<RouterView />
|
||||
|
||||
<!-- 更新提示弹窗 -->
|
||||
<UpdateDialog
|
||||
v-if="updateInfo"
|
||||
v-model="showUpdateDialog"
|
||||
:current-version="updateInfo.current_version"
|
||||
:latest-version="updateInfo.latest_version || ''"
|
||||
:release-url="updateInfo.release_url"
|
||||
/>
|
||||
</AppShell>
|
||||
</template>
|
||||
|
||||
@@ -304,10 +313,12 @@ import { useRoute, useRouter } from 'vue-router'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { useDarkMode } from '@/composables/useDarkMode'
|
||||
import { isDemoMode } from '@/config/demo'
|
||||
import { adminApi, type CheckUpdateResponse } from '@/api/admin'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import AppShell from '@/components/layout/AppShell.vue'
|
||||
import SidebarNav from '@/components/layout/SidebarNav.vue'
|
||||
import HeaderLogo from '@/components/HeaderLogo.vue'
|
||||
import UpdateDialog from '@/components/common/UpdateDialog.vue'
|
||||
import {
|
||||
Home,
|
||||
Users,
|
||||
@@ -345,17 +356,67 @@ const showAuthError = ref(false)
|
||||
const mobileMenuOpen = ref(false)
|
||||
let authCheckInterval: number | null = null
|
||||
|
||||
// 更新检查相关
|
||||
const showUpdateDialog = ref(false)
|
||||
const updateInfo = ref<CheckUpdateResponse | null>(null)
|
||||
|
||||
// 路由变化时自动关闭移动端菜单
|
||||
watch(() => route.path, () => {
|
||||
mobileMenuOpen.value = false
|
||||
})
|
||||
|
||||
// 检查是否应该显示更新提示
|
||||
function shouldShowUpdatePrompt(latestVersion: string): boolean {
|
||||
const ignoreKey = 'aether_update_ignore'
|
||||
const ignoreData = localStorage.getItem(ignoreKey)
|
||||
if (!ignoreData) return true
|
||||
|
||||
try {
|
||||
const { version, until } = JSON.parse(ignoreData)
|
||||
// 如果忽略的是同一版本且未过期,则不显示
|
||||
if (version === latestVersion && Date.now() < until) {
|
||||
return false
|
||||
}
|
||||
} catch {
|
||||
// 解析失败,显示提示
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查更新
|
||||
async function checkForUpdate() {
|
||||
// 只有管理员才检查更新
|
||||
if (authStore.user?.role !== 'admin') return
|
||||
|
||||
// 同一会话内只检查一次
|
||||
const sessionKey = 'aether_update_checked'
|
||||
if (sessionStorage.getItem(sessionKey)) return
|
||||
sessionStorage.setItem(sessionKey, '1')
|
||||
|
||||
try {
|
||||
const result = await adminApi.checkUpdate()
|
||||
if (result.has_update && result.latest_version) {
|
||||
if (shouldShowUpdatePrompt(result.latest_version)) {
|
||||
updateInfo.value = result
|
||||
showUpdateDialog.value = true
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// 静默失败,不影响用户体验
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
authCheckInterval = setInterval(() => {
|
||||
if (authStore.user && !authStore.token) {
|
||||
showAuthError.value = true
|
||||
}
|
||||
}, 5000)
|
||||
|
||||
// 延迟检查更新,避免影响页面加载
|
||||
setTimeout(() => {
|
||||
checkForUpdate()
|
||||
}, 2000)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
|
||||
@@ -425,9 +425,9 @@ const MOCK_ENDPOINT_KEYS = [
|
||||
|
||||
// Mock Endpoints
|
||||
const MOCK_ENDPOINTS = [
|
||||
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 3, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 3, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 3, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
|
||||
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 2, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 2, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
|
||||
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 2, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
|
||||
]
|
||||
|
||||
// Mock 能力定义
|
||||
@@ -1224,7 +1224,7 @@ function generateMockEndpointsForProvider(providerId: string) {
|
||||
'https://generativelanguage.googleapis.com',
|
||||
auth_type: format.includes('GEMINI') ? 'api_key' : 'bearer',
|
||||
timeout: 120,
|
||||
max_retries: 3,
|
||||
max_retries: 2,
|
||||
priority: 100 - index * 10,
|
||||
weight: 100,
|
||||
health_score: healthDetail?.health_score ?? 1.0,
|
||||
|
||||
@@ -54,7 +54,7 @@ const fieldNameMap: Record<string, string> = {
|
||||
*/
|
||||
const errorTypeMap: Record<string, (error: ValidationError) => string> = {
|
||||
'string_too_short': (error) => {
|
||||
const minLength = error.ctx?.min_length || 10
|
||||
const minLength = error.ctx?.min_length || 3
|
||||
return `长度不能少于 ${minLength} 个字符`
|
||||
},
|
||||
'string_too_long': (error) => {
|
||||
|
||||
12
migrate.sh
12
migrate.sh
@@ -1,12 +0,0 @@
|
||||
#!/bin/bash
|
||||
# 数据库迁移脚本 - 在 Docker 容器内执行 Alembic 迁移
|
||||
|
||||
set -e
|
||||
|
||||
CONTAINER_NAME="aether-app"
|
||||
|
||||
echo "Running database migrations in container: $CONTAINER_NAME"
|
||||
|
||||
docker exec $CONTAINER_NAME alembic upgrade head
|
||||
|
||||
echo "Database migration completed successfully"
|
||||
@@ -1,34 +0,0 @@
|
||||
# file generated by setuptools-scm
|
||||
# don't change, don't track in version control
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"__version_tuple__",
|
||||
"version",
|
||||
"version_tuple",
|
||||
"__commit_id__",
|
||||
"commit_id",
|
||||
]
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
||||
COMMIT_ID = Union[str, None]
|
||||
else:
|
||||
VERSION_TUPLE = object
|
||||
COMMIT_ID = object
|
||||
|
||||
version: str
|
||||
__version__: str
|
||||
__version_tuple__: VERSION_TUPLE
|
||||
version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '0.2.5'
|
||||
__version_tuple__ = version_tuple = (0, 2, 5)
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
@@ -18,6 +18,7 @@ from src.core.key_capabilities import get_capability
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
from src.services.cache.provider_cache import ProviderCacheService
|
||||
from src.models.endpoint_models import (
|
||||
BatchUpdateKeyPriorityRequest,
|
||||
EndpointAPIKeyCreate,
|
||||
@@ -411,6 +412,10 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
|
||||
db.commit()
|
||||
db.refresh(key)
|
||||
|
||||
# 如果更新了 rate_multiplier,清除缓存
|
||||
if "rate_multiplier" in update_data:
|
||||
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
|
||||
|
||||
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
|
||||
|
||||
try:
|
||||
@@ -550,6 +555,7 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
|
||||
"endpoint_base_url": endpoint.base_url,
|
||||
"api_format": api_format,
|
||||
"capabilities": caps_list,
|
||||
"health_score": key.health_score,
|
||||
"success_rate": success_rate,
|
||||
"avg_response_time_ms": avg_response_time_ms,
|
||||
"request_count": key.request_count,
|
||||
|
||||
@@ -11,9 +11,11 @@ from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
|
||||
from src.models.database import Provider
|
||||
from src.services.cache.provider_cache import ProviderCacheService
|
||||
|
||||
router = APIRouter(tags=["Provider CRUD"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -296,6 +298,11 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
# 如果更新了 billing_type,清除缓存
|
||||
if "billing_type" in update_data:
|
||||
await ProviderCacheService.invalidate_provider_cache(provider.id)
|
||||
logger.debug(f"已清除 Provider 缓存: {provider.id}")
|
||||
|
||||
context.add_audit_metadata(
|
||||
action="update_provider",
|
||||
provider_id=provider.id,
|
||||
|
||||
@@ -42,6 +42,42 @@ def _get_version_from_git() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _get_current_version() -> str:
|
||||
"""获取当前版本号"""
|
||||
version = _get_version_from_git()
|
||||
if version:
|
||||
return version
|
||||
try:
|
||||
from src._version import __version__
|
||||
|
||||
return __version__
|
||||
except ImportError:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _parse_version(version_str: str) -> tuple:
|
||||
"""解析版本号为可比较的元组,支持 3-4 段版本号
|
||||
|
||||
例如:
|
||||
- '0.2.5' -> (0, 2, 5, 0)
|
||||
- '0.2.5.1' -> (0, 2, 5, 1)
|
||||
- 'v0.2.5-4-g1234567' -> (0, 2, 5, 0)
|
||||
"""
|
||||
import re
|
||||
|
||||
version_str = version_str.lstrip("v")
|
||||
main_version = re.split(r"[-+]", version_str)[0]
|
||||
try:
|
||||
parts = main_version.split(".")
|
||||
# 标准化为 4 段,便于比较
|
||||
int_parts = [int(p) for p in parts]
|
||||
while len(int_parts) < 4:
|
||||
int_parts.append(0)
|
||||
return tuple(int_parts[:4])
|
||||
except ValueError:
|
||||
return (0, 0, 0, 0)
|
||||
|
||||
|
||||
@router.get("/version")
|
||||
async def get_system_version():
|
||||
"""
|
||||
@@ -52,18 +88,111 @@ async def get_system_version():
|
||||
**返回字段**:
|
||||
- `version`: 版本号字符串
|
||||
"""
|
||||
# 优先从 git 获取
|
||||
version = _get_version_from_git()
|
||||
if version:
|
||||
return {"version": version}
|
||||
return {"version": _get_current_version()}
|
||||
|
||||
|
||||
@router.get("/check-update")
|
||||
async def check_update():
|
||||
"""
|
||||
检查系统更新
|
||||
|
||||
从 GitHub Tags 获取最新版本并与当前版本对比。
|
||||
|
||||
**返回字段**:
|
||||
- `current_version`: 当前版本号
|
||||
- `latest_version`: 最新版本号
|
||||
- `has_update`: 是否有更新可用
|
||||
- `release_url`: 最新版本的 GitHub 页面链接
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
current_version = _get_current_version()
|
||||
github_repo = "Aethersailor/Aether"
|
||||
github_tags_url = f"https://api.github.com/repos/{github_repo}/tags"
|
||||
|
||||
# 回退到静态版本文件
|
||||
try:
|
||||
from src._version import __version__
|
||||
async with HTTPClientPool.get_temp_client(
|
||||
timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0)
|
||||
) as client:
|
||||
response = await client.get(
|
||||
github_tags_url,
|
||||
headers={
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
"User-Agent": f"Aether/{current_version}",
|
||||
},
|
||||
params={"per_page": 10},
|
||||
)
|
||||
|
||||
return {"version": __version__}
|
||||
except ImportError:
|
||||
return {"version": "unknown"}
|
||||
if response.status_code != 200:
|
||||
return {
|
||||
"current_version": current_version,
|
||||
"latest_version": None,
|
||||
"has_update": False,
|
||||
"release_url": None,
|
||||
"error": f"GitHub API 返回错误: {response.status_code}",
|
||||
}
|
||||
|
||||
tags = response.json()
|
||||
if not tags:
|
||||
return {
|
||||
"current_version": current_version,
|
||||
"latest_version": None,
|
||||
"has_update": False,
|
||||
"release_url": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# 找到最新的版本 tag(按版本号排序,而非时间)
|
||||
version_tags = []
|
||||
for tag in tags:
|
||||
tag_name = tag.get("name", "")
|
||||
if tag_name.startswith("v") or tag_name[0].isdigit():
|
||||
version_tags.append((tag_name, _parse_version(tag_name)))
|
||||
|
||||
if not version_tags:
|
||||
return {
|
||||
"current_version": current_version,
|
||||
"latest_version": None,
|
||||
"has_update": False,
|
||||
"release_url": None,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
# 按版本号排序,取最大的
|
||||
version_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
latest_tag = version_tags[0][0]
|
||||
latest_version = latest_tag.lstrip("v")
|
||||
|
||||
current_tuple = _parse_version(current_version)
|
||||
latest_tuple = _parse_version(latest_version)
|
||||
has_update = latest_tuple > current_tuple
|
||||
|
||||
return {
|
||||
"current_version": current_version,
|
||||
"latest_version": latest_version,
|
||||
"has_update": has_update,
|
||||
"release_url": f"https://github.com/{github_repo}/releases/tag/{latest_tag}",
|
||||
"error": None,
|
||||
}
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return {
|
||||
"current_version": current_version,
|
||||
"latest_version": None,
|
||||
"has_update": False,
|
||||
"release_url": None,
|
||||
"error": "检查更新超时",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"current_version": current_version,
|
||||
"latest_version": None,
|
||||
"has_update": False,
|
||||
"release_url": None,
|
||||
"error": f"检查更新失败: {str(e)}",
|
||||
}
|
||||
|
||||
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -887,7 +1016,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
)
|
||||
existing_ep.headers = ep_data.get("headers")
|
||||
existing_ep.timeout = ep_data.get("timeout", 300)
|
||||
existing_ep.max_retries = ep_data.get("max_retries", 3)
|
||||
existing_ep.max_retries = ep_data.get("max_retries", 2)
|
||||
existing_ep.max_concurrent = ep_data.get("max_concurrent")
|
||||
existing_ep.rate_limit = ep_data.get("rate_limit")
|
||||
existing_ep.is_active = ep_data.get("is_active", True)
|
||||
@@ -903,7 +1032,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
|
||||
base_url=ep_data["base_url"],
|
||||
headers=ep_data.get("headers"),
|
||||
timeout=ep_data.get("timeout", 300),
|
||||
max_retries=ep_data.get("max_retries", 3),
|
||||
max_retries=ep_data.get("max_retries", 2),
|
||||
max_concurrent=ep_data.get("max_concurrent"),
|
||||
rate_limit=ep_data.get("rate_limit"),
|
||||
is_active=ep_data.get("is_active", True),
|
||||
|
||||
@@ -691,64 +691,70 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
|
||||
)
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
# endpoint.timeout 作为整体请求超时
|
||||
# 获取复用的 HTTP 客户端(支持代理配置)
|
||||
# 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
request_timeout = float(endpoint.timeout or 300)
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
http_client = await HTTPClientPool.get_proxy_client(
|
||||
proxy_config=endpoint.proxy,
|
||||
)
|
||||
|
||||
# 注意:不使用 async with,因为复用的客户端不应该被关闭
|
||||
# 超时通过 timeout 参数控制
|
||||
resp = await http_client.post(
|
||||
url,
|
||||
json=provider_payload,
|
||||
headers=provider_hdrs,
|
||||
timeout=httpx.Timeout(request_timeout),
|
||||
)
|
||||
async with http_client:
|
||||
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
|
||||
|
||||
status_code = resp.status_code
|
||||
response_headers = dict(resp.headers)
|
||||
status_code = resp.status_code
|
||||
response_headers = dict(resp.headers)
|
||||
|
||||
if resp.status_code == 401:
|
||||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
||||
elif resp.status_code == 429:
|
||||
raise ProviderRateLimitException(
|
||||
f"提供商速率限制: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
response_headers=response_headers,
|
||||
if resp.status_code == 401:
|
||||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
||||
elif resp.status_code == 429:
|
||||
raise ProviderRateLimitException(
|
||||
f"提供商速率限制: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
response_headers=response_headers,
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
# 记录响应体以便调试
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.error(
|
||||
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
# 记录响应体以便调试
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.error(
|
||||
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商服务不可用: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_body,
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
# 记录非200响应以便调试
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_body,
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商服务不可用: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_body,
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
# 记录非200响应以便调试
|
||||
error_body = ""
|
||||
try:
|
||||
error_body = resp.text[:1000]
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_body,
|
||||
)
|
||||
|
||||
response_json = resp.json()
|
||||
return response_json if isinstance(response_json, dict) else {}
|
||||
response_json = resp.json()
|
||||
return response_json if isinstance(response_json, dict) else {}
|
||||
|
||||
try:
|
||||
# 解析能力需求
|
||||
|
||||
@@ -1534,72 +1534,78 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
|
||||
)
|
||||
|
||||
# 创建 HTTP 客户端(支持代理配置)
|
||||
# endpoint.timeout 作为整体请求超时
|
||||
# 获取复用的 HTTP 客户端(支持代理配置)
|
||||
# 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
|
||||
from src.clients.http_client import HTTPClientPool
|
||||
|
||||
request_timeout = float(endpoint.timeout or 300)
|
||||
http_client = HTTPClientPool.create_client_with_proxy(
|
||||
http_client = await HTTPClientPool.get_proxy_client(
|
||||
proxy_config=endpoint.proxy,
|
||||
)
|
||||
|
||||
# 注意:不使用 async with,因为复用的客户端不应该被关闭
|
||||
# 超时通过 timeout 参数控制
|
||||
resp = await http_client.post(
|
||||
url,
|
||||
json=provider_payload,
|
||||
headers=provider_headers,
|
||||
timeout=httpx.Timeout(request_timeout),
|
||||
)
|
||||
async with http_client:
|
||||
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
|
||||
|
||||
status_code = resp.status_code
|
||||
response_headers = dict(resp.headers)
|
||||
status_code = resp.status_code
|
||||
response_headers = dict(resp.headers)
|
||||
|
||||
if resp.status_code == 401:
|
||||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
||||
elif resp.status_code == 429:
|
||||
raise ProviderRateLimitException(
|
||||
f"提供商速率限制: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
response_headers=response_headers,
|
||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
elif 300 <= resp.status_code < 400:
|
||||
redirect_url = resp.headers.get("location", "unknown")
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
if resp.status_code == 401:
|
||||
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
|
||||
elif resp.status_code == 429:
|
||||
raise ProviderRateLimitException(
|
||||
f"提供商速率限制: {provider.name}",
|
||||
provider_name=str(provider.name),
|
||||
response_headers=response_headers,
|
||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||
)
|
||||
elif resp.status_code >= 500:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
elif 300 <= resp.status_code < 400:
|
||||
redirect_url = resp.headers.get("location", "unknown")
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
|
||||
)
|
||||
elif resp.status_code != 200:
|
||||
error_text = resp.text
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||
provider_name=str(provider.name),
|
||||
upstream_status=resp.status_code,
|
||||
upstream_response=error_text,
|
||||
)
|
||||
|
||||
# 安全解析 JSON 响应,处理可能的编码错误
|
||||
try:
|
||||
response_json = resp.json()
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
# 记录原始响应信息用于调试
|
||||
content_type = resp.headers.get("content-type", "unknown")
|
||||
content_encoding = resp.headers.get("content-encoding", "none")
|
||||
logger.error(
|
||||
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
||||
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
|
||||
f"响应长度: {len(resp.content)} bytes"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
|
||||
)
|
||||
# 安全解析 JSON 响应,处理可能的编码错误
|
||||
try:
|
||||
response_json = resp.json()
|
||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||
# 记录原始响应信息用于调试
|
||||
content_type = resp.headers.get("content-type", "unknown")
|
||||
content_encoding = resp.headers.get("content-encoding", "none")
|
||||
logger.error(
|
||||
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
|
||||
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
|
||||
f"响应长度: {len(resp.content)} bytes"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
|
||||
)
|
||||
|
||||
# 提取 Provider 响应元数据(子类可覆盖)
|
||||
response_metadata_result = self._extract_response_metadata(response_json)
|
||||
# 提取 Provider 响应元数据(子类可覆盖)
|
||||
response_metadata_result = self._extract_response_metadata(response_json)
|
||||
|
||||
return response_json if isinstance(response_json, dict) else {}
|
||||
return response_json if isinstance(response_json, dict) else {}
|
||||
|
||||
try:
|
||||
# 解析能力需求
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
"""
|
||||
全局HTTP客户端池管理
|
||||
避免每次请求都创建新的AsyncClient,提高性能
|
||||
|
||||
性能优化说明:
|
||||
1. 默认客户端:无代理场景,全局复用单一客户端
|
||||
2. 代理客户端缓存:相同代理配置复用同一客户端,避免重复创建
|
||||
3. 连接池复用:Keep-alive 连接减少 TCP 握手开销
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from urllib.parse import quote, urlparse
|
||||
|
||||
import httpx
|
||||
@@ -12,6 +20,32 @@ import httpx
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
|
||||
# 模块级锁,避免类属性延迟初始化的竞态条件
|
||||
_proxy_clients_lock = asyncio.Lock()
|
||||
_default_client_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
计算代理配置的缓存键
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置字典
|
||||
|
||||
Returns:
|
||||
缓存键字符串,无代理时返回 "__no_proxy__"
|
||||
"""
|
||||
if not proxy_config:
|
||||
return "__no_proxy__"
|
||||
|
||||
# 构建代理 URL 作为缓存键的基础
|
||||
proxy_url = build_proxy_url(proxy_config)
|
||||
if not proxy_url:
|
||||
return "__no_proxy__"
|
||||
|
||||
# 使用 MD5 哈希来避免过长的键名
|
||||
return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"
|
||||
|
||||
|
||||
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -61,11 +95,20 @@ class HTTPClientPool:
|
||||
全局HTTP客户端池单例
|
||||
|
||||
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
|
||||
|
||||
性能优化:
|
||||
1. 默认客户端:无代理场景复用
|
||||
2. 代理客户端缓存:相同代理配置复用同一客户端
|
||||
3. LRU 淘汰:代理客户端超过上限时淘汰最久未使用的
|
||||
"""
|
||||
|
||||
_instance: Optional["HTTPClientPool"] = None
|
||||
_default_client: Optional[httpx.AsyncClient] = None
|
||||
_clients: Dict[str, httpx.AsyncClient] = {}
|
||||
# 代理客户端缓存:{cache_key: (client, last_used_time)}
|
||||
_proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {}
|
||||
# 代理客户端缓存上限(避免内存泄漏)
|
||||
_max_proxy_clients: int = 50
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
@@ -73,12 +116,50 @@ class HTTPClientPool:
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_default_client(cls) -> httpx.AsyncClient:
|
||||
async def get_default_client_async(cls) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取默认的HTTP客户端
|
||||
获取默认的HTTP客户端(异步线程安全版本)
|
||||
|
||||
用于大多数HTTP请求,具有合理的默认配置
|
||||
"""
|
||||
if cls._default_client is not None:
|
||||
return cls._default_client
|
||||
|
||||
async with _default_client_lock:
|
||||
# 双重检查,避免重复创建
|
||||
if cls._default_client is None:
|
||||
cls._default_client = httpx.AsyncClient(
|
||||
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||
verify=True, # 启用SSL验证
|
||||
timeout=httpx.Timeout(
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_connections=config.http_max_connections,
|
||||
max_keepalive_connections=config.http_keepalive_connections,
|
||||
keepalive_expiry=config.http_keepalive_expiry,
|
||||
),
|
||||
follow_redirects=True, # 跟随重定向
|
||||
)
|
||||
logger.info(
|
||||
f"全局HTTP客户端池已初始化: "
|
||||
f"max_connections={config.http_max_connections}, "
|
||||
f"keepalive={config.http_keepalive_connections}, "
|
||||
f"keepalive_expiry={config.http_keepalive_expiry}s"
|
||||
)
|
||||
return cls._default_client
|
||||
|
||||
@classmethod
|
||||
def get_default_client(cls) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取默认的HTTP客户端(同步版本,向后兼容)
|
||||
|
||||
⚠️ 注意:此方法在高并发首次调用时可能存在竞态条件,
|
||||
推荐使用 get_default_client_async() 异步版本。
|
||||
"""
|
||||
if cls._default_client is None:
|
||||
cls._default_client = httpx.AsyncClient(
|
||||
http2=False, # 暂时禁用HTTP/2以提高兼容性
|
||||
@@ -90,13 +171,18 @@ class HTTPClientPool:
|
||||
pool=config.http_pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_connections=100, # 最大连接数
|
||||
max_keepalive_connections=20, # 最大保活连接数
|
||||
keepalive_expiry=30.0, # 保活过期时间(秒)
|
||||
max_connections=config.http_max_connections,
|
||||
max_keepalive_connections=config.http_keepalive_connections,
|
||||
keepalive_expiry=config.http_keepalive_expiry,
|
||||
),
|
||||
follow_redirects=True, # 跟随重定向
|
||||
)
|
||||
logger.info("全局HTTP客户端池已初始化")
|
||||
logger.info(
|
||||
f"全局HTTP客户端池已初始化: "
|
||||
f"max_connections={config.http_max_connections}, "
|
||||
f"keepalive={config.http_keepalive_connections}, "
|
||||
f"keepalive_expiry={config.http_keepalive_expiry}s"
|
||||
)
|
||||
return cls._default_client
|
||||
|
||||
@classmethod
|
||||
@@ -130,6 +216,101 @@ class HTTPClientPool:
|
||||
|
||||
return cls._clients[name]
|
||||
|
||||
@classmethod
|
||||
def _get_proxy_clients_lock(cls) -> asyncio.Lock:
|
||||
"""获取代理客户端缓存锁(模块级单例,避免竞态条件)"""
|
||||
return _proxy_clients_lock
|
||||
|
||||
@classmethod
|
||||
async def _evict_lru_proxy_client(cls) -> None:
|
||||
"""淘汰最久未使用的代理客户端"""
|
||||
if len(cls._proxy_clients) < cls._max_proxy_clients:
|
||||
return
|
||||
|
||||
# 找到最久未使用的客户端
|
||||
oldest_key = min(cls._proxy_clients.keys(), key=lambda k: cls._proxy_clients[k][1])
|
||||
old_client, _ = cls._proxy_clients.pop(oldest_key)
|
||||
|
||||
# 异步关闭旧客户端
|
||||
try:
|
||||
await old_client.aclose()
|
||||
logger.debug(f"淘汰代理客户端: {oldest_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭代理客户端失败: {e}")
|
||||
|
||||
@classmethod
|
||||
async def get_proxy_client(
|
||||
cls,
|
||||
proxy_config: Optional[Dict[str, Any]] = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""
|
||||
获取代理客户端(带缓存复用)
|
||||
|
||||
相同代理配置会复用同一个客户端,大幅减少连接建立开销。
|
||||
注意:返回的客户端使用默认超时配置,如需自定义超时请在请求时传递 timeout 参数。
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置字典,包含 url, username, password
|
||||
|
||||
Returns:
|
||||
可复用的 httpx.AsyncClient 实例
|
||||
"""
|
||||
cache_key = _compute_proxy_cache_key(proxy_config)
|
||||
|
||||
# 无代理时返回默认客户端
|
||||
if cache_key == "__no_proxy__":
|
||||
return await cls.get_default_client_async()
|
||||
|
||||
lock = cls._get_proxy_clients_lock()
|
||||
async with lock:
|
||||
# 检查缓存
|
||||
if cache_key in cls._proxy_clients:
|
||||
client, _ = cls._proxy_clients[cache_key]
|
||||
# 健康检查:如果客户端已关闭,移除并重新创建
|
||||
if client.is_closed:
|
||||
del cls._proxy_clients[cache_key]
|
||||
logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}")
|
||||
else:
|
||||
# 更新最后使用时间
|
||||
cls._proxy_clients[cache_key] = (client, time.time())
|
||||
return client
|
||||
|
||||
# 淘汰旧客户端(如果超过上限)
|
||||
await cls._evict_lru_proxy_client()
|
||||
|
||||
# 创建新客户端(使用默认超时,请求时可覆盖)
|
||||
client_config: Dict[str, Any] = {
|
||||
"http2": False,
|
||||
"verify": True,
|
||||
"follow_redirects": True,
|
||||
"limits": httpx.Limits(
|
||||
max_connections=config.http_max_connections,
|
||||
max_keepalive_connections=config.http_keepalive_connections,
|
||||
keepalive_expiry=config.http_keepalive_expiry,
|
||||
),
|
||||
"timeout": httpx.Timeout(
|
||||
connect=config.http_connect_timeout,
|
||||
read=config.http_read_timeout,
|
||||
write=config.http_write_timeout,
|
||||
pool=config.http_pool_timeout,
|
||||
),
|
||||
}
|
||||
|
||||
# 添加代理配置
|
||||
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||
if proxy_url:
|
||||
client_config["proxy"] = proxy_url
|
||||
|
||||
client = httpx.AsyncClient(**client_config)
|
||||
cls._proxy_clients[cache_key] = (client, time.time())
|
||||
|
||||
logger.debug(
|
||||
f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, "
|
||||
f"缓存数量: {len(cls._proxy_clients)}"
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def close_all(cls):
|
||||
"""关闭所有HTTP客户端"""
|
||||
@@ -143,6 +324,16 @@ class HTTPClientPool:
|
||||
logger.debug(f"命名HTTP客户端已关闭: {name}")
|
||||
|
||||
cls._clients.clear()
|
||||
|
||||
# 关闭代理客户端缓存
|
||||
for cache_key, (client, _) in cls._proxy_clients.items():
|
||||
try:
|
||||
await client.aclose()
|
||||
logger.debug(f"代理客户端已关闭: {cache_key}")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭代理客户端失败: {e}")
|
||||
|
||||
cls._proxy_clients.clear()
|
||||
logger.info("所有HTTP客户端已关闭")
|
||||
|
||||
@classmethod
|
||||
@@ -185,13 +376,15 @@ class HTTPClientPool:
|
||||
"""
|
||||
创建带代理配置的HTTP客户端
|
||||
|
||||
⚠️ 性能警告:此方法每次都创建新客户端,推荐使用 get_proxy_client() 复用连接。
|
||||
|
||||
Args:
|
||||
proxy_config: 代理配置字典,包含 url, username, password
|
||||
timeout: 超时配置
|
||||
**kwargs: 其他 httpx.AsyncClient 配置参数
|
||||
|
||||
Returns:
|
||||
配置好的 httpx.AsyncClient 实例
|
||||
配置好的 httpx.AsyncClient 实例(调用者需要负责关闭)
|
||||
"""
|
||||
client_config: Dict[str, Any] = {
|
||||
"http2": False,
|
||||
@@ -213,11 +406,21 @@ class HTTPClientPool:
|
||||
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
|
||||
if proxy_url:
|
||||
client_config["proxy"] = proxy_url
|
||||
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
|
||||
logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}")
|
||||
|
||||
client_config.update(kwargs)
|
||||
return httpx.AsyncClient(**client_config)
|
||||
|
||||
@classmethod
|
||||
def get_pool_stats(cls) -> Dict[str, Any]:
|
||||
"""获取连接池统计信息"""
|
||||
return {
|
||||
"default_client_active": cls._default_client is not None,
|
||||
"named_clients_count": len(cls._clients),
|
||||
"proxy_clients_count": len(cls._proxy_clients),
|
||||
"max_proxy_clients": cls._max_proxy_clients,
|
||||
}
|
||||
|
||||
|
||||
# 便捷访问函数
|
||||
def get_http_client() -> httpx.AsyncClient:
|
||||
|
||||
@@ -145,6 +145,24 @@ class Config:
|
||||
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
|
||||
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
|
||||
|
||||
# HTTP 连接池配置
|
||||
# HTTP_MAX_CONNECTIONS: 最大连接数,影响并发能力
|
||||
# - 每个连接占用一个 socket,过多会耗尽系统资源
|
||||
# - 默认根据 Worker 数量自动计算:单 Worker 200,多 Worker 按比例分配
|
||||
# HTTP_KEEPALIVE_CONNECTIONS: 保活连接数,影响连接复用效率
|
||||
# - 高频请求场景应该增大此值
|
||||
# - 默认为 max_connections 的 30%(长连接场景更高效)
|
||||
# HTTP_KEEPALIVE_EXPIRY: 保活过期时间(秒)
|
||||
# - 过短会频繁重建连接,过长会占用资源
|
||||
# - 默认 30 秒,生图等长连接场景可适当增大
|
||||
self.http_max_connections = int(
|
||||
os.getenv("HTTP_MAX_CONNECTIONS") or self._auto_http_max_connections()
|
||||
)
|
||||
self.http_keepalive_connections = int(
|
||||
os.getenv("HTTP_KEEPALIVE_CONNECTIONS") or self._auto_http_keepalive_connections()
|
||||
)
|
||||
self.http_keepalive_expiry = float(os.getenv("HTTP_KEEPALIVE_EXPIRY", "30.0"))
|
||||
|
||||
# 流式处理配置
|
||||
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
|
||||
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
|
||||
@@ -224,6 +242,53 @@ class Config:
|
||||
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
|
||||
return self.db_pool_size
|
||||
|
||||
def _auto_http_max_connections(self) -> int:
|
||||
"""
|
||||
智能计算 HTTP 最大连接数
|
||||
|
||||
计算依据:
|
||||
1. 系统 socket 资源有限(Linux 默认 ulimit -n 通常为 1024)
|
||||
2. 多 Worker 部署时每个进程独立连接池
|
||||
3. 需要为数据库连接、Redis 连接等预留资源
|
||||
|
||||
公式: base_connections / workers
|
||||
- 单 Worker: 200 连接(适合开发/低负载)
|
||||
- 多 Worker: 按比例分配,确保总数不超过系统限制
|
||||
|
||||
范围: 50 - 500
|
||||
"""
|
||||
# 基础连接数:假设系统可用 socket 约 800 个用于 HTTP
|
||||
# (预留给 DB、Redis、内部服务等)
|
||||
base_connections = 800
|
||||
workers = max(self.worker_processes, 1)
|
||||
|
||||
# 每个 Worker 分配的连接数
|
||||
per_worker = base_connections // workers
|
||||
|
||||
# 限制范围:最小 50(保证基本并发),最大 500(避免资源耗尽)
|
||||
return max(50, min(per_worker, 500))
|
||||
|
||||
def _auto_http_keepalive_connections(self) -> int:
|
||||
"""
|
||||
智能计算 HTTP 保活连接数
|
||||
|
||||
计算依据:
|
||||
1. 保活连接用于复用,减少 TCP 握手开销
|
||||
2. 对于 API 网关场景,上游请求频繁,保活比例应较高
|
||||
3. 生图等长连接场景,连接会被长时间占用
|
||||
|
||||
公式: max_connections * 0.3
|
||||
- 30% 的比例在复用效率和资源占用间取得平衡
|
||||
- 长连接场景建议手动调高到 50-70%
|
||||
|
||||
范围: 10 - max_connections
|
||||
"""
|
||||
# 保活连接数为最大连接数的 30%
|
||||
keepalive = int(self.http_max_connections * 0.3)
|
||||
|
||||
# 最小 10 个保活连接,最大不超过 max_connections
|
||||
return max(10, min(keepalive, self.http_max_connections))
|
||||
|
||||
def _parse_ttfb_timeout(self) -> float:
|
||||
"""
|
||||
解析 TTFB 超时配置,带错误处理和范围限制
|
||||
|
||||
@@ -134,7 +134,7 @@ class EndpointAPIKeyCreate(BaseModel):
|
||||
"""为 Endpoint 添加 API Key"""
|
||||
|
||||
endpoint_id: str = Field(..., description="Endpoint ID")
|
||||
api_key: str = Field(..., min_length=10, max_length=500, description="API Key(将自动加密)")
|
||||
api_key: str = Field(..., min_length=3, max_length=500, description="API Key(将自动加密)")
|
||||
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
|
||||
|
||||
# 成本计算
|
||||
@@ -175,13 +175,9 @@ class EndpointAPIKeyCreate(BaseModel):
|
||||
@classmethod
|
||||
def validate_api_key(cls, v: str) -> str:
|
||||
"""验证 API Key 安全性"""
|
||||
# 移除首尾空白
|
||||
# 移除首尾空白(长度校验由 Field min_length 处理)
|
||||
v = v.strip()
|
||||
|
||||
# 检查最小长度
|
||||
if len(v) < 10:
|
||||
raise ValueError("API Key 长度不能少于 10 个字符")
|
||||
|
||||
# 检查危险字符(SQL 注入防护)
|
||||
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
|
||||
for char in dangerous_chars:
|
||||
@@ -219,7 +215,7 @@ class EndpointAPIKeyUpdate(BaseModel):
|
||||
"""更新 Endpoint API Key"""
|
||||
|
||||
api_key: Optional[str] = Field(
|
||||
default=None, min_length=10, max_length=500, description="API Key(将自动加密)"
|
||||
default=None, min_length=3, max_length=500, description="API Key(将自动加密)"
|
||||
)
|
||||
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
|
||||
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")
|
||||
|
||||
@@ -8,7 +8,9 @@ import hashlib
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import jwt
|
||||
@@ -30,6 +32,44 @@ from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.user.apikey import ApiKeyService
|
||||
|
||||
|
||||
# API Key last_used_at 更新节流配置
|
||||
# 同一个 API Key 在此时间间隔内只会更新一次 last_used_at
|
||||
_LAST_USED_UPDATE_INTERVAL = 60 # 秒
|
||||
_LAST_USED_CACHE_MAX_SIZE = 10000 # LRU 缓存最大条目数
|
||||
|
||||
# 进程内缓存:记录每个 API Key 最后一次更新 last_used_at 的时间
|
||||
# 使用 OrderedDict 实现 LRU,避免内存无限增长
|
||||
_api_key_last_update_times: OrderedDict[str, float] = OrderedDict()
|
||||
_last_update_lock = Lock()
|
||||
|
||||
|
||||
def _should_update_last_used(api_key_id: str) -> bool:
|
||||
"""判断是否应该更新 API Key 的 last_used_at
|
||||
|
||||
使用节流策略,同一个 Key 在指定间隔内只更新一次。
|
||||
线程安全,使用 LRU 策略限制缓存大小。
|
||||
|
||||
Returns:
|
||||
True 表示应该更新,False 表示跳过
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
with _last_update_lock:
|
||||
last_update = _api_key_last_update_times.get(api_key_id, 0)
|
||||
|
||||
if now - last_update >= _LAST_USED_UPDATE_INTERVAL:
|
||||
_api_key_last_update_times[api_key_id] = now
|
||||
# LRU: 移到末尾(最近使用)
|
||||
_api_key_last_update_times.move_to_end(api_key_id)
|
||||
|
||||
# 超过最大容量时,移除最旧的条目
|
||||
while len(_api_key_last_update_times) > _LAST_USED_CACHE_MAX_SIZE:
|
||||
_api_key_last_update_times.popitem(last=False)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# JWT配置从config读取
|
||||
if not config.jwt_secret_key:
|
||||
# 如果没有配置,生成一个随机密钥并警告
|
||||
@@ -367,9 +407,10 @@ class AuthService:
|
||||
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
|
||||
return None
|
||||
|
||||
# 更新最后使用时间
|
||||
key_record.last_used_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||
# 更新最后使用时间(使用节流策略,减少数据库写入)
|
||||
if _should_update_last_used(key_record.id):
|
||||
key_record.last_used_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
|
||||
|
||||
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
|
||||
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
|
||||
|
||||
67
src/services/cache/aware_scheduler.py
vendored
67
src/services/cache/aware_scheduler.py
vendored
@@ -121,11 +121,13 @@ class CacheAwareScheduler:
|
||||
PRIORITY_MODE_GLOBAL_KEY,
|
||||
}
|
||||
# 调度模式常量
|
||||
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式
|
||||
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式
|
||||
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式:严格按优先级,忽略缓存
|
||||
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式:优先缓存,同优先级哈希分散
|
||||
SCHEDULING_MODE_LOAD_BALANCE = "load_balance" # 负载均衡模式:忽略缓存,同优先级随机轮换
|
||||
ALLOWED_SCHEDULING_MODES = {
|
||||
SCHEDULING_MODE_FIXED_ORDER,
|
||||
SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
SCHEDULING_MODE_LOAD_BALANCE,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -680,8 +682,9 @@ class CacheAwareScheduler:
|
||||
f"(api_format={target_format.value}, model={model_name})"
|
||||
)
|
||||
|
||||
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
|
||||
# 4. 根据调度模式应用不同的排序策略
|
||||
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
|
||||
# 缓存亲和模式:优先使用缓存的,同优先级内哈希分散
|
||||
if affinity_key and candidates:
|
||||
candidates = await self._apply_cache_affinity(
|
||||
candidates=candidates,
|
||||
@@ -689,8 +692,13 @@ class CacheAwareScheduler:
|
||||
api_format=target_format,
|
||||
global_model_id=global_model_id,
|
||||
)
|
||||
elif self.scheduling_mode == self.SCHEDULING_MODE_LOAD_BALANCE:
|
||||
# 负载均衡模式:忽略缓存,同优先级内随机轮换
|
||||
candidates = self._apply_load_balance(candidates)
|
||||
for candidate in candidates:
|
||||
candidate.is_cached = False
|
||||
else:
|
||||
# 固定顺序模式:标记所有候选为非缓存
|
||||
# 固定顺序模式:严格按优先级,忽略缓存
|
||||
for candidate in candidates:
|
||||
candidate.is_cached = False
|
||||
|
||||
@@ -1163,6 +1171,57 @@ class CacheAwareScheduler:
|
||||
|
||||
return result
|
||||
|
||||
def _apply_load_balance(
|
||||
self, candidates: List[ProviderCandidate]
|
||||
) -> List[ProviderCandidate]:
|
||||
"""
|
||||
负载均衡模式:同优先级内随机轮换
|
||||
|
||||
排序逻辑:
|
||||
1. 按优先级分组(provider_priority, internal_priority 或 global_priority)
|
||||
2. 同优先级组内随机打乱
|
||||
3. 不考虑缓存亲和性
|
||||
"""
|
||||
if not candidates:
|
||||
return candidates
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
# 使用 tuple 作为统一的 key 类型,兼容两种模式
|
||||
priority_groups: Dict[tuple, List[ProviderCandidate]] = defaultdict(list)
|
||||
|
||||
# 根据优先级模式选择分组方式
|
||||
if self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY:
|
||||
# 全局 Key 优先模式:按 global_priority 分组
|
||||
for candidate in candidates:
|
||||
global_priority = (
|
||||
candidate.key.global_priority
|
||||
if candidate.key and candidate.key.global_priority is not None
|
||||
else 999999
|
||||
)
|
||||
priority_groups[(global_priority,)].append(candidate)
|
||||
else:
|
||||
# 提供商优先模式:按 (provider_priority, internal_priority) 分组
|
||||
for candidate in candidates:
|
||||
key = (
|
||||
candidate.provider.provider_priority or 999999,
|
||||
candidate.key.internal_priority if candidate.key else 999999,
|
||||
)
|
||||
priority_groups[key].append(candidate)
|
||||
|
||||
result: List[ProviderCandidate] = []
|
||||
for priority in sorted(priority_groups.keys()):
|
||||
group = priority_groups[priority]
|
||||
if len(group) > 1:
|
||||
# 同优先级内随机打乱
|
||||
shuffled = list(group)
|
||||
random.shuffle(shuffled)
|
||||
result.extend(shuffled)
|
||||
else:
|
||||
result.extend(group)
|
||||
|
||||
return result
|
||||
|
||||
def _shuffle_keys_by_internal_priority(
|
||||
self,
|
||||
keys: List[ProviderAPIKey],
|
||||
|
||||
171
src/services/cache/provider_cache.py
vendored
Normal file
171
src/services/cache/provider_cache.py
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Provider 缓存服务 - 减少 Provider 和 ProviderAPIKey 查询
|
||||
|
||||
用于缓存 Provider 的 billing_type 和 ProviderAPIKey 的 rate_multiplier,
|
||||
这些数据在 UsageService.record_usage() 中被频繁查询但变化不频繁。
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheService
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderAPIKey
|
||||
|
||||
|
||||
class ProviderCacheService:
|
||||
"""Provider 缓存服务
|
||||
|
||||
提供 Provider 和 ProviderAPIKey 的缓存查询功能,减少数据库访问。
|
||||
主要用于 UsageService 中获取费率倍数和计费类型。
|
||||
"""
|
||||
|
||||
CACHE_TTL = CacheTTL.PROVIDER # 5 分钟
|
||||
|
||||
@staticmethod
|
||||
async def get_provider_api_key_rate_multiplier(
|
||||
db: Session, provider_api_key_id: str
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
获取 ProviderAPIKey 的 rate_multiplier(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_api_key_id: ProviderAPIKey ID
|
||||
|
||||
Returns:
|
||||
rate_multiplier 或 None(如果找不到)
|
||||
"""
|
||||
cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data is not None:
|
||||
logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}...")
|
||||
# 缓存的 "NOT_FOUND" 表示数据库中不存在
|
||||
if cached_data == "NOT_FOUND":
|
||||
return None
|
||||
return float(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
provider_key = (
|
||||
db.query(ProviderAPIKey.rate_multiplier)
|
||||
.filter(ProviderAPIKey.id == provider_api_key_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 3. 写入缓存
|
||||
if provider_key:
|
||||
rate_multiplier = provider_key.rate_multiplier or 1.0
|
||||
await CacheService.set(
|
||||
cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}...")
|
||||
return rate_multiplier
|
||||
else:
|
||||
# 缓存负结果
|
||||
await CacheService.set(
|
||||
cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_provider_billing_type(
|
||||
db: Session, provider_id: str
|
||||
) -> Optional[ProviderBillingType]:
|
||||
"""
|
||||
获取 Provider 的 billing_type(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
billing_type 或 None(如果找不到)
|
||||
"""
|
||||
cache_key = f"provider:billing_type:{provider_id}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data is not None:
|
||||
logger.debug(f"Provider billing_type 缓存命中: {provider_id[:8]}...")
|
||||
if cached_data == "NOT_FOUND":
|
||||
return None
|
||||
try:
|
||||
return ProviderBillingType(cached_data)
|
||||
except ValueError:
|
||||
# 缓存值无效,删除并重新查询
|
||||
await CacheService.delete(cache_key)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
provider = (
|
||||
db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
|
||||
)
|
||||
|
||||
# 3. 写入缓存
|
||||
if provider:
|
||||
billing_type = provider.billing_type
|
||||
await CacheService.set(
|
||||
cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Provider billing_type 已缓存: {provider_id[:8]}...")
|
||||
return billing_type
|
||||
else:
|
||||
# 缓存负结果
|
||||
await CacheService.set(
|
||||
cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_rate_multiplier_and_free_tier(
|
||||
db: Session,
|
||||
provider_api_key_id: Optional[str],
|
||||
provider_id: Optional[str],
|
||||
) -> Tuple[float, bool]:
|
||||
"""
|
||||
获取费率倍数和是否免费套餐(带缓存)
|
||||
|
||||
这是 UsageService._get_rate_multiplier_and_free_tier 的缓存版本。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_api_key_id: ProviderAPIKey ID(可选)
|
||||
provider_id: Provider ID(可选)
|
||||
|
||||
Returns:
|
||||
(rate_multiplier, is_free_tier) 元组
|
||||
"""
|
||||
actual_rate_multiplier = 1.0
|
||||
is_free_tier = False
|
||||
|
||||
# 获取费率倍数
|
||||
if provider_api_key_id:
|
||||
rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
|
||||
db, provider_api_key_id
|
||||
)
|
||||
if rate_multiplier is not None:
|
||||
actual_rate_multiplier = rate_multiplier
|
||||
|
||||
# 获取计费类型
|
||||
if provider_id:
|
||||
billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
|
||||
if billing_type == ProviderBillingType.FREE_TIER:
|
||||
is_free_tier = True
|
||||
|
||||
return actual_rate_multiplier, is_free_tier
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
|
||||
"""清除 ProviderAPIKey 缓存"""
|
||||
await CacheService.delete(f"provider_api_key:rate_multiplier:{provider_api_key_id}")
|
||||
logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_provider_cache(provider_id: str) -> None:
|
||||
"""清除 Provider 缓存"""
|
||||
await CacheService.delete(f"provider:billing_type:{provider_id}")
|
||||
logger.debug(f"Provider 缓存已清除: {provider_id[:8]}...")
|
||||
@@ -85,6 +85,8 @@ class ConcurrencyManager:
|
||||
"""
|
||||
获取当前并发数
|
||||
|
||||
性能优化:使用 MGET 批量获取,减少 Redis 往返次数
|
||||
|
||||
Args:
|
||||
endpoint_id: Endpoint ID(可选)
|
||||
key_id: ProviderAPIKey ID(可选)
|
||||
@@ -104,15 +106,21 @@ class ConcurrencyManager:
|
||||
key_count = 0
|
||||
|
||||
try:
|
||||
# 使用 MGET 批量获取,减少 Redis 往返(2 次 GET -> 1 次 MGET)
|
||||
keys_to_fetch = []
|
||||
if endpoint_id:
|
||||
endpoint_key = self._get_endpoint_key(endpoint_id)
|
||||
result = await self._redis.get(endpoint_key)
|
||||
endpoint_count = int(result) if result else 0
|
||||
|
||||
keys_to_fetch.append(self._get_endpoint_key(endpoint_id))
|
||||
if key_id:
|
||||
key_key = self._get_key_key(key_id)
|
||||
result = await self._redis.get(key_key)
|
||||
key_count = int(result) if result else 0
|
||||
keys_to_fetch.append(self._get_key_key(key_id))
|
||||
|
||||
if keys_to_fetch:
|
||||
results = await self._redis.mget(keys_to_fetch)
|
||||
idx = 0
|
||||
if endpoint_id:
|
||||
endpoint_count = int(results[idx]) if results[idx] else 0
|
||||
idx += 1
|
||||
if key_id:
|
||||
key_count = int(results[idx]) if results[idx] else 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取并发数失败: {e}")
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -20,6 +21,49 @@ class LogLevel(str, Enum):
|
||||
FULL = "full" # 记录完整请求和响应(包含body,敏感信息会脱敏)
|
||||
|
||||
|
||||
# 进程内缓存 TTL(秒)- 系统配置变化不频繁,使用较长的 TTL
|
||||
_CONFIG_CACHE_TTL = 60 # 1 分钟
|
||||
|
||||
# 进程内缓存存储: {key: (value, expire_time)}
|
||||
_config_cache: Dict[str, Tuple[Any, float]] = {}
|
||||
|
||||
|
||||
def _get_cached_config(key: str) -> Tuple[bool, Any]:
|
||||
"""从进程内缓存获取配置值
|
||||
|
||||
Returns:
|
||||
(hit, value): hit=True 表示缓存命中,value 为缓存的值
|
||||
"""
|
||||
if key in _config_cache:
|
||||
value, expire_time = _config_cache[key]
|
||||
if time.time() < expire_time:
|
||||
return True, value
|
||||
# 缓存过期,安全删除(避免并发时 KeyError)
|
||||
_config_cache.pop(key, None)
|
||||
return False, None
|
||||
|
||||
|
||||
def _set_cached_config(key: str, value: Any) -> None:
|
||||
"""设置进程内缓存"""
|
||||
_config_cache[key] = (value, time.time() + _CONFIG_CACHE_TTL)
|
||||
|
||||
|
||||
def invalidate_config_cache(key: Optional[str] = None) -> None:
|
||||
"""清除配置缓存
|
||||
|
||||
Args:
|
||||
key: 配置键,如果为 None 则清除所有缓存
|
||||
"""
|
||||
global _config_cache
|
||||
if key is None:
|
||||
_config_cache = {}
|
||||
logger.debug("已清除所有系统配置缓存")
|
||||
else:
|
||||
# 使用 pop 安全删除,避免并发时 KeyError
|
||||
if _config_cache.pop(key, None) is not None:
|
||||
logger.debug(f"已清除系统配置缓存: {key}")
|
||||
|
||||
|
||||
class SystemConfigService:
|
||||
"""系统配置服务类"""
|
||||
|
||||
@@ -127,14 +171,23 @@ class SystemConfigService:
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]:
|
||||
"""获取系统配置值"""
|
||||
"""获取系统配置值(带进程内缓存)"""
|
||||
# 1. 检查进程内缓存
|
||||
hit, cached_value = _get_cached_config(key)
|
||||
if hit:
|
||||
return cached_value
|
||||
|
||||
# 2. 查询数据库
|
||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
if config:
|
||||
_set_cached_config(key, config.value)
|
||||
return config.value
|
||||
|
||||
# 如果配置不存在,检查默认值
|
||||
# 3. 如果配置不存在,使用默认值
|
||||
if key in cls.DEFAULT_CONFIGS:
|
||||
return cls.DEFAULT_CONFIGS[key]["value"]
|
||||
value = cls.DEFAULT_CONFIGS[key]["value"]
|
||||
_set_cached_config(key, value)
|
||||
return value
|
||||
|
||||
return default
|
||||
|
||||
@@ -185,6 +238,9 @@ class SystemConfigService:
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
|
||||
# 清除缓存
|
||||
invalidate_config_cache(key)
|
||||
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
@@ -243,6 +299,8 @@ class SystemConfigService:
|
||||
if config:
|
||||
db.delete(config)
|
||||
db.commit()
|
||||
# 清除缓存
|
||||
invalidate_config_cache(key)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.enums import ProviderBillingType
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, Provider, ProviderAPIKey, Usage, User, UserRole
|
||||
from src.services.model.cost import ModelCostService
|
||||
@@ -362,22 +361,12 @@ class UsageService:
|
||||
provider_api_key_id: Optional[str],
|
||||
provider_id: Optional[str],
|
||||
) -> Tuple[float, bool]:
|
||||
"""获取费率倍数和是否免费套餐"""
|
||||
actual_rate_multiplier = 1.0
|
||||
if provider_api_key_id:
|
||||
provider_key = (
|
||||
db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first()
|
||||
)
|
||||
if provider_key and provider_key.rate_multiplier:
|
||||
actual_rate_multiplier = provider_key.rate_multiplier
|
||||
"""获取费率倍数和是否免费套餐(使用缓存)"""
|
||||
from src.services.cache.provider_cache import ProviderCacheService
|
||||
|
||||
is_free_tier = False
|
||||
if provider_id:
|
||||
provider_obj = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER:
|
||||
is_free_tier = True
|
||||
|
||||
return actual_rate_multiplier, is_free_tier
|
||||
return await ProviderCacheService.get_rate_multiplier_and_free_tier(
|
||||
db, provider_api_key_id, provider_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def _calculate_costs(
|
||||
|
||||
Reference in New Issue
Block a user