19 Commits

Author SHA1 Message Date
fawney19
465da6f818 feat: extract usage info in the OpenAI streaming response parser
Some OpenAI-compatible APIs (e.g. Doubao) send usage information in the final chunk;
prompt_tokens and completion_tokens are now extracted correctly.
2026-01-05 12:50:05 +08:00
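The final-chunk usage extraction described above can be sketched as follows. This is a minimal illustration, not the project's actual parser; the function name and stream shape are assumptions:

```python
import json

def extract_usage_from_sse(lines):
    """Scan SSE lines for a chunk carrying a `usage` object.

    Some OpenAI-compatible APIs attach usage only to the final chunk;
    earlier chunks carry `"usage": null` and must be skipped.
    """
    usage = None
    for line in lines:
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):].strip()
        if payload == "[DONE]":
            break
        try:
            chunk = json.loads(payload)
        except json.JSONDecodeError:
            continue  # tolerate keep-alives or partial lines
        if isinstance(chunk.get("usage"), dict):
            usage = chunk["usage"]
    return usage

stream = [
    'data: {"choices":[{"delta":{"content":"Hi"}}],"usage":null}',
    'data: {"choices":[],"usage":{"prompt_tokens":12,"completion_tokens":3}}',
    "data: [DONE]",
]
u = extract_usage_from_sse(stream)
```

Keeping only the last `usage` dict seen before `[DONE]` matches the "final chunk" behavior without assuming which chunk carries it.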
fawney19
e5f12fddd9 feat: harden stream prefetching and tune the adaptive concurrency algorithm
Stream prefetch hardening:
- Added a prefetch byte cap (64 KB) to prevent memory growth on responses with no newlines
- After prefetching, detect non-SSE error responses (HTML pages, plain JSON errors)
- Extracted check_html_response and check_prefetched_response_error into utils.py

Adaptive concurrency tuning (boundary memory + progressive probing):
- Scale-down: replaced multiplicative decrease with boundary - 1, so a single 429 converges near the true limit
- Scale-up: normal growth stays below the known boundary; probing growth may cautiously exceed it (+1 each time)
- Record the boundary only on concurrency-limit 429s, so RPM/UNKNOWN types cannot overwrite it
2026-01-05 12:17:45 +08:00
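The boundary-memory scaling described above can be sketched like this. The class and method names are assumptions for illustration, not the project's code:

```python
class AdaptiveConcurrency:
    """Sketch of boundary memory + progressive probing.

    - A concurrency-limit 429 records the failing level as the boundary
      and converges to boundary - 1 (instead of multiplicative decrease).
    - Normal scale-up stays below the known boundary; a probing scale-up
      may cautiously exceed it, +1 at a time.
    """

    def __init__(self, limit=4):
        self.limit = limit
        self.boundary = None  # last level known to trigger a concurrency 429

    def on_429(self, is_concurrency_limit):
        if is_concurrency_limit:
            self.boundary = self.limit      # only concurrency 429s set it
            self.limit = max(1, self.boundary - 1)
        else:
            # RPM/UNKNOWN 429s back off but never overwrite the boundary
            self.limit = max(1, self.limit - 1)

    def on_success(self, probe=False):
        if self.boundary is None or probe:
            self.limit += 1                 # probing may cross the boundary
        elif self.limit + 1 < self.boundary:
            self.limit += 1                 # normal growth stays under it
```

With a boundary of 8, a single 429 drops the limit to 7; ordinary successes then hold at 7, and only an explicit probe pushes back toward 8.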
fawney19
4fa9a1303a feat: improve the timing of TTFB and streaming-status recording
Streaming-status update mechanism:
- Record TTFB and update the streaming status uniformly on first output
- Refactored the status-update logic in CliMessageHandlerBase to remove duplication
- Ensure provider/key information is available when the streaming status is updated

Frontend improvements:
- Added support for the first_byte_time_ms field
- Admin endpoints can return the provider/api_key_name fields
- Refined active-request polling to judge more accurately when a full refresh is needed

Database and API:
- UsageService.get_active_requests_status gained an include_admin_fields parameter
- Admin endpoints enable this parameter to obtain the extra information
2026-01-05 10:31:34 +08:00
fawney19
43f349d415 fix: ensure provider info is set before the CLI handler's streaming-status update
After execute_with_fallback returns, explicitly set the provider info on ctx,
matching the behavior of chat_handler_base.py, to avoid an empty provider
when the streaming status is updated.
2026-01-05 09:36:35 +08:00
fawney19
02069954de fix: pass first_byte_time_ms when updating the streaming status 2026-01-05 09:29:38 +08:00
fawney19
2e15875fed feat: endpoint API supports the custom_path field
- ProviderEndpointCreate accepts a custom_path parameter
- ProviderEndpointUpdate accepts a custom_path parameter
- ProviderEndpointResponse returns the custom_path field
- custom_path is passed to the database model when an endpoint is created
2026-01-05 09:22:20 +08:00
fawney19
b34cfb676d fix: pass provider-related IDs when updating the streaming status
Added provider_id, provider_endpoint_id and provider_api_key_id parameters
to update_usage_status, so these fields are recorded correctly when a
streaming request enters the streaming state.
2026-01-05 09:12:03 +08:00
fawney19
3064497636 refactor: improve extraction and propagation of upstream error messages
- Added an extract_error_message utility to unify error-message extraction
- Attach an upstream_response attribute to HTTPStatusError to preserve the original error
- Prefer the upstream response body over the exception's string representation
- Removed the error-message length limits (500/1000 characters)
- Fixed a boundary-condition check to match the "Unable to read" prefix with startswith
- Simplified the conditional logic in result.py
2026-01-05 03:18:55 +08:00
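The extract_error_message idea can be sketched as below. Only the `upstream_response` attribute name comes from the commit notes; the rest is an illustrative assumption:

```python
def extract_error_message(exc):
    """Prefer the preserved upstream body over str(exc).

    Sketch: the caller attaches an `upstream_response` attribute to the
    raised exception so the original upstream error text survives; when
    present, it is returned untruncated instead of the exception's
    string representation.
    """
    upstream = getattr(exc, "upstream_response", None)
    if upstream:
        return upstream
    return str(exc)

# Example: attaching the upstream body to a generic exception
e = RuntimeError("HTTP 429 Too Many Requests")
e.upstream_response = '{"error":{"message":"rate limited"}}'
```

Attaching the body as an attribute keeps the exception type intact while letting any layer downstream recover the original upstream payload.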
fawney19
dec681fea0 fix: unify timezone handling so all datetimes carry timezone info
- token_bucket.py: get_reset_time and the Redis backend use timezone.utc
- sliding_window.py: get_reset_time and retry_after calculations use timezone.utc
- provider_strategy.py: ensure timezone info is present after dateutil.parser parsing
2026-01-05 02:23:24 +08:00
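The fix boils down to always producing timezone-aware datetimes, since mixing naive and aware values breaks comparisons and arithmetic. A minimal sketch (the function names mirror the commit; the bodies are assumptions):

```python
from datetime import datetime, timedelta, timezone

def ensure_aware(dt):
    """Attach UTC to a naive datetime (e.g. after dateutil parsing)."""
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)

def get_reset_time(window_seconds, now=None):
    """Aware reset timestamp; a naive utcnow() here would make the
    result incomparable with aware datetimes elsewhere."""
    now = now or datetime.now(timezone.utc)
    return now + timedelta(seconds=window_seconds)
```

An already-aware datetime passes through `ensure_aware` unchanged, so the helper is safe to call unconditionally after parsing.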
fawney19
523e27ba9a fix: API key expiry uses the application timezone instead of UTC
- Backend: parse_expiry_date uses APP_TIMEZONE (default Asia/Shanghai)
- Frontend: removed "UTC" from the hint text
2026-01-05 02:18:16 +08:00
fawney19
e7db76e581 refactor: API key expiry now uses a date picker; rate_limit supports "no limit"
- Frontend: replaced the "number of days" input with a more intuitive date picker
- Backend: added an expires_at field (ISO date format), compatible with the legacy expire_days
- rate_limit now accepts null to mean no limit; the default of 100 was removed
- Parsing: the expiry time is set to 23:59:59.999999 UTC on the chosen day
2026-01-05 02:16:16 +08:00
fawney19
689339117a refactor: extract a ModelMultiSelect component with stale-model detection
- Added a ModelMultiSelect component that displays and removes stale models
- Added a useInvalidModels composable to detect invalid references in allowed_models
- Refactored StandaloneKeyFormDialog and UserFormDialog to use the new component
- Added design-note comments on the GlobalModel deletion logic
2026-01-05 01:20:58 +08:00
fawney19
b202765be4 perf: improve streaming TTFB by moving the DB status update after the first yield
- StreamUsageTracker: yield the first chunk before updating the streaming status
- EnhancedStreamUsageTracker: added the same TTFB recording and status-update logic
- Ensures the client's first byte is not delayed by database operations
2026-01-05 00:13:23 +08:00
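The yield-first ordering can be shown with a plain async generator. Names here are illustrative; the real trackers are StreamUsageTracker/EnhancedStreamUsageTracker:

```python
import asyncio

events = []

async def track_stream(chunks, update_status):
    """Yield the first chunk before persisting the 'streaming' status.

    Sketch of the TTFB fix: the database write (simulated by
    update_status) runs only after the first chunk has already been
    yielded, so the client's first byte is not delayed by storage.
    """
    first = True
    async for chunk in chunks:
        yield chunk                           # first byte reaches the client
        if first:
            first = False
            await update_status("streaming")  # deferred DB update

async def _update(status):
    events.append(("status", status))

async def _chunks():
    yield b"hello"
    yield b"world"

async def _demo():
    async for chunk in track_stream(_chunks(), _update):
        events.append(("chunk", chunk))

asyncio.run(_demo())
```

The recorded event order shows the first chunk leaving before the status update runs, which is exactly the latency property the commit is after.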
fawney19
3bbf3073df feat: pass through the upstream error when all providers fail
- FallbackOrchestrator keeps the last error after all candidate combinations fail
- Extract the upstream status code and response body from httpx.HTTPStatusError
- ProviderNotAvailableException carries the upstream error information
- ErrorResponse passes through the upstream status code and response when returning an error
2026-01-04 23:50:15 +08:00
fawney19
f46aaa2182 debug: add logging for an empty provider during streaming-status updates
- base_handler: detect and log an empty provider when updating the streaming status
- cli_handler_base: fix the streaming status not being updated when prefetched data is empty
- usage service: detect the anomaly where the status becomes streaming while the provider is still pending
2026-01-04 23:16:01 +08:00
fawney19
a2f33a6c35 perf: split the heatmap into a dedicated endpoint and add Redis caching
- New dedicated heatmap API endpoints (/api/admin/usage/heatmap, /api/users/me/usage/heatmap)
- Added a Redis cache layer (5-minute TTL) to reduce database queries
- The heatmap cache is cleared when a user's role changes
- Frontend loads statistics and the heatmap in parallel, with loading/error states
- Fixed cache_decorator's missing JSON parse error handling
- Updated the docker-compose startup command hint
2026-01-04 22:42:58 +08:00
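A sketch of a TTL cache decorator with tolerant JSON decoding, which is the behavior the cache_decorator fix describes. The in-memory Redis stand-in, key names, and decorator shape are assumptions:

```python
import functools
import json
import time

class FakeRedis:
    """In-memory stand-in for a Redis client (illustrative only)."""
    def __init__(self):
        self._store = {}
    def get(self, key):
        val = self._store.get(key)
        if val and val[1] > time.time():
            return val[0]
        return None
    def setex(self, key, ttl, value):
        self._store[key] = (value, time.time() + ttl)

def cached(redis, key, ttl=300):
    """5-minute cache that recomputes on a corrupt cache entry
    instead of raising a JSONDecodeError to the caller."""
    def deco(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            raw = redis.get(key)
            if raw is not None:
                try:
                    return json.loads(raw)
                except json.JSONDecodeError:
                    pass  # corrupt entry: fall through and recompute
            result = fn(*args, **kwargs)
            redis.setex(key, ttl, json.dumps(result))
            return result
        return wrapper
    return deco

r = FakeRedis()
calls = []

@cached(r, "usage:heatmap", ttl=300)
def heatmap():
    calls.append(1)
    return {"days": 7}
```

The key point is the `except json.JSONDecodeError: pass` path: a corrupt entry degrades to a cache miss rather than a 500.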
fawney19
b6bd6357ed perf: fix the N+1 query problem in the GlobalModel list query 2026-01-04 20:05:23 +08:00
fawney19
c3a5878b1b feat: improve usage-query pagination and heatmap performance
- The usage-query endpoint now accepts limit/offset pagination parameters
- Heatmap statistics now read the precomputed StatsDaily/StatsUserDaily tables instead of querying the Usage table live
- Fixed avg_response_time_ms being wrongly skipped when its value is 0
2026-01-04 18:02:47 +08:00
fawney19
c02ac56da8 chore: update docker-compose commands to docker compose
Standardize on the modern Docker Compose V2 invocation
2026-01-03 01:39:45 +08:00
50 changed files with 1405 additions and 487 deletions

View File

@@ -58,13 +58,13 @@ cp .env.example .env
python generate_keys.py # generate keys, then fill them into .env
# 3. Deploy
docker-compose up -d
docker compose up -d
# 4. On first deployment, initialize the database
./migrate.sh
# 5. Update
docker-compose pull && docker-compose up -d && ./migrate.sh
docker compose pull && docker compose up -d && ./migrate.sh
```
### Docker Compose (build images locally)
@@ -86,7 +86,7 @@ python generate_keys.py # generate keys, then fill them into .env
```bash
# Start dependencies
docker-compose -f docker-compose.build.yml up -d postgres redis
docker compose -f docker-compose.build.yml up -d postgres redis
# Backend
uv sync

View File

@@ -30,7 +30,7 @@ from src.models.database import Base
config = context.config
# Read the database URL from environment variables
# Prefer DATABASE_URL; otherwise build it from DB_PASSWORD (consistent with docker-compose)
# Prefer DATABASE_URL; otherwise build it from DB_PASSWORD (consistent with docker compose)
database_url = os.getenv("DATABASE_URL")
if not database_url:
db_password = os.getenv("DB_PASSWORD", "")

View File

@@ -1,7 +1,7 @@
# Aether deployment config - local build
# Usage:
# First base build: docker build -f Dockerfile.base -t aether-base:latest .
# Start services: docker-compose -f docker-compose.build.yml up -d --build
# Start services: docker compose -f docker-compose.build.yml up -d --build
services:
postgres:

View File

@@ -1,5 +1,5 @@
# Aether deployment config - prebuilt images
# Usage: docker-compose up -d
# Usage: docker compose up -d
services:
postgres:

View File

@@ -42,7 +42,7 @@ export interface UserApiKeyExport {
allowed_endpoints?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
rate_limit?: number | null // null = no limit
concurrent_limit?: number | null
force_capabilities?: any
is_active: boolean
@@ -220,7 +220,7 @@ export interface AdminApiKey {
total_requests?: number
total_tokens?: number
total_cost_usd?: number
rate_limit?: number
rate_limit?: number | null // null = no limit
allowed_providers?: string[] | null // list of allowed providers
allowed_api_formats?: string[] | null // list of allowed API formats
allowed_models?: string[] | null // list of allowed models
@@ -236,8 +236,8 @@ export interface CreateStandaloneApiKeyRequest {
allowed_providers?: string[] | null
allowed_api_formats?: string[] | null
allowed_models?: string[] | null
rate_limit?: number
expire_days?: number | null // null = never expires
rate_limit?: number | null // null = no limit
expires_at?: string | null // ISO date string, e.g. "2025-12-31"; null = never expires
initial_balance_usd: number // initial balance, required
auto_delete_on_expiry?: boolean // whether to auto-delete after expiry
}

View File

@@ -75,6 +75,16 @@ export interface ModelSummary {
actual_total_cost_usd?: number // multiplier-adjusted cost (admin only)
}
// Provider summary interface
export interface ProviderSummary {
provider: string
requests: number
total_tokens: number
total_cost_usd: number
success_rate: number | null
avg_response_time_ms: number | null
}
// Usage statistics response interface
export interface UsageResponse {
total_requests: number
@@ -87,6 +97,13 @@ export interface UsageResponse {
quota_usd: number | null
used_usd: number
summary_by_model: ModelSummary[]
summary_by_provider?: ProviderSummary[]
pagination?: {
total: number
limit: number
offset: number
has_more: boolean
}
records: UsageRecordDetail[]
activity_heatmap?: ActivityHeatmap | null
}
@@ -175,6 +192,8 @@ export const meApi = {
async getUsage(params?: {
start_date?: string
end_date?: string
limit?: number
offset?: number
}): Promise<UsageResponse> {
const response = await apiClient.get<UsageResponse>('/api/users/me/usage', { params })
return response.data
@@ -184,11 +203,12 @@ export const meApi = {
async getActiveRequests(ids?: string): Promise<{
requests: Array<{
id: string
status: string
status: 'pending' | 'streaming' | 'completed' | 'failed'
input_tokens: number
output_tokens: number
cost: number
response_time_ms: number | null
first_byte_time_ms: number | null
}>
}> {
const params = ids ? { ids } : {}
@@ -267,5 +287,14 @@ export const meApi = {
}> {
const response = await apiClient.get('/api/users/me/usage/interval-timeline', { params })
return response.data
},
/**
* Fetch activity heatmap data (user)
* Cached on the backend for 5 minutes
*/
async getActivityHeatmap(): Promise<ActivityHeatmap> {
const response = await apiClient.get<ActivityHeatmap>('/api/users/me/usage/heatmap')
return response.data
}
}

View File

@@ -193,10 +193,22 @@ export const usageApi = {
output_tokens: number
cost: number
response_time_ms: number | null
first_byte_time_ms: number | null
provider?: string | null
api_key_name?: string | null
}>
}> {
const params = ids?.length ? { ids: ids.join(',') } : {}
const response = await apiClient.get('/api/admin/usage/active', { params })
return response.data
},
/**
* Fetch activity heatmap data (admin)
* Cached on the backend for 5 minutes
*/
async getActivityHeatmap(): Promise<ActivityHeatmap> {
const response = await apiClient.get<ActivityHeatmap>('/api/admin/usage/heatmap')
return response.data
}
}

View File

@@ -0,0 +1,117 @@
<template>
<div class="space-y-2">
<Label class="text-sm font-medium">允许的模型</Label>
<div class="relative">
<button
type="button"
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
@click="isOpen = !isOpen"
>
<span :class="modelValue.length ? 'text-foreground' : 'text-muted-foreground'">
{{ modelValue.length ? `已选择 ${modelValue.length}` : '全部可用' }}
<span
v-if="invalidModels.length"
class="text-destructive"
>({{ invalidModels.length }} 个已失效)</span>
</span>
<ChevronDown
class="h-4 w-4 text-muted-foreground transition-transform"
:class="isOpen ? 'rotate-180' : ''"
/>
</button>
<div
v-if="isOpen"
class="fixed inset-0 z-[80]"
@click.stop="isOpen = false"
/>
<div
v-if="isOpen"
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
>
<!-- Stale models are pinned to the top and can only be deselected -->
<div
v-for="modelName in invalidModels"
:key="modelName"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer bg-destructive/5"
@click="removeModel(modelName)"
>
<input
type="checkbox"
:checked="true"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="removeModel(modelName)"
>
<span class="text-sm text-destructive">{{ modelName }}</span>
<span class="text-xs text-destructive/70">(已失效)</span>
</div>
<!-- Valid models -->
<div
v-for="model in models"
:key="model.name"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
@click="toggleModel(model.name)"
>
<input
type="checkbox"
:checked="modelValue.includes(model.name)"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="toggleModel(model.name)"
>
<span class="text-sm">{{ model.name }}</span>
</div>
<div
v-if="models.length === 0 && invalidModels.length === 0"
class="px-3 py-2 text-sm text-muted-foreground"
>
暂无可用模型
</div>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed } from 'vue'
import { Label } from '@/components/ui'
import { ChevronDown } from 'lucide-vue-next'
import { useInvalidModels } from '@/composables/useInvalidModels'
export interface ModelWithName {
name: string
}
const props = defineProps<{
modelValue: string[]
models: ModelWithName[]
}>()
const emit = defineEmits<{
'update:modelValue': [value: string[]]
}>()
const isOpen = ref(false)
// Detect stale models
const { invalidModels } = useInvalidModels(
computed(() => props.modelValue),
computed(() => props.models)
)
function toggleModel(name: string) {
const newValue = [...props.modelValue]
const index = newValue.indexOf(name)
if (index === -1) {
newValue.push(name)
} else {
newValue.splice(index, 1)
}
emit('update:modelValue', newValue)
}
function removeModel(name: string) {
const newValue = props.modelValue.filter(m => m !== name)
emit('update:modelValue', newValue)
}
</script>

View File

@@ -7,3 +7,6 @@
export { default as EmptyState } from './EmptyState.vue'
export { default as AlertDialog } from './AlertDialog.vue'
export { default as LoadingState } from './LoadingState.vue'
// Form components
export { default as ModelMultiSelect } from './ModelMultiSelect.vue'

View File

@@ -0,0 +1,34 @@
import { computed, type Ref, type ComputedRef } from 'vue'
/**
* Composable for detecting stale models
*
* Detects model names in allowed_models that no longer exist in globalModels;
* these models may have been deleted while references to them were never cleaned up.
*
* @example
* ```typescript
* const { invalidModels } = useInvalidModels(
* computed(() => form.value.allowed_models),
* globalModels
* )
* ```
*/
export interface ModelWithName {
name: string
}
export function useInvalidModels<T extends ModelWithName>(
allowedModels: Ref<string[]> | ComputedRef<string[]>,
globalModels: Ref<T[]>
): { invalidModels: ComputedRef<string[]> } {
const validModelNames = computed(() =>
new Set(globalModels.value.map(m => m.name))
)
const invalidModels = computed(() =>
allowedModels.value.filter(name => !validModelNames.value.has(name))
)
return { invalidModels }
}

View File

@@ -79,45 +79,45 @@
<div class="space-y-2">
<Label
for="form-expire-days"
for="form-expires-at"
class="text-sm font-medium"
>有效期设置</Label>
<div class="flex items-center gap-2">
<Input
id="form-expire-days"
:model-value="form.expire_days ?? ''"
type="number"
min="1"
max="3650"
placeholder="天数"
:class="form.never_expire ? 'flex-1 h-9 opacity-50' : 'flex-1 h-9'"
:disabled="form.never_expire"
@update:model-value="(v) => form.expire_days = parseNumberInput(v, { min: 1, max: 3650 })"
/>
<label class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap">
<input
v-model="form.never_expire"
type="checkbox"
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
@change="onNeverExpireChange"
<div class="relative flex-1">
<Input
id="form-expires-at"
:model-value="form.expires_at || ''"
type="date"
:min="minExpiryDate"
class="h-9 pr-8"
:placeholder="form.expires_at ? '' : '永不过期'"
@update:model-value="(v) => form.expires_at = v || undefined"
/>
<button
v-if="form.expires_at"
type="button"
class="absolute right-2 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground"
title="清空永不过期"
@click="clearExpiryDate"
>
永不过期
</label>
<X class="h-4 w-4" />
</button>
</div>
<label
class="flex items-center gap-1.5 border rounded-md px-2 py-1.5 bg-muted/50 cursor-pointer text-xs whitespace-nowrap"
:class="form.never_expire ? 'opacity-50' : ''"
:class="!form.expires_at ? 'opacity-50 cursor-not-allowed' : ''"
>
<input
v-model="form.auto_delete_on_expiry"
type="checkbox"
class="h-3.5 w-3.5 rounded border-gray-300 cursor-pointer"
:disabled="form.never_expire"
:disabled="!form.expires_at"
>
到期删除
</label>
</div>
<p class="text-xs text-muted-foreground">
不勾选"到期删除"则仅禁用
{{ form.expires_at ? '到期后' + (form.auto_delete_on_expiry ? '自动删除' : '仅禁用') + '(当天 23:59 失效)' : '留空表示永不过期' }}
</p>
</div>
@@ -244,55 +244,10 @@
</div>
<!-- Model multi-select dropdown -->
<div class="space-y-2">
<Label class="text-sm font-medium">允许的模型</Label>
<div class="relative">
<button
type="button"
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
@click="modelDropdownOpen = !modelDropdownOpen"
>
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length} 个` : '全部可用' }}
</span>
<ChevronDown
class="h-4 w-4 text-muted-foreground transition-transform"
:class="modelDropdownOpen ? 'rotate-180' : ''"
/>
</button>
<div
v-if="modelDropdownOpen"
class="fixed inset-0 z-[80]"
@click.stop="modelDropdownOpen = false"
/>
<div
v-if="modelDropdownOpen"
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
>
<div
v-for="model in globalModels"
:key="model.name"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
@click="toggleSelection('allowed_models', model.name)"
>
<input
type="checkbox"
:checked="form.allowed_models.includes(model.name)"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="toggleSelection('allowed_models', model.name)"
>
<span class="text-sm">{{ model.name }}</span>
</div>
<div
v-if="globalModels.length === 0"
class="px-3 py-2 text-sm text-muted-foreground"
>
暂无可用模型
</div>
</div>
</div>
</div>
<ModelMultiSelect
v-model="form.allowed_models"
:models="globalModels"
/>
</div>
</div>
</form>
@@ -325,8 +280,9 @@ import {
Input,
Label,
} from '@/components/ui'
import { Plus, SquarePen, Key, Shield, ChevronDown } from 'lucide-vue-next'
import { Plus, SquarePen, Key, Shield, ChevronDown, X } from 'lucide-vue-next'
import { useFormDialog } from '@/composables/useFormDialog'
import { ModelMultiSelect } from '@/components/common'
import { getProvidersSummary } from '@/api/endpoints/providers'
import { getGlobalModels } from '@/api/global-models'
import { adminApi } from '@/api/admin'
@@ -338,8 +294,7 @@ export interface StandaloneKeyFormData {
id?: string
name: string
initial_balance_usd?: number
expire_days?: number
never_expire: boolean
expires_at?: string // ISO date string, e.g. "2025-12-31"; undefined = never expires
rate_limit?: number
auto_delete_on_expiry: boolean
allowed_providers: string[]
@@ -363,7 +318,6 @@ const saving = ref(false)
// Dropdown state
const providerDropdownOpen = ref(false)
const apiFormatDropdownOpen = ref(false)
const modelDropdownOpen = ref(false)
// Option data
const providers = ref<ProviderWithEndpointsSummary[]>([])
@@ -374,8 +328,7 @@ const allApiFormats = ref<string[]>([])
const form = ref<StandaloneKeyFormData>({
name: '',
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
expires_at: undefined,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
@@ -383,12 +336,18 @@ const form = ref<StandaloneKeyFormData>({
allowed_models: []
})
// Compute the minimum selectable date (tomorrow)
const minExpiryDate = computed(() => {
const tomorrow = new Date()
tomorrow.setDate(tomorrow.getDate() + 1)
return tomorrow.toISOString().split('T')[0]
})
function resetForm() {
form.value = {
name: '',
initial_balance_usd: 10,
expire_days: undefined,
never_expire: true,
expires_at: undefined,
rate_limit: undefined,
auto_delete_on_expiry: false,
allowed_providers: [],
@@ -397,7 +356,6 @@ function resetForm() {
}
providerDropdownOpen.value = false
apiFormatDropdownOpen.value = false
modelDropdownOpen.value = false
}
function loadKeyData() {
@@ -406,8 +364,7 @@ function loadKeyData() {
id: props.apiKey.id,
name: props.apiKey.name || '',
initial_balance_usd: props.apiKey.initial_balance_usd,
expire_days: props.apiKey.expire_days,
never_expire: props.apiKey.never_expire,
expires_at: props.apiKey.expires_at,
rate_limit: props.apiKey.rate_limit,
auto_delete_on_expiry: props.apiKey.auto_delete_on_expiry,
allowed_providers: props.apiKey.allowed_providers || [],
@@ -452,12 +409,10 @@ function toggleSelection(field: 'allowed_providers' | 'allowed_api_formats' | 'a
}
}
// Never-expire toggle
function onNeverExpireChange() {
if (form.value.never_expire) {
form.value.expire_days = undefined
form.value.auto_delete_on_expiry = false
}
// Clear the expiry date (also clears the delete-on-expiry option)
function clearExpiryDate() {
form.value.expires_at = undefined
form.value.auto_delete_on_expiry = false
}
// Submit the form

View File

@@ -18,8 +18,22 @@
<span class="flex-shrink-0"></span>
</div>
</div>
<div
v-if="isLoading"
class="h-full min-h-[160px] flex items-center justify-center text-sm text-muted-foreground"
>
<Loader2 class="h-5 w-5 animate-spin mr-2" />
加载中...
</div>
<div
v-else-if="hasError"
class="h-full min-h-[160px] flex items-center justify-center text-sm text-destructive"
>
<AlertCircle class="h-4 w-4 mr-1.5" />
加载失败
</div>
<ActivityHeatmap
v-if="hasData"
v-else-if="hasData"
:data="data"
:show-header="false"
/>
@@ -34,6 +48,7 @@
<script setup lang="ts">
import { computed } from 'vue'
import { Loader2, AlertCircle } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue'
import ActivityHeatmap from '@/components/stats/ActivityHeatmap.vue'
import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
@@ -41,6 +56,8 @@ import type { ActivityHeatmap as ActivityHeatmapData } from '@/types/activity'
const props = defineProps<{
data: ActivityHeatmapData | null
title: string
isLoading?: boolean
hasError?: boolean
}>()
const legendLevels = [0.08, 0.25, 0.45, 0.65, 0.85]

View File

@@ -64,9 +64,6 @@ export function useUsageData(options: UseUsageDataOptions) {
}))
})
// Activity heatmap data
const activityHeatmapData = computed(() => stats.value.activity_heatmap)
// Load statistics (without loading records)
async function loadStats(dateRange?: DateRangeParams) {
isLoadingStats.value = true
@@ -93,7 +90,7 @@ export function useUsageData(options: UseUsageDataOptions) {
cache_stats: (statsData as any).cache_stats,
period_start: '',
period_end: '',
activity_heatmap: statsData.activity_heatmap || null
activity_heatmap: null
}
modelStats.value = modelData.map(item => ({
@@ -143,7 +140,7 @@ export function useUsageData(options: UseUsageDataOptions) {
avg_response_time: userData.avg_response_time || 0,
period_start: '',
period_end: '',
activity_heatmap: userData.activity_heatmap || null
activity_heatmap: null
}
modelStats.value = (userData.summary_by_model || []).map((item: any) => ({
@@ -305,7 +302,6 @@ export function useUsageData(options: UseUsageDataOptions) {
// Computed properties
enhancedModelStats,
activityHeatmapData,
// Methods
loadStats,

View File

@@ -1,5 +1,3 @@
import type { ActivityHeatmap } from '@/types/activity'
// Statistics state
export interface UsageStatsState {
total_requests: number
@@ -17,7 +15,6 @@ export interface UsageStatsState {
}
period_start: string
period_end: string
activity_heatmap: ActivityHeatmap | null
}
// Model statistics
@@ -115,7 +112,6 @@ export function createDefaultStats(): UsageStatsState {
error_rate: undefined,
cache_stats: undefined,
period_start: '',
period_end: '',
activity_heatmap: null
period_end: ''
}
}

View File

@@ -316,55 +316,10 @@
</div>
<!-- Model multi-select dropdown -->
<div class="space-y-2">
<Label class="text-sm font-medium">允许的模型</Label>
<div class="relative">
<button
type="button"
class="w-full h-10 px-3 border rounded-lg bg-background text-left flex items-center justify-between hover:bg-muted/50 transition-colors"
@click="modelDropdownOpen = !modelDropdownOpen"
>
<span :class="form.allowed_models.length ? 'text-foreground' : 'text-muted-foreground'">
{{ form.allowed_models.length ? `已选择 ${form.allowed_models.length}` : '全部可用' }}
</span>
<ChevronDown
class="h-4 w-4 text-muted-foreground transition-transform"
:class="modelDropdownOpen ? 'rotate-180' : ''"
/>
</button>
<div
v-if="modelDropdownOpen"
class="fixed inset-0 z-[80]"
@click.stop="modelDropdownOpen = false"
/>
<div
v-if="modelDropdownOpen"
class="absolute z-[90] w-full mt-1 bg-popover border rounded-lg shadow-lg max-h-48 overflow-y-auto"
>
<div
v-for="model in globalModels"
:key="model.name"
class="flex items-center gap-2 px-3 py-2 hover:bg-muted/50 cursor-pointer"
@click="toggleSelection('allowed_models', model.name)"
>
<input
type="checkbox"
:checked="form.allowed_models.includes(model.name)"
class="h-4 w-4 rounded border-gray-300 cursor-pointer"
@click.stop
@change="toggleSelection('allowed_models', model.name)"
>
<span class="text-sm">{{ model.name }}</span>
</div>
<div
v-if="globalModels.length === 0"
class="px-3 py-2 text-sm text-muted-foreground"
>
暂无可用模型
</div>
</div>
</div>
</div>
<ModelMultiSelect
v-model="form.allowed_models"
:models="globalModels"
/>
</div>
</div>
</form>
@@ -404,10 +359,12 @@ import {
} from '@/components/ui'
import { UserPlus, SquarePen, ChevronDown } from 'lucide-vue-next'
import { useFormDialog } from '@/composables/useFormDialog'
import { ModelMultiSelect } from '@/components/common'
import { getProvidersSummary } from '@/api/endpoints/providers'
import { getGlobalModels } from '@/api/global-models'
import { adminApi } from '@/api/admin'
import { log } from '@/utils/logger'
import type { ProviderWithEndpointsSummary, GlobalModelResponse } from '@/api/endpoints/types'
export interface UserFormData {
id?: string
@@ -440,11 +397,10 @@ const roleSelectOpen = ref(false)
// 下拉框状态
const providerDropdownOpen = ref(false)
const endpointDropdownOpen = ref(false)
const modelDropdownOpen = ref(false)
// 选项数据
const providers = ref<any[]>([])
const globalModels = ref<any[]>([])
const providers = ref<ProviderWithEndpointsSummary[]>([])
const globalModels = ref<GlobalModelResponse[]>([])
const apiFormats = ref<Array<{ value: string; label: string }>>([])
// 表单数据

View File

@@ -850,28 +850,20 @@ async function deleteApiKey(apiKey: AdminApiKey) {
}
function editApiKey(apiKey: AdminApiKey) {
// Compute the remaining expiry days
let expireDays: number | undefined = undefined
let neverExpire = true
// Parse the expiry date into YYYY-MM-DD format
// Keep the original date without time filtering (avoids accidentally clearing a key that expires today while editing)
let expiresAt: string | undefined = undefined
if (apiKey.expires_at) {
const expiresDate = new Date(apiKey.expires_at)
const now = new Date()
const diffMs = expiresDate.getTime() - now.getTime()
const diffDays = Math.ceil(diffMs / (1000 * 60 * 60 * 24))
if (diffDays > 0) {
expireDays = diffDays
neverExpire = false
}
expiresAt = expiresDate.toISOString().split('T')[0]
}
editingKeyData.value = {
id: apiKey.id,
name: apiKey.name || '',
expire_days: expireDays,
never_expire: neverExpire,
rate_limit: apiKey.rate_limit || 100,
expires_at: expiresAt,
rate_limit: apiKey.rate_limit ?? undefined,
auto_delete_on_expiry: apiKey.auto_delete_on_expiry || false,
allowed_providers: apiKey.allowed_providers || [],
allowed_api_formats: apiKey.allowed_api_formats || [],
@@ -1033,14 +1025,25 @@ function closeKeyFormDialog() {
// Unified form-submit handling
async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
// Validate the expiry date (if set, it must be later than today)
if (data.expires_at) {
const selectedDate = new Date(data.expires_at)
const today = new Date()
today.setHours(0, 0, 0, 0)
if (selectedDate <= today) {
error('过期日期必须晚于今天')
return
}
}
keyFormDialogRef.value?.setSaving(true)
try {
if (data.id) {
// Update
const updateData: Partial<CreateStandaloneApiKeyRequest> = {
name: data.name || undefined,
rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null),
rate_limit: data.rate_limit ?? null, // undefined = no limit; pass null explicitly
expires_at: data.expires_at || null, // undefined/empty = never expires
auto_delete_on_expiry: data.auto_delete_on_expiry,
// An empty array clears the restriction (allow all); the backend stores an empty array as NULL
allowed_providers: data.allowed_providers,
@@ -1058,8 +1061,8 @@ async function handleKeyFormSubmit(data: StandaloneKeyFormData) {
const createData: CreateStandaloneApiKeyRequest = {
name: data.name || undefined,
initial_balance_usd: data.initial_balance_usd,
rate_limit: data.rate_limit,
expire_days: data.never_expire ? null : (data.expire_days || null),
rate_limit: data.rate_limit ?? null, // undefined = no limit; pass null explicitly
expires_at: data.expires_at || null, // undefined/empty = never expires
auto_delete_on_expiry: data.auto_delete_on_expiry,
// An empty array means no restriction (allow all); the backend stores an empty array as NULL
allowed_providers: data.allowed_providers,

View File

@@ -5,6 +5,8 @@
<ActivityHeatmapCard
:data="activityHeatmapData"
:title="isAdminPage ? '总体活跃天数' : '我的活跃天数'"
:is-loading="isLoadingHeatmap"
:has-error="heatmapError"
/>
<IntervalTimelineCard
:title="isAdminPage ? '请求间隔时间线' : '我的请求间隔'"
@@ -112,8 +114,11 @@ import {
import type { PeriodValue, FilterStatusValue } from '@/features/usage/types'
import type { UserOption } from '@/features/usage/components/UsageRecordsTable.vue'
import { log } from '@/utils/logger'
import type { ActivityHeatmap } from '@/types/activity'
import { useToast } from '@/composables/useToast'
const route = useRoute()
const { warning } = useToast()
const authStore = useAuthStore()
// Whether this is the admin page
@@ -144,13 +149,35 @@ const {
currentRecords,
totalRecords,
enhancedModelStats,
activityHeatmapData,
availableModels,
availableProviders,
loadStats,
loadRecords
} = useUsageData({ isAdminPage })
// Heatmap state
const activityHeatmapData = ref<ActivityHeatmap | null>(null)
const isLoadingHeatmap = ref(false)
const heatmapError = ref(false)
// Load heatmap data
async function loadHeatmapData() {
isLoadingHeatmap.value = true
heatmapError.value = false
try {
if (isAdminPage.value) {
activityHeatmapData.value = await usageApi.getActivityHeatmap()
} else {
activityHeatmapData.value = await meApi.getActivityHeatmap()
}
} catch (error) {
log.error('加载热力图数据失败:', error)
heatmapError.value = true
} finally {
isLoadingHeatmap.value = false
}
}
// The user page needs client-side filtering
const filteredRecords = computed(() => {
if (!isAdminPage.value) {
@@ -232,27 +259,40 @@ async function pollActiveRequests() {
? await usageApi.getActiveRequests(activeRequestIds.value)
: await meApi.getActiveRequests(idsParam)
// Check for status changes
let hasChanges = false
let shouldRefresh = false
for (const update of requests) {
const record = currentRecords.value.find(r => r.id === update.id)
if (record && record.status !== update.status) {
hasChanges = true
// If the status became completed or failed, refresh to get full data
if (update.status === 'completed' || update.status === 'failed') {
break
}
// Otherwise only update the status and token info
if (!record) {
// The backend returned an unknown active request; trigger a refresh for full data
shouldRefresh = true
continue
}
// On a status change, completed/failed requires a refresh for full data
if (record.status !== update.status) {
record.status = update.status
record.input_tokens = update.input_tokens
record.output_tokens = update.output_tokens
record.cost = update.cost
record.response_time_ms = update.response_time_ms ?? undefined
}
if (update.status === 'completed' || update.status === 'failed') {
shouldRefresh = true
}
// In-flight requests also need continuous updates (provider/key/TTFB may only be persisted after streaming begins)
record.input_tokens = update.input_tokens
record.output_tokens = update.output_tokens
record.cost = update.cost
record.response_time_ms = update.response_time_ms ?? undefined
record.first_byte_time_ms = update.first_byte_time_ms ?? undefined
// The admin endpoint returns extra fields
if ('provider' in update && typeof update.provider === 'string') {
record.provider = update.provider
}
if ('api_key_name' in update) {
record.api_key_name = typeof update.api_key_name === 'string' ? update.api_key_name : undefined
}
}
// If any request completed or failed, refresh the whole list to get full data
if (hasChanges && requests.some(r => r.status === 'completed' || r.status === 'failed')) {
if (shouldRefresh) {
await refreshData()
}
} catch (error) {
@@ -335,7 +375,22 @@ const selectedRequestId = ref<string | null>(null)
// Initial load
onMounted(async () => {
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
await loadStats(dateRange)
// Load statistics and the heatmap in parallel (allSettled so one failure does not affect the other)
const [statsResult, heatmapResult] = await Promise.allSettled([
loadStats(dateRange),
loadHeatmapData()
])
// Check the load results and notify the user
if (statsResult.status === 'rejected') {
log.error('加载统计数据失败:', statsResult.reason)
warning('统计数据加载失败,请刷新重试')
}
if (heatmapResult.status === 'rejected') {
log.error('加载热力图数据失败:', heatmapResult.reason)
// No toast on heatmap failure; the UI already shows a placeholder
}
// The admin page loads the user list and the first page of records
if (isAdminPage.value) {

View File

@@ -3,22 +3,64 @@
Standalone-balance keys are not tied to user quotas; they carry an independent balance limit and are intended for non-registered users.
"""
from datetime import datetime, timezone
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from zoneinfo import ZoneInfo
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db
from src.models.api import CreateApiKeyRequest
from src.models.database import ApiKey, User
from src.models.database import ApiKey
from src.services.user.apikey import ApiKeyService
# Application timezone config, defaults to Asia/Shanghai
APP_TIMEZONE = ZoneInfo(os.getenv("APP_TIMEZONE", "Asia/Shanghai"))
def parse_expiry_date(date_str: Optional[str]) -> Optional[datetime]:
"""Parse an expiry date string into a datetime object.
Args:
date_str: date string, supporting "YYYY-MM-DD" or full ISO format
Returns:
A datetime (23:59:59.999999 of that day, application timezone), or None if the input is empty
Raises:
InvalidRequestException: invalid date format
"""
if not date_str or not date_str.strip():
return None
date_str = date_str.strip()
# Try the YYYY-MM-DD format
try:
parsed_date = datetime.strptime(date_str, "%Y-%m-%d")
# Set to the end of that day (23:59:59.999999, application timezone)
return parsed_date.replace(
hour=23, minute=59, second=59, microsecond=999999, tzinfo=APP_TIMEZONE
)
except ValueError:
pass
# 尝试完整 ISO 格式
try:
return datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except ValueError:
pass
raise InvalidRequestException(f"无效的日期格式: {date_str},请使用 YYYY-MM-DD 格式")
router = APIRouter(prefix="/api/admin/api-keys", tags=["Admin - API Keys (Standalone)"])
pipeline = ApiRequestPipeline()
@@ -215,6 +257,9 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
# 独立Key需要关联到管理员用户(从 context 获取)
admin_user_id = context.user.id
# 解析过期时间(优先使用 expires_at,其次使用 expire_days)
expires_at_dt = parse_expiry_date(self.key_data.expires_at)
# 创建独立Key
api_key, plain_key = ApiKeyService.create_api_key(
db=db,
@@ -224,7 +269,8 @@ class AdminCreateStandaloneKeyAdapter(AdminApiAdapter):
allowed_api_formats=self.key_data.allowed_api_formats,
allowed_models=self.key_data.allowed_models,
rate_limit=self.key_data.rate_limit, # None 表示不限制
expire_days=self.key_data.expire_days,
expire_days=self.key_data.expire_days, # 兼容旧版
expires_at=expires_at_dt, # 优先使用
initial_balance_usd=self.key_data.initial_balance_usd,
is_standalone=True, # 标记为独立Key
auto_delete_on_expiry=self.key_data.auto_delete_on_expiry,
@@ -270,7 +316,8 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
update_data = {}
if self.key_data.name is not None:
update_data["name"] = self.key_data.name
if self.key_data.rate_limit is not None:
# rate_limit: 显式传递时更新(包括 null 表示无限制)
if "rate_limit" in self.key_data.model_fields_set:
update_data["rate_limit"] = self.key_data.rate_limit
if (
hasattr(self.key_data, "auto_delete_on_expiry")
@@ -287,19 +334,21 @@ class AdminUpdateApiKeyAdapter(AdminApiAdapter):
update_data["allowed_models"] = self.key_data.allowed_models
# 处理过期时间
if self.key_data.expire_days is not None:
if self.key_data.expire_days > 0:
from datetime import timedelta
# 优先使用 expires_at(如果显式传递且有值)
if self.key_data.expires_at and self.key_data.expires_at.strip():
update_data["expires_at"] = parse_expiry_date(self.key_data.expires_at)
elif "expires_at" in self.key_data.model_fields_set:
# expires_at 明确传递为 null 或空字符串,设为永不过期
update_data["expires_at"] = None
# 兼容旧版 expire_days
elif "expire_days" in self.key_data.model_fields_set:
if self.key_data.expire_days is not None and self.key_data.expire_days > 0:
update_data["expires_at"] = datetime.now(timezone.utc) + timedelta(
days=self.key_data.expire_days
)
else:
# expire_days = 0 或负数表示永不过期
# expire_days = None/0/负数 表示永不过期
update_data["expires_at"] = None
elif hasattr(self.key_data, "expire_days") and self.key_data.expire_days is None:
# 明确传递 None设为永不过期
update_data["expires_at"] = None
# 使用 ApiKeyService 更新
updated_key = ApiKeyService.update_api_key(db, self.key_id, **update_data)
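The expiry-update precedence above (explicit `expires_at` wins, an explicit null clears the expiry, legacy `expire_days` is the fallback) can be consolidated into a small standalone sketch. `resolve_expiry` and the `"unchanged"` sentinel are hypothetical names for illustration; `fields_set` stands in for Pydantic's `model_fields_set`.

```python
from datetime import datetime, timedelta, timezone

def resolve_expiry(expires_at, expire_days, fields_set):
    """Sketch of the update precedence: explicit expires_at > explicit
    null (clear) > legacy expire_days > leave unchanged."""
    if expires_at and expires_at.strip():
        return "parse:" + expires_at.strip()  # would call parse_expiry_date
    if "expires_at" in fields_set:
        return None  # explicitly cleared -> never expires
    if "expire_days" in fields_set:
        if expire_days and expire_days > 0:
            return datetime.now(timezone.utc) + timedelta(days=expire_days)
        return None  # None/0/negative -> never expires
    return "unchanged"

assert resolve_expiry("2026-01-31", None, {"expires_at"}) == "parse:2026-01-31"
assert resolve_expiry(None, None, {"expires_at"}) is None
assert resolve_expiry(None, 0, {"expire_days"}) is None
assert resolve_expiry(None, None, set()) == "unchanged"
```

Checking membership in `model_fields_set` rather than truthiness is what lets a client distinguish "field omitted" from "field explicitly set to null".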

View File

@@ -206,6 +206,7 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
provider_id=self.provider_id,
api_format=self.endpoint_data.api_format,
base_url=self.endpoint_data.base_url,
custom_path=self.endpoint_data.custom_path,
headers=self.endpoint_data.headers,
timeout=self.endpoint_data.timeout,
max_retries=self.endpoint_data.max_retries,

View File

@@ -146,20 +146,25 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
search=self.search,
)
# 为每个 GlobalModel 添加统计数据
# 一次性查询所有 GlobalModel 的 provider_count(优化 N+1 问题)
model_ids = [gm.id for gm in models]
provider_counts = {}
if model_ids:
count_results = (
context.db.query(
Model.global_model_id, func.count(func.distinct(Model.provider_id))
)
.filter(Model.global_model_id.in_(model_ids))
.group_by(Model.global_model_id)
.all()
)
provider_counts = {gm_id: count for gm_id, count in count_results}
# 构建响应
model_responses = []
for gm in models:
# 统计关联的 Model 数量(去重 Provider)
provider_count = (
context.db.query(func.count(func.distinct(Model.provider_id)))
.filter(Model.global_model_id == gm.id)
.scalar()
or 0
)
response = GlobalModelResponse.model_validate(gm)
response.provider_count = provider_count
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
response.provider_count = provider_counts.get(gm.id, 0)
model_responses.append(response)
return GlobalModelListResponse(
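The N+1 fix above replaces one `COUNT` query per model with a single `GROUP BY` query whose results are looked up from a dict. The same pattern in plain SQL, using a hypothetical mini-schema over sqlite3 for illustration:

```python
import sqlite3

# Hypothetical mini-schema mirroring Model(global_model_id, provider_id).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE model (global_model_id TEXT, provider_id TEXT)")
conn.executemany(
    "INSERT INTO model VALUES (?, ?)",
    [("gm1", "p1"), ("gm1", "p2"), ("gm1", "p2"), ("gm2", "p1")],
)

def provider_counts(model_ids):
    """One GROUP BY query for all models instead of one COUNT per model."""
    placeholders = ",".join("?" * len(model_ids))
    rows = conn.execute(
        f"SELECT global_model_id, COUNT(DISTINCT provider_id) "
        f"FROM model WHERE global_model_id IN ({placeholders}) "
        f"GROUP BY global_model_id",
        model_ids,
    ).fetchall()
    return dict(rows)

counts = provider_counts(["gm1", "gm2", "gm3"])
# models with no rows fall back to 0 via dict.get, as in the response loop
assert counts.get("gm1", 0) == 2 and counts.get("gm3", 0) == 0
```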

View File

@@ -2,7 +2,7 @@
提供商策略管理 API 端点
"""
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -103,6 +103,9 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
if config.quota_last_reset_at:
new_reset_at = parser.parse(config.quota_last_reset_at)
# 确保有时区信息,如果没有则假设为 UTC
if new_reset_at.tzinfo is None:
new_reset_at = new_reset_at.replace(tzinfo=timezone.utc)
provider.quota_last_reset_at = new_reset_at
# 自动同步该周期内的历史使用量
@@ -118,7 +121,11 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
logger.info(f"Synced usage for provider {provider.name}: ${period_usage:.4f} since {new_reset_at}")
if config.quota_expires_at:
provider.quota_expires_at = parser.parse(config.quota_expires_at)
expires_at = parser.parse(config.quota_expires_at)
# 确保有时区信息,如果没有则假设为 UTC
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
provider.quota_expires_at = expires_at
db.commit()
db.refresh(provider)
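Both branches above apply the same normalization: a parsed timestamp without tzinfo is assumed to be UTC before it is stored. A minimal sketch of that rule (`ensure_utc` is a hypothetical helper name):

```python
from datetime import datetime, timezone

def ensure_utc(dt: datetime) -> datetime:
    """Assume UTC when a parsed timestamp carries no timezone info;
    leave already-aware timestamps untouched."""
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

naive = datetime(2026, 1, 5, 12, 0, 0)
assert ensure_utc(naive).tzinfo is timezone.utc

aware = datetime(2026, 1, 5, 12, 0, 0, tzinfo=timezone.utc)
assert ensure_utc(aware) is aware
```

Normalizing at the boundary keeps later comparisons like `datetime.now(timezone.utc) - since` from raising "can't subtract offset-naive and offset-aware datetimes".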
@@ -149,7 +156,7 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
since = datetime.now() - timedelta(hours=self.hours)
since = datetime.now(timezone.utc) - timedelta(hours=self.hours)
stats = (
db.query(ProviderUsageTracking)
.filter(

View File

@@ -1133,7 +1133,7 @@ class AdminImportUsersAdapter(AdminApiAdapter):
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit", 100),
rate_limit=key_data.get("rate_limit"), # None = 无限制
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),

View File

@@ -73,6 +73,20 @@ async def get_usage_stats(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/heatmap")
async def get_activity_heatmap(
request: Request,
db: Session = Depends(get_db),
):
"""
Get activity heatmap data for the past 365 days.
This endpoint is cached for 5 minutes to reduce database load.
"""
adapter = AdminActivityHeatmapAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/records")
async def get_usage_records(
request: Request,
@@ -168,12 +182,6 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
).count()
activity_heatmap = UsageService.get_daily_activity(
db=db,
window_days=365,
include_actual_cost=True,
)
context.add_audit_metadata(
action="usage_stats",
start_date=self.start_date.isoformat() if self.start_date else None,
@@ -204,10 +212,22 @@ class AdminUsageStatsAdapter(AdminApiAdapter):
),
"cache_read_cost": float(cache_stats.cache_read_cost or 0) if cache_stats else 0,
},
"activity_heatmap": activity_heatmap,
}
class AdminActivityHeatmapAdapter(AdminApiAdapter):
"""Activity heatmap adapter with Redis caching."""
async def handle(self, context): # type: ignore[override]
result = await UsageService.get_cached_heatmap(
db=context.db,
user_id=None,
include_actual_cost=True,
)
context.add_audit_metadata(action="activity_heatmap")
return result
class AdminUsageByModelAdapter(AdminApiAdapter):
def __init__(self, start_date: Optional[datetime], end_date: Optional[datetime], limit: int):
self.start_date = start_date
@@ -670,7 +690,9 @@ class AdminActiveRequestsAdapter(AdminApiAdapter):
if not id_list:
return {"requests": []}
requests = UsageService.get_active_requests_status(db=db, ids=id_list)
requests = UsageService.get_active_requests_status(
db=db, ids=id_list, include_admin_fields=True
)
return {"requests": requests}

View File

@@ -248,6 +248,7 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
raise InvalidRequestException("请求数据验证失败")
update_data = request.model_dump(exclude_unset=True)
old_role = existing_user.role
if "role" in update_data and update_data["role"]:
if hasattr(update_data["role"], "value"):
update_data["role"] = update_data["role"]
@@ -258,6 +259,12 @@ class AdminUpdateUserAdapter(AdminApiAdapter):
if not user:
raise NotFoundException("用户不存在", "user")
# 角色变更时清除热力图缓存(影响 include_actual_cost 权限)
if "role" in update_data and update_data["role"] != old_role:
from src.services.usage.service import UsageService
await UsageService.clear_user_heatmap_cache(self.user_id)
changed_fields = list(update_data.keys())
context.add_audit_metadata(
action="update_user",
@@ -424,7 +431,7 @@ class AdminCreateUserKeyAdapter(AdminApiAdapter):
name=key_data.name,
allowed_providers=key_data.allowed_providers,
allowed_models=key_data.allowed_models,
rate_limit=key_data.rate_limit or 100,
rate_limit=key_data.rate_limit, # None = 无限制
expire_days=key_data.expire_days,
initial_balance_usd=None, # 普通Key不设置余额限制
is_standalone=False, # 不是独立Key

View File

@@ -47,7 +47,6 @@ if TYPE_CHECKING:
from src.api.handlers.base.stream_context import StreamContext
class MessageTelemetry:
"""
负责记录 Usage/Audit,避免处理器里重复代码。
@@ -406,7 +405,7 @@ class BaseMessageHandler:
asyncio.create_task(_do_update())
def _update_usage_to_streaming_with_ctx(self, ctx: "StreamContext") -> None:
"""更新 Usage 状态为 streaming,同时更新 provider 和 target_model
"""更新 Usage 状态为 streaming,同时更新 provider 相关信息
使用 asyncio 后台任务执行数据库更新,避免阻塞流式传输
@@ -414,7 +413,7 @@ class BaseMessageHandler:
并在最终 record_success 时传递到数据库,避免重复记录导致数据不一致。
Args:
ctx: 流式上下文,包含 provider_name 和 mapped_model
ctx: 流式上下文,包含 provider 相关信息
"""
import asyncio
from src.database.database import get_db
@@ -422,6 +421,17 @@ class BaseMessageHandler:
target_request_id = self.request_id
provider = ctx.provider_name
target_model = ctx.mapped_model
provider_id = ctx.provider_id
endpoint_id = ctx.endpoint_id
key_id = ctx.key_id
first_byte_time_ms = ctx.first_byte_time_ms
# 如果 provider 为空,记录警告(不应该发生,但用于调试)
if not provider:
logger.warning(
f"[{target_request_id}] 更新 streaming 状态时 provider 为空: "
f"ctx.provider_name={ctx.provider_name}, ctx.provider_id={ctx.provider_id}"
)
async def _do_update() -> None:
try:
@@ -434,6 +444,10 @@ class BaseMessageHandler:
status="streaming",
provider=provider,
target_model=target_model,
provider_id=provider_id,
provider_endpoint_id=endpoint_id,
provider_api_key_id=key_id,
first_byte_time_ms=first_byte_time_ms,
)
finally:
db.close()

View File

@@ -36,6 +36,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config
from src.core.error_utils import extract_error_message
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -500,6 +501,8 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误: {e.response.status_code}\n Response: {error_text}")
await http_client.aclose()
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
e.upstream_response = error_text # type: ignore[attr-defined]
raise
except EmbeddedErrorException:
@@ -549,7 +552,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
error_message=extract_error_message(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
@@ -785,7 +788,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(e),
error_message=extract_error_message(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
@@ -802,10 +805,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
try:
if hasattr(e.response, "is_stream_consumed") and not e.response.is_stream_consumed:
error_bytes = await e.response.aread()
return error_bytes.decode("utf-8", errors="replace")[:500]
return error_bytes.decode("utf-8", errors="replace")
else:
return (
e.response.text[:500] if hasattr(e.response, "_content") else "Unable to read"
e.response.text if hasattr(e.response, "_content") else "Unable to read"
)
except Exception as decode_error:
return f"Unable to read error: {decode_error}"
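The `extract_error_message` helper referenced above lives in `src.core.error_utils` and is not shown in this diff; based on the commit description ("attach `upstream_response` to the exception, prefer upstream content over `str(error)`"), its core behavior is presumably along these lines:

```python
def extract_error_message(error: BaseException) -> str:
    """Assumed sketch: prefer the raw upstream response body attached to
    the exception (see `e.upstream_response = error_text` above) over the
    exception's string representation."""
    upstream = getattr(error, "upstream_response", None)
    if upstream:
        return upstream
    return str(error)

e = RuntimeError("HTTP 500")
assert extract_error_message(e) == "HTTP 500"
e.upstream_response = '{"error": "overloaded"}'
assert extract_error_message(e) == '{"error": "overloaded"}'
```

Recording the provider's own error body instead of `str(HTTPStatusError)` is what makes the stored `error_message` useful for debugging upstream failures.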

View File

@@ -34,7 +34,12 @@ from src.api.handlers.base.base_handler import (
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import build_sse_headers
from src.api.handlers.base.utils import (
build_sse_headers,
check_html_response,
check_prefetched_response_error,
)
from src.core.error_utils import extract_error_message
# 直接从具体模块导入,避免循环依赖
from src.api.handlers.base.response_parser import (
@@ -57,6 +62,7 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
@@ -328,9 +334,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_generator,
provider_name,
attempt_id,
_provider_id,
_endpoint_id,
_key_id,
provider_id,
endpoint_id,
key_id,
) = await self.orchestrator.execute_with_fallback(
api_format=ctx.api_format,
model_name=ctx.model,
@@ -340,7 +346,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
is_stream=True,
capability_requirements=capability_requirements or None,
)
# 更新上下文(确保 provider 信息已设置,用于 streaming 状态更新)
ctx.attempt_id = attempt_id
if not ctx.provider_name:
ctx.provider_name = provider_name
if not ctx.provider_id:
ctx.provider_id = provider_id
if not ctx.endpoint_id:
ctx.endpoint_id = endpoint_id
if not ctx.key_id:
ctx.key_id = key_id
# 创建后台任务记录统计
background_tasks = BackgroundTasks()
@@ -488,6 +504,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
error_text = await self._extract_error_text(e)
logger.error(f"Provider 返回错误状态: {e.response.status_code}\n Response: {error_text}")
await http_client.aclose()
# 将上游错误信息附加到异常,以便故障转移时能够返回给客户端
e.upstream_response = error_text # type: ignore[attr-defined]
raise
except EmbeddedErrorException:
@@ -523,8 +541,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
try:
sse_parser = SSEEventParser()
last_data_time = time.time()
streaming_status_updated = False
buffer = b""
output_state = {"first_yield": True, "streaming_updated": False}
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
@@ -532,11 +550,6 @@ class CliMessageHandlerBase(BaseMessageHandler):
needs_conversion = self._needs_format_conversion(ctx)
async for chunk in stream_response.aiter_bytes():
# 在第一次输出数据前更新状态为 streaming
if not streaming_status_updated:
self._update_usage_to_streaming_with_ctx(ctx)
streaming_status_updated = True
buffer += chunk
# 处理缓冲区中的完整行
while b"\n" in buffer:
@@ -561,6 +574,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -578,6 +592,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return # 结束生成器
@@ -585,8 +600,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -637,7 +654,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
except httpx.RemoteProtocolError as e:
except httpx.RemoteProtocolError:
if ctx.data_count > 0:
error_event = {
"type": "error",
@@ -691,7 +708,9 @@ class CliMessageHandlerBase(BaseMessageHandler):
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
"""
prefetched_chunks: list = []
max_prefetch_lines = 5 # 最多预读5行来检测错误
max_prefetch_lines = config.stream_prefetch_lines  # 最多预读的行数,用于检测错误
max_prefetch_bytes = StreamDefaults.MAX_PREFETCH_BYTES # 避免无换行响应导致 buffer 增长
total_prefetched_bytes = 0
buffer = b""
line_count = 0
should_stop = False
@@ -718,14 +737,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
total_prefetched_bytes += len(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk)
total_prefetched_bytes += len(chunk)
buffer += chunk
# 尝试按行解析缓冲区
# 尝试按行解析缓冲区SSE 格式)
while b"\n" in buffer:
line_bytes, buffer = buffer.split(b"\n", 1)
try:
@@ -742,15 +763,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
normalized_line = line.rstrip("\r")
# 检测 HTML 响应(base_url 配置错误的常见症状)
lower_line = normalized_line.lower()
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
if check_html_response(normalized_line):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
)
if not normalized_line or normalized_line.startswith(":"):
@@ -799,9 +820,30 @@ class CliMessageHandlerBase(BaseMessageHandler):
should_stop = True
break
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
logger.debug(
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
f"max_bytes={max_prefetch_bytes}"
)
break
if should_stop or line_count >= max_prefetch_lines:
break
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
# 处理某些代理返回的纯 JSON 错误(可能无换行/多行 JSON),以及 HTML 页面(base_url 配置错误)
if not should_stop and prefetched_chunks:
check_prefetched_response_error(
prefetched_chunks=prefetched_chunks,
parser=provider_parser,
request_id=self.request_id,
provider_name=str(provider.name),
endpoint_id=endpoint.id,
base_url=endpoint.base_url,
)
except (EmbeddedErrorException, ProviderTimeoutException, ProviderNotAvailableException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise
@@ -833,17 +875,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
sse_parser = SSEEventParser()
last_data_time = time.time()
buffer = b""
first_yield = True # 标记是否是第一次 yield
output_state = {"first_yield": True, "streaming_updated": False}
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
# 检查是否需要格式转换
needs_conversion = self._needs_format_conversion(ctx)
# 在第一次输出数据前更新状态为 streaming
if prefetched_chunks:
self._update_usage_to_streaming_with_ctx(ctx)
# 先处理预读的字节块
for chunk in prefetched_chunks:
buffer += chunk
@@ -870,10 +908,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -883,16 +918,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield)
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -931,10 +960,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
event.get("event"),
event.get("data") or "",
)
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield b"\n"
continue
@@ -952,6 +978,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
},
}
self._mark_first_output(ctx, output_state)
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
return
@@ -959,16 +986,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
if needs_conversion:
converted_line = self._convert_sse_line(ctx, line, events)
if converted_line:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (converted_line + "\n").encode("utf-8")
else:
# 记录首字时间 (第一次 yield) - 如果预读数据为空
if first_yield:
ctx.record_first_byte_time(self.start_time)
first_yield = False
self._mark_first_output(ctx, output_state)
yield (line + "\n").encode("utf-8")
for event in events:
@@ -1352,7 +1373,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(error),
error_message=extract_error_message(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
@@ -1620,7 +1641,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=str(e),
error_message=extract_error_message(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
@@ -1640,14 +1661,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
for encoding in ["utf-8", "gbk", "latin1"]:
try:
return error_bytes.decode(encoding)[:500]
return error_bytes.decode(encoding)
except (UnicodeDecodeError, LookupError):
continue
return error_bytes.decode("utf-8", errors="replace")[:500]
return error_bytes.decode("utf-8", errors="replace")
else:
return (
e.response.text[:500]
e.response.text
if hasattr(e.response, "_content")
else "Unable to read response"
)
@@ -1665,6 +1686,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
return False
return ctx.provider_api_format.upper() != ctx.client_api_format.upper()
def _mark_first_output(self, ctx: StreamContext, state: Dict[str, bool]) -> None:
"""
标记首次输出:记录 TTFB 并更新 streaming 状态
在第一次 yield 数据前调用,确保:
1. 首字时间 (TTFB) 已记录到 ctx
2. Usage 状态已更新为 streaming(包含 provider/key/TTFB 信息)
Args:
ctx: 流上下文
state: 包含 first_yield 和 streaming_updated 的状态字典
"""
if state["first_yield"]:
ctx.record_first_byte_time(self.start_time)
state["first_yield"] = False
if not state["streaming_updated"]:
self._update_usage_to_streaming_with_ctx(ctx)
state["streaming_updated"] = True
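The shared `output_state` dict lets every yield site call one idempotent hook: TTFB is recorded exactly once, and the streaming status update fires exactly once. A standalone sketch of that pattern, with callbacks standing in for `ctx.record_first_byte_time` and `_update_usage_to_streaming_with_ctx`:

```python
def mark_first_output(state, record_ttfb, update_streaming):
    """Idempotent first-output hook: both side effects run only on the
    first call, no matter how many chunks are yielded afterwards."""
    if state["first_yield"]:
        record_ttfb()
        state["first_yield"] = False
    if not state["streaming_updated"]:
        update_streaming()
        state["streaming_updated"] = True

calls = []
state = {"first_yield": True, "streaming_updated": False}
for _ in range(3):  # simulate three yielded chunks
    mark_first_output(state, lambda: calls.append("ttfb"),
                      lambda: calls.append("stream"))
assert calls == ["ttfb", "stream"]
```

Ordering matters here: TTFB is written to the context before the streaming update runs, which is why the status row can include `first_byte_time_ms`.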
def _convert_sse_line(
self,
ctx: StreamContext,

View File

@@ -98,6 +98,17 @@ class OpenAIResponseParser(ResponseParser):
chunk.is_done = True
stats.has_completion = True
# 提取 usage 信息(某些 OpenAI 兼容 API,如豆包,会在最后一个 chunk 中发送 usage)
# 这个 chunk 通常 choices 为空数组,但包含完整的 usage 信息
usage = parsed.get("usage")
if usage and isinstance(usage, dict):
chunk.input_tokens = usage.get("prompt_tokens", 0)
chunk.output_tokens = usage.get("completion_tokens", 0)
# 更新 stats
stats.input_tokens = chunk.input_tokens
stats.output_tokens = chunk.output_tokens
stats.chunk_count += 1
stats.data_count += 1
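A self-contained sketch of what the parser change handles: the final chunk from such providers typically has an empty `choices` array but a populated `usage` object (`extract_usage` is a hypothetical helper for illustration):

```python
import json

def extract_usage(chunk_json: str):
    """Return (prompt_tokens, completion_tokens) if the chunk carries a
    usage object, else None. Mirrors the final-chunk shape some
    OpenAI-compatible APIs (e.g. Doubao) emit."""
    parsed = json.loads(chunk_json)
    usage = parsed.get("usage")
    if isinstance(usage, dict):
        return usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0)
    return None

final_chunk = '{"choices": [], "usage": {"prompt_tokens": 12, "completion_tokens": 34}}'
assert extract_usage(final_chunk) == (12, 34)
# ordinary delta chunks carry no usage and are ignored
assert extract_usage('{"choices": [{"delta": {"content": "hi"}}]}') is None
```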

View File

@@ -25,8 +25,17 @@ from src.api.handlers.base.content_extractors import (
from src.api.handlers.base.parsers import get_parser_for_format
from src.api.handlers.base.response_parser import ResponseParser
from src.api.handlers.base.stream_context import StreamContext
from src.api.handlers.base.utils import (
check_html_response,
check_prefetched_response_error,
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.core.exceptions import EmbeddedErrorException, ProviderTimeoutException
from src.core.exceptions import (
EmbeddedErrorException,
ProviderNotAvailableException,
ProviderTimeoutException,
)
from src.core.logger import logger
from src.models.database import Provider, ProviderEndpoint
from src.utils.sse_parser import SSEEventParser
@@ -165,6 +174,7 @@ class StreamProcessor:
endpoint: ProviderEndpoint,
ctx: StreamContext,
max_prefetch_lines: int = 5,
max_prefetch_bytes: int = StreamDefaults.MAX_PREFETCH_BYTES,
) -> list:
"""
预读流的前几行,检测嵌套错误
@@ -180,12 +190,14 @@ class StreamProcessor:
endpoint: Endpoint 对象
ctx: 流式上下文
max_prefetch_lines: 最多预读行数
max_prefetch_bytes: 最多预读字节数(避免无换行响应导致 buffer 增长)
Returns:
预读的字节块列表
Raises:
EmbeddedErrorException: 如果检测到嵌套错误
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
ProviderTimeoutException: 如果首字节超时(TTFB timeout)
"""
prefetched_chunks: list = []
@@ -193,6 +205,7 @@ class StreamProcessor:
buffer = b""
line_count = 0
should_stop = False
total_prefetched_bytes = 0
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
@@ -206,11 +219,13 @@ class StreamProcessor:
provider_name=str(provider.name),
)
prefetched_chunks.append(first_chunk)
total_prefetched_bytes += len(first_chunk)
buffer += first_chunk
# 继续读取剩余的预读数据
async for chunk in aiter:
prefetched_chunks.append(chunk)
total_prefetched_bytes += len(chunk)
buffer += chunk
# 尝试按行解析缓冲区
@@ -228,10 +243,21 @@ class StreamProcessor:
line_count += 1
# 检测 HTML 响应(base_url 配置错误的常见症状)
if check_html_response(line):
logger.error(
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
)
# 跳过空行和注释行
if not line or line.startswith(":"):
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
@@ -248,7 +274,6 @@ class StreamProcessor:
data = json.loads(data_str)
except json.JSONDecodeError:
if line_count >= max_prefetch_lines:
should_stop = True
break
continue
@@ -276,14 +301,34 @@ class StreamProcessor:
should_stop = True
break
# 达到预读字节上限,停止继续预读(避免无换行响应导致内存增长)
if not should_stop and total_prefetched_bytes >= max_prefetch_bytes:
logger.debug(
f" [{self.request_id}] 预读达到字节上限,停止继续预读: "
f"Provider={provider.name}, bytes={total_prefetched_bytes}, "
f"max_bytes={max_prefetch_bytes}"
)
break
if should_stop or line_count >= max_prefetch_lines:
break
except (EmbeddedErrorException, ProviderTimeoutException):
# 预读结束后,检查是否为非 SSE 格式的 HTML/JSON 响应
if not should_stop and prefetched_chunks:
check_prefetched_response_error(
prefetched_chunks=prefetched_chunks,
parser=parser,
request_id=self.request_id,
provider_name=str(provider.name),
endpoint_id=endpoint.id,
base_url=endpoint.base_url,
)
except (EmbeddedErrorException, ProviderNotAvailableException, ProviderTimeoutException):
# 重新抛出可重试的 Provider 异常,触发故障转移
raise
except (OSError, IOError) as e:
# 网络 I/O 异常:记录警告,可能需要重试
logger.warning(
f" [{self.request_id}] 预读流时发生网络异常: {type(e).__name__}: {e}"
)
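The prefetch loop above is bounded in two dimensions: a line limit for normal SSE streams and a byte cap for newline-free responses. A minimal sketch of that double bound (the 64KB value follows the commit description; the real cap comes from `StreamDefaults.MAX_PREFETCH_BYTES`):

```python
MAX_PREFETCH_BYTES = 64 * 1024  # assumed value of StreamDefaults.MAX_PREFETCH_BYTES

def prefetch(chunks, max_lines=5, max_bytes=MAX_PREFETCH_BYTES):
    """Stop prefetching on either the line limit or the byte cap, so a
    response without newlines cannot grow the buffer without bound."""
    buffered, total, lines = [], 0, 0
    for chunk in chunks:
        buffered.append(chunk)
        total += len(chunk)
        lines += chunk.count(b"\n")
        if total >= max_bytes or lines >= max_lines:
            break
    return b"".join(buffered)

# a newline-free response is cut off after the first chunk crosses the cap
out = prefetch([b"x" * 70000, b"y" * 70000])
assert out == b"x" * 70000  # the second chunk is never read
```

Before this cap, a misbehaving upstream that streamed JSON without any `\n` would have been buffered in full while the loop waited for a line to parse.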
@@ -332,15 +377,15 @@ class StreamProcessor:
# 处理预读数据
if prefetched_chunks:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
for chunk in prefetched_chunks:
# 记录首字时间 (TTFB) - 在 yield 之前记录
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 把原始数据转发给客户端
yield chunk
@@ -363,14 +408,14 @@ class StreamProcessor:
# 处理剩余的流数据
async for chunk in byte_iterator:
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
if start_time is not None:
ctx.record_first_byte_time(start_time)
start_time = None # 只记录一次
# 首次输出前触发 streaming 回调(确保 TTFB 已写入 ctx)
if not streaming_started and self.on_streaming_start:
self.on_streaming_start()
streaming_started = True
# 原始数据透传
yield chunk

View File

@@ -2,8 +2,10 @@
Handler 基础工具函数
"""
import json
from typing import Any, Dict, Optional
from src.core.exceptions import EmbeddedErrorException, ProviderNotAvailableException
from src.core.logger import logger
@@ -107,3 +109,95 @@ def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[st
if extra_headers:
headers.update(extra_headers)
return headers
def check_html_response(line: str) -> bool:
"""
检查行是否为 HTML 响应(base_url 配置错误的常见症状)
Args:
line: 要检查的行内容
Returns:
True 如果检测到 HTML 响应
"""
lower_line = line.lstrip().lower()
return lower_line.startswith("<!doctype") or lower_line.startswith("<html")
def check_prefetched_response_error(
prefetched_chunks: list,
parser: Any,
request_id: str,
provider_name: str,
endpoint_id: Optional[str],
base_url: Optional[str],
) -> None:
"""
检查预读的响应是否为非 SSE 格式的错误响应(HTML 或纯 JSON 错误)
某些代理可能返回:
1. HTML 页面base_url 配置错误)
2. 纯 JSON 错误(无换行或多行 JSON)
Args:
prefetched_chunks: 预读的字节块列表
parser: 响应解析器(需要有 is_error_response 和 parse_response 方法)
request_id: 请求 ID用于日志
provider_name: Provider 名称
endpoint_id: Endpoint ID
base_url: Endpoint 的 base_url
Raises:
ProviderNotAvailableException: 如果检测到 HTML 响应
EmbeddedErrorException: 如果检测到 JSON 错误响应
"""
if not prefetched_chunks:
return
try:
prefetched_bytes = b"".join(prefetched_chunks)
stripped = prefetched_bytes.lstrip()
# 去除 BOM
if stripped.startswith(b"\xef\xbb\xbf"):
stripped = stripped[3:]
# HTML 响应(通常是 base_url 配置错误导致返回网页)
lower_prefix = stripped[:32].lower()
if lower_prefix.startswith(b"<!doctype") or lower_prefix.startswith(b"<html"):
endpoint_short = endpoint_id[:8] + "..." if endpoint_id else "N/A"
logger.error(
f" [{request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
f"Provider={provider_name}, Endpoint={endpoint_short}, "
f"base_url={base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
)
# 纯 JSON(可能无换行/多行 JSON)
if stripped.startswith(b"{") or stripped.startswith(b"["):
payload_str = stripped.decode("utf-8", errors="replace").strip()
data = json.loads(payload_str)
if isinstance(data, dict) and parser.is_error_response(data):
parsed = parser.parse_response(data, 200)
logger.warning(
f" [{request_id}] 检测到 JSON 错误响应: "
f"Provider={provider_name}, "
f"error_type={parsed.error_type}, "
f"message={parsed.error_message}"
)
raise EmbeddedErrorException(
provider_name=provider_name,
error_code=(
int(parsed.error_type)
if parsed.error_type and parsed.error_type.isdigit()
else None
),
error_message=parsed.error_message,
error_status=parsed.error_type,
)
except json.JSONDecodeError:
pass
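The classification logic of `check_prefetched_response_error` can be sketched standalone. `classify_prefetched` is a hypothetical name, and the error-shape test (`"error" in payload`) is a deliberate simplification of the parser's `is_error_response`:

```python
import json

def classify_prefetched(chunks):
    """Sketch of the prefetch post-check: HTML pages signal a
    misconfigured base_url; a bare JSON object may carry an error body;
    anything else is treated as a normal SSE stream."""
    data = b"".join(chunks).lstrip()
    if data.startswith(b"\xef\xbb\xbf"):  # strip UTF-8 BOM
        data = data[3:]
    head = data[:32].lower()
    if head.startswith(b"<!doctype") or head.startswith(b"<html"):
        return "html"
    if data.startswith(b"{") or data.startswith(b"["):
        try:
            payload = json.loads(data.decode("utf-8", errors="replace"))
        except json.JSONDecodeError:
            return "sse"
        if isinstance(payload, dict) and "error" in payload:
            return "json-error"
    return "sse"

assert classify_prefetched([b"<!DOCTYPE html><html>"]) == "html"
assert classify_prefetched([b'{"error": {"message": "quota"}}']) == "json-error"
assert classify_prefetched([b"data: {}\n\n"]) == "sse"
```

Running this only after the line-based prefetch loop finishes is what catches error bodies that contain no newline at all and would otherwise slip past line-by-line detection.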

View File

@@ -104,9 +104,11 @@ async def get_my_usage(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: int = Query(100, ge=1, le=200, description="每页记录数(默认100,最大200)"),
offset: int = Query(0, ge=0, le=2000, description="偏移量(用于分页,最大2000)"),
db: Session = Depends(get_db),
):
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date)
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date, limit=limit, offset=offset)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -133,6 +135,20 @@ async def get_my_interval_timeline(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/usage/heatmap")
async def get_my_activity_heatmap(
request: Request,
db: Session = Depends(get_db),
):
"""
Get user's activity heatmap data for the past 365 days.
This endpoint is cached for 5 minutes to reduce database load.
"""
adapter = GetMyActivityHeatmapAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/providers")
async def list_available_providers(request: Request, db: Session = Depends(get_db)):
adapter = ListAvailableProvidersAdapter()
@@ -471,6 +487,8 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
class GetUsageAdapter(AuthenticatedApiAdapter):
start_date: Optional[datetime]
end_date: Optional[datetime]
limit: int = 100
offset: int = 0
async def handle(self, context): # type: ignore[override]
db = context.db
@@ -553,7 +571,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
stats["total_cost_usd"] += item["total_cost_usd"]
# 假设 summary 中的都是成功的请求
stats["success_count"] += item["requests"]
if item.get("avg_response_time_ms"):
if item.get("avg_response_time_ms") is not None:
stats["total_response_time_ms"] += item["avg_response_time_ms"] * item["requests"]
stats["response_time_count"] += item["requests"]
@@ -582,7 +600,10 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
usage_records = query.order_by(Usage.created_at.desc()).limit(100).all()
# 计算总数(用于分页)
total_records = query.count()
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
avg_resp_query = db.query(func.avg(Usage.response_time_ms)).filter(
Usage.user_id == user.id,
@@ -608,6 +629,13 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
"used_usd": user.used_usd,
"summary_by_model": summary_by_model,
"summary_by_provider": summary_by_provider,
# 分页信息
"pagination": {
"total": total_records,
"limit": self.limit,
"offset": self.offset,
"has_more": self.offset + self.limit < total_records,
},
"records": [
{
"id": r.id,
@@ -636,13 +664,6 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
],
}
response_data["activity_heatmap"] = UsageService.get_daily_activity(
db=db,
user_id=user.id,
window_days=365,
include_actual_cost=user.role == "admin",
)
# 管理员可以看到真实成本
if user.role == "admin":
response_data["total_actual_cost"] = total_actual_cost
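The pagination block added above follows the standard offset/limit contract; a minimal sketch of the same `has_more` arithmetic, detached from the ORM query:

```python
def paginate(records, total, limit=100, offset=0):
    """Slice an already-ordered result set and report whether more pages exist."""
    page = records[offset : offset + limit]
    return {
        "records": page,
        "pagination": {
            "total": total,
            "limit": limit,
            "offset": offset,
            # More pages exist only while the window has not reached the total
            "has_more": offset + limit < total,
        },
    }

rows = list(range(250))
page = paginate(rows, total=len(rows), limit=100, offset=200)
```

The last page here holds 50 records and reports `has_more: false`, which is what lets the frontend stop polling for further pages.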
@@ -709,6 +730,20 @@ class GetMyIntervalTimelineAdapter(AuthenticatedApiAdapter):
return result
class GetMyActivityHeatmapAdapter(AuthenticatedApiAdapter):
"""Activity heatmap adapter with Redis caching for user."""
async def handle(self, context): # type: ignore[override]
user = context.user
result = await UsageService.get_cached_heatmap(
db=context.db,
user_id=user.id,
include_actual_cost=user.role == "admin",
)
context.add_audit_metadata(action="activity_heatmap")
return result
class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
async def handle(self, context): # type: ignore[override]
from sqlalchemy.orm import selectinload


@@ -213,7 +213,7 @@ class RedisClientManager:
f"Redis连接失败: {error_msg}\n"
"缓存亲和性功能需要Redis支持,请确保Redis服务正常运行。\n"
"检查事项:\n"
"1. Redis服务是否已启动(docker-compose up -d redis)\n"
"1. Redis服务是否已启动(docker compose up -d redis)\n"
"2. 环境变量 REDIS_URL 或 REDIS_PASSWORD 是否配置正确\n"
"3. Redis端口(默认6379)是否可访问"
) from e


@@ -21,6 +21,9 @@ class CacheTTL:
# L1 本地缓存(用于减少 Redis 访问)
L1_LOCAL = 3 # 3秒
# 活跃度热力图缓存 - 历史数据变化不频繁
ACTIVITY_HEATMAP = 300 # 5分钟
# 并发锁 TTL - 防止死锁
CONCURRENCY_LOCK = 600 # 10分钟
@@ -38,8 +41,25 @@ class CacheSize:
# ==============================================================================
class StreamDefaults:
"""流式处理默认值"""
# 预读字节上限(避免无换行响应导致内存增长)
# 64KB 基于:
# 1. SSE 单条消息通常远小于此值
# 2. 足够检测 HTML 和 JSON 错误响应
# 3. 不会占用过多内存
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
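The cap above guards a prefetch loop along these lines. A hedged sketch with a plain chunk iterable standing in for the project's actual stream reader:

```python
MAX_PREFETCH_BYTES = 64 * 1024  # mirrors the constant above

def prefetch_first_line(chunks, max_bytes=MAX_PREFETCH_BYTES):
    """Accumulate chunks until a newline is seen or the byte cap is hit.
    Returns (prefetched_bytes, hit_cap)."""
    buf = bytearray()
    for chunk in chunks:
        buf.extend(chunk)
        if b"\n" in buf:
            return bytes(buf), False
        if len(buf) >= max_bytes:
            # No newline within the cap: likely not SSE (HTML page, raw JSON
            # error), so stop buffering instead of growing memory unboundedly.
            return bytes(buf), True
    return bytes(buf), False
```

With the cap, a provider that streams a newline-free HTML error page can hold at most 64KB in the prefetch buffer before the error-detection path takes over.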
class ConcurrencyDefaults:
"""并发控制默认值"""
"""并发控制默认值
算法说明:边界记忆 + 渐进探测
- 触发 429 时记录边界(last_concurrent_peak),新限制 = 边界 - 1
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
- 这样可以快速收敛到真实限制附近,避免过度保守
"""
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
INITIAL_LIMIT = 50
@@ -69,10 +89,6 @@ class ConcurrencyDefaults:
# 扩容步长 - 每次扩容增加的并发数
INCREASE_STEP = 2
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
# 0.85 表示降到触发 429 时并发数的 85%
DECREASE_MULTIPLIER = 0.85
# 最大并发限制上限
MAX_CONCURRENT_LIMIT = 200
@@ -84,6 +100,7 @@ class ConcurrencyDefaults:
# === 探测性扩容参数 ===
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
# 探测性扩容可以突破已知边界,尝试更高的并发
PROBE_INCREASE_INTERVAL_MINUTES = 30
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求

src/core/error_utils.py (new file)

@@ -0,0 +1,28 @@
"""
错误消息处理工具函数
"""
from typing import Optional
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
"""
从异常中提取错误消息,优先使用上游响应内容
Args:
error: 异常对象
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
Returns:
错误消息字符串
"""
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误)
upstream_response = getattr(error, "upstream_response", None)
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
return str(upstream_response)
# 回退到异常的字符串表示(str 可能为空,如 httpx 超时异常)
error_str = str(error) or repr(error)
if status_code is not None:
return f"HTTP {status_code}: {error_str}"
return error_str
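A quick illustration of the precedence rules in `extract_error_message`, as a self-contained copy of the function above:

```python
from typing import Optional

def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
    # An attached upstream_response (the provider's raw error body) wins
    upstream = getattr(error, "upstream_response", None)
    if upstream and isinstance(upstream, str) and upstream.strip():
        return str(upstream)
    # str() can be empty (e.g. httpx timeouts), so fall back to repr()
    error_str = str(error) or repr(error)
    if status_code is not None:
        return f"HTTP {status_code}: {error_str}"
    return error_str

e = RuntimeError("request failed")
e.upstream_response = '{"error": "invalid api key"}'
print(extract_error_message(e, 401))  # the upstream body, verbatim
```

Callers in the fallback orchestrator can now pass any exception without first checking which attributes it carries.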


@@ -547,11 +547,19 @@ class ErrorResponse:
- 所有错误都记录到日志,通过错误 ID 关联
"""
if isinstance(e, ProxyException):
details = e.details.copy() if e.details else {}
status_code = e.status_code
message = e.message
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息
if isinstance(e, ProviderNotAvailableException) and e.upstream_response:
if e.upstream_status:
status_code = e.upstream_status
message = e.upstream_response
return ErrorResponse.create(
error_type=e.error_type,
message=e.message,
status_code=e.status_code,
details=e.details,
message=message,
status_code=status_code,
details=details if details else None,
)
elif isinstance(e, HTTPException):
return ErrorResponse.create(


@@ -411,7 +411,7 @@ def init_db():
print(" 3. 数据库用户名和密码是否正确", file=sys.stderr)
print("", file=sys.stderr)
print("如果使用 Docker请先运行:", file=sys.stderr)
print(" docker-compose up -d postgres redis", file=sys.stderr)
print(" docker compose -f docker-compose.build.yml up -d postgres redis", file=sys.stderr)
print("", file=sys.stderr)
print("=" * 60, file=sys.stderr)
# 使用 os._exit 直接退出,避免 uvicorn 捕获并打印堆栈


@@ -309,8 +309,9 @@ class CreateApiKeyRequest(BaseModel):
allowed_endpoints: Optional[List[str]] = None # 允许使用的端点 ID 列表
allowed_api_formats: Optional[List[str]] = None # 允许使用的 API 格式列表
allowed_models: Optional[List[str]] = None # 允许使用的模型名称列表
rate_limit: Optional[int] = 100
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期
rate_limit: Optional[int] = None # None = 无限制
expire_days: Optional[int] = None # None = 永不过期,数字 = 多少天后过期(兼容旧版)
expires_at: Optional[str] = None # ISO 日期字符串,如 "2025-12-31",优先于 expire_days
initial_balance_usd: Optional[float] = Field(
None, description="初始余额USD仅用于独立KeyNone = 无限制"
)


@@ -150,7 +150,7 @@ class ApiKey(Base):
allowed_endpoints = Column(JSON, nullable=True) # 允许使用的端点 ID 列表
allowed_api_formats = Column(JSON, nullable=True) # 允许使用的 API 格式列表
allowed_models = Column(JSON, nullable=True) # 允许使用的模型名称列表
rate_limit = Column(Integer, default=100) # 每分钟请求限制
rate_limit = Column(Integer, default=None, nullable=True)  # 每分钟请求限制,None = 无限制
concurrent_limit = Column(Integer, default=5, nullable=True) # 并发请求限制
# Key 能力配置


@@ -19,6 +19,7 @@ class ProviderEndpointCreate(BaseModel):
provider_id: str = Field(..., description="Provider ID")
api_format: str = Field(..., description="API 格式 (CLAUDE, OPENAI, CLAUDE_CLI, OPENAI_CLI)")
base_url: str = Field(..., min_length=1, max_length=500, description="API 基础 URL")
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
# 请求配置
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
@@ -62,6 +63,7 @@ class ProviderEndpointUpdate(BaseModel):
base_url: Optional[str] = Field(
default=None, min_length=1, max_length=500, description="API 基础 URL"
)
custom_path: Optional[str] = Field(default=None, max_length=200, description="自定义请求路径")
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
@@ -94,6 +96,7 @@ class ProviderEndpointResponse(BaseModel):
# API 配置
api_format: str
base_url: str
custom_path: Optional[str] = None
# 请求配置
headers: Optional[Dict[str, str]] = None


@@ -21,7 +21,7 @@ WARNING: 多进程环境注意事项
import asyncio
import time
from collections import deque
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Deque, Dict
from src.core.logger import logger
@@ -95,12 +95,12 @@ class SlidingWindow:
"""获取最早的重置时间"""
self._cleanup()
if not self.requests:
return datetime.now()
return datetime.now(timezone.utc)
# 最早的请求将在window_size秒后过期
oldest_request = self.requests[0]
reset_time = oldest_request + self.window_size
return datetime.fromtimestamp(reset_time)
return datetime.fromtimestamp(reset_time, tz=timezone.utc)
class SlidingWindowStrategy(RateLimitStrategy):
@@ -250,7 +250,7 @@ class SlidingWindowStrategy(RateLimitStrategy):
retry_after = None
if not allowed:
# 计算需要等待的时间(最早请求过期的时间)
retry_after = int((reset_at - datetime.now()).total_seconds()) + 1
retry_after = int((reset_at - datetime.now(timezone.utc)).total_seconds()) + 1
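The switch to `datetime.now(timezone.utc)` in this hunk matters because Python refuses to mix naive and aware datetimes in arithmetic. A small sketch of the failure mode the change avoids:

```python
from datetime import datetime, timezone

# Epoch-based window timestamps convert cleanly to aware datetimes
reset_at = datetime.fromtimestamp(1_700_000_000 + 60, tz=timezone.utc)

# Subtracting an aware datetime and a naive one raises TypeError, which is
# why both sides of the retry_after computation must be timezone-aware.
try:
    _ = reset_at - datetime.now()  # naive "now"
except TypeError:
    mixed_ok = False
else:
    mixed_ok = True

retry_after = int((reset_at - datetime.now(timezone.utc)).total_seconds()) + 1
```

Keeping everything in UTC also makes `reset_at` values comparable across processes regardless of server locale.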
return RateLimitResult(
allowed=allowed,


@@ -3,7 +3,7 @@
import asyncio
import os
import time
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple
from ...clients.redis_client import get_redis_client_sync
@@ -63,11 +63,11 @@ class TokenBucket:
def get_reset_time(self) -> datetime:
"""获取下次完全恢复的时间"""
if self.tokens >= self.capacity:
return datetime.now()
return datetime.now(timezone.utc)
tokens_needed = self.capacity - self.tokens
seconds_to_full = tokens_needed / self.refill_rate
return datetime.now() + timedelta(seconds=seconds_to_full)
return datetime.now(timezone.utc) + timedelta(seconds=seconds_to_full)
class TokenBucketStrategy(RateLimitStrategy):
@@ -370,7 +370,7 @@ class RedisTokenBucketBackend:
if tokens is None or last_refill is None:
remaining = capacity
reset_at = datetime.now() + timedelta(seconds=capacity / refill_rate)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=capacity / refill_rate)
else:
tokens_value = float(tokens)
last_refill_value = float(last_refill)
@@ -378,7 +378,7 @@ class RedisTokenBucketBackend:
tokens_value = min(capacity, tokens_value + delta * refill_rate)
remaining = int(tokens_value)
reset_after = 0 if tokens_value >= capacity else (capacity - tokens_value) / refill_rate
reset_at = datetime.now() + timedelta(seconds=reset_after)
reset_at = datetime.now(timezone.utc) + timedelta(seconds=reset_after)
allowed = remaining >= amount
retry_after = None


@@ -148,6 +148,8 @@ class GlobalModelService:
删除 GlobalModel
默认行为: 级联删除所有关联的 Provider 模型实现
注意: 不清理 API Key 和 User 的 allowed_models 引用,
保留无效引用可让用户在前端看到"已失效"的模型,便于手动清理或等待重建同名模型
"""
global_model = GlobalModelService.get_global_model(db, global_model_id)


@@ -237,7 +237,7 @@ class ErrorClassifier:
result["reason"] = str(data.get("reason", data.get("code", "")))
except (json.JSONDecodeError, TypeError, KeyError):
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
result["message"] = error_text
return result
@@ -323,8 +323,8 @@ class ErrorClassifier:
if parts:
return ": ".join(parts) if len(parts) > 1 else parts[0]
# 无法解析,返回原始文本(截断)
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
# 无法解析,返回原始文本
return parsed["raw"]
def classify(
self,
@@ -484,11 +484,15 @@ class ErrorClassifier:
return ProviderNotAvailableException(
message=detailed_message,
provider_name=provider_name,
upstream_status=status,
upstream_response=error_response_text,
)
return ProviderNotAvailableException(
message=detailed_message,
provider_name=provider_name,
upstream_status=status,
upstream_response=error_response_text,
)
async def handle_http_error(
@@ -532,12 +536,14 @@ class ErrorClassifier:
provider_name = str(provider.name)
# 尝试读取错误响应内容
error_response_text = None
try:
if http_error.response and hasattr(http_error.response, "text"):
error_response_text = http_error.response.text[:1000] # 限制长度
except Exception:
pass
# 优先使用 handler 附加的 upstream_response 属性(流式请求中 response.text 可能为空)
error_response_text = getattr(http_error, "upstream_response", None)
if not error_response_text:
try:
if http_error.response and hasattr(http_error.response, "text"):
error_response_text = http_error.response.text
except Exception:
pass
logger.warning(f" [{request_id}] HTTP错误 (attempt={attempt}/{max_attempts}): "
f"{http_error.response.status_code if http_error.response else 'unknown'}")


@@ -30,6 +30,7 @@ from redis import Redis
from sqlalchemy.orm import Session
from src.core.enums import APIFormat
from src.core.error_utils import extract_error_message
from src.core.exceptions import (
ConcurrencyLimitError,
ProviderNotAvailableException,
@@ -401,7 +402,7 @@ class FallbackOrchestrator:
db=self.db,
candidate_id=candidate_record_id,
error_type="HTTPStatusError",
error_message=f"HTTP {status_code}: {str(cause)}",
error_message=extract_error_message(cause, status_code),
status_code=status_code,
latency_ms=elapsed_ms,
concurrent_requests=captured_key_concurrent,
@@ -425,31 +426,22 @@ class FallbackOrchestrator:
attempt=attempt,
max_attempts=max_attempts,
)
# str(cause) 可能为空(如 httpx 超时异常),使用 repr() 作为备用
error_msg = str(cause) or repr(cause)
# 如果是 ProviderNotAvailableException附加上游响应
if hasattr(cause, "upstream_response") and cause.upstream_response:
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
RequestCandidateService.mark_candidate_failed(
db=self.db,
candidate_id=candidate_record_id,
error_type=type(cause).__name__,
error_message=error_msg,
error_message=extract_error_message(cause),
latency_ms=elapsed_ms,
concurrent_requests=captured_key_concurrent,
)
return "continue" if has_retry_left else "break"
# 未知错误:记录失败并抛出
error_msg = str(cause) or repr(cause)
# 如果是 ProviderNotAvailableException附加上游响应
if hasattr(cause, "upstream_response") and cause.upstream_response:
error_msg = f"{error_msg} | 上游响应: {cause.upstream_response[:500]}"
RequestCandidateService.mark_candidate_failed(
db=self.db,
candidate_id=candidate_record_id,
error_type=type(cause).__name__,
error_message=error_msg,
error_message=extract_error_message(cause),
latency_ms=elapsed_ms,
concurrent_requests=captured_key_concurrent,
)
@@ -543,7 +535,9 @@ class FallbackOrchestrator:
raise last_error
# 所有组合都已尝试完毕,全部失败
self._raise_all_failed_exception(request_id, max_attempts, last_candidate, model_name, api_format_enum)
self._raise_all_failed_exception(
request_id, max_attempts, last_candidate, model_name, api_format_enum, last_error
)
async def _try_candidate_with_retries(
self,
@@ -565,6 +559,7 @@ class FallbackOrchestrator:
provider = candidate.provider
endpoint = candidate.endpoint
max_retries_for_candidate = int(endpoint.max_retries) if candidate.is_cached else 1
last_error: Optional[Exception] = None
for retry_index in range(max_retries_for_candidate):
attempt_counter += 1
@@ -599,6 +594,7 @@ class FallbackOrchestrator:
return {"success": True, "response": response}
except ExecutionError as exec_err:
last_error = exec_err.cause
action = await self._handle_candidate_error(
exec_err=exec_err,
candidate=candidate,
@@ -630,6 +626,7 @@ class FallbackOrchestrator:
"success": False,
"attempt_counter": attempt_counter,
"max_attempts": max_attempts,
"error": last_error,
}
def _attach_metadata_to_error(
@@ -678,6 +675,7 @@ class FallbackOrchestrator:
last_candidate: Optional[ProviderCandidate],
model_name: str,
api_format_enum: APIFormat,
last_error: Optional[Exception] = None,
) -> NoReturn:
"""所有组合都失败时抛出异常"""
logger.error(f" [{request_id}] 所有 {max_attempts} 个组合均失败")
@@ -693,9 +691,38 @@ class FallbackOrchestrator:
"api_format": api_format_enum.value,
}
# 提取上游错误响应
upstream_status: Optional[int] = None
upstream_response: Optional[str] = None
if last_error:
# 从 httpx.HTTPStatusError 提取
if isinstance(last_error, httpx.HTTPStatusError):
upstream_status = last_error.response.status_code
# 优先使用我们附加的 upstream_response 属性(流已读取时 response.text 可能为空)
upstream_response = getattr(last_error, "upstream_response", None)
if not upstream_response:
try:
upstream_response = last_error.response.text
except Exception:
pass
# 从其他异常属性提取(如 ProviderNotAvailableException
else:
upstream_status = getattr(last_error, "upstream_status", None)
upstream_response = getattr(last_error, "upstream_response", None)
# 如果响应为空或无效,使用异常的字符串表示
if (
not upstream_response
or not upstream_response.strip()
or upstream_response.startswith("Unable to read")
):
upstream_response = str(last_error)
raise ProviderNotAvailableException(
f"所有Provider均不可用,已尝试{max_attempts}个组合",
request_metadata=request_metadata,
upstream_status=upstream_status,
upstream_response=upstream_response,
)
async def execute_with_fallback(


@@ -1,14 +1,16 @@
"""
自适应并发调整器 - 基于滑动窗口利用率的并发限制调整
自适应并发调整器 - 基于边界记忆的并发限制调整
核心改进(相对于旧版基于"持续高利用率"的方案):
- 使用滑动窗口采样,容忍并发波动
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
核心算法:边界记忆 + 渐进探测
- 触发 429 时记录边界(last_concurrent_peak),这就是真实上限
- 缩容策略:新限制 = 边界 - 1而非乘性减少
- 扩容策略:不超过已知边界,除非是探测性扩容
- 探测性扩容:长时间无 429 时尝试突破边界
AIMD 参数说明
- 扩容:加性增加 (+INCREASE_STEP)
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
设计原则
1. 快速收敛:一次 429 就能找到接近真实的限制
2. 避免过度保守:不会因为多次 429 而无限下降
3. 安全探测:允许在稳定后尝试更高并发
"""
from datetime import datetime, timezone
@@ -35,21 +37,21 @@ class AdaptiveConcurrencyManager:
"""
自适应并发管理器
核心算法:基于滑动窗口利用率的 AIMD
- 滑动窗口记录最近 N 次请求的利用率
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
- 遇到 429 错误时乘性减少 (*0.85)
- 长时间无 429 且有流量时触发探测性扩容
核心算法:边界记忆 + 渐进探测
- 触发 429 时记录边界(last_concurrent_peak = 触发时的并发数)
- 缩容:新限制 = 边界 - 1(快速收敛到真实限制附近)
- 扩容:不超过边界(即 last_concurrent_peak),允许回到边界值尝试
- 探测性扩容:长时间(30分钟)无 429 时,可以尝试 +1 突破边界
扩容条件(满足任一即可):
1. 滑动窗口扩容:窗口内 >= 60% 的采样利用率 >= 70%,且不在冷却期
2. 探测性扩容:距上次 429 超过 30 分钟,且期间有足够请求量
1. 利用率扩容:窗口内利用率比例 >= 60%,且当前限制 < 边界
2. 探测性扩容:距上次 429 超过 30 分钟,可以尝试突破边界
关键特性:
1. 滑动窗口容忍并发波动,不会因单次低利用率重置
2. 区分并发限制和 RPM 限制
3. 探测性扩容避免长期卡在低限制
4. 记录调整历史
1. 快速收敛:一次 429 就能学到接近真实的限制值
2. 边界保护:普通扩容不会超过已知边界
3. 安全探测:长时间稳定后允许尝试更高并发
4. 区分并发限制和 RPM 限制
"""
# 默认配置 - 使用统一常量
@@ -59,7 +61,6 @@ class AdaptiveConcurrencyManager:
# AIMD 参数
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
DECREASE_MULTIPLIER = ConcurrencyDefaults.DECREASE_MULTIPLIER
# 滑动窗口参数
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
@@ -115,7 +116,13 @@ class AdaptiveConcurrencyManager:
# 更新429统计
key.last_429_at = datetime.now(timezone.utc) # type: ignore[assignment]
key.last_429_type = rate_limit_info.limit_type # type: ignore[assignment]
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
# 仅在并发限制且拿到并发数时记录边界(RPM/UNKNOWN 不应覆盖并发边界记忆)
if (
rate_limit_info.limit_type == RateLimitType.CONCURRENT
and current_concurrent is not None
and current_concurrent > 0
):
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
# 遇到 429 错误,清空利用率采样窗口(重新开始收集)
key.utilization_samples = [] # type: ignore[assignment]
@@ -207,6 +214,9 @@ class AdaptiveConcurrencyManager:
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
# 获取已知边界(上次触发 429 时的并发数)
known_boundary = key.last_concurrent_peak
# 计算当前利用率
utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
@@ -217,22 +227,29 @@ class AdaptiveConcurrencyManager:
samples = self._update_utilization_window(key, now_ts, utilization)
# 检查是否满足扩容条件
increase_reason = self._check_increase_conditions(key, samples, now)
increase_reason = self._check_increase_conditions(key, samples, now, known_boundary)
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
old_limit = current_limit
new_limit = self._increase_limit(current_limit)
is_probe = increase_reason == "probe_increase"
new_limit = self._increase_limit(current_limit, known_boundary, is_probe)
# 如果没有实际增长(已达边界),跳过
if new_limit <= old_limit:
return None
# 计算窗口统计用于日志
avg_util = sum(s["util"] for s in samples) / len(samples) if samples else 0
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples) if samples else 0
boundary_info = f"边界: {known_boundary}" if known_boundary else "无边界"
logger.info(
f"[INCREASE] {increase_reason}: Key {key.id[:8]}... | "
f"窗口采样: {len(samples)} | "
f"平均利用率: {avg_util:.1%} | "
f"高利用率比例: {high_util_ratio:.1%} | "
f"{boundary_info} | "
f"调整: {old_limit} -> {new_limit}"
)
@@ -246,13 +263,14 @@ class AdaptiveConcurrencyManager:
high_util_ratio=round(high_util_ratio, 2),
sample_count=len(samples),
current_concurrent=current_concurrent,
known_boundary=known_boundary,
)
# 更新限制
key.learned_max_concurrent = new_limit # type: ignore[assignment]
# 如果是探测性扩容,更新探测时间
if increase_reason == "probe_increase":
if is_probe:
key.last_probe_increase_at = now # type: ignore[assignment]
# 扩容后清空采样窗口,重新开始收集
@@ -303,7 +321,11 @@ class AdaptiveConcurrencyManager:
return samples
def _check_increase_conditions(
self, key: ProviderAPIKey, samples: List[Dict[str, Any]], now: datetime
self,
key: ProviderAPIKey,
samples: List[Dict[str, Any]],
now: datetime,
known_boundary: Optional[int] = None,
) -> Optional[str]:
"""
检查是否满足扩容条件
@@ -312,6 +334,7 @@ class AdaptiveConcurrencyManager:
key: API Key对象
samples: 利用率采样列表
now: 当前时间
known_boundary: 已知边界(触发 429 时的并发数)
Returns:
扩容原因(如果满足条件),否则返回 None
@@ -320,15 +343,25 @@ class AdaptiveConcurrencyManager:
if self._is_in_cooldown(key):
return None
# 条件1:滑动窗口扩容
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
# 条件1:滑动窗口扩容(不超过边界)
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
high_util_count = sum(1 for s in samples if s["util"] >= self.UTILIZATION_THRESHOLD)
high_util_ratio = high_util_count / len(samples)
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
return "high_utilization"
# 检查是否还有扩容空间(边界保护)
if known_boundary:
# 允许扩容到边界值(而非 boundary - 1,因为缩容时已经 -1 了)
if current_limit < known_boundary:
return "high_utilization"
# 已达边界,不触发普通扩容
else:
# 无边界信息,允许扩容
return "high_utilization"
# 条件2:探测性扩容(长时间无 429 且有流量)
# 条件2:探测性扩容(长时间无 429 且有流量,可以突破边界)
if self._should_probe_increase(key, samples, now):
return "probe_increase"
@@ -406,32 +439,65 @@ class AdaptiveConcurrencyManager:
current_concurrent: Optional[int] = None,
) -> int:
"""
减少并发限制
减少并发限制(基于边界记忆策略)
策略:
- 如果知道当前并发数设置为当前并发的70%
- 否则,使用乘性减少
- 如果知道触发 429 时的并发数,新限制 = 并发数 - 1
- 这样可以快速收敛到真实限制附近,而不会过度保守
- 例如:真实限制 8,触发时并发 8 -> 新限制 7(而非 8*0.85=6)
"""
if current_concurrent:
# 基于当前并发数减少
new_limit = max(
int(current_concurrent * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
)
if current_concurrent is not None and current_concurrent > 0:
# 边界记忆策略:新限制 = 触发边界 - 1
candidate = current_concurrent - 1
else:
# 乘性减少
new_limit = max(
int(current_limit * self.DECREASE_MULTIPLIER), self.MIN_CONCURRENT_LIMIT
)
# 没有并发信息时,保守减少 1
candidate = current_limit - 1
# 保证不会“缩容变扩容”(例如 current_concurrent > current_limit 的异常场景)
candidate = min(candidate, current_limit - 1)
new_limit = max(candidate, self.MIN_CONCURRENT_LIMIT)
return new_limit
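The boundary-minus-one shrink can be exercised in isolation. A standalone sketch with the class constants inlined (a floor of 1 is assumed here for `MIN_CONCURRENT_LIMIT`):

```python
MIN_CONCURRENT_LIMIT = 1  # assumed floor for illustration

def decrease_limit(current_limit, current_concurrent=None):
    if current_concurrent is not None and current_concurrent > 0:
        candidate = current_concurrent - 1  # boundary memory: limit = peak - 1
    else:
        candidate = current_limit - 1       # no concurrency info: back off by one
    # Never "shrink" upward when the observed peak exceeds the current limit
    candidate = min(candidate, current_limit - 1)
    return max(candidate, MIN_CONCURRENT_LIMIT)

# One 429 at concurrency 8 converges straight to 7 (vs 8 * 0.85 = 6 before)
```

The single subtraction is what gives the one-shot convergence the commit message describes: repeated 429s at the same peak keep producing the same limit instead of decaying multiplicatively.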
def _increase_limit(self, current_limit: int) -> int:
def _increase_limit(
self,
current_limit: int,
known_boundary: Optional[int] = None,
is_probe: bool = False,
) -> int:
"""
增加并发限制
增加并发限制(考虑边界保护)
策略:加性增加 (+1)
策略:
- 普通扩容:每次 +INCREASE_STEP但不超过 known_boundary
(因为缩容时已经 -1 了,这里允许回到边界值尝试)
- 探测性扩容:每次只 +1,可以突破边界但要谨慎
Args:
current_limit: 当前限制
known_boundary: 已知边界(last_concurrent_peak,即触发 429 时的并发数)
is_probe: 是否是探测性扩容(可以突破边界)
"""
new_limit = min(current_limit + self.INCREASE_STEP, self.MAX_CONCURRENT_LIMIT)
if is_probe:
# 探测模式:每次只 +1,谨慎突破边界
new_limit = current_limit + 1
else:
# 普通模式:每次 +INCREASE_STEP
new_limit = current_limit + self.INCREASE_STEP
# 边界保护:普通扩容不超过 known_boundary(允许回到边界值尝试;探测性扩容按文档约定可突破边界,不在此处截断)
if known_boundary and not is_probe:
if new_limit > known_boundary:
new_limit = known_boundary
# 全局上限保护
new_limit = min(new_limit, self.MAX_CONCURRENT_LIMIT)
# 确保有增长(否则返回原值表示不扩容)
if new_limit <= current_limit:
return current_limit
return new_limit
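The growth side can be sketched the same way. This assumes, as the docstring states, that probe increases are allowed past the known boundary while normal increases are capped at it (constants `INCREASE_STEP = 2` and `MAX_CONCURRENT_LIMIT = 200` taken from the config above):

```python
INCREASE_STEP = 2
MAX_CONCURRENT_LIMIT = 200

def increase_limit(current_limit, known_boundary=None, is_probe=False):
    # Probe mode inches forward one step at a time; normal mode takes the full step
    new_limit = current_limit + 1 if is_probe else current_limit + INCREASE_STEP
    # Normal growth stops at the boundary (shrinking already subtracted 1,
    # so returning to the boundary value itself is allowed)
    if known_boundary and not is_probe and new_limit > known_boundary:
        new_limit = known_boundary
    new_limit = min(new_limit, MAX_CONCURRENT_LIMIT)
    # Signal "no growth" by returning the unchanged limit
    return current_limit if new_limit <= current_limit else new_limit
```

With this shape, a key stuck exactly at its boundary only grows via a probe, which matches the "渐进探测" half of the algorithm.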
def _record_adjustment(
@@ -503,11 +569,16 @@ class AdaptiveConcurrencyManager:
if key.last_probe_increase_at:
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
# 边界信息
known_boundary = key.last_concurrent_peak
return {
"adaptive_mode": is_adaptive,
"max_concurrent": key.max_concurrent, # NULL=自适应,数字=固定限制
"effective_limit": effective_limit, # 当前有效限制
"learned_limit": key.learned_max_concurrent, # 学习到的限制
# 边界记忆相关
"known_boundary": known_boundary, # 触发 429 时的并发数(已知上限)
"concurrent_429_count": int(key.concurrent_429_count or 0),
"rpm_429_count": int(key.rpm_429_count or 0),
"last_429_at": last_429_at_str,


@@ -289,11 +289,11 @@ class RequestResult:
status_code = 500
error_type = "internal_error"
# 构建错误消息,包含上游响应信息
error_message = str(exception)
if isinstance(exception, ProviderNotAvailableException):
if exception.upstream_response:
error_message = f"{error_message} | 上游响应: {exception.upstream_response[:500]}"
# 构建错误消息:优先使用上游响应作为主要错误信息
if isinstance(exception, ProviderNotAvailableException) and exception.upstream_response:
error_message = exception.upstream_response
else:
error_message = str(exception)
return cls(
status=RequestStatus.FAILED,


@@ -86,6 +86,118 @@ class UsageRecordParams:
class UsageService:
"""用量统计服务"""
# ==================== 缓存键常量 ====================
# 热力图缓存键前缀(依赖 TTL 自动过期,用户角色变更时主动清除)
HEATMAP_CACHE_KEY_PREFIX = "activity_heatmap"
# ==================== 热力图缓存 ====================
@classmethod
def _get_heatmap_cache_key(cls, user_id: Optional[str], include_actual_cost: bool) -> str:
"""生成热力图缓存键"""
cost_suffix = "with_cost" if include_actual_cost else "no_cost"
if user_id:
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:user:{user_id}:{cost_suffix}"
else:
return f"{cls.HEATMAP_CACHE_KEY_PREFIX}:admin:all:{cost_suffix}"
@classmethod
async def clear_user_heatmap_cache(cls, user_id: str) -> None:
"""
清除用户的热力图缓存(用户角色变更时调用)
Args:
user_id: 用户ID
"""
from src.clients.redis_client import get_redis_client
redis_client = await get_redis_client(require_redis=False)
if not redis_client:
return
# 清除该用户的所有热力图缓存(with_cost 和 no_cost)
keys_to_delete = [
cls._get_heatmap_cache_key(user_id, include_actual_cost=True),
cls._get_heatmap_cache_key(user_id, include_actual_cost=False),
]
for key in keys_to_delete:
try:
await redis_client.delete(key)
logger.debug(f"已清除热力图缓存: {key}")
except Exception as e:
logger.warning(f"清除热力图缓存失败: {key}, error={e}")
@classmethod
async def get_cached_heatmap(
cls,
db: Session,
user_id: Optional[str] = None,
include_actual_cost: bool = False,
) -> Dict[str, Any]:
"""
获取带缓存的热力图数据
缓存策略:
- TTL: 5分钟(CacheTTL.ACTIVITY_HEATMAP)
- 仅依赖 TTL 自动过期,新使用记录最多延迟 5 分钟出现
- 用户角色变更时通过 clear_user_heatmap_cache() 主动清除
Args:
db: 数据库会话
user_id: 用户ID,None 表示获取全局热力图(管理员)
include_actual_cost: 是否包含实际成本
Returns:
热力图数据字典
"""
from src.clients.redis_client import get_redis_client
from src.config.constants import CacheTTL
import json
cache_key = cls._get_heatmap_cache_key(user_id, include_actual_cost)
cache_ttl = CacheTTL.ACTIVITY_HEATMAP
redis_client = await get_redis_client(require_redis=False)
# 尝试从缓存获取
if redis_client:
try:
cached = await redis_client.get(cache_key)
if cached:
try:
return json.loads(cached) # type: ignore[no-any-return]
except json.JSONDecodeError as e:
logger.warning(f"热力图缓存解析失败,删除损坏缓存: {cache_key}, error={e}")
try:
await redis_client.delete(cache_key)
except Exception:
pass
except Exception as e:
logger.error(f"读取热力图缓存出错: {cache_key}, error={e}")
# 从数据库查询
result = cls.get_daily_activity(
db=db,
user_id=user_id,
window_days=365,
include_actual_cost=include_actual_cost,
)
# 保存到缓存(失败不影响返回结果)
if redis_client:
try:
await redis_client.setex(
cache_key,
cache_ttl,
json.dumps(result, ensure_ascii=False, default=str),
)
except Exception as e:
logger.warning(f"保存热力图缓存失败: {cache_key}, error={e}")
return result
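The TTL-only get-or-compute pattern above can be shown end to end with a tiny in-memory stand-in for the Redis client (the `FakeRedis` class here is purely illustrative; the real code uses the async redis client and `setex`):

```python
import json
import time

class FakeRedis:  # illustrative stand-in for the Redis client
    def __init__(self):
        self._store = {}
    def get(self, key):
        value, expires = self._store.get(key, (None, 0.0))
        return value if time.monotonic() < expires else None
    def setex(self, key, ttl, value):
        self._store[key] = (value, time.monotonic() + ttl)

def get_cached(redis, key, ttl, compute):
    """Serve the cached payload if present, else compute, best-effort store,
    and return. Corrupt JSON simply falls through to a fresh compute."""
    cached = redis.get(key)
    if cached:
        try:
            return json.loads(cached)
        except json.JSONDecodeError:
            pass
    result = compute()
    redis.setex(key, ttl, json.dumps(result, default=str))
    return result

calls = []
compute = lambda: (calls.append(1), {"days": 365})[1]
r = FakeRedis()
first = get_cached(r, "activity_heatmap:user:u1:no_cost", 300, compute)
second = get_cached(r, "activity_heatmap:user:u1:no_cost", 300, compute)
```

The second call is served from the cache, so the expensive heatmap query runs once per TTL window per cache key.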
# ==================== 内部数据类 ====================
@staticmethod
@@ -1027,7 +1139,12 @@ class UsageService:
window_days: int = 365,
include_actual_cost: bool = False,
) -> Dict[str, Any]:
"""按天统计请求活跃度,用于渲染热力图。"""
"""按天统计请求活跃度,用于渲染热力图。
优化策略:
- 历史数据从预计算的 StatsDaily/StatsUserDaily 表读取
- 只有"今天"的数据才实时查询 Usage 表
"""
def ensure_timezone(value: datetime) -> datetime:
if value.tzinfo is None:
@@ -1041,54 +1158,109 @@ class UsageService:
ensure_timezone(start_date) if start_date else end_dt - timedelta(days=window_days - 1)
)
# 对齐到自然日的开始/结束,避免遗漏边界数据
start_dt = start_dt.replace(hour=0, minute=0, second=0, microsecond=0)
end_dt = end_dt.replace(hour=23, minute=59, second=59, microsecond=999999)
from src.utils.database_helpers import date_trunc_portable
bind = db.get_bind()
dialect = bind.dialect.name if bind is not None else "sqlite"
day_bucket = date_trunc_portable(dialect, "day", Usage.created_at).label("day")
columns = [
day_bucket,
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
]
if include_actual_cost:
columns.append(func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"))
query = db.query(*columns).filter(Usage.created_at >= start_dt, Usage.created_at <= end_dt)
if user_id:
query = query.filter(Usage.user_id == user_id)
query = query.group_by(day_bucket).order_by(day_bucket)
rows = query.all()
def normalize_period(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value[:10]
if isinstance(value, datetime):
return value.date().isoformat()
return str(value)
# 对齐到自然日的开始/结束
start_dt = datetime.combine(start_dt.date(), datetime.min.time(), tzinfo=timezone.utc)
end_dt = datetime.combine(end_dt.date(), datetime.max.time(), tzinfo=timezone.utc)
today = now.date()
today_start_dt = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
aggregated: Dict[str, Dict[str, Any]] = {}
for row in rows:
key = normalize_period(row.day)
aggregated[key] = {
"requests": int(row.requests or 0),
"total_tokens": int(row.total_tokens or 0),
"total_cost_usd": float(row.total_cost_usd or 0.0),
}
if include_actual_cost:
aggregated[key]["actual_total_cost_usd"] = float(row.actual_total_cost_usd or 0.0)
# 1. 从预计算表读取历史数据(不包括今天)
if user_id:
from src.models.database import StatsUserDaily
hist_query = db.query(StatsUserDaily).filter(
StatsUserDaily.user_id == user_id,
StatsUserDaily.date >= start_dt,
StatsUserDaily.date < today_start_dt,
)
for row in hist_query.all():
key = (
row.date.date().isoformat()
if isinstance(row.date, datetime)
else str(row.date)[:10]
)
aggregated[key] = {
"requests": row.total_requests or 0,
"total_tokens": (
(row.input_tokens or 0)
+ (row.output_tokens or 0)
+ (row.cache_creation_tokens or 0)
+ (row.cache_read_tokens or 0)
),
"total_cost_usd": float(row.total_cost or 0.0),
}
# StatsUserDaily 没有 actual_total_cost 字段,用户视图不需要倍率成本
else:
from src.models.database import StatsDaily
hist_query = db.query(StatsDaily).filter(
StatsDaily.date >= start_dt,
StatsDaily.date < today_start_dt,
)
for row in hist_query.all():
key = (
row.date.date().isoformat()
if isinstance(row.date, datetime)
else str(row.date)[:10]
)
aggregated[key] = {
"requests": row.total_requests or 0,
"total_tokens": (
(row.input_tokens or 0)
+ (row.output_tokens or 0)
+ (row.cache_creation_tokens or 0)
+ (row.cache_read_tokens or 0)
),
"total_cost_usd": float(row.total_cost or 0.0),
}
if include_actual_cost:
aggregated[key]["actual_total_cost_usd"] = float(
row.actual_total_cost or 0.0 # type: ignore[attr-defined]
)
# 2. 实时查询今天的数据(如果在查询范围内)
if today >= start_dt.date() and today <= end_dt.date():
today_start = datetime.combine(today, datetime.min.time(), tzinfo=timezone.utc)
today_end = datetime.combine(today, datetime.max.time(), tzinfo=timezone.utc)
if include_actual_cost:
today_query = db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
func.sum(Usage.actual_total_cost_usd).label("actual_total_cost_usd"),
).filter(
Usage.created_at >= today_start,
Usage.created_at <= today_end,
)
else:
today_query = db.query(
func.count(Usage.id).label("requests"),
func.sum(Usage.total_tokens).label("total_tokens"),
func.sum(Usage.total_cost_usd).label("total_cost_usd"),
).filter(
Usage.created_at >= today_start,
Usage.created_at <= today_end,
)
if user_id:
today_query = today_query.filter(Usage.user_id == user_id)
today_row = today_query.first()
if today_row and today_row.requests:
aggregated[today.isoformat()] = {
"requests": int(today_row.requests or 0),
"total_tokens": int(today_row.total_tokens or 0),
"total_cost_usd": float(today_row.total_cost_usd or 0.0),
}
if include_actual_cost:
aggregated[today.isoformat()]["actual_total_cost_usd"] = float(
today_row.actual_total_cost_usd or 0.0
)
# 3. Build the result
days: List[Dict[str, Any]] = []
cursor = start_dt.date()
end_date_only = end_dt.date()
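The two-phase read above (pre-aggregated historical rows, plus a realtime query for today) can be sketched as a small merge step. This is a minimal illustration with hypothetical names (`merge_daily_stats`, plain dicts instead of ORM rows), not the service's actual API:

```python
from datetime import date

def merge_daily_stats(historical, today_stats, today=None):
    # `historical` maps ISO date -> pre-aggregated stats dict;
    # `today_stats` is the realtime aggregate for the current day.
    # Today's realtime bucket overrides any stale pre-aggregated row.
    merged = dict(historical)
    today = today or date.today()
    if today_stats and today_stats.get("requests"):
        merged[today.isoformat()] = today_stats
    return merged
```

The guard on `requests` mirrors the `if today_row and today_row.requests:` check: a day with no traffic simply stays absent and is filled in later when the full date range is materialized.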
@@ -1304,6 +1476,9 @@ class UsageService:
provider: Optional[str] = None,
target_model: Optional[str] = None,
first_byte_time_ms: Optional[int] = None,
provider_id: Optional[str] = None,
provider_endpoint_id: Optional[str] = None,
provider_api_key_id: Optional[str] = None,
) -> Optional[Usage]:
"""
Quickly update the status of a usage record
@@ -1316,6 +1491,9 @@ class UsageService:
provider: Provider name (optional; updated when status becomes streaming)
target_model: Mapped target model name (optional)
first_byte_time_ms: Time to first byte / TTFB (optional; updated when status becomes streaming)
provider_id: Provider ID (optional; updated when status becomes streaming)
provider_endpoint_id: Endpoint ID (optional; updated when status becomes streaming)
provider_api_key_id: Provider API key ID (optional; updated when status becomes streaming)
Returns:
The updated Usage record, or None if the record was not found
@@ -1331,10 +1509,22 @@ class UsageService:
usage.error_message = error_message
if provider:
usage.provider = provider
elif status == "streaming" and usage.provider == "pending":
# Status changed to streaming but provider is still the "pending" placeholder; log a warning
logger.warning(
f"Status updated to streaming but provider is still pending: request_id={request_id}, "
f"current provider={usage.provider}"
)
if target_model:
usage.target_model = target_model
if first_byte_time_ms is not None:
usage.first_byte_time_ms = first_byte_time_ms
if provider_id is not None:
usage.provider_id = provider_id
if provider_endpoint_id is not None:
usage.provider_endpoint_id = provider_endpoint_id
if provider_api_key_id is not None:
usage.provider_api_key_id = provider_api_key_id
db.commit()
@@ -1446,6 +1636,8 @@ class UsageService:
ids: Optional[List[str]] = None,
user_id: Optional[str] = None,
default_timeout_seconds: int = 300,
*,
include_admin_fields: bool = False,
) -> List[Dict[str, Any]]:
"""
Get active request status (for frontend polling) and automatically clean up timed-out pending/streaming requests
@@ -1482,6 +1674,15 @@ class UsageService:
ProviderEndpoint.timeout.label("endpoint_timeout"),
).outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
# Admin polling: may include the provider and upstream key name (note: never expose upstream key info on the regular user endpoint)
if include_admin_fields:
from src.models.database import ProviderAPIKey
query = query.add_columns(
Usage.provider,
ProviderAPIKey.name.label("api_key_name"),
).outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
if ids:
query = query.filter(Usage.id.in_(ids))
if user_id:
@@ -1518,8 +1719,9 @@ class UsageService:
)
db.commit()
return [
{
result: List[Dict[str, Any]] = []
for r in records:
item: Dict[str, Any] = {
"id": r.id,
"status": "failed" if r.id in timeout_ids else r.status,
"input_tokens": r.input_tokens,
@@ -1528,8 +1730,12 @@ class UsageService:
"response_time_ms": r.response_time_ms,
"response_time_ms": r.response_time_ms,
"first_byte_time_ms": r.first_byte_time_ms,  # time to first byte (TTFB)
}
for r in records
]
if include_admin_fields:
item["provider"] = r.provider
item["api_key_name"] = r.api_key_name
result.append(item)
return result
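The conditional row-shaping above (a base payload for every caller, admin-only fields attached on demand) reduces to a small function. A sketch with hypothetical names, using a plain dict in place of the SQLAlchemy row:

```python
def shape_active_request(row, include_admin_fields=False):
    # Base payload is safe for any authenticated user; the provider and
    # upstream key name are attached only for admin callers.
    item = {
        "id": row["id"],
        "status": row["status"],
        "first_byte_time_ms": row.get("first_byte_time_ms"),
    }
    if include_admin_fields:
        item["provider"] = row.get("provider")
        item["api_key_name"] = row.get("api_key_name")
    return item
```

Shaping per-row rather than post-filtering a full payload means the sensitive columns are never even placed in the user-facing dict.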
# ========== Cache affinity analysis methods ==========

View File

@@ -459,34 +459,38 @@ class StreamUsageTracker:
logger.debug(f"ID:{self.request_id} | Start tracking streaming response | estimated input tokens: {self.input_tokens}")
chunk_count = 0
first_chunk_received = False
first_byte_time_ms = None  # record TTFB up front; computing it after a yield would be inaccurate
try:
async for chunk in stream:
chunk_count += 1
# Save the raw byte stream (for error diagnostics)
self.raw_chunks.append(chunk)
# When the first chunk arrives, update status to streaming and record TTFB
if not first_chunk_received:
first_chunk_received = True
if self.request_id:
try:
# Compute TTFB (using the request's original start time, or track_stream's start time)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
)
except Exception as e:
logger.warning(f"Failed to update usage record status to streaming: {e}")
# When the first chunk arrives, record the TTFB timestamp (but defer the DB update to avoid blocking)
if chunk_count == 1:
# Compute TTFB (using the request's original start time, or track_stream's start time)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
# Forward the raw chunk to the client
# Forward the raw chunk to the client first, so TTFB is not affected by DB operations
yield chunk
# Update the DB status after the yield (only executed for the first chunk)
if chunk_count == 1 and self.request_id:
try:
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
provider_id=self.provider_id,
provider_endpoint_id=self.provider_endpoint_id,
provider_api_key_id=self.provider_api_key_id,
)
except Exception as e:
logger.warning(f"Failed to update usage record status to streaming: {e}")
# Parse the chunk to extract content and usage info (chunk is raw bytes)
content, usage = self.parse_stream_chunk(chunk)
@@ -916,15 +920,38 @@ class EnhancedStreamUsageTracker(StreamUsageTracker):
logger.debug(f"ID:{self.request_id} | Start tracking streaming response (Enhanced) | estimated input tokens: {self.input_tokens}")
chunk_count = 0
first_byte_time_ms = None  # record TTFB up front; computing it after a yield would be inaccurate
try:
async for chunk in stream:
chunk_count += 1
# 保存原始字节流(用于错误诊断)
self.raw_chunks.append(chunk)
# 返回原始块给客户端
# 第一个 chunk 收到时,记录 TTFB 时间点(但先不更新数据库,避免阻塞)
if chunk_count == 1:
# Compute TTFB (using the request's original start time, or track_stream's start time)
base_time = self.request_start_time or self.start_time
first_byte_time_ms = int((time.time() - base_time) * 1000) if base_time else None
# Forward the raw chunk to the client first, so TTFB is not affected by DB operations
yield chunk
# Update the DB status after the yield (only executed for the first chunk)
if chunk_count == 1 and self.request_id:
try:
UsageService.update_usage_status(
db=self.db,
request_id=self.request_id,
status="streaming",
provider=self.provider,
first_byte_time_ms=first_byte_time_ms,
provider_id=self.provider_id,
provider_endpoint_id=self.provider_endpoint_id,
provider_api_key_id=self.provider_api_key_id,
)
except Exception as e:
logger.warning(f"Failed to update usage record status to streaming: {e}")
# Parse the chunk to extract content and usage info (chunk is raw bytes)
content, usage = self.parse_stream_chunk(chunk)

View File

@@ -25,9 +25,10 @@ class ApiKeyService:
allowed_providers: Optional[List[str]] = None,
allowed_api_formats: Optional[List[str]] = None,
allowed_models: Optional[List[str]] = None,
rate_limit: int = 100,
rate_limit: Optional[int] = None,
concurrent_limit: int = 5,
expire_days: Optional[int] = None,
expires_at: Optional[datetime] = None,  # explicit expiry time; takes precedence over expire_days
initial_balance_usd: Optional[float] = None,
is_standalone: bool = False,
auto_delete_on_expiry: bool = False,
@@ -44,6 +45,7 @@ class ApiKeyService:
rate_limit: Rate limit
concurrent_limit: Concurrency limit
expire_days: Days until expiry (None = never expires)
expires_at: Explicit expiry time; takes precedence over expire_days
initial_balance_usd: Initial balance in USD (standalone keys only; None = unlimited)
is_standalone: Whether this is a standalone-balance key (admin-only creation)
auto_delete_on_expiry: Whether to delete the key after expiry (True = hard delete, False = disable only)
@@ -54,10 +56,10 @@ class ApiKeyService:
key_hash = ApiKey.hash_key(key)
key_encrypted = crypto_service.encrypt(key)  # store the key encrypted
# Compute the expiry time
expires_at = None
if expire_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
# Compute the expiry time: prefer expires_at, fall back to expire_days
final_expires_at = expires_at
if final_expires_at is None and expire_days:
final_expires_at = datetime.now(timezone.utc) + timedelta(days=expire_days)
# Convert empty lists to None (meaning: no restriction)
api_key = ApiKey(
@@ -70,7 +72,7 @@ class ApiKeyService:
allowed_models=allowed_models or None,
rate_limit=rate_limit,
concurrent_limit=concurrent_limit,
expires_at=expires_at,
expires_at=final_expires_at,
balance_used_usd=0.0,
current_balance_usd=initial_balance_usd,  # use the initial balance directly; None = unlimited
is_standalone=is_standalone,
@@ -145,6 +147,9 @@ class ApiKeyService:
# Fields that may be explicitly set to an empty list/None (empty lists become None, meaning "all")
nullable_list_fields = {"allowed_providers", "allowed_api_formats", "allowed_models"}
# Fields that may be explicitly set to None (e.g. expires_at=None means never expires, rate_limit=None means unlimited)
nullable_fields = {"expires_at", "rate_limit"}
for field, value in kwargs.items():
if field not in updatable_fields:
continue
@@ -153,6 +158,9 @@ class ApiKeyService:
if value is not None:
# Convert empty lists to None (meaning: allow all)
setattr(api_key, field, value if value else None)
elif field in nullable_fields:
# These fields may be explicitly set to None
setattr(api_key, field, value)
elif value is not None:
setattr(api_key, field, value)
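The update loop above distinguishes three kinds of fields: list fields where `[]` is normalized to `None` ("no restriction"), nullable fields where an explicit `None` is a meaningful value, and everything else where `None` simply means "not provided". A self-contained sketch of that dispatch, with hypothetical names (`apply_partial_update`, a plain object instead of the ORM model):

```python
NULLABLE_LIST_FIELDS = {"allowed_providers", "allowed_api_formats", "allowed_models"}
NULLABLE_FIELDS = {"expires_at", "rate_limit"}

def apply_partial_update(obj, updatable_fields, updates):
    # Three behaviours per field:
    #  - list fields: [] is normalized to None ("no restriction")
    #  - nullable fields: an explicit None is a valid value (clear the limit)
    #  - everything else: None means "not provided", so skip it
    for field, value in updates.items():
        if field not in updatable_fields:
            continue
        if field in NULLABLE_LIST_FIELDS:
            if value is not None:
                setattr(obj, field, value if value else None)
        elif field in NULLABLE_FIELDS:
            setattr(obj, field, value)
        elif value is not None:
            setattr(obj, field, value)
```

The split is what lets a client send `rate_limit: null` to remove a limit while leaving fields it did not mention untouched.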

View File

@@ -49,8 +49,16 @@ def cache_result(key_prefix: str, ttl: int = 60, user_specific: bool = True) ->
# Try to read from the cache
cached = await redis_client.get(cache_key)
if cached:
logger.debug(f"Cache hit: {cache_key}")
return json.loads(cached)
try:
result = json.loads(cached)
logger.debug(f"Cache hit: {cache_key}")
return result
except json.JSONDecodeError as e:
logger.warning(f"Cache parse failed; deleting corrupted entry: {cache_key}, error: {e}")
try:
await redis_client.delete(cache_key)
except Exception:
pass
# Execute the original function
result = await func(*args, **kwargs)
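The hardened cache read above is a self-healing pattern: a corrupted entry is deleted (best effort) and the call falls through to recomputation. A minimal sketch with a hypothetical `read_cached` helper and an in-memory stand-in for the redis client:

```python
import asyncio
import json

async def read_cached(redis_client, cache_key):
    # Return the decoded cached value; on corrupt JSON, delete the entry
    # (best effort) and return None so the caller recomputes the result.
    cached = await redis_client.get(cache_key)
    if cached is None:
        return None
    try:
        return json.loads(cached)
    except json.JSONDecodeError:
        try:
            await redis_client.delete(cache_key)
        except Exception:
            pass  # a failed delete must not break the request path
        return None

class FakeRedis:
    # Minimal in-memory stand-in for the async redis client used above.
    def __init__(self, store):
        self.store = store
    async def get(self, key):
        return self.store.get(key)
    async def delete(self, key):
        self.store.pop(key, None)
```

Swallowing the delete failure is deliberate: the worst case is one extra parse failure on the next request, which is preferable to surfacing a cache-maintenance error to the user.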