7 Commits

Author SHA1 Message Date
fawney19
1521ce5a96 feat: 添加负载均衡调度模式
- 新增 load_balance 调度模式,同优先级内随机轮换
- 前端支持三种调度模式切换:缓存亲和、负载均衡、固定顺序
2026-01-08 03:20:04 +08:00
fawney19
f2e62dd197 feat: 添加版本更新检查功能
- 后端新增 /api/admin/system/check-update 接口,从 GitHub Tags 获取最新版本
- 前端新增 UpdateDialog 组件,管理员登录后自动检查更新并弹窗提示
- 同一会话内只检查一次,点击"稍后提醒"后 24 小时内不再提示
- CI 和 deploy.sh 自动生成 _version.py 版本文件
2026-01-08 03:01:54 +08:00
fawney19
d378630b38 perf: 添加多层缓存优化减少数据库查询
- 新增 ProviderCacheService 缓存 Provider 和 ProviderAPIKey 数据
- SystemConfigService 添加进程内缓存(TTL 60秒)
- API Key last_used_at 更新添加节流策略(60秒间隔)
- HTTP 连接池配置改为可配置,支持根据 Worker 数量自动计算
- 前端优先级管理改用 health_score 显示健康度
2026-01-08 02:34:59 +08:00
fawney19
d9e6346911 fix: 降低 API Key 最小长度限制至 3 个字符 2026-01-08 01:53:16 +08:00
fawney19
238788e0e9 fix: 统一端点默认重试次数为 2
同步前端表单、mock 数据和后端导入配置中端点的默认重试次数
2026-01-08 01:40:40 +08:00
fawney19
68ff828505 feat: 容器启动时自动执行数据库迁移
- 添加 entrypoint.sh 在容器启动前执行 alembic upgrade head
- 更新 Dockerfile.app 和 Dockerfile.app.local 使用新入口脚本
- 移除手动迁移脚本 migrate.sh
- 简化 README 部署说明
2026-01-08 01:28:36 +08:00
fawney19
59447fc12b fix: 固定容器内部端口为 8084,避免 PORT 环境变量导致端口冲突 2026-01-07 21:51:55 +08:00
28 changed files with 878 additions and 133 deletions

View File

@@ -146,10 +146,33 @@ jobs:
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=
- name: Extract version from tag
id: version
run: |
# 从 tag 提取版本号,如 v0.2.5 -> 0.2.5
VERSION="${GITHUB_REF#refs/tags/v}"
if [ "$VERSION" = "$GITHUB_REF" ]; then
# 不是 tag 触发,使用 git describe
VERSION=$(git describe --tags --always | sed 's/^v//')
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Extracted version: $VERSION"
- name: Update Dockerfile.app to use registry base image
run: |
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
- name: Generate version file
run: |
# 生成 _version.py 文件
cat > src/_version.py << EOF
# Auto-generated by CI
__version__ = '${{ steps.version.outputs.version }}'
__version_tuple__ = tuple(int(x) for x in '${{ steps.version.outputs.version }}'.split('.') if x.isdigit())
version = __version__
version_tuple = __version_tuple__
EOF
- name: Build and push app image
uses: docker/build-push-action@v5
with:

3
.gitignore vendored
View File

@@ -224,3 +224,6 @@ extracted_*.ts
.deps-hash
.code-hash
.migration-hash
# Version file (auto-generated by hatch-vcs)
src/_version.py

View File

