mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-11 20:18:30 +08:00
Compare commits
19 Commits
cddc22d2b3
...
v0.2.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
465da6f818 | ||
|
|
e5f12fddd9 | ||
|
|
4fa9a1303a | ||
|
|
43f349d415 | ||
|
|
02069954de | ||
|
|
2e15875fed | ||
|
|
b34cfb676d | ||
|
|
3064497636 | ||
|
|
dec681fea0 | ||
|
|
523e27ba9a | ||
|
|
e7db76e581 | ||
|
|
689339117a | ||
|
|
b202765be4 | ||
|
|
3bbf3073df | ||
|
|
f46aaa2182 | ||
|
|
a2f33a6c35 | ||
|
|
b6bd6357ed | ||
|
|
c3a5878b1b | ||
|
|
c02ac56da8 |
@@ -58,13 +58,13 @@ cp .env.example .env
|
|||||||
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||||
|
|
||||||
# 3. 部署
|
# 3. 部署
|
||||||
docker-compose up -d
|
docker compose up -d
|
||||||
|
|
||||||
# 4. 首次部署时, 初始化数据库
|
# 4. 首次部署时, 初始化数据库
|
||||||
./migrate.sh
|
./migrate.sh
|
||||||
|
|
||||||
# 5. 更新
|
# 5. 更新
|
||||||
docker-compose pull && docker-compose up -d && ./migrate.sh
|
docker compose pull && docker compose up -d && ./migrate.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
### Docker Compose(本地构建镜像)
|
### Docker Compose(本地构建镜像)
|
||||||
@@ -86,7 +86,7 @@ python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 启动依赖
|
# 启动依赖
|
||||||
docker-compose -f docker-compose.build.yml up -d postgres redis
|
docker compose -f docker-compose.build.yml up -d postgres redis
|
||||||
|
|
||||||
# 后端
|
# 后端
|
||||||
uv sync
|
uv sync
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from src.models.database import Base
|
|||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
# 从环境变量获取数据库 URL
|
# 从环境变量获取数据库 URL
|
||||||
# 优先使用 DATABASE_URL,否则从 DB_PASSWORD 自动构建(与 docker-compose 保持一致)
|
# 优先使用 DATABASE_URL,否则从 DB_PASSWORD 自动构建(与 docker compose 保持一致)
|
||||||
database_url = os.getenv("DATABASE_URL")
|
database_url = os.getenv("DATABASE_URL")
|
||||||
if not database_url:
|
if not database_url:
|
||||||
db_password = os.getenv("DB_PASSWORD", "")
|
db_password = os.getenv("DB_PASSWORD", "")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Aether 部署配置 - 本地构建
|
# Aether 部署配置 - 本地构建
|
||||||
# 使用方法:
|
# 使用方法:
|
||||||
# 首次构建 base: docker build -f Dockerfile.base -t aether-base:latest .
|
# 首次构建 base: docker build -f Dockerfile.base -t aether-base:latest .
|
||||||
# 启动服务: docker-compose -f docker-compose.build.yml up -d --build
|
# 启动服务: docker compose -f docker-compose.build.yml up -d --build
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# Aether 部署配置 - 使用预构建镜像
|
# Aether 部署配置 - 使用预构建镜像
|
||||||
# 使用方法: docker-compose up -d
|
# 使用方法: docker compose up -d
|
||||||
|
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ export interface UserApiKeyExport {
|
|||||||
allowed_endpoints?: string[] | null
|
allowed_endpoints?: string[] | null
|
||||||
allowed_api_formats?: string[] | null
|
allowed_api_formats?: string[] | null
|
||||||
allowed_models?: string[] | null
|
allowed_models?: string[] | null
|
||||||
rate_limit?: number
|
rate_limit?: number | null // null = 无限制
|
||||||
concurrent_limit?: number | null
|
concurrent_limit?: number | null
|
||||||
force_capabilities?: any
|
force_capabilities?: any
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
@@ -220,7 +220,7 @@ export interface AdminApiKey {
|
|||||||
total_requests?: number
|
total_requests?: number
|
||||||
total_tokens?: number
|
total_tokens?: number
|
||||||
total_cost_usd?: number
|
total_cost_usd?: number
|
||||||
rate_limit?: number
|
rate_limit?: number | null // null = 无限制
|
||||||
allowed_providers?: string[] | null // 允许的提供商列表
|
allowed_providers?: string[] | null // 允许的提供商列表
|
||||||
allowed_api_formats?: string[] | null // 允许的 API 格式列表
|
allowed_api_formats?: string[] | null // 允许的 API 格式列表
|
||||||
allowed_models?: string[] | null // 允许的模型列表
|
allowed_models?: string[] | null // 允许的模型列表
|
||||||
@@ -236,8 +236,8 @@ export interface CreateStandaloneApiKeyRequest {
|
|||||||
allowed_providers?: string[] | null
|
allowed_providers?: string[] | null
|
||||||
allowed_api_formats?: string[] | null
|
allowed_api_formats?: string[] | null
|
||||||
allowed_models?: string[] | null
|
allowed_models?: string[] | null
|
||||||
rate_limit?: number
|
rate_limit?: number | null // null = 无限制
|
||||||
expire_days?: number | null // null = 永不过期
|
expires_at?: string | null // ISO 日期字符串,如 "2025-12-31",null = 永不过期
|
||||||
initial_balance_usd: number // 初始余额,必须设置
|
initial_balance_usd: number // 初始余额,必须设置
|
||||||
auto_delete_on_expiry?: boolean // 过期后是否自动删除
|
auto_delete_on_expiry?: boolean // 过期后是否自动删除
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,6 +75,16 @@ export interface ModelSummary {
|
|||||||
actual_total_cost_usd?: number // 倍率消耗(仅管理员可见)
|
actual_total_cost_usd?: number // 倍率消耗(仅管理员可见)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 提供商统计接口
|
||||||
|
export interface ProviderSummary {
|
||||||
|
provider: string
|
||||||
|
requests: number
|
||||||
|
total_tokens: number
|
||||||
|
total_cost_usd: number
|
||||||
|
success_rate: number | null
|
||||||
|
avg_response_time_ms: number | null
|
||||||
|
}
|
||||||
|
|
||||||
// 使用统计响应接口
|
// 使用统计响应接口
|
||||||
export interface UsageResponse {
|
export interface UsageResponse {
|
||||||
total_requests: number
|
total_requests: number
|
||||||
@@ -87,6 +97,13 @@ export interface UsageResponse {
|
|||||||
quota_usd: number | null
|
quota_usd: number | null
|
||||||
used_usd: number
|
used_usd: number
|
||||||
summary_by_model: ModelSummary[]
|
summary_by_model: ModelSummary[]
|
||||||
|
summary_by_provider?: ProviderSummary[]
|
||||||
|
pagination?: {
|
||||||
|
total: number
|
||||||
|
limit: number
|
||||||
|
offset: number
|
||||||
|
has_more: boolean
|
||||||
|
}
|
||||||
records: UsageRecordDetail[]
|
records: UsageRecordDetail[]
|
||||||
activity_heatmap?: ActivityHeatmap | null
|
activity_heatmap?: ActivityHeatmap | null
|
||||||
}
|
}
|
||||||
@@ -175,6 +192,8 @@ export const meApi = {
|
|||||||
async getUsage(params?: {
|
async getUsage(params?: {
|
||||||
start_date?: string
|
start_date?: string
|
||||||
end_date?: string
|
end_date?: string
|
||||||
|
limit?: number
|
||||||
|
offset?: number
|
||||||
}): Promise<UsageResponse> {
|
}): Promise<UsageResponse> {
|
||||||
const response = await apiClient.get<UsageResponse>('/api/users/me/usage', { params })
|
const response = await apiClient.get<UsageResponse>('/api/users/me/usage', { params })
|
||||||
return response.data
|
return response.data
|
||||||
@@ -184,11 +203,12 @@ export const meApi = {
|
|||||||
async getActiveRequests(ids?: string): Promise<{
|
async getActiveRequests(ids?: string): Promise<{
|
||||||
requests: Array<{
|
requests: Array<{
|
||||||
id: string
|
id: string
|
||||||
status: string
|
status: 'pending' | 'streaming' | 'completed' | 'failed'
|
||||||
input_tokens: number
|
input_tokens: number
|
||||||
output_tokens: number
|
output_tokens: number
|
||||||
cost: number
|
cost: number
|
||||||
response_time_ms: number | null
|
response_time_ms: number | null
|
||||||
|
first_byte_time_ms: number | null
|
||||||
}>
|
}>
|
||||||
}> {
|
}> {
|
||||||
const params = ids ? { ids } : {}
|
const params = ids ? { ids } : {}
|
||||||
@@ -267,5 +287,14 @@ export const meApi = {
|
|||||||
}> {
|
}> {
|
||||||
const response = await apiClient.get('/api/users/me/usage/interval-timeline', { params })
|
const response = await apiClient.get('/api/users/me/usage/interval-timeline', { params })
|
||||||
return response.data
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取活跃度热力图数据(用户)
|
||||||
|
* 后端已缓存5分钟
|
||||||
|
*/
|
||||||
|
async getActivityHeatmap(): Promise<ActivityHeatmap> {
|
||||||
|
const response = await apiClient.get<ActivityHeatmap>('/api/users/me/usage/heatmap')
|
||||||
|
return response.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -193,10 +193,22 @@ export const usageApi = {
|
|||||||
output_tokens: number
|
output_tokens: number
|
||||||
cost: number
|
cost: number
|
||||||
response_time_ms: number | null
|
response_time_ms: number | null
|
||||||
|
first_byte_time_ms: number | null
|
||||||
|
provider?: string | null
|
||||||
|
api_key_name?: string | null
|
||||||
}>
|
}>
|
||||||
}> {
|
}> {
|
||||||
const params = ids?.length ? { ids: ids.join(',') } : {}
|
const params = ids?.length ? { ids: ids.join(',') } : {}
|
||||||
const response = await apiClient.get('/api/admin/usage/active', { params })
|
const response = await apiClient.get('/api/admin/usage/active', { params })
|
||||||
return response.data
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取活跃度热力图数据(管理员)
|
||||||
|
* 后端已缓存5分钟
|
||||||
|
*/
|
||||||
|
async getActivityHeatmap(): Promise<ActivityHeatmap> {
|
||||||
|
const response = await apiClient.get<ActivityHeatmap>('/api/admin/usage/heatmap')
|
||||||
|
return response.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
117
frontend/src/components/common/ModelMultiSelect.vue
Normal file
117
frontend/src/components/common/ModelMultiSelect.vue
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
<template>
|
||||||
|
<div class="space-y-2">
|
||||||
|
<Label class="text-sm font-medium">允许的模型</Label>
|
||||||
|
<div class="relative">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
||||||
|
@click="isOpen = !isOpen"
|
||||||
|
>
|
||||||
|
<span :class="modelValue.length ? 'text-foreground' : 'text-muted-foreground'">
|
||||||
|
{{ modelValue.length ? `已选择 ${modelValue.length} 个` : '全部可用' }}
|
||||||
|
<span
|
||||||
|
v-if="invalidModels.length"
|
||||||
|
class="text-destructive"
|
||||||
|
>({{ invalidModels.length }} 个已失效)</span>
|
||||||
|
</span>
|
||||||
|
<ChevronDown
|
||||||
|
class="h-4 w-4 text-muted-foreground transition-transform"
|
||||||
|
:class="isOpen ? 'rotate-180' : ''"
|
||||||
|
/>
|
||||||
|
</button>
|
||||||
|
<div
|
||||||
|
v-if="isOpen"
|
||||||
|
class="fixed inset-0 z-[80]"
|
||||||
|
@click.stop="isOpen = false"
|
||||||
|
/>
|
||||||
|
<div
|
||||||
|
v-if="isOpen"
|
||||||
|
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
|
||||||
|
>
|
||||||
|
<!-- 失效模型(置顶显示,只能取消选择) -->
|
||||||
|
<div
|
||||||
|
v-for="modelName in invalidModels"
|
||||||
|
:key="modelName"
|
||||||
|
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer bg-destructive/5"
|
||||||
|
@click="removeModel(modelName)"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="true"
|
||||||
|
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||||
|
@click.stop
|
||||||
|
@change="removeModel(modelName)"
|
||||||
|
>
|
||||||
|
<span class="text-sm text-destructive">{{ modelName }}</span>
|
||||||
|
<span class="text-xs text-destructive/70">(已失效)</span>
|
||||||
|
</div>
|
||||||
|
<!-- 有效模型 -->
|
||||||
|
<div
|
||||||
|
v-for="model in models"
|
||||||
|
:key="model.name"
|
||||||
|
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
||||||
|
@click="toggleModel(model.name)"
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
:checked="modelValue.includes(model.name)"
|
||||||
|
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
||||||
|
@click.stop
|
||||||
|
@change="toggleModel(model.name)"
|
||||||
|
>
|
||||||
|
<span class="text-sm">{{ model.name }}</span>
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="models.length === 0 && invalidModels.length === 0"
|
||||||
|
class="px-3 py-2 text-sm text-muted-foreground"
|
||||||
|
>
|
||||||
|
暂无可用模型
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</template>
|
||||||
|
|
||||||
|
<script setup lang="ts">
|
||||||
|
import { ref, computed } from 'vue'
|
||||||
|
import { Label } from '@/components/ui'
|
||||||
|
import { ChevronDown } from 'lucide-vue-next'
|
||||||
|
import { useInvalidModels } from '@/composables/useInvalidModels'
|
||||||
|
|
||||||
|
export interface ModelWithName {
|
||||||
|
name: string
|
||||||
|
}
|
||||||
|
|
||||||
|
const props = defineProps<{
|
||||||
|
modelValue: string[]
|
||||||
|
models: ModelWithName[]
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const emit = defineEmits<{
|
||||||
|
'update:modelValue': [value: string[]]
|
||||||
|
}>()
|
||||||
|
|
||||||
|
const isOpen = ref(false)
|
||||||
|
|
||||||
|
// 检测失效模型
|
||||||
|
const { invalidModels } = useInvalidModels(
|
||||||
|
computed(() => props.modelValue),
|
||||||
|
computed(() => props.models)
|
||||||
|
)
|
||||||
|
|
||||||
|
function toggleModel(name: string) {
|
||||||
|
const newValue = [...props.modelValue]
|
||||||
|
const index = newValue.indexOf(name)
|
||||||
|
if (index === -1) {
|
||||||
|
newValue.push(name)
|
||||||
|
} else {
|
||||||
|
newValue.splice(index, 1)
|
||||||
|
}
|
||||||
|
emit('update:modelValue', newValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
function removeModel(name: string) {
|
||||||
|
const newValue = props.modelValue.filter(m => m !== name)
|
||||||
|
emit('update:modelValue', newValue)
|
||||||
|
}
|
||||||
|
</script>
|
||||||
@@ -7,3 +7,6 @@
|
|||||||
export { default as EmptyState } from './EmptyState.vue'
|
export { default as EmptyState } from './EmptyState.vue'
|
||||||
export { default as AlertDialog } from './AlertDialog.vue'
|
export { default as AlertDialog } from './AlertDialog.vue'
|
||||||
export { default as LoadingState } from './LoadingState.vue'
|
export { default as LoadingState } from './LoadingState.vue'
|
||||||
|
|
||||||
|
// 表单组件
|
||||||
|
export { default as ModelMultiSelect } from './ModelMultiSelect.vue'
|
||||||
|
|||||||
34
frontend/src/composables/useInvalidModels.ts
Normal file
34
frontend/src/composables/useInvalidModels.ts
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import { computed, type Ref, type ComputedRef } from 'vue'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检测失效模型的 composable
|
||||||
|
*
|
||||||
|
* 用于检测 allowed_models 中已不存在于 globalModels 的模型名称,
|
||||||
|
* 这些模型可能已被删除但引用未清理。
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```typescript
|
||||||
|
* const { invalidModels } = useInvalidModels(
|
||||||
|
* computed(() => form.value.allowed_models),
|
||||||
|
* globalModels
|
||||||
|
* )
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export interface ModelWithName {
|
||||||
|
name: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useInvalidModels<T extends ModelWithName>(
|
||||||
|
allowedModels: Ref<string[]> | ComputedRef<string[]>,
|
||||||
|
globalModels: Ref<T[]>
|
||||||
|
): { invalidModels: ComputedRef<string[]> } {
|
||||||
|
const validModelNames = computed(() =>
|
||||||
|
new Set(globalModels.value.map(m => m.name))
|
||||||
|
)
|
||||||
|
|
||||||
|
const invalidModels = computed(() =>
|
||||||
|
allowedModels.value.filter(name => !validModelNames.value.has(name))
|
||||||
|
)
|
||||||
|
|
||||||
|
return { invalidModels }
|
||||||
|
}
|
||||||
@@ -79,45 +79,45 @@
|
|||||||
|
|
||||||
<div class="space-y-2">
|
<div class="space-y-2">
|
||||||
<Label
|
<Label
|
||||||
for="form-expire-days"
|
for="form-expires-at"
|
||||||
class="text-sm font-medium"
|
class="text-sm font-medium"
|
||||||
>有效期设置</Label>
|
>有效期设置</Label>
|
||||||
<div class="flex items-center gap-2">
|
<div class="flex items-center gap-2">
|
||||||
<Input
|
<div class="relative flex-1">
|
||||||
id="form-expire-days"
|
<Input
|
||||||
:model-value="form.expire_days ?? ''"
|
id="form-expires-at"
|
||||||
type="number"
|
:model-value="form.expires_at || ''"
|
||||||
min="1"
|
type="date"
|
||||||
max="3650"
|
:min="minExpiryDate"
|
||||||
placeholder="天数"
|
class="h-9 pr-8"
|
||||||
:class="form.never_expire ? 'flex-1 h-9 opacity-50' : 'flex-1 h-9'"
|
:placeholder="form.expires_at ? '' : '永不过期'"
|
||||||
:disabled="form.never_expire"
|
@update:model-value="(v) => form.expires_at = v || undefined"
|
||||||
@update:model-value="(v) => form.expire_days = parseNumberInput(v, { min: 1, max: 3650 })"
|
/>
|
||||||
/>
|
<button
|
||||||
<label class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap">
|
v-if="form.expires_at"
|
||||||
<input
|
type="button"
|
||||||
v-model="form.never_expire"
|
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground"
|
||||||
type="checkbox"
|
title="清空(永不过期)"
|
||||||
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
|
@click="clearExpiryDate"
|
||||||
@change="onNeverExpireChange"
|
|
||||||
>
|
>
|
||||||
永不过期
|
<X class="h-4 w-4" />
|
||||||
</label>
|
</button>
|
||||||
|
</div>
|
||||||
<label
|
<label
|
||||||
class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap"
|
class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap"
|
||||||
:class="form.never_expire ? 'opacity-50' : ''"
|
:class="!form.expires_at ? 'opacity-50 cursor-not-allowed' : ''"
|
||||||
>
|
>
|
||||||
<input
|
<input
|
||||||
v-model="form.auto_delete_on_expiry"
|
v-model="form.auto_delete_on_expiry"
|
||||||
type="checkbox"
|
type="checkbox"
|
||||||
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
|
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
|
||||||
:disabled="form.never_expire"
|
:disabled="!form.expires_at"
|
||||||
>
|
>
|
||||||
到期删除
|
到期删除
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
<p class="text-xs text-muted-foreground">
|
<p class="text-xs text-muted-foreground">
|
||||||
不勾选"到期删除"则仅禁用
|
{{ form.expires_at ? '到期后' + (form.auto_delete_on_expiry ? '自动删除' : '仅禁用') + '(当天 23:59 失效)' : '留空表示永不过期' }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -244,55 +244,10 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 模型多选下拉框 -->
|
<!-- 模型多选下拉框 -->
|
||||||
<div class="space-y-2">
|
<ModelMultiSelect
|
||||||
<Label class="text-sm font-medium">允许的模型</Label>
|
v-model="form.allowed_models"
|
||||||
<div class="relative">
|
:models="globalModels"
|
||||||
<button
|
/>
|
||||||
type="button"
|
|
||||||
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
|
||||||
@click="modelDropdownOpen = !modelDropdownOpen"
|
|
||||||
>
|
|
||||||
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
|
|
||||||
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length} 个` : '全部可用' }}
|
|
||||||
</span>
|
|
||||||
<ChevronDown
|
|
||||||
class="h-4 w-4 text-muted-foreground transition-transform"
|
|
||||||
:class="modelDropdownOpen ? 'rotate-180' : ''"
|
|
||||||
/>
|
|
||||||
</button>
|
|
||||||
<div
|
|
||||||
v-if="modelDropdownOpen"
|
|
||||||
class="fixed inset-0 z-[80]"
|
|
||||||
@click.stop="modelDropdownOpen = false"
|
|
||||||
/>
|
|
||||||
<div
|
|
||||||
v-if="modelDropdownOpen"
|
|
||||||
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
v-for="model in globalModels"
|
|
||||||
:key="model.name"
|
|
||||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
|
||||||
@click="toggleSelection('allowed_models', model.name)"
|
|
||||||
>
|
|
||||||
<input
|
|
||||||
type="checkbox"
|
|
||||||
:checked="form.allowed_models.includes(model.name)"
|
|
||||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
|
||||||
@click.stop
|
|
||||||
@change="toggleSelection('allowed_models', model.name)"
|
|
||||||
>
|
|
||||||
<span class="text-sm">{{ model.name }}</span>
|
|
||||||
</div>
|
|
||||||
<div
|
|
||||||
v-if="globalModels.length === 0"
|
|
||||||
class="px-3 py-2 text-sm text-muted-foreground"
|
|
||||||
>
|
|
||||||
暂无可用模型
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
@@ -325,8 +280,9 @@ import {
|
|||||||
Input,
|
Input,
|
||||||
Label,
|
Label,
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
import { Plus, SquarePen, Key, Shield, ChevronDown } from 'lucide-vue-next'
|
import { Plus, SquarePen, Key, Shield, ChevronDown, X } from 'lucide-vue-next'
|
||||||
import { useFormDialog } from '@/composables/useFormDialog'
|
import { useFormDialog } from '@/composables/useFormDialog'
|
||||||
|
import { ModelMultiSelect } from '@/components/common'
|
||||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||||
import { getGlobalModels } from '@/api/global-models'
|
import { getGlobalModels } from '@/api/global-models'
|
||||||
import { adminApi } from '@/api/admin'
|
import { adminApi } from '@/api/admin'
|
||||||
@@ -338,8 +294,7 @@ export interface StandaloneKeyFormData {
|
|||||||
id?: string
|
id?: string
|
||||||
name: string
|
name: string
|
||||||
initial_balance_usd?: number
|
initial_balance_usd?: number
|
||||||
expire_days?: number
|
expires_at?: string // ISO 日期字符串,如 "2025-12-31",undefined = 永不过期
|
||||||
never_expire: boolean
|
|
||||||
rate_limit?: number
|
rate_limit?: number
|
||||||
auto_delete_on_expiry: boolean
|
auto_delete_on_expiry: boolean
|
||||||
allowed_providers: string[]
|
allowed_providers: string[]
|
||||||
@@ -363,7 +318,6 @@ const saving = ref(false)
|
|||||||
// 下拉框状态
|
// 下拉框状态
|
||||||
const providerDropdownOpen = ref(false)
|
const providerDropdownOpen = ref(false)
|
||||||
const apiFormatDropdownOpen = ref(false)
|
const apiFormatDropdownOpen = ref(false)
|
||||||
const modelDropdownOpen = ref(false)
|
|
||||||
|
|
||||||
// 选项数据
|
// 选项数据
|
||||||
const providers = ref<ProviderWithEndpointsSummary[]>([])
|
const providers = ref<ProviderWithEndpointsSummary[]>([])
|
||||||
@@ -374,8 +328,7 @@ const allApiFormats = ref<string[]>([])
|
|||||||
const form = ref<StandaloneKeyFormData>({
|
const form = ref<StandaloneKeyFormData>({
|
||||||
name: '',
|
name: '',
|
||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expires_at: undefined,
|
||||||
never_expire: true,
|
|
||||||
rate_limit: undefined,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
@@ -383,12 +336,18 @@ const form = ref<StandaloneKeyFormData>({
|
|||||||
allowed_models: []
|
allowed_models: []
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 计算最小可选日期(明天)
|
||||||
|
const minExpiryDate = computed(() => {
|
||||||
|
const tomorrow = new Date()
|
||||||
|
tomorrow.setDate(tomorrow.getDate() + 1)
|
||||||
|
return tomorrow.toISOString().split('T')[0]
|
||||||
|
})
|
||||||
|
|
||||||
function resetForm() {
|
function resetForm() {
|
||||||
form.value = {
|
form.value = {
|
||||||
name: '',
|
name: '',
|
||||||
initial_balance_usd: 10,
|
initial_balance_usd: 10,
|
||||||
expire_days: undefined,
|
expires_at: undefined,
|
||||||
never_expire: true,
|
|
||||||
rate_limit: undefined,
|
rate_limit: undefined,
|
||||||
auto_delete_on_expiry: false,
|
auto_delete_on_expiry: false,
|
||||||
allowed_providers: [],
|
allowed_providers: [],
|
||||||
@@ -397,7 +356,6 @@ function resetForm() {
|
|||||||
}
|
}
|
||||||
providerDropdownOpen.value = false
|
providerDropdownOpen.value = false
|
||||||
apiFormatDropdownOpen.value = false
|
apiFormatDropdownOpen.value = false
|
||||||
modelDropdownOpen.value = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function loadKeyData() {
|
function loadKeyData() {
|
||||||
@@ -406,8 +364,7 @@ function loadKeyData() {
|
|||||||
id: props.apiKey.id,
|
id: props.apiKey.id,
|
||||||
name: props.apiKey.name || '',
|
name: props.apiKey.name || '',
|
||||||
initial_balance_usd: props.apiKey.initial_balance_usd,
|
initial_balance_usd: props.apiKey.initial_balance_usd,
|
||||||
expire_days: props.apiKey.expire_days,
|
expires_at: props.apiKey.expires_at,
|
||||||
never_expire: props.apiKey.never_expire,
|
|
||||||
rate_limit: props.apiKey.rate_limit,
|
rate_limit: props.apiKey.rate_limit,
|
||||||
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
|
||||||
allowed_providers: props.apiKey.allowed_providers || [],
|
allowed_providers: props.apiKey.allowed_providers || [],
|
||||||
@@ -452,12 +409,10 @@ function toggleSelection(field: 'allowed_providers' | 'allowed_api_formats' | 'a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 永不过期切换
|
// 清空过期日期(同时清空到期删除选项)
|
||||||
function onNeverExpireChange() {
|
function clearExpiryDate() {
|
||||||
if (form.value.never_expire) {
|
form.value.expires_at = undefined
|
||||||
form.value.expire_days = undefined
|
form.value.auto_delete_on_expiry = false
|
||||||
form.value.auto_delete_on_expiry = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提交表单
|
// 提交表单
|
||||||
|
|||||||
@@ -18,8 +18,22 @@
|
|||||||
<span class="flex-shrink-0">多</span>
|
<span class="flex-shrink-0">多</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="isLoading"
|
||||||
|
class="h-full min-h-[160px] flex items-center justify-center text-sm text-muted-foreground"
|
||||||
|
>
|
||||||
|
<Loader2 class="h-5 w-5 animate-spin mr-2" />
|
||||||
|
加载中...
|
||||||
|
</div>
|
||||||
|
<div
|
||||||
|
v-else-if="hasError"
|
||||||
|
class="h-full min-h-[160px] flex items-center justify-center text-sm text-destructive"
|
||||||
|
>
|
||||||
|
<AlertCircle class="h-4 w-4 mr-1.5" />
|
||||||
|
加载失败
|
||||||
|
</div>
|
||||||
<ActivityHeatmap
|
<ActivityHeatmap
|
||||||
v-if="hasData"
|
v-else-if="hasData"
|
||||||
:data="data"
|
:data="data"
|
||||||
:show-header="false"
|
:show-header="false"
|
||||||
/>
|
/>
|
||||||
@@ -34,6 +48,7 @@
|
|||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue'
|
import { computed } from 'vue'
|
||||||
|
import { Loader2, AlertCircle } from 'lucide-vue-next'
|
||||||
import Card from '@/components/ui/card.vue'
|
import Card from '@/components/ui/card.vue'
|
||||||
import ActivityHeatmap from '@/components/stats/ActivityHeatmap.vue'
|
import ActivityHeatmap from '@/components/stats/ActivityHeatmap.vue'
|
||||||
import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
|
import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
|
||||||
@@ -41,6 +56,8 @@ import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
data: ActivityHeatmapData | null
|
data: ActivityHeatmapData | null
|
||||||
title: string
|
title: string
|
||||||
|
isLoading?: boolean
|
||||||
|
hasError?: boolean
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const legendLevels = [0.08, 0.25, 0.45, 0.65, 0.85]
|
const legendLevels = [0.08, 0.25, 0.45, 0.65, 0.85]
|
||||||
|
|||||||
@@ -64,9 +64,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
|||||||
}))
|
}))
|
||||||
})
|
})
|
||||||
|
|
||||||
// 活跃度热图数据
|
|
||||||
const activityHeatmapData = computed(() => stats.value.activity_heatmap)
|
|
||||||
|
|
||||||
// 加载统计数据(不加载记录)
|
// 加载统计数据(不加载记录)
|
||||||
async function loadStats(dateRange?: DateRangeParams) {
|
async function loadStats(dateRange?: DateRangeParams) {
|
||||||
isLoadingStats.value = true
|
isLoadingStats.value = true
|
||||||
@@ -93,7 +90,7 @@ export function useUsageData(options: UseUsageDataOptions) {
|
|||||||
cache_stats: (statsData as any).cache_stats,
|
cache_stats: (statsData as any).cache_stats,
|
||||||
period_start: '',
|
period_start: '',
|
||||||
period_end: '',
|
period_end: '',
|
||||||
activity_heatmap: statsData.activity_heatmap || null
|
activity_heatmap: null
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStats.value = modelData.map(item => ({
|
modelStats.value = modelData.map(item => ({
|
||||||
@@ -143,7 +140,7 @@ export function useUsageData(options: UseUsageDataOptions) {
|
|||||||
avg_response_time: userData.avg_response_time || 0,
|
avg_response_time: userData.avg_response_time || 0,
|
||||||
period_start: '',
|
period_start: '',
|
||||||
period_end: '',
|
period_end: '',
|
||||||
activity_heatmap: userData.activity_heatmap || null
|
activity_heatmap: null
|
||||||
}
|
}
|
||||||
|
|
||||||
modelStats.value = (userData.summary_by_model || []).map((item: any) => ({
|
modelStats.value = (userData.summary_by_model || []).map((item: any) => ({
|
||||||
@@ -305,7 +302,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
|||||||
|
|
||||||
// 计算属性
|
// 计算属性
|
||||||
enhancedModelStats,
|
enhancedModelStats,
|
||||||
activityHeatmapData,
|
|
||||||
|
|
||||||
// 方法
|
// 方法
|
||||||
loadStats,
|
loadStats,
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import type { ActivityHeatmap } from '@/types/activity'
|
|
||||||
|
|
||||||
// 统计数据状态
|
// 统计数据状态
|
||||||
export interface UsageStatsState {
|
export interface UsageStatsState {
|
||||||
total_requests: number
|
total_requests: number
|
||||||
@@ -17,7 +15,6 @@ export interface UsageStatsState {
|
|||||||
}
|
}
|
||||||
period_start: string
|
period_start: string
|
||||||
period_end: string
|
period_end: string
|
||||||
activity_heatmap: ActivityHeatmap | null
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 模型统计
|
// 模型统计
|
||||||
@@ -115,7 +112,6 @@ export function createDefaultStats(): UsageStatsState {
|
|||||||
error_rate: undefined,
|
error_rate: undefined,
|
||||||
cache_stats: undefined,
|
cache_stats: undefined,
|
||||||
period_start: '',
|
period_start: '',
|
||||||
period_end: '',
|
period_end: ''
|
||||||
activity_heatmap: null
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -316,55 +316,10 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 模型多选下拉框 -->
|
<!-- 模型多选下拉框 -->
|
||||||
<div class="space-y-2">
|
<ModelMultiSelect
|
||||||
<Label class="text-sm font-medium">允许的模型</Label>
|
v-model="form.allowed_models"
|
||||||
<div class="relative">
|
:models="globalModels"
|
||||||
<button
|
/>
|
||||||
type="button"
|
|
||||||
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
|
|
||||||
@click="modelDropdownOpen = !modelDropdownOpen"
|
|
||||||
>
|
|
||||||
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
|
|
||||||
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length} 个` : '全部可用' }}
|
|
||||||
</span>
|
|
||||||
<ChevronDown
|
|
||||||
class="h-4 w-4 text-muted-foreground transition-transform"
|
|
||||||
:class="modelDropdownOpen ? 'rotate-180' : ''"
|
|
||||||
/>
|
|
||||||
</button>
|
|
||||||
<div
|
|
||||||
v-if="modelDropdownOpen"
|
|
||||||
class="fixed inset-0 z-[80]"
|
|
||||||
@click.stop="modelDropdownOpen = false"
|
|
||||||
/>
|
|
||||||
<div
|
|
||||||
v-if="modelDropdownOpen"
|
|
||||||
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
v-for="model in globalModels"
|
|
||||||
:key="model.name"
|
|
||||||
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
|
|
||||||
@click="toggleSelection('allowed_models', model.name)"
|
|
||||||
>
|
|
||||||
<input
|
|
||||||
type="checkbox"
|
|
||||||
:checked="form.allowed_models.includes(model.name)"
|
|
||||||
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
|
|
||||||
@click.stop
|
|
||||||
@change="toggleSelection('allowed_models', model.name)"
|
|
||||||
>
|
|
||||||
<span class="text-sm">{{ model.name }}</span>
|
|
||||||
</div>
|
|
||||||
<div
|
|
||||||
v-if="globalModels.length === 0"
|
|
||||||
class="px-3 py-2 text-sm text-muted-foreground"
|
|
||||||
>
|
|
||||||
暂无可用模型
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
@@ -404,10 +359,12 @@ import {
|
|||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
import { UserPlus, SquarePen, ChevronDown } from 'lucide-vue-next'
|
import { UserPlus, SquarePen, ChevronDown } from 'lucide-vue-next'
|
||||||
import { useFormDialog } from '@/composables/useFormDialog'
|
import { useFormDialog } from '@/composables/useFormDialog'
|
||||||
|
import { ModelMultiSelect } from '@/components/common'
|
||||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||||
import { getGlobalModels } from '@/api/global-models'
|
import { getGlobalModels } from '@/api/global-models'
|
||||||
import { adminApi } from '@/api/admin'
|
import { adminApi } from '@/api/admin'
|
||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
|
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
|
||||||
|
|
||||||
export interface UserFormData {
|
export interface UserFormData {
|
||||||
id?: string
|
id?: string
|
||||||
@@ -440,11 +397,10 @@ const roleSelectOpen = ref(false)
|
|||||||
// 下拉框状态
|
// 下拉框状态
|
||||||
const providerDropdownOpen = ref(false)
|
const providerDropdownOpen = ref(false)
|
||||||
const endpointDropdownOpen = ref(false)
|
const endpointDropdownOpen = ref(false)
|
||||||
const modelDropdownOpen = ref(false)
|
|
||||||
|
|
||||||
// 选项数据
|
// 选项数据
|
||||||
const providers = ref<any[]>([])
|
const providers = ref<ProviderWithEndpointsSummary[]>([])
|
||||||
const globalModels = ref<any[]>([])
|
const globalModels = ref<GlobalModelResponse[]>([])
|
||||||
const apiFormats = ref<Array<{ value: string; label: string }>>([])
|
const apiFormats = ref<Array<{ value: string; label: string }>>([])
|
||||||
|
|
||||||
// 表单数据
|
// 表单数据
|
||||||
|
|||||||
@@ -850,28 +850,20 @@ async function deleteApiKey(apiKey: AdminApiKey) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function editApiKey(apiKey: AdminApiKey) {
|
function editApiKey(apiKey: AdminApiKey) {
|
||||||
// 计算过期天数
|
// 解析过期日期为 YYYY-MM-DD 格式
|
||||||
let expireDays: number | undefined = undefined
|
// 保留原始日期,不做时间过滤(避免编辑当天过期的 Key 时意外清空)
|
||||||
let neverExpire = true
|
let expiresAt: string | undefined = undefined
|
||||||
|
|
||||||
if (apiKey.expires_at) {
|
if (apiKey.expires_at) {
|
||||||
const expiresDate = new Date(apiKey.expires_at)
|
const expiresDate = new Date(apiKey.expires_at)
|
||||||
const now = new Date()
|
expiresAt = expiresDate.toISOString().split('T')[0]
|
||||||
const diffMs = expiresDate.getTime() - now.getTime()
|
|
||||||
const diffDays = Math.ceil(diffMs / (1000 * 60 * 60 * 24))
|
|
||||||
|
|
||||||
if (diffDays > 0) {
|
|
||||||
expireDays = diffDays
|
|
||||||
neverExpire = false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
editingKeyData.value = {
|
editingKeyData.value = {
|
||||||
id: apiKey.id,
|
id: apiKey.id,
|
||||||
name: apiKey.name || '',
|
name: apiKey.name || '',
|
||||||
expire_days: expireDays,
|
expires_at: expiresAt,
|
||||||
never_expire: neverExpire,
|
rate_limit: apiKey.rate_limit ?? undefined,
|
||||||
rate_limit: apiKey.rate_limit || 100,
|
|
||||||
auto_delete_on_expiry: apiKey.auto_delete_on_expiry || false,
|
auto_delete_on_expiry: apiKey.auto_delete_on_expiry || false,
|
||||||
allowed_providers: apiKey.allowed_providers || [],
|
allowed_providers: apiKey.allowed_providers || [],
|
||||||
allowed_api_formats: apiKey.allowed_api_formats || [],
|
allowed_api_formats: apiKey.allowed_api_formats || [],
|
||||||
@@ -1033,14 +1025,25 @@ function closeKeyFormDialog() {
|
|||||||
|
|
||||||
// 统一处理表单提交
|
// 统一处理表单提交
|
||||||
async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
||||||
|
// 验证过期日期(如果设置了,必须晚于今天)
|
||||||
|
if (data.expires_at) {
|
||||||
|
const selectedDate = new Date(data.expires_at)
|
||||||
|
const today = new Date()
|
||||||
|
today.setHours(0, 0, 0, 0)
|
||||||
|
if (selectedDate <= today) {
|
||||||
|
error('过期日期必须晚于今天')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
keyFormDialogRef.value?.setSaving(true)
|
keyFormDialogRef.value?.setSaving(true)
|
||||||
try {
|
try {
|
||||||
if (data.id) {
|
if (data.id) {
|
||||||
// 更新
|
// 更新
|
||||||
const updateData: Partial<CreateStandaloneApiKeyRequest> = {
|
const updateData: Partial<CreateStandaloneApiKeyRequest> = {
|
||||||
name: data.name || undefined,
|
name: data.name || undefined,
|
||||||
rate_limit: data.rate_limit,
|
rate_limit: data.rate_limit ?? null, // undefined = 无限制,显式传 null
|
||||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
expires_at: data.expires_at || null, // undefined/空 = 永不过期
|
||||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||||
// 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
|
// 空数组表示清除限制(允许全部),后端会将空数组存为 NULL
|
||||||
allowed_providers: data.allowed_providers,
|
allowed_providers: data.allowed_providers,
|
||||||
@@ -1058,8 +1061,8 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
|
|||||||
const createData: CreateStandaloneApiKeyRequest = {
|
const createData: CreateStandaloneApiKeyRequest = {
|
||||||
name: data.name || undefined,
|
name: data.name || undefined,
|
||||||
initial_balance_usd: data.initial_balance_usd,
|
initial_balance_usd: data.initial_balance_usd,
|
||||||
rate_limit: data.rate_limit,
|
rate_limit: data.rate_limit ?? null, // undefined = 无限制,显式传 null
|
||||||
expire_days: data.never_expire ? null : (data.expire_days || null),
|
expires_at: data.expires_at || null, // undefined/空 = 永不过期
|
||||||
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
auto_delete_on_expiry: data.auto_delete_on_expiry,
|
||||||
// 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
|
// 空数组表示不设置限制(允许全部),后端会将空数组存为 NULL
|
||||||
allowed_providers: data.allowed_providers,
|
allowed_providers: data.allowed_providers,
|
||||||
|
|||||||
@@ -5,6 +5,8 @@
|
|||||||
<ActivityHeatmapCard
|
<ActivityHeatmapCard
|
||||||
:data="activityHeatmapData"
|
:data="activityHeatmapData"
|
||||||
:title="isAdminPage ? '总体活跃天数' : '我的活跃天数'"
|
:title="isAdminPage ? '总体活跃天数' : '我的活跃天数'"
|
||||||
|
:is-loading="isLoadingHeatmap"
|
||||||
|
:has-error="heatmapError"
|
||||||
/>
|
/>
|
||||||
<IntervalTimelineCard
|
<IntervalTimelineCard
|
||||||
:title="isAdminPage ? '请求间隔时间线' : '我的请求间隔'"
|
:title="isAdminPage ? '请求间隔时间线' : '我的请求间隔'"
|
||||||
@@ -112,8 +114,11 @@ import {
|
|||||||
import type { PeriodValue, FilterStatusValue } from '@/features/usage/types'
|
import type { PeriodValue, FilterStatusValue } from '@/features/usage/types'
|
||||||
import type { UserOption } from '@/features/usage/components/UsageRecordsTable.vue'
|
import type { UserOption } from '@/features/usage/components/UsageRecordsTable.vue'
|
||||||
import { log } from '@/utils/logger'
|
import { log } from '@/utils/logger'
|
||||||
|
import type { ActivityHeatmap } from '@/types/activity'
|
||||||
|
import { useToast } from '@/composables/useToast'
|
||||||
|
|
||||||
const route = useRoute()
|
const route = useRoute()
|
||||||
|
const { warning } = useToast()
|
||||||
const authStore = useAuthStore()
|
const authStore = useAuthStore()
|
||||||
|
|
||||||
// 判断是否是管理员页面
|
// 判断是否是管理员页面
|
||||||
@@ -144,13 +149,35 @@ const {
|
|||||||
currentRecords,
|
currentRecords,
|
||||||
totalRecords,
|
totalRecords,
|
||||||
enhancedModelStats,
|
enhancedModelStats,
|
||||||
activityHeatmapData,
|
|
||||||
availableModels,
|
availableModels,
|
||||||
availableProviders,
|
availableProviders,
|
||||||
loadStats,
|
loadStats,
|
||||||
loadRecords
|
loadRecords
|
||||||
} = useUsageData({ isAdminPage })
|
} = useUsageData({ isAdminPage })
|
||||||
|
|
||||||
|
// 热力图状态
|
||||||
|
const activityHeatmapData = ref<ActivityHeatmap | null>(null)
|
||||||
|
const isLoadingHeatmap = ref(false)
|
||||||
|
const heatmapError = ref(false)
|
||||||
|
|
||||||
|
// 加载热力图数据
|
||||||
|
async function loadHeatmapData() {
|
||||||
|
isLoadingHeatmap.value = true
|
||||||
|
heatmapError.value = false
|
||||||
|
try {
|
||||||
|
if (isAdminPage.value) {
|
||||||
|
activityHeatmapData.value = await usageApi.getActivityHeatmap()
|
||||||
|
} else {
|
||||||
|
activityHeatmapData.value = await meApi.getActivityHeatmap()
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
log.error('加载热力图数据失败:', error)
|
||||||
|
heatmapError.value = true
|
||||||
|
} finally {
|
||||||
|
isLoadingHeatmap.value = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 用户页面需要前端筛选
|
// 用户页面需要前端筛选
|
||||||
const filteredRecords = computed(() => {
|
const filteredRecords = computed(() => {
|
||||||
if (!isAdminPage.value) {
|
if (!isAdminPage.value) {
|
||||||
@@ -232,27 +259,40 @@ async function pollActiveRequests() {
|
|||||||
? await usageApi.getActiveRequests(activeRequestIds.value)
|
? await usageApi.getActiveRequests(activeRequestIds.value)
|
||||||
: await meApi.getActiveRequests(idsParam)
|
: await meApi.getActiveRequests(idsParam)
|
||||||
|
|
||||||
// 检查是否有状态变化
|
let shouldRefresh = false
|
||||||
let hasChanges = false
|
|
||||||
for (const update of requests) {
|
for (const update of requests) {
|
||||||
const record = currentRecords.value.find(r => r.id === update.id)
|
const record = currentRecords.value.find(r => r.id === update.id)
|
||||||
if (record && record.status !== update.status) {
|
if (!record) {
|
||||||
hasChanges = true
|
// 后端返回了未知的活跃请求,触发刷新以获取完整数据
|
||||||
// 如果状态变为 completed 或 failed,需要刷新获取完整数据
|
shouldRefresh = true
|
||||||
if (update.status === 'completed' || update.status === 'failed') {
|
continue
|
||||||
break
|
}
|
||||||
}
|
|
||||||
// 否则只更新状态和 token 信息
|
// 状态变化:completed/failed 需要刷新获取完整数据
|
||||||
|
if (record.status !== update.status) {
|
||||||
record.status = update.status
|
record.status = update.status
|
||||||
record.input_tokens = update.input_tokens
|
}
|
||||||
record.output_tokens = update.output_tokens
|
if (update.status === 'completed' || update.status === 'failed') {
|
||||||
record.cost = update.cost
|
shouldRefresh = true
|
||||||
record.response_time_ms = update.response_time_ms ?? undefined
|
}
|
||||||
|
|
||||||
|
// 进行中状态也需要持续更新(provider/key/TTFB 可能在 streaming 后才落库)
|
||||||
|
record.input_tokens = update.input_tokens
|
||||||
|
record.output_tokens = update.output_tokens
|
||||||
|
record.cost = update.cost
|
||||||
|
record.response_time_ms = update.response_time_ms ?? undefined
|
||||||
|
record.first_byte_time_ms = update.first_byte_time_ms ?? undefined
|
||||||
|
// 管理员接口返回额外字段
|
||||||
|
if ('provider' in update && typeof update.provider === 'string') {
|
||||||
|
record.provider = update.provider
|
||||||
|
}
|
||||||
|
if ('api_key_name' in update) {
|
||||||
|
record.api_key_name = typeof update.api_key_name === 'string' ? update.api_key_name : undefined
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果有请求完成或失败,刷新整个列表获取完整数据
|
if (shouldRefresh) {
|
||||||
if (hasChanges && requests.some(r => r.status === 'completed' || r.status === 'failed')) {
|
|
||||||
await refreshData()
|
await refreshData()
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -335,7 +375,22 @@ const selectedRequestId = ref<string | null>(null)
|
|||||||
// 初始化加载
|
// 初始化加载
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||||
await loadStats(dateRange)
|
|
||||||
|
// 并行加载统计数据和热力图(使用 allSettled 避免其中一个失败影响另一个)
|
||||||
|
const [statsResult, heatmapResult] = await Promise.allSettled([
|
||||||
|
loadStats(dateRange),
|
||||||
|
loadHeatmapData()
|
||||||
|
])
|
||||||
|
|
||||||
|
// 检查加载结果并通知用户
|
||||||
|
if (statsResult.status === 'rejected') {
|
||||||
|
log.error('加载统计数据失败:', statsResult.reason)
|
||||||
|
warning('统计数据加载失败,请刷新重试')
|
||||||
|
}
|
||||||
|
if (heatmapResult.status === 'rejected') {
|
||||||
|
log.error('加载热力图数据失败:', heatmapResult.reason)
|
||||||
|
// 热力图加载失败不提示,因为 UI 已显示占位符
|
||||||
|
}
|
||||||
|
|
||||||
// 管理员页面加载用户列表和第一页记录
|
// 管理员页面加载用户列表和第一页记录
|
||||||
if (isAdminPage.value) {
|
if (isAdminPage.value) {
|
||||||
|
|||||||
@@ -3,22 +3,64 @@
|
|||||||
独立余额Key:不关联用户配额,有独立余额限制,用于给非注册用户使用。
|
独立余额Key:不关联用户配额,有独立余额限制,用于给非注册用户使用。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
import os
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.api.base.admin_adapter import AdminApiAdapter
|
from src.api.base.admin_adapter import AdminApiAdapter
|
||||||
from src.api.base.pipeline import ApiRequestPipeline
|
from src.api.base.pipeline import ApiRequestPipeline
|
||||||
from src.core.exceptions import NotFoundException
|
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.database import get_db
|
from src.database import get_db
|
||||||
from src.models.api import CreateApiKeyRequest
|
from src.models.api import CreateApiKeyRequest
|
||||||
from src.models.database import ApiKey, User
|
from src.models.database import ApiKey
|
||||||
from src.services.user.apikey import ApiKeyService
|
from src.services.user.apikey import ApiKeyService
|
||||||
|
|
||||||
|
|
||||||
|
# 应用时区配置,默认为 Asia/Shanghai
|
||||||
|
APP_TIMEZONE = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai"))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_expiry_date(date_str: Optional[str]) -> Optional[datetime]:
|
||||||
|
"""解析过期日期字符串为 datetime 对象。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date_str: 日期字符串,支持 "YYYY-MM-DD" 或 ISO 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
datetime 对象(当天 23:59:59.999999,应用时区),或 None 如果输入为空
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BadRequestException: 日期格式无效
|
||||||
|
"""
|
||||||
|
if not date_str or not date_str.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
date_str = date_str.strip()
|
||||||
|
|
||||||
|
# 尝试 YYYY-MM-DD 格式
|
||||||
|
try:
|
||||||
|
parsed_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||||
|
# 设置为当天结束时间 (23:59:59.999999,应用时区)
|
||||||
|
return parsed_date.replace(
|
||||||
|
hour=23, minute=59, second=59, microsecond=999999, tzinfo=APP_TIMEZONE
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 尝试完整 ISO 格式
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise InvalidRequestException(f"无效的日期格式: {date_str},请使用 YYYY-MM-DD 格式")
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
|
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
@@ -215,6 +257,9 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
|||||||
# 独立Key需要关联到管理员用户(从context获取)
|
# 独立Key需要关联到管理员用户(从context获取)
|
||||||
admin_user_id = context.user.id
|
admin_user_id = context.user.id
|
||||||
|
|
||||||
|
# 解析过期时间(优先使用 expires_at,其次使用 expire_days)
|
||||||
|
expires_at_dt = parse_expiry_date(self.key_data.expires_at)
|
||||||
|
|
||||||
# 创建独立Key
|
# 创建独立Key
|
||||||
api_key, plain_key = ApiKeyService.create_api_key(
|
api_key, plain_key = ApiKeyService.create_api_key(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -224,7 +269,8 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
|
|||||||
allowed_api_formats=self.key_data.allowed_api_formats,
|
allowed_api_formats=self.key_data.allowed_api_formats,
|
||||||
allowed_models=self.key_data.allowed_models,
|
allowed_models=self.key_data.allowed_models,
|
||||||
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
rate_limit=self.key_data.rate_limit, # None 表示不限制
|
||||||
expire_days=self.key_data.expire_days,
|
expire_days=self.key_data.expire_days, # 兼容旧版
|
||||||
|
expires_at=expires_at_dt, # 优先使用
|
||||||
initial_balance_usd=self.key_data.initial_balance_usd,
|
initial_balance_usd=self.key_data.initial_balance_usd,
|
||||||
is_standalone=True, # 标记为独立Key
|
is_standalone=True, # 标记为独立Key
|
||||||
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
|
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
|
||||||
@@ -270,7 +316,8 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
|
|||||||
update_data = {}
|
update_data = {}
|
||||||
if self.key_data.name is not None:
|
if self.key_data.name is not None:
|
||||||
update_data["name"] = self.key_data.name
|
update_data["name"] = self.key_data.name
|
||||||
if self.key_data.rate_limit is not None:
|
# rate_limit: 显式传递时更新(包括 null 表示无限制)
|
||||||
|
if "rate_limit" in self.key_data.model_fields_set:
|
||||||
update_data["rate_limit"] = self.key_data.rate_limit
|
update_data["rate_limit"] = self.key_data.rate_limit
|
||||||
if (
|
if (
|
||||||
hasattr(self.key_data, "auto_delete_on_expiry")
|
hasattr(self.key_data, "auto_delete_on_expiry")
|
||||||
@@ -287,19 +334,21 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
|
|||||||
update_data["allowed_models"] = self.key_data.allowed_models
|
update_data["allowed_models"] = self.key_data.allowed_models
|
||||||
|
|
||||||
# 处理过期时间
|
# 处理过期时间
|
||||||
if self.key_data.expire_days is not None:
|
# 优先使用 expires_at(如果显式传递且有值)
|
||||||
if self.key_data.expire_days > 0:
|
if self.key_data.expires_at and self.key_data.expires_at.strip():
|
||||||
from datetime import timedelta
|
update_data["expires_at"] = parse_expiry_date(self.key_data.expires_at)
|
||||||
|
elif "expires_at" in self.key_data.model_fields_set:
|
||||||
|
# expires_at 明确传递为 null 或空字符串,设为永不过期
|
||||||
|
update_data["expires_at"] = None
|
||||||
|
# 兼容旧版 expire_days
|
||||||
|
elif "expire_days" in self.key_data.model_fields_set:
|
||||||
|
if self.key_data.expire_days is not None and self.key_data.expire_days > 0:
|
||||||
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
|
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
|
||||||
days=self.key_data.expire_days
|
days=self.key_data.expire_days
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# expire_days = 0 或负数表示永不过期
|
# expire_days = None/0/负数 表示永不过期
|
||||||
update_data["expires_at"] = None
|
update_data["expires_at"] = None
|
||||||
elif hasattr(self.key_data, "expire_days") and self.key_data.expire_days is None:
|
|
||||||
# 明确传递 None,设为永不过期
|
|
||||||
update_data["expires_at"] = None
|
|
||||||
|
|
||||||
# 使用 ApiKeyService 更新
|
# 使用 ApiKeyService 更新
|
||||||
updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
|
updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
|
||||||
|
|||||||
@@ -206,6 +206,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
|
|||||||
provider_id=self.provider_id,
|
provider_id=self.provider_id,
|
||||||
api_format=self.endpoint_data.api_format,
|
api_format=self.endpoint_data.api_format,
|
||||||
base_url=self.endpoint_data.base_url,
|
base_url=self.endpoint_data.base_url,
|
||||||
|
custom_path=self.endpoint_data.custom_path,
|
||||||
headers=self.endpoint_data.headers,
|
headers=self.endpoint_data.headers,
|
||||||
timeout=self.endpoint_data.timeout,
|
timeout=self.endpoint_data.timeout,
|
||||||
max_retries=self.endpoint_data.max_retries,
|
max_retries=self.endpoint_data.max_retries,
|
||||||
|
|||||||
@@ -146,20 +146,25 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
|||||||
search=self.search,
|
search=self.search,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 为每个 GlobalModel 添加统计数据
|
# 一次性查询所有 GlobalModel 的 provider_count(优化 N+1 问题)
|
||||||
|
model_ids = [gm.id for gm in models]
|
||||||
|
provider_counts = {}
|
||||||
|
if model_ids:
|
||||||
|
count_results = (
|
||||||
|
context.db.query(
|
||||||
|
Model.global_model_id, func.count(func.distinct(Model.provider_id))
|
||||||
|
)
|
||||||
|
.filter(Model.global_model_id.in_(model_ids))
|
||||||
|
.group_by(Model.global_model_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
provider_counts = {gm_id: count for gm_id, count in count_results}
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
model_responses = []
|
model_responses = []
|
||||||
for gm in models:
|
for gm in models:
|
||||||
# 统计关联的 Model 数量(去重 Provider)
|
|
||||||
provider_count = (
|
|
||||||
context.db.query(func.count(func.distinct(Model.provider_id)))
|
|
||||||
.filter(Model.global_model_id == gm.id)
|
|
||||||
.scalar()
|
|
||||||
or 0
|
|
||||||
)
|
|
||||||
|
|
||||||
response = GlobalModelResponse.model_validate(gm)
|
response = GlobalModelResponse.model_validate(gm)
|
||||||
response.provider_count = provider_count
|
response.provider_count = provider_counts.get(gm.id, 0)
|
||||||
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
|
|
||||||
model_responses.append(response)
|
model_responses.append(response)
|
||||||
|
|
||||||
return GlobalModelListResponse(
|
return GlobalModelListResponse(
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
提供商策略管理 API 端点
|
提供商策略管理 API 端点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
@@ -103,6 +103,9 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
if config.quota_last_reset_at:
|
if config.quota_last_reset_at:
|
||||||
new_reset_at = parser.parse(config.quota_last_reset_at)
|
new_reset_at = parser.parse(config.quota_last_reset_at)
|
||||||
|
# 确保有时区信息,如果没有则假设为 UTC
|
||||||
|
if new_reset_at.tzinfo is None:
|
||||||
|
new_reset_at = new_reset_at.replace(tzinfo=timezone.utc)
|
||||||
provider.quota_last_reset_at = new_reset_at
|
provider.quota_last_reset_at = new_reset_at
|
||||||
|
|
||||||
# 自动同步该周期内的历史使用量
|
# 自动同步该周期内的历史使用量
|
||||||
@@ -118,7 +121,11 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
|
|||||||
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
|
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
|
||||||
|
|
||||||
if config.quota_expires_at:
|
if config.quota_expires_at:
|
||||||
provider.quota_expires_at = parser.parse(config.quota_expires_at)
|
expires_at = parser.parse(config.quota_expires_at)
|
||||||
|
# 确保有时区信息,如果没有则假设为 UTC
|
||||||
|
if expires_at.tzinfo is None:
|
||||||
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||||
|
provider.quota_expires_at = expires_at
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(provider)
|
db.refresh(provider)
|
||||||
@@ -149,7 +156,7 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
|
|||||||
if not provider:
|
if not provider:
|
||||||
raise HTTPException(status_code=404, detail="Provider not found")
|
raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
since = datetime.now() - timedelta(hours=self.hours)
|
since = datetime.now(timezone.utc) - timedelta(hours=self.hours)
|
||||||
stats = (
|
stats = (
|
||||||
db.query(ProviderUsageTracking)
|
db.query(ProviderUsageTracking)
|
||||||
.filter(
|
.filter(
|
||||||
|
|||||||
@@ -1133,7 +1133,7 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
|||||||
allowed_endpoints=key_data.get("allowed_endpoints"),
|
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||||||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||||
allowed_models=key_data.get("allowed_models"),
|
allowed_models=key_data.get("allowed_models"),
|
||||||
rate_limit=key_data.get("rate_limit", 100),
|
rate_limit=key_data.get("rate_limit"), # None = 无限制
|
||||||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||||
force_capabilities=key_data.get("force_capabilities"),
|
force_capabilities=key_data.get("force_capabilities"),
|
||||||
is_active=key_data.get("is_active", True),
|
is_active=key_data.get("is_active", True),
|
||||||
|
|||||||
@@ -73,6 +73,20 @@ async def get_usage_stats(
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/heatmap")
|
||||||
|
async def get_activity_heatmap(
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get activity heatmap data for the past 365 days.
|
||||||
|
|
||||||
|
This endpoint is cached for 5 minutes to reduce database load.
|
||||||
|
"""
|
||||||
|
adapter = AdminActivityHeatmapAdapter()
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records")
|
@router.get("/records")
|
||||||
async def get_usage_records(
|
async def get_usage_records(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -168,12 +182,6 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
|
|||||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||||
).count()
|
).count()
|
||||||
|
|
||||||
activity_heatmap = UsageService.get_daily_activity(
|
|
||||||
db=db,
|
|
||||||
window_days=365,
|
|
||||||
include_actual_cost=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
context.add_audit_metadata(
|
context.add_audit_metadata(
|
||||||
action="usage_stats",
|
action="usage_stats",
|
||||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||||
@@ -204,10 +212,22 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
|
|||||||
),
|
),
|
||||||
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
|
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
|
||||||
},
|
},
|
||||||
"activity_heatmap": activity_heatmap,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AdminActivityHeatmapAdapter(AdminApiAdapter):
|
||||||
|
"""Activity heatmap adapter with Redis caching."""
|
||||||
|
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
result = await UsageService.get_cached_heatmap(
|
||||||
|
db=context.db,
|
||||||
|
user_id=None,
|
||||||
|
include_actual_cost=True,
|
||||||
|
)
|
||||||
|
context.add_audit_metadata(action="activity_heatmap")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class AdminUsageByModelAdapter(AdminApiAdapter):
|
class AdminUsageByModelAdapter(AdminApiAdapter):
|
||||||
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
|
||||||
self.start_date = start_date
|
self.start_date = start_date
|
||||||
@@ -670,7 +690,9 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
|
|||||||
if not id_list:
|
if not id_list:
|
||||||
return {"requests": []}
|
return {"requests": []}
|
||||||
|
|
||||||
requests = UsageService.get_active_requests_status(db=db, ids=id_list)
|
requests = UsageService.get_active_requests_status(
|
||||||
|
db=db, ids=id_list, include_admin_fields=True
|
||||||
|
)
|
||||||
return {"requests": requests}
|
return {"requests": requests}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -248,6 +248,7 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
|
|||||||
raise InvalidRequestException("请求数据验证失败")
|
raise InvalidRequestException("请求数据验证失败")
|
||||||
|
|
||||||
update_data = request.model_dump(exclude_unset=True)
|
update_data = request.model_dump(exclude_unset=True)
|
||||||
|
old_role = existing_user.role
|
||||||
if "role" in update_data and update_data["role"]:
|
if "role" in update_data and update_data["role"]:
|
||||||
if hasattr(update_data["role"], "value"):
|
if hasattr(update_data["role"], "value"):
|
||||||
update_data["role"] = update_data["role"]
|
update_data["role"] = update_data["role"]
|
||||||
@@ -258,6 +259,12 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
|
|||||||
if not user:
|
if not user:
|
||||||
raise NotFoundException("用户不存在", "user")
|
raise NotFoundException("用户不存在", "user")
|
||||||
|
|
||||||
|
# 角色变更时清除热力图缓存(影响 include_actual_cost 权限)
|
||||||
|
if "role" in update_data and update_data["role"] != old_role:
|
||||||
|
from src.services.usage.service import UsageService
|
||||||
|
|
||||||
|
await UsageService.clear_user_heatmap_cache(self.user_id)
|
||||||
|
|
||||||
changed_fields = list(update_data.keys())
|
changed_fields = list(update_data.keys())
|
||||||
context.add_audit_metadata(
|
context.add_audit_metadata(
|
||||||
action="update_user",
|
action="update_user",
|
||||||
@@ -424,7 +431,7 @@ class AdminCreateUserKeyAdapter(AdminApiAdapter):
|
|||||||
name=key_data.name,
|
name=key_data.name,
|
||||||
allowed_providers=key_data.allowed_providers,
|
allowed_providers=key_data.allowed_providers,
|
||||||
allowed_models=key_data.allowed_models,
|
allowed_models=key_data.allowed_models,
|
||||||
rate_limit=key_data.rate_limit or 100,
|
rate_limit=key_data.rate_limit, # None = 无限制
|
||||||
expire_days=key_data.expire_days,
|
expire_days=key_data.expire_days,
|
||||||
initial_balance_usd=None, # 普通Key不设置余额限制
|
initial_balance_usd=None, # 普通Key不设置余额限制
|
||||||
is_standalone=False, # 不是独立Key
|
is_standalone=False, # 不是独立Key
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ if TYPE_CHECKING:
|
|||||||
from src.api.handlers.base.stream_context import StreamContext
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class MessageTelemetry:
|
class MessageTelemetry:
|
||||||
"""
|
"""
|
||||||
负责记录 Usage/Audit,避免处理器里重复代码。
|
负责记录 Usage/Audit,避免处理器里重复代码。
|
||||||
@@ -406,7 +405,7 @@ class BaseMessageHandler:
|
|||||||
asyncio.create_task(_do_update())
|
asyncio.create_task(_do_update())
|
||||||
|
|
||||||
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
|
||||||
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
|
"""更新 Usage 状态为 streaming,同时更新 provider 相关信息
|
||||||
|
|
||||||
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
|
||||||
|
|
||||||
@@ -414,7 +413,7 @@ class BaseMessageHandler:
|
|||||||
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ctx: 流式上下文,包含 provider_name 和 mapped_model
|
ctx: 流式上下文,包含 provider 相关信息
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.database.database import get_db
|
from src.database.database import get_db
|
||||||
@@ -422,6 +421,17 @@ class BaseMessageHandler:
|
|||||||
target_request_id = self.request_id
|
target_request_id = self.request_id
|
||||||
provider = ctx.provider_name
|
provider = ctx.provider_name
|
||||||
target_model = ctx.mapped_model
|
target_model = ctx.mapped_model
|
||||||
|
provider_id = ctx.provider_id
|
||||||
|
endpoint_id = ctx.endpoint_id
|
||||||
|
key_id = ctx.key_id
|
||||||
|
first_byte_time_ms = ctx.first_byte_time_ms
|
||||||
|
|
||||||
|
# 如果 provider 为空,记录警告(不应该发生,但用于调试)
|
||||||
|
if not provider:
|
||||||
|
logger.warning(
|
||||||
|
f"[{target_request_id}] 更新 streaming 状态时 provider 为空: "
|
||||||
|
f"ctx.provider_name={ctx.provider_name}, ctx.provider_id={ctx.provider_id}"
|
||||||
|
)
|
||||||
|
|
||||||
async def _do_update() -> None:
|
async def _do_update() -> None:
|
||||||
try:
|
try:
|
||||||
@@ -434,6 +444,10 @@ class BaseMessageHandler:
|
|||||||
status="streaming",
|
status="streaming",
|
||||||
provider=provider,
|
provider=provider,
|
||||||
target_model=target_model,
|
target_model=target_model,
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_endpoint_id=endpoint_id,
|
||||||
|
provider_api_key_id=key_id,
|
||||||
|
first_byte_time_ms=first_byte_time_ms,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
|
|||||||
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
||||||
from src.api.handlers.base.utils import build_sse_headers
|
from src.api.handlers.base.utils import build_sse_headers
|
||||||
from src.config.settings import config
|
from src.config.settings import config
|
||||||
|
from src.core.error_utils import extract_error_message
|
||||||
from src.core.exceptions import (
|
from src.core.exceptions import (
|
||||||
EmbeddedErrorException,
|
EmbeddedErrorException,
|
||||||
ProviderAuthException,
|
ProviderAuthException,
|
||||||
@@ -500,6 +501,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
error_text = await self._extract_error_text(e)
|
error_text = await self._extract_error_text(e)
|
||||||
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
|
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
|
||||||
await http_client.aclose()
|
await http_client.aclose()
|
||||||
|
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
|
||||||
|
e.upstream_response = error_text # type: ignore[attr-defined]
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except EmbeddedErrorException:
|
||||||
@@ -549,7 +552,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=str(error),
|
error_message=extract_error_message(error),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=True,
|
is_stream=True,
|
||||||
@@ -785,7 +788,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
model=model,
|
model=model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=str(e),
|
error_message=extract_error_message(e),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=False,
|
is_stream=False,
|
||||||
@@ -802,10 +805,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
|||||||
try:
|
try:
|
||||||
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
|
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
|
||||||
error_bytes = await e.response.aread()
|
error_bytes = await e.response.aread()
|
||||||
return error_bytes.decode("utf-8", errors="replace")[:500]
|
return error_bytes.decode("utf-8", errors="replace")
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
e.response.text[:500] if hasattr(e.response, "_content") else "Unable to read"
|
e.response.text if hasattr(e.response, "_content") else "Unable to read"
|
||||||
)
|
)
|
||||||
except Exception as decode_error:
|
except Exception as decode_error:
|
||||||
return f"Unable to read error: {decode_error}"
|
return f"Unable to read error: {decode_error}"
|
||||||
|
|||||||
@@ -34,7 +34,12 @@ from src.api.handlers.base.base_handler import (
|
|||||||
from src.api.handlers.base.parsers import get_parser_for_format
|
from src.api.handlers.base.parsers import get_parser_for_format
|
||||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||||
from src.api.handlers.base.stream_context import StreamContext
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
from src.api.handlers.base.utils import build_sse_headers
|
from src.api.handlers.base.utils import (
|
||||||
|
build_sse_headers,
|
||||||
|
check_html_response,
|
||||||
|
check_prefetched_response_error,
|
||||||
|
)
|
||||||
|
from src.core.error_utils import extract_error_message
|
||||||
|
|
||||||
# 直接从具体模块导入,避免循环依赖
|
# 直接从具体模块导入,避免循环依赖
|
||||||
from src.api.handlers.base.response_parser import (
|
from src.api.handlers.base.response_parser import (
|
||||||
@@ -57,6 +62,7 @@ from src.models.database import (
|
|||||||
ProviderEndpoint,
|
ProviderEndpoint,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
from src.config.constants import StreamDefaults
|
||||||
from src.config.settings import config
|
from src.config.settings import config
|
||||||
from src.services.provider.transport import build_provider_url
|
from src.services.provider.transport import build_provider_url
|
||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
@@ -328,9 +334,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
stream_generator,
|
stream_generator,
|
||||||
provider_name,
|
provider_name,
|
||||||
attempt_id,
|
attempt_id,
|
||||||
_provider_id,
|
provider_id,
|
||||||
_endpoint_id,
|
endpoint_id,
|
||||||
_key_id,
|
key_id,
|
||||||
) = await self.orchestrator.execute_with_fallback(
|
) = await self.orchestrator.execute_with_fallback(
|
||||||
api_format=ctx.api_format,
|
api_format=ctx.api_format,
|
||||||
model_name=ctx.model,
|
model_name=ctx.model,
|
||||||
@@ -340,7 +346,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
is_stream=True,
|
is_stream=True,
|
||||||
capability_requirements=capability_requirements or None,
|
capability_requirements=capability_requirements or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 更新上下文(确保 provider 信息已设置,用于 streaming 状态更新)
|
||||||
ctx.attempt_id = attempt_id
|
ctx.attempt_id = attempt_id
|
||||||
|
if not ctx.provider_name:
|
||||||
|
ctx.provider_name = provider_name
|
||||||
|
if not ctx.provider_id:
|
||||||
|
ctx.provider_id = provider_id
|
||||||
|
if not ctx.endpoint_id:
|
||||||
|
ctx.endpoint_id = endpoint_id
|
||||||
|
if not ctx.key_id:
|
||||||
|
ctx.key_id = key_id
|
||||||
|
|
||||||
# 创建后台任务记录统计
|
# 创建后台任务记录统计
|
||||||
background_tasks = BackgroundTasks()
|
background_tasks = BackgroundTasks()
|
||||||
@@ -488,6 +504,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
error_text = await self._extract_error_text(e)
|
error_text = await self._extract_error_text(e)
|
||||||
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
|
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
|
||||||
await http_client.aclose()
|
await http_client.aclose()
|
||||||
|
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
|
||||||
|
e.upstream_response = error_text # type: ignore[attr-defined]
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except EmbeddedErrorException:
|
except EmbeddedErrorException:
|
||||||
@@ -523,8 +541,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
try:
|
try:
|
||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
last_data_time = time.time()
|
last_data_time = time.time()
|
||||||
streaming_status_updated = False
|
|
||||||
buffer = b""
|
buffer = b""
|
||||||
|
output_state = {"first_yield": True, "streaming_updated": False}
|
||||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
@@ -532,11 +550,6 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
async for chunk in stream_response.aiter_bytes():
|
async for chunk in stream_response.aiter_bytes():
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
|
||||||
if not streaming_status_updated:
|
|
||||||
self._update_usage_to_streaming_with_ctx(ctx)
|
|
||||||
streaming_status_updated = True
|
|
||||||
|
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
# 处理缓冲区中的完整行
|
# 处理缓冲区中的完整行
|
||||||
while b"\n" in buffer:
|
while b"\n" in buffer:
|
||||||
@@ -561,6 +574,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data") or "",
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
|
self._mark_first_output(ctx, output_state)
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -578,6 +592,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
self._mark_first_output(ctx, output_state)
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
return # 结束生成器
|
return # 结束生成器
|
||||||
|
|
||||||
@@ -585,8 +600,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if needs_conversion:
|
if needs_conversion:
|
||||||
converted_line = self._convert_sse_line(ctx, line, events)
|
converted_line = self._convert_sse_line(ctx, line, events)
|
||||||
if converted_line:
|
if converted_line:
|
||||||
|
self._mark_first_output(ctx, output_state)
|
||||||
yield (converted_line + "\n").encode("utf-8")
|
yield (converted_line + "\n").encode("utf-8")
|
||||||
else:
|
else:
|
||||||
|
self._mark_first_output(ctx, output_state)
|
||||||
yield (line + "\n").encode("utf-8")
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
@@ -637,7 +654,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
except httpx.RemoteProtocolError as e:
|
except httpx.RemoteProtocolError:
|
||||||
if ctx.data_count > 0:
|
if ctx.data_count > 0:
|
||||||
error_event = {
|
error_event = {
|
||||||
"type": "error",
|
"type": "error",
|
||||||
@@ -691,7 +708,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||||
"""
|
"""
|
||||||
prefetched_chunks: list = []
|
prefetched_chunks: list = []
|
||||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
max_prefetch_lines = config.stream_prefetch_lines # 最多预读行数来检测错误
|
||||||
|
max_prefetch_bytes = StreamDefaults.MAX_PREFETCH_BYTES # 避免无换行响应导致 buffer 增长
|
||||||
|
total_prefetched_bytes = 0
|
||||||
buffer = b""
|
buffer = b""
|
||||||
line_count = 0
|
line_count = 0
|
||||||
should_stop = False
|
should_stop = False
|
||||||
@@ -718,14 +737,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
provider_name=str(provider.name),
|
provider_name=str(provider.name),
|
||||||
)
|
)
|
||||||
prefetched_chunks.append(first_chunk)
|
prefetched_chunks.append(first_chunk)
|
||||||
|
total_prefetched_bytes += len(first_chunk)
|
||||||
buffer += first_chunk
|
buffer += first_chunk
|
||||||
|
|
||||||
# 继续读取剩余的预读数据
|
# 继续读取剩余的预读数据
|
||||||
async for chunk in aiter:
|
async for chunk in aiter:
|
||||||
prefetched_chunks.append(chunk)
|
prefetched_chunks.append(chunk)
|
||||||
|
total_prefetched_bytes += len(chunk)
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
|
||||||
# 尝试按行解析缓冲区
|
# 尝试按行解析缓冲区(SSE 格式)
|
||||||
while b"\n" in buffer:
|
while b"\n" in buffer:
|
||||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||||
try:
|
try:
|
||||||
@@ -742,15 +763,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
normalized_line = line.rstrip("\r")
|
normalized_line = line.rstrip("\r")
|
||||||
|
|
||||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||||
lower_line = normalized_line.lower()
|
if check_html_response(normalized_line):
|
||||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||||
f"base_url={endpoint.base_url}"
|
f"base_url={endpoint.base_url}"
|
||||||
)
|
)
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
|
||||||
|
f"请检查 endpoint 的 base_url 配置是否正确"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not normalized_line or normalized_line.startswith(":"):
|
if not normalized_line or normalized_line.startswith(":"):
|
||||||
@@ -799,9 +820,30 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
should_stop = True
|
should_stop = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
|
||||||
|
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
|
||||||
|
logger.debug(
|
||||||
|
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
|
||||||
|
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
|
||||||
|
f"max_bytes={max_prefetch_bytes}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
if should_stop or line_count >= max_prefetch_lines:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
|
||||||
|
# 处理某些代理返回的纯 JSON 错误(可能无换行/多行 JSON)以及 HTML 页面(base_url 配置错误)
|
||||||
|
if not should_stop and prefetched_chunks:
|
||||||
|
check_prefetched_response_error(
|
||||||
|
prefetched_chunks=prefetched_chunks,
|
||||||
|
parser=provider_parser,
|
||||||
|
request_id=self.request_id,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
endpoint_id=endpoint.id,
|
||||||
|
base_url=endpoint.base_url,
|
||||||
|
)
|
||||||
|
|
||||||
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
|
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
|
||||||
# 重新抛出可重试的 Provider 异常,触发故障转移
|
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||||
raise
|
raise
|
||||||
@@ -833,17 +875,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
sse_parser = SSEEventParser()
|
sse_parser = SSEEventParser()
|
||||||
last_data_time = time.time()
|
last_data_time = time.time()
|
||||||
buffer = b""
|
buffer = b""
|
||||||
first_yield = True # 标记是否是第一次 yield
|
output_state = {"first_yield": True, "streaming_updated": False}
|
||||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
# 检查是否需要格式转换
|
# 检查是否需要格式转换
|
||||||
needs_conversion = self._needs_format_conversion(ctx)
|
needs_conversion = self._needs_format_conversion(ctx)
|
||||||
|
|
||||||
# 在第一次输出数据前更新状态为 streaming
|
|
||||||
if prefetched_chunks:
|
|
||||||
self._update_usage_to_streaming_with_ctx(ctx)
|
|
||||||
|
|
||||||
# 先处理预读的字节块
|
# 先处理预读的字节块
|
||||||
for chunk in prefetched_chunks:
|
for chunk in prefetched_chunks:
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
@@ -870,10 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data") or "",
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
# 记录首字时间 (第一次 yield)
|
self._mark_first_output(ctx, output_state)
|
||||||
if first_yield:
|
|
||||||
ctx.record_first_byte_time(self.start_time)
|
|
||||||
first_yield = False
|
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -883,16 +918,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if needs_conversion:
|
if needs_conversion:
|
||||||
converted_line = self._convert_sse_line(ctx, line, events)
|
converted_line = self._convert_sse_line(ctx, line, events)
|
||||||
if converted_line:
|
if converted_line:
|
||||||
# 记录首字时间 (第一次 yield)
|
self._mark_first_output(ctx, output_state)
|
||||||
if first_yield:
|
|
||||||
ctx.record_first_byte_time(self.start_time)
|
|
||||||
first_yield = False
|
|
||||||
yield (converted_line + "\n").encode("utf-8")
|
yield (converted_line + "\n").encode("utf-8")
|
||||||
else:
|
else:
|
||||||
# 记录首字时间 (第一次 yield)
|
self._mark_first_output(ctx, output_state)
|
||||||
if first_yield:
|
|
||||||
ctx.record_first_byte_time(self.start_time)
|
|
||||||
first_yield = False
|
|
||||||
yield (line + "\n").encode("utf-8")
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
@@ -931,10 +960,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
event.get("event"),
|
event.get("event"),
|
||||||
event.get("data") or "",
|
event.get("data") or "",
|
||||||
)
|
)
|
||||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
self._mark_first_output(ctx, output_state)
|
||||||
if first_yield:
|
|
||||||
ctx.record_first_byte_time(self.start_time)
|
|
||||||
first_yield = False
|
|
||||||
yield b"\n"
|
yield b"\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -952,6 +978,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
self._mark_first_output(ctx, output_state)
|
||||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -959,16 +986,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
if needs_conversion:
|
if needs_conversion:
|
||||||
converted_line = self._convert_sse_line(ctx, line, events)
|
converted_line = self._convert_sse_line(ctx, line, events)
|
||||||
if converted_line:
|
if converted_line:
|
||||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
self._mark_first_output(ctx, output_state)
|
||||||
if first_yield:
|
|
||||||
ctx.record_first_byte_time(self.start_time)
|
|
||||||
first_yield = False
|
|
||||||
yield (converted_line + "\n").encode("utf-8")
|
yield (converted_line + "\n").encode("utf-8")
|
||||||
else:
|
else:
|
||||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
self._mark_first_output(ctx, output_state)
|
||||||
if first_yield:
|
|
||||||
ctx.record_first_byte_time(self.start_time)
|
|
||||||
first_yield = False
|
|
||||||
yield (line + "\n").encode("utf-8")
|
yield (line + "\n").encode("utf-8")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
@@ -1352,7 +1373,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
model=ctx.model,
|
model=ctx.model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=str(error),
|
error_message=extract_error_message(error),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=True,
|
is_stream=True,
|
||||||
@@ -1620,7 +1641,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
model=model,
|
model=model,
|
||||||
response_time_ms=response_time_ms,
|
response_time_ms=response_time_ms,
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
error_message=str(e),
|
error_message=extract_error_message(e),
|
||||||
request_headers=original_headers,
|
request_headers=original_headers,
|
||||||
request_body=actual_request_body,
|
request_body=actual_request_body,
|
||||||
is_stream=False,
|
is_stream=False,
|
||||||
@@ -1640,14 +1661,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
|
|
||||||
for encoding in ["utf-8", "gbk", "latin1"]:
|
for encoding in ["utf-8", "gbk", "latin1"]:
|
||||||
try:
|
try:
|
||||||
return error_bytes.decode(encoding)[:500]
|
return error_bytes.decode(encoding)
|
||||||
except (UnicodeDecodeError, LookupError):
|
except (UnicodeDecodeError, LookupError):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return error_bytes.decode("utf-8", errors="replace")[:500]
|
return error_bytes.decode("utf-8", errors="replace")
|
||||||
else:
|
else:
|
||||||
return (
|
return (
|
||||||
e.response.text[:500]
|
e.response.text
|
||||||
if hasattr(e.response, "_content")
|
if hasattr(e.response, "_content")
|
||||||
else "Unable to read response"
|
else "Unable to read response"
|
||||||
)
|
)
|
||||||
@@ -1665,6 +1686,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
return False
|
return False
|
||||||
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
|
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
|
||||||
|
|
||||||
|
def _mark_first_output(self, ctx: StreamContext, state: Dict[str, bool]) -> None:
|
||||||
|
"""
|
||||||
|
标记首次输出:记录 TTFB 并更新 streaming 状态
|
||||||
|
|
||||||
|
在第一次 yield 数据前调用,确保:
|
||||||
|
1. 首字时间 (TTFB) 已记录到 ctx
|
||||||
|
2. Usage 状态已更新为 streaming(包含 provider/key/TTFB 信息)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: 流上下文
|
||||||
|
state: 包含 first_yield 和 streaming_updated 的状态字典
|
||||||
|
"""
|
||||||
|
if state["first_yield"]:
|
||||||
|
ctx.record_first_byte_time(self.start_time)
|
||||||
|
state["first_yield"] = False
|
||||||
|
if not state["streaming_updated"]:
|
||||||
|
self._update_usage_to_streaming_with_ctx(ctx)
|
||||||
|
state["streaming_updated"] = True
|
||||||
|
|
||||||
def _convert_sse_line(
|
def _convert_sse_line(
|
||||||
self,
|
self,
|
||||||
ctx: StreamContext,
|
ctx: StreamContext,
|
||||||
|
|||||||
@@ -98,6 +98,17 @@ class OpenAIResponseParser(ResponseParser):
|
|||||||
chunk.is_done = True
|
chunk.is_done = True
|
||||||
stats.has_completion = True
|
stats.has_completion = True
|
||||||
|
|
||||||
|
# 提取 usage 信息(某些 OpenAI 兼容 API 如豆包会在最后一个 chunk 中发送 usage)
|
||||||
|
# 这个 chunk 通常 choices 为空数组,但包含完整的 usage 信息
|
||||||
|
usage = parsed.get("usage")
|
||||||
|
if usage and isinstance(usage, dict):
|
||||||
|
chunk.input_tokens = usage.get("prompt_tokens", 0)
|
||||||
|
chunk.output_tokens = usage.get("completion_tokens", 0)
|
||||||
|
|
||||||
|
# 更新 stats
|
||||||
|
stats.input_tokens = chunk.input_tokens
|
||||||
|
stats.output_tokens = chunk.output_tokens
|
||||||
|
|
||||||
stats.chunk_count += 1
|
stats.chunk_count += 1
|
||||||
stats.data_count += 1
|
stats.data_count += 1
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,17 @@ from src.api.handlers.base.content_extractors import (
|
|||||||
from src.api.handlers.base.parsers import get_parser_for_format
|
from src.api.handlers.base.parsers import get_parser_for_format
|
||||||
from src.api.handlers.base.response_parser import ResponseParser
|
from src.api.handlers.base.response_parser import ResponseParser
|
||||||
from src.api.handlers.base.stream_context import StreamContext
|
from src.api.handlers.base.stream_context import StreamContext
|
||||||
|
from src.api.handlers.base.utils import (
|
||||||
|
check_html_response,
|
||||||
|
check_prefetched_response_error,
|
||||||
|
)
|
||||||
|
from src.config.constants import StreamDefaults
|
||||||
from src.config.settings import config
|
from src.config.settings import config
|
||||||
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
|
from src.core.exceptions import (
|
||||||
|
EmbeddedErrorException,
|
||||||
|
ProviderNotAvailableException,
|
||||||
|
ProviderTimeoutException,
|
||||||
|
)
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
from src.models.database import Provider, ProviderEndpoint
|
from src.models.database import Provider, ProviderEndpoint
|
||||||
from src.utils.sse_parser import SSEEventParser
|
from src.utils.sse_parser import SSEEventParser
|
||||||
@@ -165,6 +174,7 @@ class StreamProcessor:
|
|||||||
endpoint: ProviderEndpoint,
|
endpoint: ProviderEndpoint,
|
||||||
ctx: StreamContext,
|
ctx: StreamContext,
|
||||||
max_prefetch_lines: int = 5,
|
max_prefetch_lines: int = 5,
|
||||||
|
max_prefetch_bytes: int = StreamDefaults.MAX_PREFETCH_BYTES,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
预读流的前几行,检测嵌套错误
|
预读流的前几行,检测嵌套错误
|
||||||
@@ -180,12 +190,14 @@ class StreamProcessor:
|
|||||||
endpoint: Endpoint 对象
|
endpoint: Endpoint 对象
|
||||||
ctx: 流式上下文
|
ctx: 流式上下文
|
||||||
max_prefetch_lines: 最多预读行数
|
max_prefetch_lines: 最多预读行数
|
||||||
|
max_prefetch_bytes: 最多预读字节数(避免无换行响应导致 buffer 增长)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
预读的字节块列表
|
预读的字节块列表
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
EmbeddedErrorException: 如果检测到嵌套错误
|
EmbeddedErrorException: 如果检测到嵌套错误
|
||||||
|
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||||
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
|
||||||
"""
|
"""
|
||||||
prefetched_chunks: list = []
|
prefetched_chunks: list = []
|
||||||
@@ -193,6 +205,7 @@ class StreamProcessor:
|
|||||||
buffer = b""
|
buffer = b""
|
||||||
line_count = 0
|
line_count = 0
|
||||||
should_stop = False
|
should_stop = False
|
||||||
|
total_prefetched_bytes = 0
|
||||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||||
|
|
||||||
@@ -206,11 +219,13 @@ class StreamProcessor:
|
|||||||
provider_name=str(provider.name),
|
provider_name=str(provider.name),
|
||||||
)
|
)
|
||||||
prefetched_chunks.append(first_chunk)
|
prefetched_chunks.append(first_chunk)
|
||||||
|
total_prefetched_bytes += len(first_chunk)
|
||||||
buffer += first_chunk
|
buffer += first_chunk
|
||||||
|
|
||||||
# 继续读取剩余的预读数据
|
# 继续读取剩余的预读数据
|
||||||
async for chunk in aiter:
|
async for chunk in aiter:
|
||||||
prefetched_chunks.append(chunk)
|
prefetched_chunks.append(chunk)
|
||||||
|
total_prefetched_bytes += len(chunk)
|
||||||
buffer += chunk
|
buffer += chunk
|
||||||
|
|
||||||
# 尝试按行解析缓冲区
|
# 尝试按行解析缓冲区
|
||||||
@@ -228,10 +243,21 @@ class StreamProcessor:
|
|||||||
|
|
||||||
line_count += 1
|
line_count += 1
|
||||||
|
|
||||||
|
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||||
|
if check_html_response(line):
|
||||||
|
logger.error(
|
||||||
|
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||||
|
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||||
|
f"base_url={endpoint.base_url}"
|
||||||
|
)
|
||||||
|
raise ProviderNotAvailableException(
|
||||||
|
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
|
||||||
|
f"请检查 endpoint 的 base_url 配置是否正确"
|
||||||
|
)
|
||||||
|
|
||||||
# 跳过空行和注释行
|
# 跳过空行和注释行
|
||||||
if not line or line.startswith(":"):
|
if not line or line.startswith(":"):
|
||||||
if line_count >= max_prefetch_lines:
|
if line_count >= max_prefetch_lines:
|
||||||
should_stop = True
|
|
||||||
break
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -248,7 +274,6 @@ class StreamProcessor:
|
|||||||
data = json.loads(data_str)
|
data = json.loads(data_str)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
if line_count >= max_prefetch_lines:
|
if line_count >= max_prefetch_lines:
|
||||||
should_stop = True
|
|
||||||
break
|
break
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -276,14 +301,34 @@ class StreamProcessor:
|
|||||||
should_stop = True
|
should_stop = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
|
||||||
|
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
|
||||||
|
logger.debug(
|
||||||
|
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
|
||||||
|
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
|
||||||
|
f"max_bytes={max_prefetch_bytes}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
if should_stop or line_count >= max_prefetch_lines:
|
if should_stop or line_count >= max_prefetch_lines:
|
||||||
break
|
break
|
||||||
|
|
||||||
except (EmbeddedErrorException, ProviderTimeoutException):
|
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
|
||||||
|
if not should_stop and prefetched_chunks:
|
||||||
|
check_prefetched_response_error(
|
||||||
|
prefetched_chunks=prefetched_chunks,
|
||||||
|
parser=parser,
|
||||||
|
request_id=self.request_id,
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
endpoint_id=endpoint.id,
|
||||||
|
base_url=endpoint.base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
except (EmbeddedErrorException, ProviderNotAvailableException, ProviderTimeoutException):
|
||||||
# 重新抛出可重试的 Provider 异常,触发故障转移
|
# 重新抛出可重试的 Provider 异常,触发故障转移
|
||||||
raise
|
raise
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
# 网络 I/O <EFBFBD><EFBFBD><EFBFBD>常:记录警告,可能需要重试
|
# 网络 I/O 异常:记录警告,可能需要重试
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
|
||||||
)
|
)
|
||||||
@@ -332,15 +377,15 @@ class StreamProcessor:
|
|||||||
|
|
||||||
# 处理预读数据
|
# 处理预读数据
|
||||||
if prefetched_chunks:
|
if prefetched_chunks:
|
||||||
if not streaming_started and self.on_streaming_start:
|
|
||||||
self.on_streaming_start()
|
|
||||||
streaming_started = True
|
|
||||||
|
|
||||||
for chunk in prefetched_chunks:
|
for chunk in prefetched_chunks:
|
||||||
# 记录首字时间 (TTFB) - 在 yield 之前记录
|
# 记录首字时间 (TTFB) - 在 yield 之前记录
|
||||||
if start_time is not None:
|
if start_time is not None:
|
||||||
ctx.record_first_byte_time(start_time)
|
ctx.record_first_byte_time(start_time)
|
||||||
start_time = None # 只记录一次
|
start_time = None # 只记录一次
|
||||||
|
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
|
||||||
|
if not streaming_started and self.on_streaming_start:
|
||||||
|
self.on_streaming_start()
|
||||||
|
streaming_started = True
|
||||||
|
|
||||||
# 把原始数据转发给客户端
|
# 把原始数据转发给客户端
|
||||||
yield chunk
|
yield chunk
|
||||||
@@ -363,14 +408,14 @@ class StreamProcessor:
|
|||||||
|
|
||||||
# 处理剩余的流数据
|
# 处理剩余的流数据
|
||||||
async for chunk in byte_iterator:
|
async for chunk in byte_iterator:
|
||||||
if not streaming_started and self.on_streaming_start:
|
|
||||||
self.on_streaming_start()
|
|
||||||
streaming_started = True
|
|
||||||
|
|
||||||
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
|
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
|
||||||
if start_time is not None:
|
if start_time is not None:
|
||||||
ctx.record_first_byte_time(start_time)
|
ctx.record_first_byte_time(start_time)
|
||||||
start_time = None # 只记录一次
|
start_time = None # 只记录一次
|
||||||
|
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
|
||||||
|
if not streaming_started and self.on_streaming_start:
|
||||||
|
self.on_streaming_start()
|
||||||
|
streaming_started = True
|
||||||
|
|
||||||
# 原始数据透传
|
# 原始数据透传
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|||||||
@@ -2,8 +2,10 @@
|
|||||||
Handler 基础工具函数
|
Handler 基础工具函数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.core.exceptions import EmbeddedErrorException, ProviderNotAvailableException
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -107,3 +109,95 @@ def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[st
|
|||||||
if extra_headers:
|
if extra_headers:
|
||||||
headers.update(extra_headers)
|
headers.update(extra_headers)
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def check_html_response(line: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查行是否为 HTML 响应(base_url 配置错误的常见症状)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
line: 要检查的行内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 如果检测到 HTML 响应
|
||||||
|
"""
|
||||||
|
lower_line = line.lstrip().lower()
|
||||||
|
return lower_line.startswith("<!doctype") or lower_line.startswith("<html")
|
||||||
|
|
||||||
|
|
||||||
|
def check_prefetched_response_error(
|
||||||
|
prefetched_chunks: list,
|
||||||
|
parser: Any,
|
||||||
|
request_id: str,
|
||||||
|
provider_name: str,
|
||||||
|
endpoint_id: Optional[str],
|
||||||
|
base_url: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
检查预读的响应是否为非 SSE 格式的错误响应(HTML 或纯 JSON 错误)
|
||||||
|
|
||||||
|
某些代理可能返回:
|
||||||
|
1. HTML 页面(base_url 配置错误)
|
||||||
|
2. 纯 JSON 错误(无换行或多行 JSON)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefetched_chunks: 预读的字节块列表
|
||||||
|
parser: 响应解析器(需要有 is_error_response 和 parse_response 方法)
|
||||||
|
request_id: 请求 ID(用于日志)
|
||||||
|
provider_name: Provider 名称
|
||||||
|
endpoint_id: Endpoint ID
|
||||||
|
base_url: Endpoint 的 base_url
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ProviderNotAvailableException: 如果检测到 HTML 响应
|
||||||
|
EmbeddedErrorException: 如果检测到 JSON 错误响应
|
||||||
|
"""
|
||||||
|
if not prefetched_chunks:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
prefetched_bytes = b"".join(prefetched_chunks)
|
||||||
|
stripped = prefetched_bytes.lstrip()
|
||||||
|
|
||||||
|
# 去除 BOM
|
||||||
|
if stripped.startswith(b"\xef\xbb\xbf"):
|
||||||
|
stripped = stripped[3:]
|
||||||
|
|
||||||
|
# HTML 响应(通常是 base_url 配置错误导致返回网页)
|
||||||
|
lower_prefix = stripped[:32].lower()
|
||||||
|
if lower_prefix.startswith(b"<!doctype") or lower_prefix.startswith(b"<html"):
|
||||||
|
endpoint_short = endpoint_id[:8] + "..." if endpoint_id else "N/A"
|
||||||
|
logger.error(
|
||||||
|
f" [{request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||||
|
f"Provider={provider_name}, Endpoint={endpoint_short}, "
|
||||||
|
f"base_url={base_url}"
|
||||||
|
)
|
||||||
|
raise ProviderNotAvailableException(
|
||||||
|
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应,"
|
||||||
|
f"请检查 endpoint 的 base_url 配置是否正确"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 纯 JSON(可能无换行/多行 JSON)
|
||||||
|
if stripped.startswith(b"{") or stripped.startswith(b"["):
|
||||||
|
payload_str = stripped.decode("utf-8", errors="replace").strip()
|
||||||
|
data = json.loads(payload_str)
|
||||||
|
if isinstance(data, dict) and parser.is_error_response(data):
|
||||||
|
parsed = parser.parse_response(data, 200)
|
||||||
|
logger.warning(
|
||||||
|
f" [{request_id}] 检测到 JSON 错误响应: "
|
||||||
|
f"Provider={provider_name}, "
|
||||||
|
f"error_type={parsed.error_type}, "
|
||||||
|
f"message={parsed.error_message}"
|
||||||
|
)
|
||||||
|
raise EmbeddedErrorException(
|
||||||
|
provider_name=provider_name,
|
||||||
|
error_code=(
|
||||||
|
int(parsed.error_type)
|
||||||
|
if parsed.error_type and parsed.error_type.isdigit()
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
error_message=parsed.error_message,
|
||||||
|
error_status=parsed.error_type,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|||||||
@@ -104,9 +104,11 @@ async def get_my_usage(
|
|||||||
request: Request,
|
request: Request,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: Optional[datetime] = None,
|
||||||
end_date: Optional[datetime] = None,
|
end_date: Optional[datetime] = None,
|
||||||
|
limit: int = Query(100, ge=1, le=200, description="每页记录数,默认100,最大200"),
|
||||||
|
offset: int = Query(0, ge=0, le=2000, description="偏移量,用于分页,最大2000"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date)
|
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date, limit=limit, offset=offset)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -133,6 +135,20 @@ async def get_my_interval_timeline(
|
|||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/usage/heatmap")
|
||||||
|
async def get_my_activity_heatmap(
|
||||||
|
request: Request,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get user's activity heatmap data for the past 365 days.
|
||||||
|
|
||||||
|
This endpoint is cached for 5 minutes to reduce database load.
|
||||||
|
"""
|
||||||
|
adapter = GetMyActivityHeatmapAdapter()
|
||||||
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers")
|
@router.get("/providers")
|
||||||
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
|
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
|
||||||
adapter = ListAvailableProvidersAdapter()
|
adapter = ListAvailableProvidersAdapter()
|
||||||
@@ -471,6 +487,8 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
|
|||||||
class GetUsageAdapter(AuthenticatedApiAdapter):
|
class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||||
start_date: Optional[datetime]
|
start_date: Optional[datetime]
|
||||||
end_date: Optional[datetime]
|
end_date: Optional[datetime]
|
||||||
|
limit: int = 100
|
||||||
|
offset: int = 0
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
db = context.db
|
db = context.db
|
||||||
@@ -553,7 +571,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
stats["total_cost_usd"] += item["total_cost_usd"]
|
stats["total_cost_usd"] += item["total_cost_usd"]
|
||||||
# 假设 summary 中的都是成功的请求
|
# 假设 summary 中的都是成功的请求
|
||||||
stats["success_count"] += item["requests"]
|
stats["success_count"] += item["requests"]
|
||||||
if item.get("avg_response_time_ms"):
|
if item.get("avg_response_time_ms") is not None:
|
||||||
stats["total_response_time_ms"] += item["avg_response_time_ms"] * item["requests"]
|
stats["total_response_time_ms"] += item["avg_response_time_ms"] * item["requests"]
|
||||||
stats["response_time_count"] += item["requests"]
|
stats["response_time_count"] += item["requests"]
|
||||||
|
|
||||||
@@ -582,7 +600,10 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
query = query.filter(Usage.created_at >= self.start_date)
|
query = query.filter(Usage.created_at >= self.start_date)
|
||||||
if self.end_date:
|
if self.end_date:
|
||||||
query = query.filter(Usage.created_at <= self.end_date)
|
query = query.filter(Usage.created_at <= self.end_date)
|
||||||
usage_records = query.order_by(Usage.created_at.desc()).limit(100).all()
|
|
||||||
|
# 计算总数用于分页
|
||||||
|
total_records = query.count()
|
||||||
|
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||||
|
|
||||||
avg_resp_query = db.query(func.avg(Usage.response_time_ms)).filter(
|
avg_resp_query = db.query(func.avg(Usage.response_time_ms)).filter(
|
||||||
Usage.user_id == user.id,
|
Usage.user_id == user.id,
|
||||||
@@ -608,6 +629,13 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
"used_usd": user.used_usd,
|
"used_usd": user.used_usd,
|
||||||
"summary_by_model": summary_by_model,
|
"summary_by_model": summary_by_model,
|
||||||
"summary_by_provider": summary_by_provider,
|
"summary_by_provider": summary_by_provider,
|
||||||
|
# 分页信息
|
||||||
|
"pagination": {
|
||||||
|
"total": total_records,
|
||||||
|
"limit": self.limit,
|
||||||
|
"offset": self.offset,
|
||||||
|
"has_more": self.offset + self.limit < total_records,
|
||||||
|
},
|
||||||
"records": [
|
"records": [
|
||||||
{
|
{
|
||||||
"id": r.id,
|
"id": r.id,
|
||||||
@@ -636,13 +664,6 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
response_data["activity_heatmap"] = UsageService.get_daily_activity(
|
|
||||||
db=db,
|
|
||||||
user_id=user.id,
|
|
||||||
window_days=365,
|
|
||||||
include_actual_cost=user.role == "admin",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 管理员可以看到真实成本
|
# 管理员可以看到真实成本
|
||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
response_data["total_actual_cost"] = total_actual_cost
|
response_data["total_actual_cost"] = total_actual_cost
|
||||||
@@ -709,6 +730,20 @@ class GetMyIntervalTimelineAdapter(AuthenticatedApiAdapter):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class GetMyActivityHeatmapAdapter(AuthenticatedApiAdapter):
|
||||||
|
"""Activity heatmap adapter with Redis caching for user."""
|
||||||
|
|
||||||
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
user = context.user
|
||||||
|
result = await UsageService.get_cached_heatmap(
|
||||||
|
db=context.db,
|
||||||
|
user_id=user.id,
|
||||||
|
include_actual_cost=user.role == "admin",
|
||||||
|
)
|
||||||
|
context.add_audit_metadata(action="activity_heatmap")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ class RedisClientManager:
|
|||||||
f"Redis连接失败: {error_msg}\n"
|
f"Redis连接失败: {error_msg}\n"
|
||||||
"缓存亲和性功能需要Redis支持,请确保Redis服务正常运行。\n"
|
"缓存亲和性功能需要Redis支持,请确保Redis服务正常运行。\n"
|
||||||
"检查事项:\n"
|
"检查事项:\n"
|
||||||
"1. Redis服务是否已启动(docker-compose up -d redis)\n"
|
"1. Redis服务是否已启动(docker compose up -d redis)\n"
|
||||||
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
|
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
|
||||||
"3. Redis端口(默认6379)是否可访问"
|
"3. Redis端口(默认6379)是否可访问"
|
||||||
) from e
|
) from e
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ class CacheTTL:
|
|||||||
# L1 本地缓存(用于减少 Redis 访问)
|
# L1 本地缓存(用于减少 Redis 访问)
|
||||||
L1_LOCAL = 3 # 3秒
|
L1_LOCAL = 3 # 3秒
|
||||||
|
|
||||||
|
# 活跃度热力图缓存 - 历史数据变化不频繁
|
||||||
|
ACTIVITY_HEATMAP = 300 # 5分钟
|
||||||
|
|
||||||
# 并发锁 TTL - 防止死锁
|
# 并发锁 TTL - 防止死锁
|
||||||
CONCURRENCY_LOCK = 600 # 10分钟
|
CONCURRENCY_LOCK = 600 # 10分钟
|
||||||
|
|
||||||
@@ -38,8 +41,25 @@ class CacheSize:
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class StreamDefaults:
|
||||||
|
"""流式处理默认值"""
|
||||||
|
|
||||||
|
# 预读字节上限(避免无换行响应导致内存增长)
|
||||||
|
# 64KB 基于:
|
||||||
|
# 1. SSE 单条消息通常远小于此值
|
||||||
|
# 2. 足够检测 HTML 和 JSON 错误响应
|
||||||
|
# 3. 不会占用过多内存
|
||||||
|
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
|
||||||
|
|
||||||
|
|
||||||
class ConcurrencyDefaults:
|
class ConcurrencyDefaults:
|
||||||
"""并发控制默认值"""
|
"""并发控制默认值
|
||||||
|
|
||||||
|
算法说明:边界记忆 + 渐进探测
|
||||||
|
- 触发 429 时记录边界(last_concurrent_peak),新限制 = 边界 - 1
|
||||||
|
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
|
||||||
|
- 这样可以快速收敛到真实限制附近,避免过度保守
|
||||||
|
"""
|
||||||
|
|
||||||
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
||||||
INITIAL_LIMIT = 50
|
INITIAL_LIMIT = 50
|
||||||
@@ -69,10 +89,6 @@ class ConcurrencyDefaults:
|
|||||||
# 扩容步长 - 每次扩容增加的并发数
|
# 扩容步长 - 每次扩容增加的并发数
|
||||||
INCREASE_STEP = 2
|
INCREASE_STEP = 2
|
||||||
|
|
||||||
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
|
|
||||||
# 0.85 表示降到触发 429 时并发数的 85%
|
|
||||||
DECREASE_MULTIPLIER = 0.85
|
|
||||||
|
|
||||||
# 最大并发限制上限
|
# 最大并发限制上限
|
||||||
MAX_CONCURRENT_LIMIT = 200
|
MAX_CONCURRENT_LIMIT = 200
|
||||||
|
|
||||||
@@ -84,6 +100,7 @@ class ConcurrencyDefaults:
|
|||||||
|
|
||||||
# === 探测性扩容参数 ===
|
# === 探测性扩容参数 ===
|
||||||
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
|
||||||
|
# 探测性扩容可以突破已知边界,尝试更高的并发
|
||||||
PROBE_INCREASE_INTERVAL_MINUTES = 30
|
PROBE_INCREASE_INTERVAL_MINUTES = 30
|
||||||
|
|
||||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||||
|
|||||||
28
src/core/error_utils.py
Normal file
28
src/core/error_utils.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
错误消息处理工具函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
|
||||||
|
"""
|
||||||
|
从异常中提取错误消息,优先使用上游响应内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: 异常对象
|
||||||
|
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
错误消息字符串
|
||||||
|
"""
|
||||||
|
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误)
|
||||||
|
upstream_response = getattr(error, "upstream_response", None)
|
||||||
|
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
|
||||||
|
return str(upstream_response)
|
||||||
|
|
||||||
|
# 回退到异常的字符串表示(str 可能为空,如 httpx 超时异常)
|
||||||
|
error_str = str(error) or repr(error)
|
||||||
|
if status_code is not None:
|
||||||
|
return f"HTTP {status_code}: {error_str}"
|
||||||
|
return error_str
|
||||||
@@ -547,11 +547,19 @@ class ErrorResponse:
|
|||||||
- 所有错误都记录到日志,通过错误 ID 关联
|
- 所有错误都记录到日志,通过错误 ID 关联
|
||||||
"""
|
"""
|
||||||
if isinstance(e, ProxyException):
|
if isinstance(e, ProxyException):
|
||||||
|
details = e.details.copy() if e.details else {}
|
||||||
|
status_code = e.status_code
|
||||||
|
message = e.message
|
||||||
|
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息
|
||||||
|
if isinstance(e, ProviderNotAvailableException) and e.upstream_response:
|
||||||
|
if e.upstream_status:
|
||||||
|
status_code = e.upstream_status
|
||||||
|
message = e.upstream_response
|
||||||
return ErrorResponse.create(
|
return ErrorResponse.create(
|
||||||
error_type=e.error_type,
|
error_type=e.error_type,
|
||||||
message=e.message,
|
message=message,
|
||||||
status_code=e.status_code,
|
status_code=status_code,
|
||||||
details=e.details,
|
details=details if details else None,
|
||||||
)
|
)
|
||||||
elif isinstance(e, HTTPException):
|
elif isinstance(e, HTTPException):
|
||||||
return ErrorResponse.create(
|
return ErrorResponse.create(
|
||||||
|
|||||||
@@ -411,7 +411,7 @@ def init_db():
|
|||||||
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
|
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
|
||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
print("如果使用 Docker,请先运行:", file=sys.stderr)
|
print("如果使用 Docker,请先运行:", file=sys.stderr)
|
||||||
print(" docker-compose up -d postgres redis", file=sys.stderr)
|
print(" docker compose -f docker-compose.build.yml up -d postgres redis", file=sys.stderr)
|
||||||
print("", file=sys.stderr)
|
print("", file=sys.stderr)
|
||||||
print("=" * 60, file=sys.stderr)
|
print("=" * 60, file=sys.stderr)
|
||||||
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
|
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈
|
||||||
|
|||||||
@@ -309,8 +309,9 @@ class CreateApiKeyRequest(BaseModel):
|
|||||||
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
|
||||||
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
|
||||||
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
|
||||||
rate_limit: Optional[int] = 100
|
rate_limit: Optional[int] = None # None = 无限制
|
||||||
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期
|
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期(兼容旧版)
|
||||||
|
expires_at: Optional[str] = None # ISO 日期字符串,如 "2025-12-31",优先于 expire_days
|
||||||
initial_balance_usd: Optional[float] = Field(
|
initial_balance_usd: Optional[float] = Field(
|
||||||
None, description="初始余额(USD),仅用于独立Key,None = 无限制"
|
None, description="初始余额(USD),仅用于独立Key,None = 无限制"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class ApiKey(Base):
|
|||||||
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
|
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
|
||||||
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
|
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
|
||||||
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
|
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
|
||||||
rate_limit = Column(Integer, default=100) # 每分钟请求限制
|
rate_limit = Column(Integer, default=None, nullable=True) # 每分钟请求限制,None = 无限制
|
||||||
concurrent_limit = Column(Integer, default=5, nullable=True) # 并发请求限制
|
concurrent_limit = Column(Integer, default=5, nullable=True) # 并发请求限制
|
||||||
|
|
||||||
# Key 能力配置
|
# Key 能力配置
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class ProviderEndpointCreate(BaseModel):
|
|||||||
provider_id: str = Field(..., description="Provider ID")
|
provider_id: str = Field(..., description="Provider ID")
|
||||||
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
|
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
|
||||||
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
|
||||||
|
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
|
||||||
|
|
||||||
# 请求配置
|
# 请求配置
|
||||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||||
@@ -62,6 +63,7 @@ class ProviderEndpointUpdate(BaseModel):
|
|||||||
base_url: Optional[str] = Field(
|
base_url: Optional[str] = Field(
|
||||||
default=None, min_length=1, max_length=500, description="API 基础 URL"
|
default=None, min_length=1, max_length=500, description="API 基础 URL"
|
||||||
)
|
)
|
||||||
|
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
|
||||||
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
|
||||||
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
|
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
|
||||||
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
|
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
|
||||||
@@ -94,6 +96,7 @@ class ProviderEndpointResponse(BaseModel):
|
|||||||
# API 配置
|
# API 配置
|
||||||
api_format: str
|
api_format: str
|
||||||
base_url: str
|
base_url: str
|
||||||
|
custom_path: Optional[str] = None
|
||||||
|
|
||||||
# 请求配置
|
# 请求配置
|
||||||
headers: Optional[Dict[str, str]] = None
|
headers: Optional[Dict[str, str]] = None
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ WARNING: 多进程环境注意事项
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Deque, Dict
|
from typing import Any, Deque, Dict
|
||||||
|
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
@@ -95,12 +95,12 @@ class SlidingWindow:
|
|||||||
"""获取最早的重置时间"""
|
"""获取最早的重置时间"""
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
if not self.requests:
|
if not self.requests:
|
||||||
return datetime.now()
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
# 最早的请求将在window_size秒后过期
|
# 最早的请求将在window_size秒后过期
|
||||||
oldest_request = self.requests[0]
|
oldest_request = self.requests[0]
|
||||||
reset_time = oldest_request + self.window_size
|
reset_time = oldest_request + self.window_size
|
||||||
return datetime.fromtimestamp(reset_time)
|
return datetime.fromtimestamp(reset_time, tz=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
class SlidingWindowStrategy(RateLimitStrategy):
|
class SlidingWindowStrategy(RateLimitStrategy):
|
||||||
@@ -250,7 +250,7 @@ class SlidingWindowStrategy(RateLimitStrategy):
|
|||||||
retry_after = None
|
retry_after = None
|
||||||
if not allowed:
|
if not allowed:
|
||||||
# 计算需要等待的时间(最早请求过期的时间)
|
# 计算需要等待的时间(最早请求过期的时间)
|
||||||
retry_after = int((reset_at - datetime.now()).total_seconds()) + 1
|
retry_after = int((reset_at - datetime.now(timezone.utc)).total_seconds()) + 1
|
||||||
|
|
||||||
return RateLimitResult(
|
return RateLimitResult(
|
||||||
allowed=allowed,
|
allowed=allowed,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
from ...clients.redis_client import get_redis_client_sync
|
from ...clients.redis_client import get_redis_client_sync
|
||||||
@@ -63,11 +63,11 @@ class TokenBucket:
|
|||||||
def get_reset_time(self) -> datetime:
|
def get_reset_time(self) -> datetime:
|
||||||
"""获取下次完全恢复的时间"""
|
"""获取下次完全恢复的时间"""
|
||||||
if self.tokens >= self.capacity:
|
if self.tokens >= self.capacity:
|
||||||
return datetime.now()
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
tokens_needed = self.capacity - self.tokens
|
tokens_needed = self.capacity - self.tokens
|
||||||
seconds_to_full = tokens_needed / self.refill_rate
|
seconds_to_full = tokens_needed / self.refill_rate
|
||||||
return datetime.now() + timedelta(seconds=seconds_to_full)
|
return datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)
|
||||||
|
|
||||||
|
|
||||||
class TokenBucketStrategy(RateLimitStrategy):
|
class TokenBucketStrategy(RateLimitStrategy):
|
||||||
@@ -370,7 +370,7 @@ class RedisTokenBucketBackend:
|
|||||||
|
|
||||||
if tokens is None or last_refill is None:
|
if tokens is None or last_refill is None:
|
||||||
remaining = capacity
|
remaining = capacity
|
||||||
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
|
reset_at = datetime.now(timezone.utc) + timedelta(seconds=capacity / refill_rate)
|
||||||
else:
|
else:
|
||||||
tokens_value = float(tokens)
|
tokens_value = float(tokens)
|
||||||
last_refill_value = float(last_refill)
|
last_refill_value = float(last_refill)
|
||||||
@@ -378,7 +378,7 @@ class RedisTokenBucketBackend:
|
|||||||
tokens_value = min(capacity, tokens_value + delta * refill_rate)
|
tokens_value = min(capacity, tokens_value + delta * refill_rate)
|
||||||
remaining = int(tokens_value)
|
remaining = int(tokens_value)
|
||||||
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
|
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
|
||||||
reset_at = datetime.now() + timedelta(seconds=reset_after)
|
reset_at = datetime.now(timezone.utc) + timedelta(seconds=reset_after)
|
||||||
|
|
||||||
allowed = remaining >= amount
|
allowed = remaining >= amount
|
||||||
retry_after = None
|
retry_after = None
|
||||||
|
|||||||
@@ -148,6 +148,8 @@ class GlobalModelService:
|
|||||||
删除 GlobalModel
|
删除 GlobalModel
|
||||||
|
|
||||||
默认行为: 级联删除所有关联的 Provider 模型实现
|
默认行为: 级联删除所有关联的 Provider 模型实现
|
||||||
|
注意: 不清理 API Key 和 User 的 allowed_models 引用,
|
||||||
|
保留无效引用可让用户在前端看到"已失效"的模型,便于手动清理或等待重建同名模型
|
||||||
"""
|
"""
|
||||||
global_model = GlobalModelService.get_global_model(db, global_model_id)
|
global_model = GlobalModelService.get_global_model(db, global_model_id)
|
||||||
|
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class ErrorClassifier:
|
|||||||
result["reason"] = str(data.get("reason", data.get("code", "")))
|
result["reason"] = str(data.get("reason", data.get("code", "")))
|
||||||
|
|
||||||
except (json.JSONDecodeError, TypeError, KeyError):
|
except (json.JSONDecodeError, TypeError, KeyError):
|
||||||
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
|
result["message"] = error_text
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -323,8 +323,8 @@ class ErrorClassifier:
|
|||||||
if parts:
|
if parts:
|
||||||
return ": ".join(parts) if len(parts) > 1 else parts[0]
|
return ": ".join(parts) if len(parts) > 1 else parts[0]
|
||||||
|
|
||||||
# 无法解析,返回原始文本(截断)
|
# 无法解析,返回原始文本
|
||||||
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
|
return parsed["raw"]
|
||||||
|
|
||||||
def classify(
|
def classify(
|
||||||
self,
|
self,
|
||||||
@@ -484,11 +484,15 @@ class ErrorClassifier:
|
|||||||
return ProviderNotAvailableException(
|
return ProviderNotAvailableException(
|
||||||
message=detailed_message,
|
message=detailed_message,
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
|
upstream_status=status,
|
||||||
|
upstream_response=error_response_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ProviderNotAvailableException(
|
return ProviderNotAvailableException(
|
||||||
message=detailed_message,
|
message=detailed_message,
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
|
upstream_status=status,
|
||||||
|
upstream_response=error_response_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_http_error(
|
async def handle_http_error(
|
||||||
@@ -532,12 +536,14 @@ class ErrorClassifier:
|
|||||||
provider_name = str(provider.name)
|
provider_name = str(provider.name)
|
||||||
|
|
||||||
# 尝试读取错误响应内容
|
# 尝试读取错误响应内容
|
||||||
error_response_text = None
|
# 优先使用 handler 附加的 upstream_response 属性(流式请求中 response.text 可能为空)
|
||||||
try:
|
error_response_text = getattr(http_error, "upstream_response", None)
|
||||||
if http_error.response and hasattr(http_error.response, "text"):
|
if not error_response_text:
|
||||||
error_response_text = http_error.response.text[:1000] # 限制长度
|
try:
|
||||||
except Exception:
|
if http_error.response and hasattr(http_error.response, "text"):
|
||||||
pass
|
error_response_text = http_error.response.text
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
logger.warning(f" [{request_id}] HTTP错误 (attempt={attempt}/{max_attempts}): "
|
logger.warning(f" [{request_id}] HTTP错误 (attempt={attempt}/{max_attempts}): "
|
||||||
f"{http_error.response.status_code if http_error.response else 'unknown'}")
|
f"{http_error.response.status_code if http_error.response else 'unknown'}")
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from redis import Redis
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.core.enums import APIFormat
|
from src.core.enums import APIFormat
|
||||||
|
from src.core.error_utils import extract_error_message
|
||||||
from src.core.exceptions import (
|
from src.core.exceptions import (
|
||||||
ConcurrencyLimitError,
|
ConcurrencyLimitError,
|
||||||
ProviderNotAvailableException,
|
ProviderNotAvailableException,
|
||||||
@@ -401,7 +402,7 @@ class FallbackOrchestrator:
|
|||||||
db=self.db,
|
db=self.db,
|
||||||
candidate_id=candidate_record_id,
|
candidate_id=candidate_record_id,
|
||||||
error_type="HTTPStatusError",
|
error_type="HTTPStatusError",
|
||||||
error_message=f"HTTP {status_code}: {str(cause)}",
|
error_message=extract_error_message(cause, status_code),
|
||||||
status_code=status_code,
|
status_code=status_code,
|
||||||
latency_ms=elapsed_ms,
|
latency_ms=elapsed_ms,
|
||||||
concurrent_requests=captured_key_concurrent,
|
concurrent_requests=captured_key_concurrent,
|
||||||
@@ -425,31 +426,22 @@ class FallbackOrchestrator:
|
|||||||
attempt=attempt,
|
attempt=attempt,
|
||||||
max_attempts=max_attempts,
|
max_attempts=max_attempts,
|
||||||
)
|
)
|
||||||
# str(cause) 可能为空(如 httpx 超时异常),使用 repr() 作为备用
|
|
||||||
error_msg = str(cause) or repr(cause)
|
|
||||||
# 如果是 ProviderNotAvailableException,附加上游响应
|
|
||||||
if hasattr(cause, "upstream_response") and cause.upstream_response:
|
|
||||||
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
|
|
||||||
RequestCandidateService.mark_candidate_failed(
|
RequestCandidateService.mark_candidate_failed(
|
||||||
db=self.db,
|
db=self.db,
|
||||||
candidate_id=candidate_record_id,
|
candidate_id=candidate_record_id,
|
||||||
error_type=type(cause).__name__,
|
error_type=type(cause).__name__,
|
||||||
error_message=error_msg,
|
error_message=extract_error_message(cause),
|
||||||
latency_ms=elapsed_ms,
|
latency_ms=elapsed_ms,
|
||||||
concurrent_requests=captured_key_concurrent,
|
concurrent_requests=captured_key_concurrent,
|
||||||
)
|
)
|
||||||
return "continue" if has_retry_left else "break"
|
return "continue" if has_retry_left else "break"
|
||||||
|
|
||||||
# 未知错误:记录失败并抛出
|
# 未知错误:记录失败并抛出
|
||||||
error_msg = str(cause) or repr(cause)
|
|
||||||
# 如果是 ProviderNotAvailableException,附加上游响应
|
|
||||||
if hasattr(cause, "upstream_response") and cause.upstream_response:
|
|
||||||
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
|
|
||||||
RequestCandidateService.mark_candidate_failed(
|
RequestCandidateService.mark_candidate_failed(
|
||||||
db=self.db,
|
db=self.db,
|
||||||
candidate_id=candidate_record_id,
|
candidate_id=candidate_record_id,
|
||||||
error_type=type(cause).__name__,
|
error_type=type(cause).__name__,
|
||||||
error_message=error_msg,
|
error_message=extract_error_message(cause),
|
||||||
latency_ms=elapsed_ms,
|
latency_ms=elapsed_ms,
|
||||||
concurrent_requests=captured_key_concurrent,
|
concurrent_requests=captured_key_concurrent,
|
||||||
)
|
)
|
||||||
@@ -543,7 +535,9 @@ class FallbackOrchestrator:
|
|||||||
raise last_error
|
raise last_error
|
||||||
|
|
||||||
# 所有组合都已尝试完毕,全部失败
|
# 所有组合都已尝试完毕,全部失败
|
||||||
self._raise_all_failed_exception(request_id, max_attempts, last_candidate, model_name, api_format_enum)
|
self._raise_all_failed_exception(
|
||||||
|
request_id, max_attempts, last_candidate, model_name, api_format_enum, last_error
|
||||||
|
)
|
||||||
|
|
||||||
async def _try_candidate_with_retries(
|
async def _try_candidate_with_retries(
|
||||||
self,
|
self,
|
||||||
@@ -565,6 +559,7 @@ class FallbackOrchestrator:
|
|||||||
provider = candidate.provider
|
provider = candidate.provider
|
||||||
endpoint = candidate.endpoint
|
endpoint = candidate.endpoint
|
||||||
max_retries_for_candidate = int(endpoint.max_retries) if candidate.is_cached else 1
|
max_retries_for_candidate = int(endpoint.max_retries) if candidate.is_cached else 1
|
||||||
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
for retry_index in range(max_retries_for_candidate):
|
for retry_index in range(max_retries_for_candidate):
|
||||||
attempt_counter += 1
|
attempt_counter += 1
|
||||||
@@ -599,6 +594,7 @@ class FallbackOrchestrator:
|
|||||||
return {"success": True, "response": response}
|
return {"success": True, "response": response}
|
||||||
|
|
||||||
except ExecutionError as exec_err:
|
except ExecutionError as exec_err:
|
||||||
|
last_error = exec_err.cause
|
||||||
action = await self._handle_candidate_error(
|
action = await self._handle_candidate_error(
|
||||||
exec_err=exec_err,
|
exec_err=exec_err,
|
||||||
candidate=candidate,
|
candidate=candidate,
|
||||||
@@ -630,6 +626,7 @@ class FallbackOrchestrator:
|
|||||||
"success": False,
|
"success": False,
|
||||||
"attempt_counter": attempt_counter,
|
"attempt_counter": attempt_counter,
|
||||||
"max_attempts": max_attempts,
|
"max_attempts": max_attempts,
|
||||||
|
"error": last_error,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _attach_metadata_to_error(
|
def _attach_metadata_to_error(
|
||||||
@@ -678,6 +675,7 @@ class FallbackOrchestrator:
|
|||||||
last_candidate: Optional[ProviderCandidate],
|
last_candidate: Optional[ProviderCandidate],
|
||||||
model_name: str,
|
model_name: str,
|
||||||
api_format_enum: APIFormat,
|
api_format_enum: APIFormat,
|
||||||
|
last_error: Optional[Exception] = None,
|
||||||
) -> NoReturn:
|
) -> NoReturn:
|
||||||
"""所有组合都失败时抛出异常"""
|
"""所有组合都失败时抛出异常"""
|
||||||
logger.error(f" [{request_id}] 所有 {max_attempts} 个组合均失败")
|
logger.error(f" [{request_id}] 所有 {max_attempts} 个组合均失败")
|
||||||
@@ -693,9 +691,38 @@ class FallbackOrchestrator:
|
|||||||
"api_format": api_format_enum.value,
|
"api_format": api_format_enum.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 提取上游错误响应
|
||||||
|
upstream_status: Optional[int] = None
|
||||||
|
upstream_response: Optional[str] = None
|
||||||
|
if last_error:
|
||||||
|
# 从 httpx.HTTPStatusError 提取
|
||||||
|
if isinstance(last_error, httpx.HTTPStatusError):
|
||||||
|
upstream_status = last_error.response.status_code
|
||||||
|
# 优先使用我们附加的 upstream_response 属性(流已读取时 response.text 可能为空)
|
||||||
|
upstream_response = getattr(last_error, "upstream_response", None)
|
||||||
|
if not upstream_response:
|
||||||
|
try:
|
||||||
|
upstream_response = last_error.response.text
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# 从其他异常属性提取(如 ProviderNotAvailableException)
|
||||||
|
else:
|
||||||
|
upstream_status = getattr(last_error, "upstream_status", None)
|
||||||
|
upstream_response = getattr(last_error, "upstream_response", None)
|
||||||
|
|
||||||
|
# 如果响应为空或无效,使用异常的字符串表示
|
||||||
|
if (
|
||||||
|
not upstream_response
|
||||||
|
or not upstream_response.strip()
|
||||||
|
or upstream_response.startswith("Unable to read")
|
||||||
|
):
|
||||||
|
upstream_response = str(last_error)
|
||||||
|
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"所有Provider均不可用,已尝试{max_attempts}个组合",
|
f"所有Provider均不可用,已尝试{max_attempts}个组合",
|
||||||
request_metadata=request_metadata,
|
request_metadata=request_metadata,
|
||||||
|
upstream_status=upstream_status,
|
||||||
|
upstream_response=upstream_response,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute_with_fallback(
|
async def execute_with_fallback(
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
"""
|
"""
|
||||||
自适应并发调整器 - 基于滑动窗口利用率的并发限制调整
|
自适应并发调整器 - 基于边界记忆的并发限制调整
|
||||||
|
|
||||||
核心改进(相对于旧版基于"持续高利用率"的方案):
|
核心算法:边界记忆 + 渐进探测
|
||||||
- 使用滑动窗口采样,容忍并发波动
|
- 触发 429 时记录边界(last_concurrent_peak),这就是真实上限
|
||||||
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
|
- 缩容策略:新限制 = 边界 - 1,而非乘性减少
|
||||||
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
|
- 扩容策略:不超过已知边界,除非是探测性扩容
|
||||||
|
- 探测性扩容:长时间无 429 时尝试突破边界
|
||||||
|
|
||||||
AIMD 参数说明:
|
设计原则:
|
||||||
- 扩容:加性增加 (+INCREASE_STEP)
|
1. 快速收敛:一次 429 就能找到接近真实的限制
|
||||||
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
|
2. 避免过度保守:不会因为多次 429 而无限下降
|
||||||
|
3. 安全探测:允许在稳定后尝试更高并发
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -35,21 +37,21 @@ class AdaptiveConcurrencyManager:
|
|||||||
"""
|
"""
|
||||||
自适应并发管理器
|
自适应并发管理器
|
||||||
|
|
||||||
核心算法:基于滑动窗口利用率的 AIMD
|
核心算法:边界记忆 + 渐进探测
|
||||||
- 滑动窗口记录最近 N 次请求的利用率
|
- 触发 429 时记录边界(last_concurrent_peak = 触发时的并发数)
|
||||||
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
|
- 缩容:新限制 = 边界 - 1(快速收敛到真实限制附近)
|
||||||
- 遇到 429 错误时乘性减少 (*0.85)
|
- 扩容:不超过边界(即 last_concurrent_peak),允许回到边界值尝试
|
||||||
- 长时间无 429 且有流量时触发探测性扩容
|
- 探测性扩容:长时间(30分钟)无 429 时,可以尝试 +1 突破边界
|
||||||
|
|
||||||
扩容条件(满足任一即可):
|
扩容条件(满足任一即可):
|
||||||
1. 滑动窗口扩容:窗口内 >= 60% 的采样利用率 >= 70%,且不在冷却期
|
1. 利用率扩容:窗口内高利用率比例 >= 60%,且当前限制 < 边界
|
||||||
2. 探测性扩容:距上次 429 超过 30 分钟,且期间有足够请求量
|
2. 探测性扩容:距上次 429 超过 30 分钟,可以尝试突破边界
|
||||||
|
|
||||||
关键特性:
|
关键特性:
|
||||||
1. 滑动窗口容忍并发波动,不会因单次低利用率重置
|
1. 快速收敛:一次 429 就能学到接近真实的限制值
|
||||||
2. 区分并发限制和 RPM 限制
|
2. 边界保护:普通扩容不会超过已知边界
|
||||||
3. 探测性扩容避免长期卡在低限制
|
3. 安全探测:长时间稳定后允许尝试更高并发
|
||||||
4. 记录调整历史
|
4. 区分并发限制和 RPM 限制
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 默认配置 - 使用统一常量
|
# 默认配置 - 使用统一常量
|
||||||
@@ -59,7 +61,6 @@ class AdaptiveConcurrencyManager:
|
|||||||
|
|
||||||
# AIMD 参数
|
# AIMD 参数
|
||||||
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
|
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
|
||||||
DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER
|
|
||||||
|
|
||||||
# 滑动窗口参数
|
# 滑动窗口参数
|
||||||
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
|
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
|
||||||
@@ -115,7 +116,13 @@ class AdaptiveConcurrencyManager:
|
|||||||
# 更新429统计
|
# 更新429统计
|
||||||
key.last_429_at = datetime.now(timezone.utc) # type: ignore[assignment]
|
key.last_429_at = datetime.now(timezone.utc) # type: ignore[assignment]
|
||||||
key.last_429_type = rate_limit_info.limit_type # type: ignore[assignment]
|
key.last_429_type = rate_limit_info.limit_type # type: ignore[assignment]
|
||||||
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
|
# 仅在并发限制且拿到并发数时记录边界(RPM/UNKNOWN 不应覆盖并发边界记忆)
|
||||||
|
if (
|
||||||
|
rate_limit_info.limit_type == RateLimitType.CONCURRENT
|
||||||
|
and current_concurrent is not None
|
||||||
|
and current_concurrent > 0
|
||||||
|
):
|
||||||
|
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
|
||||||
|
|
||||||
# 遇到 429 错误,清空利用率采样窗口(重新开始收集)
|
# 遇到 429 错误,清空利用率采样窗口(重新开始收集)
|
||||||
key.utilization_samples = [] # type: ignore[assignment]
|
key.utilization_samples = [] # type: ignore[assignment]
|
||||||
@@ -207,6 +214,9 @@ class AdaptiveConcurrencyManager:
|
|||||||
|
|
||||||
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
|
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
|
||||||
|
|
||||||
|
# 获取已知边界(上次触发 429 时的并发数)
|
||||||
|
known_boundary = key.last_concurrent_peak
|
||||||
|
|
||||||
# 计算当前利用率
|
# 计算当前利用率
|
||||||
utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
|
utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
|
||||||
|
|
||||||
@@ -217,22 +227,29 @@ class AdaptiveConcurrencyManager:
|
|||||||
samples = self._update_utilization_window(key, now_ts, utilization)
|
samples = self._update_utilization_window(key, now_ts, utilization)
|
||||||
|
|
||||||
# 检查是否满足扩容条件
|
# 检查是否满足扩容条件
|
||||||
increase_reason = self._check_increase_conditions(key, samples, now)
|
increase_reason = self._check_increase_conditions(key, samples, now, known_boundary)
|
||||||
|
|
||||||
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
|
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
|
||||||
old_limit = current_limit
|
old_limit = current_limit
|
||||||
new_limit = self._increase_limit(current_limit)
|
is_probe = increase_reason == "probe_increase"
|
||||||
|
new_limit = self._increase_limit(current_limit, known_boundary, is_probe)
|
||||||
|
|
||||||
|
# 如果没有实际增长(已达边界),跳过
|
||||||
|
if new_limit <= old_limit:
|
||||||
|
return None
|
||||||
|
|
||||||
# 计算窗口统计用于日志
|
# 计算窗口统计用于日志
|
||||||
avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
|
avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
|
||||||
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
|
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
|
||||||
high_util_ratio = high_util_count / len(samples) if samples else 0
|
high_util_ratio = high_util_count / len(samples) if samples else 0
|
||||||
|
|
||||||
|
boundary_info = f"边界: {known_boundary}" if known_boundary else "无边界"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
|
f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
|
||||||
f"窗口采样: {len(samples)} | "
|
f"窗口采样: {len(samples)} | "
|
||||||
f"平均利用率: {avg_util:.1%} | "
|
f"平均利用率: {avg_util:.1%} | "
|
||||||
f"高利用率比例: {high_util_ratio:.1%} | "
|
f"高利用率比例: {high_util_ratio:.1%} | "
|
||||||
|
f"{boundary_info} | "
|
||||||
f"调整: {old_limit} -> {new_limit}"
|
f"调整: {old_limit} -> {new_limit}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -246,13 +263,14 @@ class AdaptiveConcurrencyManager:
|
|||||||
high_util_ratio=round(high_util_ratio, 2),
|
high_util_ratio=round(high_util_ratio, 2),
|
||||||
sample_count=len(samples),
|
sample_count=len(samples),
|
||||||
current_concurrent=current_concurrent,
|
current_concurrent=current_concurrent,
|
||||||
|
known_boundary=known_boundary,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 更新限制
|
# 更新限制
|
||||||
key.learned_max_concurrent = new_limit # type: ignore[assignment]
|
key.learned_max_concurrent = new_limit # type: ignore[assignment]
|
||||||
|
|
||||||
# 如果是探测性扩容,更新探测时间
|
# 如果是探测性扩容,更新探测时间
|
||||||
if increase_reason == "probe_increase":
|
if is_probe:
|
||||||
key.last_probe_increase_at = now # type: ignore[assignment]
|
key.last_probe_increase_at = now # type: ignore[assignment]
|
||||||
|
|
||||||
# 扩容后清空采样窗口,重新开始收集
|
# 扩容后清空采样窗口,重新开始收集
|
||||||
@@ -303,7 +321,11 @@ class AdaptiveConcurrencyManager:
|
|||||||
return samples
|
return samples
|
||||||
|
|
||||||
def _check_increase_conditions(
|
def _check_increase_conditions(
|
||||||
self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
|
self,
|
||||||
|
key: ProviderAPIKey,
|
||||||
|
samples: List[Dict[str, Any]],
|
||||||
|
now: datetime,
|
||||||
|
known_boundary: Optional[int] = None,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
检查是否满足扩容条件
|
检查是否满足扩容条件
|
||||||
@@ -312,6 +334,7 @@ class AdaptiveConcurrencyManager:
|
|||||||
key: API Key对象
|
key: API Key对象
|
||||||
samples: 利用率采样列表
|
samples: 利用率采样列表
|
||||||
now: 当前时间
|
now: 当前时间
|
||||||
|
known_boundary: 已知边界(触发 429 时的并发数)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
扩容原因(如果满足条件),否则返回 None
|
扩容原因(如果满足条件),否则返回 None
|
||||||
@@ -320,15 +343,25 @@ class AdaptiveConcurrencyManager:
|
|||||||
if self._is_in_cooldown(key):
|
if self._is_in_cooldown(key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 条件1:滑动窗口扩容
|
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
|
||||||
|
|
||||||
|
# 条件1:滑动窗口扩容(不超过边界)
|
||||||
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
|
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
|
||||||
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
|
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
|
||||||
high_util_ratio = high_util_count / len(samples)
|
high_util_ratio = high_util_count / len(samples)
|
||||||
|
|
||||||
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
|
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
|
||||||
return "high_utilization"
|
# 检查是否还有扩容空间(边界保护)
|
||||||
|
if known_boundary:
|
||||||
|
# 允许扩容到边界值(而非 boundary - 1),因为缩容时已经 -1 了
|
||||||
|
if current_limit < known_boundary:
|
||||||
|
return "high_utilization"
|
||||||
|
# 已达边界,不触发普通扩容
|
||||||
|
else:
|
||||||
|
# 无边界信息,允许扩容
|
||||||
|
return "high_utilization"
|
||||||
|
|
||||||
# 条件2:探测性扩容(长时间无 429 且有流量)
|
# 条件2:探测性扩容(长时间无 429 且有流量,可以突破边界)
|
||||||
if self._should_probe_increase(key, samples, now):
|
if self._should_probe_increase(key, samples, now):
|
||||||
return "probe_increase"
|
return "probe_increase"
|
||||||
|
|
||||||
@@ -406,32 +439,65 @@ class AdaptiveConcurrencyManager:
|
|||||||
current_concurrent: Optional[int] = None,
|
current_concurrent: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
减少并发限制
|
减少并发限制(基于边界记忆策略)
|
||||||
|
|
||||||
策略:
|
策略:
|
||||||
- 如果知道当前并发数,设置为当前并发的70%
|
- 如果知道触发 429 时的并发数,新限制 = 并发数 - 1
|
||||||
- 否则,使用乘性减少
|
- 这样可以快速收敛到真实限制附近,而不会过度保守
|
||||||
|
- 例如:真实限制 8,触发时并发 8 -> 新限制 7(而非 8*0.85=6)
|
||||||
"""
|
"""
|
||||||
if current_concurrent:
|
if current_concurrent is not None and current_concurrent > 0:
|
||||||
# 基于当前并发数减少
|
# 边界记忆策略:新限制 = 触发边界 - 1
|
||||||
new_limit = max(
|
candidate = current_concurrent - 1
|
||||||
int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 乘性减少
|
# 没有并发信息时,保守减少 1
|
||||||
new_limit = max(
|
candidate = current_limit - 1
|
||||||
int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
|
|
||||||
)
|
# 保证不会“缩容变扩容”(例如 current_concurrent > current_limit 的异常场景)
|
||||||
|
candidate = min(candidate, current_limit - 1)
|
||||||
|
|
||||||
|
new_limit = max(candidate, self.MIN_CONCURRENT_LIMIT)
|
||||||
|
|
||||||
return new_limit
|
return new_limit
|
||||||
|
|
||||||
def _increase_limit(self, current_limit: int) -> int:
|
def _increase_limit(
|
||||||
|
self,
|
||||||
|
current_limit: int,
|
||||||
|
known_boundary: Optional[int] = None,
|
||||||
|
is_probe: bool = False,
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
增加并发限制
|
增加并发限制(考虑边界保护)
|
||||||
|
|
||||||
策略:加性增加 (+1)
|
策略:
|
||||||
|
- 普通扩容:每次 +INCREASE_STEP,但不超过 known_boundary
|
||||||
|
(因为缩容时已经 -1 了,这里允许回到边界值尝试)
|
||||||
|
- 探测性扩容:每次只 +1,可以突破边界,但要谨慎
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_limit: 当前限制
|
||||||
|
known_boundary: 已知边界(last_concurrent_peak),即触发 429 时的并发数
|
||||||
|
is_probe: 是否是探测性扩容(可以突破边界)
|
||||||
"""
|
"""
|
||||||
new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
|
if is_probe:
|
||||||
|
# 探测模式:每次只 +1,谨慎突破边界
|
||||||
|
new_limit = current_limit + 1
|
||||||
|
else:
|
||||||
|
# 普通模式:每次 +INCREASE_STEP
|
||||||
|
new_limit = current_limit + self.INCREASE_STEP
|
||||||
|
|
||||||
|
# 边界保护:普通扩容不超过 known_boundary(允许回到边界值尝试)
|
||||||
|
if known_boundary:
|
||||||
|
if new_limit > known_boundary:
|
||||||
|
new_limit = known_boundary
|
||||||
|
|
||||||
|
# 全局上限保护
|
||||||
|
new_limit = min(new_limit, self.MAX_CONCURRENT_LIMIT)
|
||||||
|
|
||||||
|
# 确保有增长(否则返回原值表示不扩容)
|
||||||
|
if new_limit <= current_limit:
|
||||||
|
return current_limit
|
||||||
|
|
||||||
return new_limit
|
return new_limit
|
||||||
|
|
||||||
def _record_adjustment(
|
def _record_adjustment(
|
||||||
@@ -503,11 +569,16 @@ class AdaptiveConcurrencyManager:
|
|||||||
if key.last_probe_increase_at:
|
if key.last_probe_increase_at:
|
||||||
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
|
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
|
||||||
|
|
||||||
|
# 边界信息
|
||||||
|
known_boundary = key.last_concurrent_peak
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"adaptive_mode": is_adaptive,
|
"adaptive_mode": is_adaptive,
|
||||||
"max_concurrent": key.max_concurrent, # NULL=自适应,数字=固定限制
|
"max_concurrent": key.max_concurrent, # NULL=自适应,数字=固定限制
|
||||||
"effective_limit": effective_limit, # 当前有效限制
|
"effective_limit": effective_limit, # 当前有效限制
|
||||||
"learned_limit": key.learned_max_concurrent, # 学习到的限制
|
"learned_limit": key.learned_max_concurrent, # 学习到的限制
|
||||||
|
# 边界记忆相关
|
||||||
|
"known_boundary": known_boundary, # 触发 429 时的并发数(已知上限)
|
||||||
"concurrent_429_count": int(key.concurrent_429_count or 0),
|
"concurrent_429_count": int(key.concurrent_429_count or 0),
|
||||||
"rpm_429_count": int(key.rpm_429_count or 0),
|
"rpm_429_count": int(key.rpm_429_count or 0),
|
||||||
"last_429_at": last_429_at_str,
|
"last_429_at": last_429_at_str,
|
||||||
|
|||||||
@@ -289,11 +289,11 @@ class RequestResult:
|
|||||||
status_code = 500
|
status_code = 500
|
||||||
error_type = "internal_error"
|
error_type = "internal_error"
|
||||||
|
|
||||||
# 构建错误消息,包含上游响应信息
|
# 构建错误消息:优先使用上游响应作为主要错误信息
|
||||||
error_message = str(exception)
|
if isinstance(exception, ProviderNotAvailableException) and exception.upstream_response:
|
||||||
if isinstance(exception, ProviderNotAvailableException):
|
error_message = exception.upstream_response
|
||||||
if exception.upstream_response:
|
else:
|
||||||
error_message = f"{error_message} | 上游响应: {exception.upstream_response[:500]}"
|
error_message = str(exception)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
status=RequestStatus.FAILED,
|
status=RequestStatus.FAILED,
|
||||||
|
|||||||
@@ -86,6 +86,118 @@ class UsageRecordParams:
|
|||||||
class UsageService:
|
class UsageService:
|
||||||
"""用量统计服务"""
|
"""用量统计服务"""
|
||||||
|
|
||||||
|
# ==================== 缓存键常量 ====================
|
||||||
|
|
||||||
|
# 热力图缓存键前缀(依赖 TTL 自动过期,用户角色变更时主动清除)
|
||||||
|
HEATMAP_CACHE_KEY_PREFIX = "activity_heatmap"
|
||||||
|
|
||||||
|
# ==================== 热力图缓存 ====================
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_heatmap_cache_key(cls, user_id: Optional[str], include_actual_cost: bool) -> str:
|
||||||
|
"""生成热力图缓存键"""
|
||||||
|
cost_suffix = "with_cost" if include_actual_cost else "no_cost"
|
||||||
|
if user_id:
|
||||||
|
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:user:{user_id}:{cost_suffix}"
|
||||||
|
else:
|
||||||
|
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:admin:all:{cost_suffix}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def clear_user_heatmap_cache(cls, user_id: str) -> None:
|
||||||
|
"""
|
||||||
|
清除用户的热力图缓存(用户角色变更时调用)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
"""
|
||||||
|
from src.clients.redis_client import get_redis_client
|
||||||
|
|
||||||
|
redis_client = await get_redis_client(require_redis=False)
|
||||||
|
if not redis_client:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 清除该用户的所有热力图缓存(with_cost 和 no_cost)
|
||||||
|
keys_to_delete = [
|
||||||
|
cls._get_heatmap_cache_key(user_id, include_actual_cost=True),
|
||||||
|
cls._get_heatmap_cache_key(user_id, include_actual_cost=False),
|
||||||
|
]
|
||||||
|
|
||||||
|
for key in keys_to_delete:
|
||||||
|
try:
|
||||||
|
await redis_client.delete(key)
|
||||||
|
logger.debug(f"已清除热力图缓存: {key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"清除热力图缓存失败: {key}, error={e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_cached_heatmap(
|
||||||
|
cls,
|
||||||
|
db: Session,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
include_actual_cost: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取带缓存的热力图数据
|
||||||
|
|
||||||
|
缓存策略:
|
||||||
|
- TTL: 5分钟(CacheTTL.ACTIVITY_HEATMAP)
|
||||||
|
- 仅依赖 TTL 自动过期,新使用记录最多延迟 5 分钟出现
|
||||||
|
- 用户角色变更时通过 clear_user_heatmap_cache() 主动清除
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
user_id: 用户ID,None 表示获取全局热力图(管理员)
|
||||||
|
include_actual_cost: 是否包含实际成本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
热力图数据字典
|
||||||
|
"""
|
||||||
|
from src.clients.redis_client import get_redis_client
|
||||||
|
from src.config.constants import CacheTTL
|
||||||
|
import json
|
||||||
|
|
||||||
|
cache_key = cls._get_heatmap_cache_key(user_id, include_actual_cost)
|
||||||
|
|
||||||
|
cache_ttl = CacheTTL.ACTIVITY_HEATMAP
|
||||||
|
redis_client = await get_redis_client(require_redis=False)
|
||||||
|
|
||||||
|
# 尝试从缓存获取
|
||||||
|
if redis_client:
|
||||||
|
try:
|
||||||
|
cached = await redis_client.get(cache_key)
|
||||||
|
if cached:
|
||||||
|
try:
|
||||||
|
return json.loads(cached) # type: ignore[no-any-return]
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"热力图缓存解析失败,删除损坏缓存: {cache_key}, error={e}")
|
||||||
|
try:
|
||||||
|
await redis_client.delete(cache_key)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取热力图缓存出错: {cache_key}, error={e}")
|
||||||
|
|
||||||
|
# 从数据库查询
|
||||||
|
result = cls.get_daily_activity(
|
||||||
|
db=db,
|
||||||
|
user_id=user_id,
|
||||||
|
window_days=365,
|
||||||
|
include_actual_cost=include_actual_cost,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到缓存(失败不影响返回结果)
|
||||||
|
if redis_client:
|
||||||
|
try:
|
||||||
|
await redis_client.setex(
|
||||||
|
cache_key,
|
||||||
|
cache_ttl,
|
||||||
|
json.dumps(result, ensure_ascii=False, default=str),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"保存热力图缓存失败: {cache_key}, error={e}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
# ==================== 内部数据类 ====================
|
# ==================== 内部数据类 ====================
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1027,7 +1139,12 @@ class UsageService:
|
|||||||
window_days: int = 365,
|
window_days: int = 365,
|
||||||
include_actual_cost: bool = False,
|
include_actual_cost: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""按天统计请求活跃度,用于渲染热力图。"""
|
"""按天统计请求活跃度,用于渲染热力图。
|
||||||
|
|
||||||
|
优化策略:
|
||||||
|
- 历史数据从预计算的 StatsDaily/StatsUserDaily 表读取
|
||||||
|
- 只有"今天"的数据才实时查询 Usage 表
|
||||||
|
"""
|
||||||
|
|
||||||
def ensure_timezone(value: datetime) -> datetime:
|
def ensure_timezone(value: datetime) -> datetime:
|
||||||
if value.tzinfo is None:
|
if value.tzinfo is None:
|
||||||
@@ -1041,54 +1158,109 @@ class UsageService:
|
|||||||
ensure_timezone(start_date) if start_date else end_dt - timedelta(days=window_days - 1)
|
ensure_timezone(start_date) if start_date else end_dt - timedelta(days=window_days - 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 对齐到自然日的开始/结束,避免遗漏边界数据
|
# 对齐到自然日的开始/结束
|
||||||
start_dt = start_dt.replace(hour=0, minute=0, second=0, microsecond=0)
|
start_dt = datetime.combine(start_dt.date(), datetime.min.time(), tzinfo=timezone.utc)
|
||||||
end_dt = end_dt.replace(hour=23, minute=59, second=59, microsecond=999999)
|
end_dt = datetime.combine(end_dt.date(), datetime.max.time(), tzinfo=timezone.utc)
|
||||||
|
|
||||||
from src.utils.database_helpers import date_trunc_portable
|
|
||||||
|
|
||||||
bind = db.get_bind()
|
|
||||||
dialect = bind.dialect.name if bind is not None else "sqlite"
|
|
||||||
day_bucket = date_trunc_portable(dialect, "day", Usage.created_at).label("day")
|
|
||||||
|
|
||||||
columns = [
|
|
||||||
day_bucket,
|
|
||||||
func.count(Usage.id).label("requests"),
|
|
||||||
func.sum(Usage.total_tokens).label("total_tokens"),
|
|
||||||
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
|
||||||
]
|
|
||||||
|
|
||||||
if include_actual_cost:
|
|
||||||
columns.append(func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"))
|
|
||||||
|
|
||||||
query = db.query(*columns).filter(Usage.created_at >= start_dt, Usage.created_at <= end_dt)
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
query = query.filter(Usage.user_id == user_id)
|
|
||||||
|
|
||||||
query = query.group_by(day_bucket).order_by(day_bucket)
|
|
||||||
rows = query.all()
|
|
||||||
|
|
||||||
def normalize_period(value: Any) -> str:
|
|
||||||
if value is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(value, str):
|
|
||||||
return value[:10]
|
|
||||||
if isinstance(value, datetime):
|
|
||||||
return value.date().isoformat()
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
|
today = now.date()
|
||||||
|
today_start_dt = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
|
||||||
aggregated: Dict[str, Dict[str, Any]] = {}
|
aggregated: Dict[str, Dict[str, Any]] = {}
|
||||||
for row in rows:
|
|
||||||
key = normalize_period(row.day)
|
|
||||||
aggregated[key] = {
|
|
||||||
"requests": int(row.requests or 0),
|
|
||||||
"total_tokens": int(row.total_tokens or 0),
|
|
||||||
"total_cost_usd": float(row.total_cost_usd or 0.0),
|
|
||||||
}
|
|
||||||
if include_actual_cost:
|
|
||||||
aggregated[key]["actual_total_cost_usd"] = float(row.actual_total_cost_usd or 0.0)
|
|
||||||
|
|
||||||
|
# 1. 从预计算表读取历史数据(不包括今天)
|
||||||
|
if user_id:
|
||||||
|
from src.models.database import StatsUserDaily
|
||||||
|
|
||||||
|
hist_query = db.query(StatsUserDaily).filter(
|
||||||
|
StatsUserDaily.user_id == user_id,
|
||||||
|
StatsUserDaily.date >= start_dt,
|
||||||
|
StatsUserDaily.date < today_start_dt,
|
||||||
|
)
|
||||||
|
for row in hist_query.all():
|
||||||
|
key = (
|
||||||
|
row.date.date().isoformat()
|
||||||
|
if isinstance(row.date, datetime)
|
||||||
|
else str(row.date)[:10]
|
||||||
|
)
|
||||||
|
aggregated[key] = {
|
||||||
|
"requests": row.total_requests or 0,
|
||||||
|
"total_tokens": (
|
||||||
|
(row.input_tokens or 0)
|
||||||
|
+ (row.output_tokens or 0)
|
||||||
|
+ (row.cache_creation_tokens or 0)
|
||||||
|
+ (row.cache_read_tokens or 0)
|
||||||
|
),
|
||||||
|
"total_cost_usd": float(row.total_cost or 0.0),
|
||||||
|
}
|
||||||
|
# StatsUserDaily 没有 actual_total_cost 字段,用户视图不需要倍率成本
|
||||||
|
else:
|
||||||
|
from src.models.database import StatsDaily
|
||||||
|
|
||||||
|
hist_query = db.query(StatsDaily).filter(
|
||||||
|
StatsDaily.date >= start_dt,
|
||||||
|
StatsDaily.date < today_start_dt,
|
||||||
|
)
|
||||||
|
for row in hist_query.all():
|
||||||
|
key = (
|
||||||
|
row.date.date().isoformat()
|
||||||
|
if isinstance(row.date, datetime)
|
||||||
|
else str(row.date)[:10]
|
||||||
|
)
|
||||||
|
aggregated[key] = {
|
||||||
|
"requests": row.total_requests or 0,
|
||||||
|
"total_tokens": (
|
||||||
|
(row.input_tokens or 0)
|
||||||
|
+ (row.output_tokens or 0)
|
||||||
|
+ (row.cache_creation_tokens or 0)
|
||||||
|
+ (row.cache_read_tokens or 0)
|
||||||
|
),
|
||||||
|
"total_cost_usd": float(row.total_cost or 0.0),
|
||||||
|
}
|
||||||
|
if include_actual_cost:
|
||||||
|
aggregated[key]["actual_total_cost_usd"] = float(
|
||||||
|
row.actual_total_cost or 0.0 # type: ignore[attr-defined]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 实时查询今天的数据(如果在查询范围内)
|
||||||
|
if today >= start_dt.date() and today <= end_dt.date():
|
||||||
|
today_start = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
|
||||||
|
today_end = datetime.combine(today, datetime.max.time(), tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
if include_actual_cost:
|
||||||
|
today_query = db.query(
|
||||||
|
func.count(Usage.id).label("requests"),
|
||||||
|
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||||
|
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||||||
|
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"),
|
||||||
|
).filter(
|
||||||
|
Usage.created_at >= today_start,
|
||||||
|
Usage.created_at <= today_end,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
today_query = db.query(
|
||||||
|
func.count(Usage.id).label("requests"),
|
||||||
|
func.sum(Usage.total_tokens).label("total_tokens"),
|
||||||
|
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
|
||||||
|
).filter(
|
||||||
|
Usage.created_at >= today_start,
|
||||||
|
Usage.created_at <= today_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
today_query = today_query.filter(Usage.user_id == user_id)
|
||||||
|
|
||||||
|
today_row = today_query.first()
|
||||||
|
if today_row and today_row.requests:
|
||||||
|
aggregated[today.isoformat()] = {
|
||||||
|
"requests": int(today_row.requests or 0),
|
||||||
|
"total_tokens": int(today_row.total_tokens or 0),
|
||||||
|
"total_cost_usd": float(today_row.total_cost_usd or 0.0),
|
||||||
|
}
|
||||||
|
if include_actual_cost:
|
||||||
|
aggregated[today.isoformat()]["actual_total_cost_usd"] = float(
|
||||||
|
today_row.actual_total_cost_usd or 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 构建返回结果
|
||||||
days: List[Dict[str, Any]] = []
|
days: List[Dict[str, Any]] = []
|
||||||
cursor = start_dt.date()
|
cursor = start_dt.date()
|
||||||
end_date_only = end_dt.date()
|
end_date_only = end_dt.date()
|
||||||
@@ -1304,6 +1476,9 @@ class UsageService:
|
|||||||
provider: Optional[str] = None,
|
provider: Optional[str] = None,
|
||||||
target_model: Optional[str] = None,
|
target_model: Optional[str] = None,
|
||||||
first_byte_time_ms: Optional[int] = None,
|
first_byte_time_ms: Optional[int] = None,
|
||||||
|
provider_id: Optional[str] = None,
|
||||||
|
provider_endpoint_id: Optional[str] = None,
|
||||||
|
provider_api_key_id: Optional[str] = None,
|
||||||
) -> Optional[Usage]:
|
) -> Optional[Usage]:
|
||||||
"""
|
"""
|
||||||
快速更新使用记录状态
|
快速更新使用记录状态
|
||||||
@@ -1316,6 +1491,9 @@ class UsageService:
|
|||||||
provider: 提供商名称(可选,streaming 状态时更新)
|
provider: 提供商名称(可选,streaming 状态时更新)
|
||||||
target_model: 映射后的目标模型名(可选)
|
target_model: 映射后的目标模型名(可选)
|
||||||
first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)
|
first_byte_time_ms: 首字时间/TTFB(可选,streaming 状态时更新)
|
||||||
|
provider_id: Provider ID(可选,streaming 状态时更新)
|
||||||
|
provider_endpoint_id: Endpoint ID(可选,streaming 状态时更新)
|
||||||
|
provider_api_key_id: Provider API Key ID(可选,streaming 状态时更新)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
更新后的 Usage 记录,如果未找到则返回 None
|
更新后的 Usage 记录,如果未找到则返回 None
|
||||||
@@ -1331,10 +1509,22 @@ class UsageService:
|
|||||||
usage.error_message = error_message
|
usage.error_message = error_message
|
||||||
if provider:
|
if provider:
|
||||||
usage.provider = provider
|
usage.provider = provider
|
||||||
|
elif status == "streaming" and usage.provider == "pending":
|
||||||
|
# 状态变为 streaming 但 provider 仍为 pending,记录警告
|
||||||
|
logger.warning(
|
||||||
|
f"状态更新为 streaming 但 provider 为空: request_id={request_id}, "
|
||||||
|
f"当前 provider={usage.provider}"
|
||||||
|
)
|
||||||
if target_model:
|
if target_model:
|
||||||
usage.target_model = target_model
|
usage.target_model = target_model
|
||||||
if first_byte_time_ms is not None:
|
if first_byte_time_ms is not None:
|
||||||
usage.first_byte_time_ms = first_byte_time_ms
|
usage.first_byte_time_ms = first_byte_time_ms
|
||||||
|
if provider_id is not None:
|
||||||
|
usage.provider_id = provider_id
|
||||||
|
if provider_endpoint_id is not None:
|
||||||
|
usage.provider_endpoint_id = provider_endpoint_id
|
||||||
|
if provider_api_key_id is not None:
|
||||||
|
usage.provider_api_key_id = provider_api_key_id
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@@ -1446,6 +1636,8 @@ class UsageService:
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
default_timeout_seconds: int = 300,
|
default_timeout_seconds: int = 300,
|
||||||
|
*,
|
||||||
|
include_admin_fields: bool = False,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
|
获取活跃请求状态(用于前端轮询),并自动清理超时的 pending/streaming 请求
|
||||||
@@ -1482,6 +1674,15 @@ class UsageService:
|
|||||||
ProviderEndpoint.timeout.label("endpoint_timeout"),
|
ProviderEndpoint.timeout.label("endpoint_timeout"),
|
||||||
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
||||||
|
|
||||||
|
# 管理员轮询:可附带 provider 与上游 key 名称(注意:不要在普通用户接口暴露上游 key 信息)
|
||||||
|
if include_admin_fields:
|
||||||
|
from src.models.database import ProviderAPIKey
|
||||||
|
|
||||||
|
query = query.add_columns(
|
||||||
|
Usage.provider,
|
||||||
|
ProviderAPIKey.name.label("api_key_name"),
|
||||||
|
).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
query = query.filter(Usage.id.in_(ids))
|
query = query.filter(Usage.id.in_(ids))
|
||||||
if user_id:
|
if user_id:
|
||||||
@@ -1518,8 +1719,9 @@ class UsageService:
|
|||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return [
|
result: List[Dict[str, Any]] = []
|
||||||
{
|
for r in records:
|
||||||
|
item: Dict[str, Any] = {
|
||||||
"id": r.id,
|
"id": r.id,
|
||||||
"status": "failed" if r.id in timeout_ids else r.status,
|
"status": "failed" if r.id in timeout_ids else r.status,
|
||||||
"input_tokens": r.input_tokens,
|
"input_tokens": r.input_tokens,
|
||||||
@@ -1528,8 +1730,12 @@ class UsageService:
|
|||||||
"response_time_ms": r.response_time_ms,
|
"response_time_ms": r.response_time_ms,
|
||||||
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
|
"first_byte_time_ms": r.first_byte_time_ms, # 首字时间 (TTFB)
|
||||||
}
|
}
|
||||||
for r in records
|
if include_admin_fields:
|
||||||
]
|
item["provider"] = r.provider
|
||||||
|
item["api_key_name"] = r.api_key_name
|
||||||
|
result.append(item)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
# ========== 缓存亲和性分析方法 ==========
|
# ========== 缓存亲和性分析方法 ==========
|
||||||
|
|
||||||
|
|||||||
@@ -459,34 +459,38 @@ class StreamUsageTracker:
|
|||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应 | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
first_chunk_received = False
|
first_byte_time_ms = None # 预先记录 TTFB,避免 yield 后计算不准确
|
||||||
try:
|
try:
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
# 保存原始字节流(用于错误诊断)
|
# 保存原始字节流(用于错误诊断)
|
||||||
self.raw_chunks.append(chunk)
|
self.raw_chunks.append(chunk)
|
||||||
|
|
||||||
# 第一个 chunk 收到时,更新状态为 streaming 并记录 TTFB
|
# 第一个 chunk 收到时,记录 TTFB 时间点(但先不更新数据库,避免阻塞)
|
||||||
if not first_chunk_received:
|
if chunk_count == 1:
|
||||||
first_chunk_received = True
|
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||||
if self.request_id:
|
base_time = self.request_start_time or self.start_time
|
||||||
try:
|
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||||
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
|
||||||
base_time = self.request_start_time or self.start_time
|
|
||||||
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
|
||||||
UsageService.update_usage_status(
|
|
||||||
db=self.db,
|
|
||||||
request_id=self.request_id,
|
|
||||||
status="streaming",
|
|
||||||
provider=self.provider,
|
|
||||||
first_byte_time_ms=first_byte_time_ms,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
|
||||||
|
|
||||||
# 返回原始块给客户端
|
# 先返回原始块给客户端,确保 TTFB 不受数据库操作影响
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
# yield 后再更新数据库状态(仅第一个 chunk 时执行)
|
||||||
|
if chunk_count == 1 and self.request_id:
|
||||||
|
try:
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=self.db,
|
||||||
|
request_id=self.request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
|
first_byte_time_ms=first_byte_time_ms,
|
||||||
|
provider_id=self.provider_id,
|
||||||
|
provider_endpoint_id=self.provider_endpoint_id,
|
||||||
|
provider_api_key_id=self.provider_api_key_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
# 解析块以提取内容和使用信息(chunk是原始字节)
|
# 解析块以提取内容和使用信息(chunk是原始字节)
|
||||||
content, usage = self.parse_stream_chunk(chunk)
|
content, usage = self.parse_stream_chunk(chunk)
|
||||||
|
|
||||||
@@ -916,15 +920,38 @@ class EnhancedStreamUsageTracker(StreamUsageTracker):
|
|||||||
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应(Enhanced) | 估算输入tokens:{self.input_tokens}")
|
logger.debug(f"ID:{self.request_id} | 开始跟踪流式响应(Enhanced) | 估算输入tokens:{self.input_tokens}")
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
first_byte_time_ms = None # 预先记录 TTFB,避免 yield 后计算不准确
|
||||||
try:
|
try:
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
# 保存原始字节流(用于错误诊断)
|
# 保存原始字节流(用于错误诊断)
|
||||||
self.raw_chunks.append(chunk)
|
self.raw_chunks.append(chunk)
|
||||||
|
|
||||||
# 返回原始块给客户端
|
# 第一个 chunk 收到时,记录 TTFB 时间点(但先不更新数据库,避免阻塞)
|
||||||
|
if chunk_count == 1:
|
||||||
|
# 计算 TTFB(使用请求原始开始时间或 track_stream 开始时间)
|
||||||
|
base_time = self.request_start_time or self.start_time
|
||||||
|
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
|
||||||
|
|
||||||
|
# 先返回原始块给客户端,确保 TTFB 不受数据库操作影响
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
# yield 后再更新数据库状态(仅第一个 chunk 时执行)
|
||||||
|
if chunk_count == 1 and self.request_id:
|
||||||
|
try:
|
||||||
|
UsageService.update_usage_status(
|
||||||
|
db=self.db,
|
||||||
|
request_id=self.request_id,
|
||||||
|
status="streaming",
|
||||||
|
provider=self.provider,
|
||||||
|
first_byte_time_ms=first_byte_time_ms,
|
||||||
|
provider_id=self.provider_id,
|
||||||
|
provider_endpoint_id=self.provider_endpoint_id,
|
||||||
|
provider_api_key_id=self.provider_api_key_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"更新使用记录状态为 streaming 失败: {e}")
|
||||||
|
|
||||||
# 解析块以提取内容和使用信息(chunk是原始字节)
|
# 解析块以提取内容和使用信息(chunk是原始字节)
|
||||||
content, usage = self.parse_stream_chunk(chunk)
|
content, usage = self.parse_stream_chunk(chunk)
|
||||||
|
|
||||||
|
|||||||
@@ -25,9 +25,10 @@ class ApiKeyService:
|
|||||||
allowed_providers: Optional[List[str]] = None,
|
allowed_providers: Optional[List[str]] = None,
|
||||||
allowed_api_formats: Optional[List[str]] = None,
|
allowed_api_formats: Optional[List[str]] = None,
|
||||||
allowed_models: Optional[List[str]] = None,
|
allowed_models: Optional[List[str]] = None,
|
||||||
rate_limit: int = 100,
|
rate_limit: Optional[int] = None,
|
||||||
concurrent_limit: int = 5,
|
concurrent_limit: int = 5,
|
||||||
expire_days: Optional[int] = None,
|
expire_days: Optional[int] = None,
|
||||||
|
expires_at: Optional[datetime] = None, # 直接传入过期时间,优先于 expire_days
|
||||||
initial_balance_usd: Optional[float] = None,
|
initial_balance_usd: Optional[float] = None,
|
||||||
is_standalone: bool = False,
|
is_standalone: bool = False,
|
||||||
auto_delete_on_expiry: bool = False,
|
auto_delete_on_expiry: bool = False,
|
||||||
@@ -44,6 +45,7 @@ class ApiKeyService:
|
|||||||
rate_limit: 速率限制
|
rate_limit: 速率限制
|
||||||
concurrent_limit: 并发限制
|
concurrent_limit: 并发限制
|
||||||
expire_days: 过期天数,None = 永不过期
|
expire_days: 过期天数,None = 永不过期
|
||||||
|
expires_at: 直接指定过期时间,优先于 expire_days
|
||||||
initial_balance_usd: 初始余额(USD),仅用于独立Key,None = 无限制
|
initial_balance_usd: 初始余额(USD),仅用于独立Key,None = 无限制
|
||||||
is_standalone: 是否为独立余额Key(仅管理员可创建)
|
is_standalone: 是否为独立余额Key(仅管理员可创建)
|
||||||
auto_delete_on_expiry: 过期后是否自动删除(True=物理删除,False=仅禁用)
|
auto_delete_on_expiry: 过期后是否自动删除(True=物理删除,False=仅禁用)
|
||||||
@@ -54,10 +56,10 @@ class ApiKeyService:
|
|||||||
key_hash = ApiKey.hash_key(key)
|
key_hash = ApiKey.hash_key(key)
|
||||||
key_encrypted = crypto_service.encrypt(key) # 加密存储密钥
|
key_encrypted = crypto_service.encrypt(key) # 加密存储密钥
|
||||||
|
|
||||||
# 计算过期时间
|
# 计算过期时间:优先使用 expires_at,其次使用 expire_days
|
||||||
expires_at = None
|
final_expires_at = expires_at
|
||||||
if expire_days:
|
if final_expires_at is None and expire_days:
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
final_expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
|
||||||
|
|
||||||
# 空数组转为 None(表示不限制)
|
# 空数组转为 None(表示不限制)
|
||||||
api_key = ApiKey(
|
api_key = ApiKey(
|
||||||
@@ -70,7 +72,7 @@ class ApiKeyService:
|
|||||||
allowed_models=allowed_models or None,
|
allowed_models=allowed_models or None,
|
||||||
rate_limit=rate_limit,
|
rate_limit=rate_limit,
|
||||||
concurrent_limit=concurrent_limit,
|
concurrent_limit=concurrent_limit,
|
||||||
expires_at=expires_at,
|
expires_at=final_expires_at,
|
||||||
balance_used_usd=0.0,
|
balance_used_usd=0.0,
|
||||||
current_balance_usd=initial_balance_usd, # 直接使用初始余额,None = 无限制
|
current_balance_usd=initial_balance_usd, # 直接使用初始余额,None = 无限制
|
||||||
is_standalone=is_standalone,
|
is_standalone=is_standalone,
|
||||||
@@ -145,6 +147,9 @@ class ApiKeyService:
|
|||||||
# 允许显式设置为空数组/None 的字段(空数组会转为 None,表示"全部")
|
# 允许显式设置为空数组/None 的字段(空数组会转为 None,表示"全部")
|
||||||
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
|
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
|
||||||
|
|
||||||
|
# 允许显式设置为 None 的字段(如 expires_at=None 表示永不过期,rate_limit=None 表示无限制)
|
||||||
|
nullable_fields = {"expires_at", "rate_limit"}
|
||||||
|
|
||||||
for field, value in kwargs.items():
|
for field, value in kwargs.items():
|
||||||
if field not in updatable_fields:
|
if field not in updatable_fields:
|
||||||
continue
|
continue
|
||||||
@@ -153,6 +158,9 @@ class ApiKeyService:
|
|||||||
if value is not None:
|
if value is not None:
|
||||||
# 空数组转为 None(表示允许全部)
|
# 空数组转为 None(表示允许全部)
|
||||||
setattr(api_key, field, value if value else None)
|
setattr(api_key, field, value if value else None)
|
||||||
|
elif field in nullable_fields:
|
||||||
|
# 这些字段允许显式设置为 None
|
||||||
|
setattr(api_key, field, value)
|
||||||
elif value is not None:
|
elif value is not None:
|
||||||
setattr(api_key, field, value)
|
setattr(api_key, field, value)
|
||||||
|
|
||||||
|
|||||||
@@ -49,8 +49,16 @@ def cache_result(key_prefix: str, ttl: int = 60, user_specific: bool = True) ->
|
|||||||
# 尝试从缓存获取
|
# 尝试从缓存获取
|
||||||
cached = await redis_client.get(cache_key)
|
cached = await redis_client.get(cache_key)
|
||||||
if cached:
|
if cached:
|
||||||
logger.debug(f"缓存命中: {cache_key}")
|
try:
|
||||||
return json.loads(cached)
|
result = json.loads(cached)
|
||||||
|
logger.debug(f"缓存命中: {cache_key}")
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"缓存解析失败,删除损坏缓存: {cache_key}, 错误: {e}")
|
||||||
|
try:
|
||||||
|
await redis_client.delete(cache_key)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# 执行原函数
|
# 执行原函数
|
||||||
result = await func(*args, **kwargs)
|
result = await func(*args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user