@@ -127,14 +127,14 @@ RUN printf '%s\n' \
'pidfile=/var/run/supervisord.pid' \
'' \
'[program:nginx]' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/${PORT:-8084}/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'command=/bin/bash -c "sed \"s/PORT_PLACEHOLDER/8084/g\" /etc/nginx/sites-available/default.template > /etc/nginx/sites-available/default && /usr/sbin/nginx -g \"daemon off;\""' \
'autostart=true' \
'autorestart=true' \
'stdout_logfile=/var/log/nginx/access.log' \
'stderr_logfile=/var/log/nginx/error.log' \
'' \
'[program:app]' \
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 0.0.0.0:%(ENV_PORT)s --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'command=gunicorn src.main:app --preload -w %(ENV_GUNICORN_WORKERS)s -k uvicorn.workers.UvicornWorker --bind 127.0.0.1:8084 --timeout 120 --access-logfile - --error-logfile - --log-level info' \
'directory=/app' \
'autostart=true' \
'autorestart=true' \
@@ -147,6 +147,10 @@ RUN printf '%s\n' \
# 创建目录
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# 入口脚本(启动前执行迁移)
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# 环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
@@ -161,4 +165,5 @@ EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
ENTRYPOINT ["/entrypoint.sh"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@@ -139,6 +139,10 @@ RUN printf '%s\n' \
# 创建目录
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# 入口脚本(启动前执行迁移)
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# 环境变量
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
@@ -152,4 +156,5 @@ EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
ENTRYPOINT ["/entrypoint.sh"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]

View File

@@ -57,14 +57,8 @@ cd Aether
cp .env.example .env
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
# 3. 部署
docker compose up -d
# 4. 首次部署时, 初始化数据库
./migrate.sh
# 5. 更新
docker compose pull && docker compose up -d && ./migrate.sh
# 3. 部署 / 更新(自动执行数据库迁移)
docker compose pull && docker compose up -d
```
### Docker Compose本地构建镜像

View File

@@ -88,9 +88,28 @@ build_base() {
save_deps_hash
}
# Write src/_version.py so the image carries the version it was built from.
generate_version_file() {
    # The version is the output of `git describe --tags --always` with a
    # leading "v" removed; git errors are silenced.
    local version
    version=$(git describe --tags --always 2>/dev/null | sed 's/^v//')
    # Fall back to "unknown" when git produced nothing.
    : "${version:=unknown}"

    echo ">>> Generating version file: $version"
    # Unquoted heredoc delimiter: $version is expanded by the shell here.
    cat > src/_version.py << EOF
# Auto-generated by deploy.sh - do not edit
__version__ = '$version'
__version_tuple__ = tuple(int(x) for x in '$version'.split('-')[0].split('.') if x.isdigit())
version = __version__
version_tuple = __version_tuple__
EOF
}
# Build the application image from Dockerfile.app.local.
build_app() {
    printf '%s\n' ">>> Building app image (code only)..."
    # The version file must exist before the build so it is part of the build context.
    generate_version_file
    docker build -f Dockerfile.app.local -t aether-app:latest .
    # Record the code hash after the build command has run.
    save_code_hash
}

8
entrypoint.sh Normal file
View File

@@ -0,0 +1,8 @@
#!/bin/bash
# Container entrypoint: apply database migrations, then hand over to the
# command given as arguments.

# Abort on the first failing command, so a failed migration stops startup.
set -o errexit

printf '%s\n' "Running database migrations..."
alembic upgrade head

printf '%s\n' "Starting application..."
# Replace this shell with the requested command instead of running it as a child.
exec "$@"

View File

@@ -159,6 +159,15 @@ export interface EmailTemplateResetResponse {
}
}
// Response of the update-check endpoint (/api/admin/system/check-update).
export interface CheckUpdateResponse {
// Version the server reports for itself.
current_version: string
// Newest version found upstream; null when none could be determined.
latest_version: string | null
// True when latest_version is newer than current_version.
has_update: boolean
// Release page for latest_version; null when there is no latest version.
release_url: string | null
// Error description when the check failed; null otherwise.
error: string | null
}
// LDAP 配置响应
export interface LdapConfigResponse {
server_url: string | null
@@ -526,6 +535,14 @@ export const adminApi = {
return response.data
},
// Ask the backend whether a newer release is available.
async checkUpdate(): Promise<CheckUpdateResponse> {
  const endpoint = '/api/admin/system/check-update'
  const { data } = await apiClient.get<CheckUpdateResponse>(endpoint)
  return data
},
// LDAP 配置相关
// 获取 LDAP 配置
async getLdapConfig(): Promise<LdapConfigResponse> {

View File

@@ -0,0 +1,112 @@
<template>
<Dialog
v-model="isOpen"
size="md"
title=""
>
<div class="flex flex-col items-center text-center py-2">
<!-- Logo -->
<HeaderLogo
size="h-16 w-16"
class-name="text-primary"
/>
<!-- Title -->
<h2 class="text-xl font-semibold text-foreground mt-4 mb-2">
发现新版本
</h2>
<!-- Version Info -->
<div class="flex items-center gap-3 mb-4">
<span class="px-3 py-1.5 rounded-lg bg-muted text-sm font-mono text-muted-foreground">
v{{ currentVersion }}
</span>
<svg
class="h-4 w-4 text-muted-foreground"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 7l5 5m0 0l-5 5m5-5H6"
/>
</svg>
<span class="px-3 py-1.5 rounded-lg bg-primary/10 text-sm font-mono font-medium text-primary">
v{{ latestVersion }}
</span>
</div>
<!-- Description -->
<p class="text-sm text-muted-foreground max-w-xs">
新版本已发布建议更新以获得最新功能和安全修复
</p>
</div>
<template #footer>
<div class="flex w-full gap-3">
<Button
variant="outline"
class="flex-1"
@click="handleLater"
>
稍后提醒
</Button>
<Button
class="flex-1"
@click="handleViewRelease"
>
查看更新
</Button>
</div>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue'
import HeaderLogo from '@/components/HeaderLogo.vue'
const props = defineProps<{
modelValue: boolean
currentVersion: string
latestVersion: string
releaseUrl: string | null
}>()
const emit = defineEmits<{
'update:modelValue': [value: boolean]
}>()
const isOpen = ref(props.modelValue)
watch(() => props.modelValue, (val) => {
isOpen.value = val
})
watch(isOpen, (val) => {
emit('update:modelValue', val)
})
/**
 * "Remind me later": remember the dismissed version for 24 hours and close
 * the dialog.
 */
function handleLater() {
  const ignoreKey = 'aether_update_ignore'
  const ignoreData = {
    version: props.latestVersion,
    until: Date.now() + 24 * 60 * 60 * 1000 // 24 hours from now
  }
  try {
    localStorage.setItem(ignoreKey, JSON.stringify(ignoreData))
  } catch {
    // Storage can be unavailable or full. Persisting the dismissal is
    // best-effort; the dialog must still close.
  }
  isOpen.value = false
}
/**
 * Open the release page (when a URL is known) in a new tab and close the
 * dialog.
 */
function handleViewRelease() {
  if (props.releaseUrl) {
    // noopener/noreferrer: the opened page gets no window.opener handle and
    // no referrer from this admin UI.
    window.open(props.releaseUrl, '_blank', 'noopener,noreferrer')
  }
  isOpen.value = false
}
</script>

View File

@@ -296,7 +296,7 @@ const form = ref({
base_url: '',
custom_path: '',
timeout: 300,
max_retries: 3,
max_retries: 2,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true,
@@ -392,7 +392,7 @@ function resetForm() {
base_url: '',
custom_path: '',
timeout: 300,
max_retries: 3,
max_retries: 2,
max_concurrent: undefined,
rate_limit: undefined,
is_active: true,

View File

@@ -349,8 +349,8 @@ const apiKeyError = computed(() => {
}
// 如果输入了值,检查长度
if (apiKey.length < 10) {
return 'API 密钥至少需要 10 个字符'
if (apiKey.length < 3) {
return 'API 密钥至少需要 3 个字符'
}
return ''

View File

@@ -262,17 +262,17 @@
<div class="shrink-0 flex items-center gap-3">
<!-- 健康度 -->
<div
v-if="key.success_rate !== null"
v-if="key.health_score != null"
class="text-xs text-right"
>
<div
class="font-medium tabular-nums"
:class="[
key.success_rate >= 0.95 ? 'text-green-600' :
key.success_rate >= 0.8 ? 'text-yellow-600' : 'text-red-500'
key.health_score >= 0.95 ? 'text-green-600' :
key.health_score >= 0.5 ? 'text-yellow-600' : 'text-red-500'
]"
>
{{ (key.success_rate * 100).toFixed(0) }}%
{{ ((key.health_score || 0) * 100).toFixed(0) }}%
</div>
<div class="text-[10px] text-muted-foreground opacity-70">
{{ key.request_count }} reqs
@@ -319,19 +319,6 @@
<div class="flex items-center gap-2 pl-4 border-l border-border">
<span class="text-xs text-muted-foreground">调度:</span>
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
@@ -345,6 +332,32 @@
>
缓存亲和
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'load_balance'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="同优先级内随机轮换,不考虑缓存"
@click="schedulingMode = 'load_balance'"
>
负载均衡
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
</div>
</div>
</div>
@@ -400,6 +413,7 @@ interface KeyWithMeta {
endpoint_base_url: string
api_format: string
capabilities: string[]
health_score: number | null
success_rate: number | null
avg_response_time_ms: number | null
request_count: number
@@ -444,7 +458,7 @@ const saving = ref(false)
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
// 调度模式状态
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
const schedulingMode = ref<'fixed_order' | 'load_balance' | 'cache_affinity'>('cache_affinity')
// 可用的 API 格式
const availableFormats = computed(() => {
@@ -477,7 +491,11 @@ async function loadCurrentPriorityMode() {
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
if (currentSchedulingMode === 'fixed_order' || currentSchedulingMode === 'load_balance' || currentSchedulingMode === 'cache_affinity') {
schedulingMode.value = currentSchedulingMode
} else {
schedulingMode.value = 'cache_affinity'
}
} catch {
activeMainTab.value = 'provider'
schedulingMode.value = 'cache_affinity'

View File

@@ -295,6 +295,15 @@
</template>
<RouterView />
<!-- 更新提示弹窗 -->
<UpdateDialog
v-if="updateInfo"
v-model="showUpdateDialog"
:current-version="updateInfo.current_version"
:latest-version="updateInfo.latest_version || ''"
:release-url="updateInfo.release_url"
/>
</AppShell>
</template>
@@ -304,10 +313,12 @@ import { useRoute, useRouter } from 'vue-router'
import { useAuthStore } from '@/stores/auth'
import { useDarkMode } from '@/composables/useDarkMode'
import { isDemoMode } from '@/config/demo'
import { adminApi, type CheckUpdateResponse } from '@/api/admin'
import Button from '@/components/ui/button.vue'
import AppShell from '@/components/layout/AppShell.vue'
import SidebarNav from '@/components/layout/SidebarNav.vue'
import HeaderLogo from '@/components/HeaderLogo.vue'
import UpdateDialog from '@/components/common/UpdateDialog.vue'
import {
Home,
Users,
@@ -345,17 +356,67 @@ const showAuthError = ref(false)
const mobileMenuOpen = ref(false)
let authCheckInterval: number | null = null
// 更新检查相关
const showUpdateDialog = ref(false)
const updateInfo = ref<CheckUpdateResponse | null>(null)
// 路由变化时自动关闭移动端菜单
watch(() => route.path, () => {
mobileMenuOpen.value = false
})
/**
 * Decide whether the update prompt should be shown for `latestVersion`.
 *
 * Returns false only when a stored "ignore" record names the same version
 * and has not expired yet. A missing, unreadable or malformed record means
 * the prompt is shown.
 */
function shouldShowUpdatePrompt(latestVersion: string): boolean {
  const ignoreKey = 'aether_update_ignore'

  try {
    // Reading storage can throw too (e.g. when access is blocked), so it
    // lives inside the try together with the JSON parsing.
    const ignoreData = localStorage.getItem(ignoreKey)
    if (!ignoreData) return true

    const { version, until } = JSON.parse(ignoreData)
    // Same version dismissed and the dismissal is still valid: stay quiet.
    if (version === latestVersion && Date.now() < until) {
      return false
    }
  } catch {
    // Unreadable or malformed record: fall through and show the prompt.
  }

  return true
}
/**
 * Query the backend for a newer release and open the update dialog when one
 * is reported and has not been dismissed. Runs for admins only and at most
 * once per browser session; failures are swallowed on purpose.
 */
async function checkForUpdate() {
  const isAdmin = authStore.user?.role === 'admin'
  if (!isAdmin) return

  // Mark the session as checked before the request is sent, so a failed
  // request is not retried within the same session.
  const sessionKey = 'aether_update_checked'
  if (sessionStorage.getItem(sessionKey)) return
  sessionStorage.setItem(sessionKey, '1')

  try {
    const result = await adminApi.checkUpdate()
    const latest = result.latest_version
    if (!result.has_update || !latest) return
    if (!shouldShowUpdatePrompt(latest)) return

    updateInfo.value = result
    showUpdateDialog.value = true
  } catch {
    // Fail silently; the update check must never disturb the user.
  }
}
onMounted(() => {
authCheckInterval = setInterval(() => {
if (authStore.user && !authStore.token) {
showAuthError.value = true
}
}, 5000)
// 延迟检查更新,避免影响页面加载
setTimeout(() => {
checkForUpdate()
}, 2000)
})
onUnmounted(() => {

View File

@@ -425,9 +425,9 @@ const MOCK_ENDPOINT_KEYS = [
// Mock Endpoints
const MOCK_ENDPOINTS = [
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 3, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 3, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 3, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 2, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 2, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 2, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
]
// Mock 能力定义
@@ -1224,7 +1224,7 @@ function generateMockEndpointsForProvider(providerId: string) {
'https://generativelanguage.googleapis.com',
auth_type: format.includes('GEMINI') ? 'api_key' : 'bearer',
timeout: 120,
max_retries: 3,
max_retries: 2,
priority: 100 - index * 10,
weight: 100,
health_score: healthDetail?.health_score ?? 1.0,

View File

@@ -54,7 +54,7 @@ const fieldNameMap: Record<string, string> = {
*/
const errorTypeMap: Record<string, (error: ValidationError) => string> = {
'string_too_short': (error) => {
const minLength = error.ctx?.min_length || 10
const minLength = error.ctx?.min_length || 3
return `长度不能少于 ${minLength} 个字符`
},
'string_too_long': (error) => {

View File

@@ -1,12 +0,0 @@
#!/bin/bash
# 数据库迁移脚本 - 在 Docker 容器内执行 Alembic 迁移
set -e
CONTAINER_NAME="aether-app"
echo "Running database migrations in container: $CONTAINER_NAME"
docker exec $CONTAINER_NAME alembic upgrade head
echo "Database migration completed successfully"

View File

@@ -1,34 +0,0 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.2.5'
__version_tuple__ = version_tuple = (0, 2, 5)
__commit_id__ = commit_id = None

View File

@@ -18,6 +18,7 @@ from src.core.key_capabilities import get_capability
from src.core.logger import logger
from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.provider_cache import ProviderCacheService
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
@@ -411,6 +412,10 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit()
db.refresh(key)
# 如果更新了 rate_multiplier清除缓存
if "rate_multiplier" in update_data:
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try:
@@ -550,6 +555,7 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
"endpoint_base_url": endpoint.base_url,
"api_format": api_format,
"capabilities": caps_list,
"health_score": key.health_score,
"success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count,

View File

@@ -11,9 +11,11 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
from src.services.cache.provider_cache import ProviderCacheService
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
@@ -296,6 +298,11 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
db.commit()
db.refresh(provider)
# 如果更新了 billing_type清除缓存
if "billing_type" in update_data:
await ProviderCacheService.invalidate_provider_cache(provider.id)
logger.debug(f"已清除 Provider 缓存: {provider.id}")
context.add_audit_metadata(
action="update_provider",
provider_id=provider.id,

View File

@@ -42,6 +42,42 @@ def _get_version_from_git() -> str | None:
return None
def _get_current_version() -> str:
    """Return the running version string.

    The value reported by ``_get_version_from_git()`` wins when it is
    non-empty. Otherwise the generated ``src._version`` module is consulted,
    and ``"unknown"`` is returned when that module cannot be imported.
    """
    git_version = _get_version_from_git()
    if not git_version:
        try:
            from src._version import __version__
        except ImportError:
            return "unknown"
        return __version__
    return git_version
def _parse_version(version_str: str) -> tuple:
"""解析版本号为可比较的元组,支持 3-4 段版本号
例如:
- '0.2.5' -> (0, 2, 5, 0)
- '0.2.5.1' -> (0, 2, 5, 1)
- 'v0.2.5-4-g1234567' -> (0, 2, 5, 0)
"""
import re
version_str = version_str.lstrip("v")
main_version = re.split(r"[-+]", version_str)[0]
try:
parts = main_version.split(".")
# 标准化为 4 段,便于比较
int_parts = [int(p) for p in parts]
while len(int_parts) < 4:
int_parts.append(0)
return tuple(int_parts[:4])
except ValueError:
return (0, 0, 0, 0)
@router.get("/version")
async def get_system_version():
"""
@@ -52,18 +88,111 @@ async def get_system_version():
**返回字段**:
- `version`: 版本号字符串
"""
# 优先从 git 获取
version = _get_version_from_git()
if version:
return {"version": version}
return {"version": _get_current_version()}
@router.get("/check-update")
async def check_update():
    """
    Check for a newer release.

    Fetches the repository tags from the GitHub API, picks the highest
    version among them and compares it with the running version.

    **Response fields**:
    - `current_version`: version currently running
    - `latest_version`: highest version found among the tags, or null
    - `has_update`: whether `latest_version` is newer than `current_version`
    - `release_url`: GitHub release page of the latest tag, or null
    - `error`: description of the failure when the check did not succeed, otherwise null
    """
    import httpx

    from src.clients.http_client import HTTPClientPool

    current_version = _get_current_version()
    github_repo = "Aethersailor/Aether"
    github_tags_url = f"https://api.github.com/repos/{github_repo}/tags"

    def _no_update(error=None):
        """Build the payload for every outcome that reports no newer version."""
        return {
            "current_version": current_version,
            "latest_version": None,
            "has_update": False,
            "release_url": None,
            "error": error,
        }

    try:
        async with HTTPClientPool.get_temp_client(
            timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0)
        ) as client:
            response = await client.get(
                github_tags_url,
                headers={
                    "Accept": "application/vnd.github.v3+json",
                    "User-Agent": f"Aether/{current_version}",
                },
                params={"per_page": 10},
            )

        if response.status_code != 200:
            return _no_update(f"GitHub API 返回错误: {response.status_code}")

        tags = response.json()
        if not tags:
            return _no_update()

        # Keep only tags that look like versions. Ordering is by parsed
        # version number, not by the order GitHub returns them in.
        version_tags = []
        for tag in tags:
            # A missing, null or empty name must not raise: slicing with [:1]
            # is safe on an empty string where indexing with [0] is not.
            tag_name = tag.get("name") or ""
            if tag_name.startswith("v") or tag_name[:1].isdigit():
                version_tags.append((tag_name, _parse_version(tag_name)))

        if not version_tags:
            return _no_update()

        # Highest parsed version wins; on ties the first tag in the response is kept.
        latest_tag = max(version_tags, key=lambda item: item[1])[0]
        latest_version = latest_tag.lstrip("v")

        current_tuple = _parse_version(current_version)
        latest_tuple = _parse_version(latest_version)
        has_update = latest_tuple > current_tuple

        return {
            "current_version": current_version,
            "latest_version": latest_version,
            "has_update": has_update,
            "release_url": f"https://github.com/{github_repo}/releases/tag/{latest_tag}",
            "error": None,
        }

    except httpx.TimeoutException:
        return _no_update("检查更新超时")
    except Exception as e:
        # Endpoint boundary: any failure is reported in the payload instead of
        # surfacing as a server error.
        return _no_update(f"检查更新失败: {str(e)}")
pipeline = ApiRequestPipeline()
@@ -887,7 +1016,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
)
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3)
existing_ep.max_retries = ep_data.get("max_retries", 2)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
@@ -903,7 +1032,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
base_url=ep_data["base_url"],
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3),
max_retries=ep_data.get("max_retries", 2),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),

View File

@@ -90,13 +90,18 @@ class HTTPClientPool:
pool=config.http_pool_timeout,
),
limits=httpx.Limits(
max_connections=100, # 最大连接数
max_keepalive_connections=20, # 最大保活连接数
keepalive_expiry=30.0, # 保活过期时间(秒)
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
follow_redirects=True, # 跟随重定向
)
logger.info("全局HTTP客户端池已初始化")
logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client
@classmethod

View File

@@ -145,6 +145,24 @@ class Config:
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
# HTTP 连接池配置
# HTTP_MAX_CONNECTIONS: 最大连接数,影响并发能力
# - 每个连接占用一个 socket过多会耗尽系统资源
# - 默认根据 Worker 数量自动计算:单 Worker 200多 Worker 按比例分配
# HTTP_KEEPALIVE_CONNECTIONS: 保活连接数,影响连接复用效率
# - 高频请求场景应该增大此值
# - 默认为 max_connections 的 30%(长连接场景更高效)
# HTTP_KEEPALIVE_EXPIRY: 保活过期时间(秒)
# - 过短会频繁重建连接,过长会占用资源
# - 默认 30 秒,生图等长连接场景可适当增大
self.http_max_connections = int(
os.getenv("HTTP_MAX_CONNECTIONS") or self._auto_http_max_connections()
)
self.http_keepalive_connections = int(
os.getenv("HTTP_KEEPALIVE_CONNECTIONS") or self._auto_http_keepalive_connections()
)
self.http_keepalive_expiry = float(os.getenv("HTTP_KEEPALIVE_EXPIRY", "30.0"))
# 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
@@ -224,6 +242,53 @@ class Config:
"""智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size
def _auto_http_max_connections(self) -> int:
"""
智能计算 HTTP 最大连接数
计算依据:
1. 系统 socket 资源有限Linux 默认 ulimit -n 通常为 1024
2. 多 Worker 部署时每个进程独立连接池
3. 需要为数据库连接、Redis 连接等预留资源
公式: base_connections / workers
- 单 Worker: 200 连接(适合开发/低负载)
- 多 Worker: 按比例分配,确保总数不超过系统限制
范围: 50 - 500
"""
# 基础连接数:假设系统可用 socket 约 800 个用于 HTTP
# (预留给 DB、Redis、内部服务等
base_connections = 800
workers = max(self.worker_processes, 1)
# 每个 Worker 分配的连接数
per_worker = base_connections // workers
# 限制范围:最小 50保证基本并发最大 500避免资源耗尽
return max(50, min(per_worker, 500))
def _auto_http_keepalive_connections(self) -> int:
"""
智能计算 HTTP 保活连接数
计算依据:
1. 保活连接用于复用,减少 TCP 握手开销
2. 对于 API 网关场景,上游请求频繁,保活比例应较高
3. 生图等长连接场景,连接会被长时间占用
公式: max_connections * 0.3
- 30% 的比例在复用效率和资源占用间取得平衡
- 长连接场景建议手动调高到 50-70%
范围: 10 - max_connections
"""
# 保活连接数为最大连接数的 30%
keepalive = int(self.http_max_connections * 0.3)
# 最小 10 个保活连接,最大不超过 max_connections
return max(10, min(keepalive, self.http_max_connections))
def _parse_ttfb_timeout(self) -> float:
"""
解析 TTFB 超时配置,带错误处理和范围限制

View File

@@ -134,7 +134,7 @@ class EndpointAPIKeyCreate(BaseModel):
"""为 Endpoint 添加 API Key"""
endpoint_id: str = Field(..., description="Endpoint ID")
api_key: str = Field(..., min_length=10, max_length=500, description="API Key将自动加密")
api_key: str = Field(..., min_length=3, max_length=500, description="API Key将自动加密")
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
# 成本计算
@@ -175,13 +175,9 @@ class EndpointAPIKeyCreate(BaseModel):
@classmethod
def validate_api_key(cls, v: str) -> str:
"""验证 API Key 安全性"""
# 移除首尾空白
# 移除首尾空白(长度校验由 Field min_length 处理)
v = v.strip()
# 检查最小长度
if len(v) < 10:
raise ValueError("API Key 长度不能少于 10 个字符")
# 检查危险字符SQL 注入防护)
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]
for char in dangerous_chars:
@@ -219,7 +215,7 @@ class EndpointAPIKeyUpdate(BaseModel):
"""更新 Endpoint API Key"""
api_key: Optional[str] = Field(
default=None, min_length=10, max_length=500, description="API Key将自动加密"
default=None, min_length=3, max_length=500, description="API Key将自动加密"
)
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")

View File

@@ -8,7 +8,9 @@ import hashlib
import secrets
import time
import uuid
from collections import OrderedDict
from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, Optional
import jwt
@@ -30,6 +32,44 @@ from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService
# API Key last_used_at 更新节流配置
# 同一个 API Key 在此时间间隔内只会更新一次 last_used_at
_LAST_USED_UPDATE_INTERVAL = 60 # 秒
_LAST_USED_CACHE_MAX_SIZE = 10000 # LRU 缓存最大条目数
# 进程内缓存:记录每个 API Key 最后一次更新 last_used_at 的时间
# 使用 OrderedDict 实现 LRU避免内存无限增长
_api_key_last_update_times: OrderedDict[str, float] = OrderedDict()
_last_update_lock = Lock()
def _should_update_last_used(api_key_id: str) -> bool:
"""判断是否应该更新 API Key 的 last_used_at
使用节流策略,同一个 Key 在指定间隔内只更新一次。
线程安全,使用 LRU 策略限制缓存大小。
Returns:
True 表示应该更新False 表示跳过
"""
now = time.time()
with _last_update_lock:
last_update = _api_key_last_update_times.get(api_key_id, 0)
if now - last_update >= _LAST_USED_UPDATE_INTERVAL:
_api_key_last_update_times[api_key_id] = now
# LRU: 移到末尾(最近使用)
_api_key_last_update_times.move_to_end(api_key_id)
# 超过最大容量时,移除最旧的条目
while len(_api_key_last_update_times) > _LAST_USED_CACHE_MAX_SIZE:
_api_key_last_update_times.popitem(last=False)
return True
return False
# JWT配置从config读取
if not config.jwt_secret_key:
# 如果没有配置,生成一个随机密钥并警告
@@ -367,9 +407,10 @@ class AuthService:
logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
return None
# 更新最后使用时间
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
# 更新最后使用时间(使用节流策略,减少数据库写入)
if _should_update_last_used(key_record.id):
key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)

View File

@@ -121,11 +121,13 @@ class CacheAwareScheduler:
PRIORITY_MODE_GLOBAL_KEY,
}
# 调度模式常量
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式:严格按优先级,忽略缓存
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式:优先缓存,同优先级哈希分散
SCHEDULING_MODE_LOAD_BALANCE = "load_balance" # 负载均衡模式:忽略缓存,同优先级随机轮换
ALLOWED_SCHEDULING_MODES = {
SCHEDULING_MODE_FIXED_ORDER,
SCHEDULING_MODE_CACHE_AFFINITY,
SCHEDULING_MODE_LOAD_BALANCE,
}
def __init__(
@@ -680,8 +682,9 @@ class CacheAwareScheduler:
f"(api_format={target_format.value}, model={model_name})"
)
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
# 4. 根据调度模式应用不同的排序策略
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
# 缓存亲和模式:优先使用缓存的,同优先级内哈希分散
if affinity_key and candidates:
candidates = await self._apply_cache_affinity(
candidates=candidates,
@@ -689,8 +692,13 @@ class CacheAwareScheduler:
api_format=target_format,
global_model_id=global_model_id,
)
elif self.scheduling_mode == self.SCHEDULING_MODE_LOAD_BALANCE:
# 负载均衡模式:忽略缓存,同优先级内随机轮换
candidates = self._apply_load_balance(candidates)
for candidate in candidates:
candidate.is_cached = False
else:
# 固定顺序模式:标记所有候选为非缓存
# 固定顺序模式:严格按优先级,忽略缓存
for candidate in candidates:
candidate.is_cached = False
@@ -1163,6 +1171,57 @@ class CacheAwareScheduler:
return result
def _apply_load_balance(
    self, candidates: List[ProviderCandidate]
) -> List[ProviderCandidate]:
    """
    Load-balance mode: randomly rotate candidates that share a priority.

    Ordering rules:
    1. Group candidates by priority. When ``self.priority_mode`` equals
       ``self.PRIORITY_MODE_GLOBAL_KEY`` the group key is
       ``(global_priority,)``; otherwise it is
       ``(provider_priority, internal_priority)``.
    2. Emit the groups in ascending key order and shuffle the members of
       each group with ``random.shuffle``.
    3. Cache affinity is not considered.

    A priority that is ``None`` (or belongs to a candidate without a key)
    is replaced by a large sentinel so it sorts last. Only ``None`` is
    treated as "unset": an explicit priority of ``0`` keeps its value and
    sorts first, and a ``None`` internal priority can no longer make the
    group keys un-orderable.

    Args:
        candidates: Candidates to reorder. The list itself is not mutated.

    Returns:
        A new list holding the same candidates in load-balanced order, or
        ``candidates`` unchanged when it is empty.
    """
    if not candidates:
        return candidates

    from collections import defaultdict

    unset_priority = 999999

    def _or_unset(value):
        # Substitute the sentinel only for a genuinely missing priority.
        return unset_priority if value is None else value

    use_global_priority = self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY

    # Tuples are used as group keys in both modes so they sort uniformly.
    priority_groups: Dict[tuple, List[ProviderCandidate]] = defaultdict(list)
    for candidate in candidates:
        key = candidate.key
        if use_global_priority:
            group_key = (_or_unset(key.global_priority if key else None),)
        else:
            group_key = (
                _or_unset(candidate.provider.provider_priority),
                _or_unset(key.internal_priority if key else None),
            )
        priority_groups[group_key].append(candidate)

    result: List[ProviderCandidate] = []
    for group_key in sorted(priority_groups):
        group = priority_groups[group_key]
        if len(group) > 1:
            # The group list is local to this call, so shuffling it in
            # place does not touch the caller's list.
            random.shuffle(group)
        result.extend(group)

    return result
def _shuffle_keys_by_internal_priority(
self,
keys: List[ProviderAPIKey],

171
src/services/cache/provider_cache.py vendored Normal file
View File

@@ -0,0 +1,171 @@
"""
Provider 缓存服务 - 减少 Provider 和 ProviderAPIKey 查询
用于缓存 Provider 的 billing_type 和 ProviderAPIKey 的 rate_multiplier
这些数据在 UsageService.record_usage() 中被频繁查询但变化不频繁。
"""
from typing import Optional, Tuple
from sqlalchemy.orm import Session
from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey
class ProviderCacheService:
    """Cached lookups for Provider and ProviderAPIKey billing data.

    Wraps the database queries for a ProviderAPIKey's ``rate_multiplier`` and
    a Provider's ``billing_type`` with ``CacheService`` so repeated lookups
    avoid the database. Missing rows are cached too (negative caching) using
    the ``_NOT_FOUND`` sentinel.
    """

    CACHE_TTL = CacheTTL.PROVIDER  # 5 minutes

    # Sentinel stored in the cache to remember "nothing to return" results.
    _NOT_FOUND = "NOT_FOUND"

    @staticmethod
    async def get_provider_api_key_rate_multiplier(
        db: Session, provider_api_key_id: str
    ) -> Optional[float]:
        """
        Return the ``rate_multiplier`` of a ProviderAPIKey, using the cache.

        Args:
            db: Database session.
            provider_api_key_id: ProviderAPIKey ID.

        Returns:
            The rate multiplier (``1.0`` when the stored value is falsy), or
            ``None`` when no such ProviderAPIKey exists.
        """
        cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}"

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}...")
            # The sentinel means the row is known not to exist.
            if cached_data == ProviderCacheService._NOT_FOUND:
                return None
            try:
                return float(cached_data)
            except (TypeError, ValueError):
                # Unusable cached value: drop it and fall through to the
                # database, mirroring the billing_type lookup below.
                await CacheService.delete(cache_key)

        # 2. Cache miss: query the database (single column only).
        provider_key = (
            db.query(ProviderAPIKey.rate_multiplier)
            .filter(ProviderAPIKey.id == provider_api_key_id)
            .first()
        )

        # 3. Populate the cache, including the negative result.
        if provider_key is None:
            await CacheService.set(
                cache_key,
                ProviderCacheService._NOT_FOUND,
                ttl_seconds=ProviderCacheService.CACHE_TTL,
            )
            return None

        rate_multiplier = provider_key.rate_multiplier or 1.0
        await CacheService.set(
            cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
        )
        logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}...")
        return rate_multiplier

    @staticmethod
    async def get_provider_billing_type(
        db: Session, provider_id: str
    ) -> Optional[ProviderBillingType]:
        """
        Return the ``billing_type`` of a Provider, using the cache.

        Args:
            db: Database session.
            provider_id: Provider ID.

        Returns:
            The billing type, or ``None`` when the Provider does not exist or
            has no billing type set.
        """
        cache_key = f"provider:billing_type:{provider_id}"

        # 1. Try the cache first.
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"Provider billing_type 缓存命中: {provider_id[:8]}...")
            if cached_data == ProviderCacheService._NOT_FOUND:
                return None
            try:
                return ProviderBillingType(cached_data)
            except ValueError:
                # Invalid cached value: drop it and re-query the database.
                await CacheService.delete(cache_key)

        # 2. Cache miss: query the database (single column only).
        provider = (
            db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
        )

        # 3. Populate the cache. A missing row and a NULL billing_type both
        #    yield None, so both are stored as the negative sentinel instead
        #    of dereferencing ``.value`` on None.
        billing_type = provider.billing_type if provider is not None else None
        if billing_type is None:
            await CacheService.set(
                cache_key,
                ProviderCacheService._NOT_FOUND,
                ttl_seconds=ProviderCacheService.CACHE_TTL,
            )
            return None

        await CacheService.set(
            cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
        )
        logger.debug(f"Provider billing_type 已缓存: {provider_id[:8]}...")
        return billing_type

    @staticmethod
    async def get_rate_multiplier_and_free_tier(
        db: Session,
        provider_api_key_id: Optional[str],
        provider_id: Optional[str],
    ) -> Tuple[float, bool]:
        """
        Return the rate multiplier and whether the provider is free tier.

        Combines the two cached lookups of this class.

        Args:
            db: Database session.
            provider_api_key_id: ProviderAPIKey ID (optional).
            provider_id: Provider ID (optional).

        Returns:
            ``(rate_multiplier, is_free_tier)``. The multiplier defaults to
            ``1.0`` and the flag to ``False`` when the corresponding ID is
            not given or the lookup returns ``None``.
        """
        actual_rate_multiplier = 1.0
        is_free_tier = False

        # Rate multiplier.
        if provider_api_key_id:
            rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
                db, provider_api_key_id
            )
            if rate_multiplier is not None:
                actual_rate_multiplier = rate_multiplier

        # Billing type.
        if provider_id:
            billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
            is_free_tier = billing_type == ProviderBillingType.FREE_TIER

        return actual_rate_multiplier, is_free_tier

    @staticmethod
    async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
        """Delete the cached rate multiplier of the given ProviderAPIKey."""
        await CacheService.delete(f"provider_api_key:rate_multiplier:{provider_api_key_id}")
        logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")

    @staticmethod
    async def invalidate_provider_cache(provider_id: str) -> None:
        """Delete the cached billing type of the given Provider."""
        await CacheService.delete(f"provider:billing_type:{provider_id}")
        logger.debug(f"Provider 缓存已清除: {provider_id[:8]}...")

View File

@@ -3,8 +3,9 @@
"""
import json
import time
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
@@ -20,6 +21,49 @@ class LogLevel(str, Enum):
FULL = "full" # 记录完整请求和响应包含body敏感信息会脱敏
# In-process cache TTL in seconds; system config changes rarely, so a fairly long TTL is used
_CONFIG_CACHE_TTL = 60  # 1 minute
# In-process cache storage: {key: (value, expire_time)}
_config_cache: Dict[str, Tuple[Any, float]] = {}
def _get_cached_config(key: str) -> Tuple[bool, Any]:
    """Look up a config value in the in-process cache.

    The entry is fetched with a single ``dict.get`` instead of a membership
    test followed by a subscript, so an entry removed in between (for example
    by ``invalidate_config_cache``) results in a miss rather than a KeyError.

    Args:
        key: Config key.

    Returns:
        ``(hit, value)``: ``hit`` is True when an unexpired entry exists, in
        which case ``value`` is the cached value; otherwise ``(False, None)``.
    """
    entry = _config_cache.get(key)
    if entry is None:
        return False, None

    value, expire_time = entry
    if time.time() < expire_time:
        return True, value

    # Expired: remove with a default so a concurrent removal is harmless.
    _config_cache.pop(key, None)
    return False, None
def _set_cached_config(key: str, value: Any) -> None:
    """Store ``value`` under ``key`` in the in-process cache.

    The entry expires ``_CONFIG_CACHE_TTL`` seconds after the current time.
    """
    expires_at = time.time() + _CONFIG_CACHE_TTL
    _config_cache[key] = (value, expires_at)
def invalidate_config_cache(key: Optional[str] = None) -> None:
    """Evict entries from the in-process config cache.

    Args:
        key: Config key to evict. When ``None``, the whole cache is replaced
            with an empty dict.
    """
    global _config_cache

    if key is not None:
        # pop with a default so an already-missing key is not an error.
        removed = _config_cache.pop(key, None)
        if removed is not None:
            logger.debug(f"已清除系统配置缓存: {key}")
        return

    _config_cache = {}
    logger.debug("已清除所有系统配置缓存")
class SystemConfigService:
"""系统配置服务类"""
@@ -127,14 +171,23 @@ class SystemConfigService:
@classmethod
def get_config(cls, db: Session, key: str, default: Any = None) -> Optional[Any]:
"""获取系统配置值"""
"""获取系统配置值(带进程内缓存)"""
# 1. 检查进程内缓存
hit, cached_value = _get_cached_config(key)
if hit:
return cached_value
# 2. 查询数据库
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
if config:
_set_cached_config(key, config.value)
return config.value
# 如果配置不存在,检查默认值
# 3. 如果配置不存在,使用默认值
if key in cls.DEFAULT_CONFIGS:
return cls.DEFAULT_CONFIGS[key]["value"]
value = cls.DEFAULT_CONFIGS[key]["value"]
_set_cached_config(key, value)
return value
return default
@@ -185,6 +238,9 @@ class SystemConfigService:
db.commit()
db.refresh(config)
# 清除缓存
invalidate_config_cache(key)
return config
@staticmethod
@@ -243,6 +299,8 @@ class SystemConfigService:
if config:
db.delete(config)
db.commit()
# 清除缓存
invalidate_config_cache(key)
return True
return False

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import ApiKey, Provider, ProviderAPIKey, Usage, User, UserRole
from src.services.model.cost import ModelCostService
@@ -362,22 +361,12 @@ class UsageService:
provider_api_key_id: Optional[str],
provider_id: Optional[str],
) -> Tuple[float, bool]:
"""获取费率倍数和是否免费套餐"""
actual_rate_multiplier = 1.0
if provider_api_key_id:
provider_key = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.id == provider_api_key_id).first()
)
if provider_key and provider_key.rate_multiplier:
actual_rate_multiplier = provider_key.rate_multiplier
"""获取费率倍数和是否免费套餐(使用缓存)"""
from src.services.cache.provider_cache import ProviderCacheService
is_free_tier = False
if provider_id:
provider_obj = db.query(Provider).filter(Provider.id == provider_id).first()
if provider_obj and provider_obj.billing_type == ProviderBillingType.FREE_TIER:
is_free_tier = True
return actual_rate_multiplier, is_free_tier
return await ProviderCacheService.get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
)
@classmethod
async def _calculate_costs(