mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-09 19:22:26 +08:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f78d5cbf3 | ||
|
|
431c6de8d2 | ||
|
|
142e15bbcc | ||
|
|
31acc5c607 | ||
|
|
bfa0a26d41 | ||
|
|
93ab9b6a5e | ||
|
|
35e29d46bd |
@@ -13,6 +13,7 @@ export interface UsersExportData {
|
|||||||
version: string
|
version: string
|
||||||
exported_at: string
|
exported_at: string
|
||||||
users: UserExport[]
|
users: UserExport[]
|
||||||
|
standalone_keys?: StandaloneKeyExport[]
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface UserExport {
|
export interface UserExport {
|
||||||
@@ -46,11 +47,15 @@ export interface UserApiKeyExport {
|
|||||||
concurrent_limit?: number | null
|
concurrent_limit?: number | null
|
||||||
force_capabilities?: any
|
force_capabilities?: any
|
||||||
is_active: boolean
|
is_active: boolean
|
||||||
|
expires_at?: string | null
|
||||||
auto_delete_on_expiry?: boolean
|
auto_delete_on_expiry?: boolean
|
||||||
total_requests?: number
|
total_requests?: number
|
||||||
total_cost_usd?: number
|
total_cost_usd?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 独立余额 Key 导出结构(与 UserApiKeyExport 相同,但不包含 is_standalone)
|
||||||
|
export type StandaloneKeyExport = Omit<UserApiKeyExport, 'is_standalone'>
|
||||||
|
|
||||||
export interface GlobalModelExport {
|
export interface GlobalModelExport {
|
||||||
name: string
|
name: string
|
||||||
display_name: string
|
display_name: string
|
||||||
@@ -189,6 +194,7 @@ export interface UsersImportResponse {
|
|||||||
stats: {
|
stats: {
|
||||||
users: { created: number; updated: number; skipped: number }
|
users: { created: number; updated: number; skipped: number }
|
||||||
api_keys: { created: number; skipped: number }
|
api_keys: { created: number; skipped: number }
|
||||||
|
standalone_keys?: { created: number; skipped: number }
|
||||||
errors: string[]
|
errors: string[]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -473,5 +479,13 @@ export const adminApi = {
|
|||||||
`/api/admin/system/email/templates/${templateType}/reset`
|
`/api/admin/system/email/templates/${templateType}/reset`
|
||||||
)
|
)
|
||||||
return response.data
|
return response.data
|
||||||
|
},
|
||||||
|
|
||||||
|
// 获取系统版本信息
|
||||||
|
async getSystemVersion(): Promise<{ version: string }> {
|
||||||
|
const response = await apiClient.get<{ version: string }>(
|
||||||
|
'/api/admin/system/version'
|
||||||
|
)
|
||||||
|
return response.data
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,6 +62,11 @@ export interface UsageRecordDetail {
|
|||||||
cache_creation_price_per_1m?: number
|
cache_creation_price_per_1m?: number
|
||||||
cache_read_price_per_1m?: number
|
cache_read_price_per_1m?: number
|
||||||
price_per_request?: number // 按次计费价格
|
price_per_request?: number // 按次计费价格
|
||||||
|
api_key?: {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
display: string
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 模型统计接口
|
// 模型统计接口
|
||||||
@@ -192,6 +197,7 @@ export const meApi = {
|
|||||||
async getUsage(params?: {
|
async getUsage(params?: {
|
||||||
start_date?: string
|
start_date?: string
|
||||||
end_date?: string
|
end_date?: string
|
||||||
|
search?: string // 通用搜索:密钥名、模型名
|
||||||
limit?: number
|
limit?: number
|
||||||
offset?: number
|
offset?: number
|
||||||
}): Promise<UsageResponse> {
|
}): Promise<UsageResponse> {
|
||||||
|
|||||||
@@ -192,10 +192,17 @@ export async function getModelsDevList(officialOnly: boolean = true): Promise<Mo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 按 provider 名称和模型名称排序
|
// 按 provider 名称排序,provider 中的模型按 release_date 从近到远排序
|
||||||
items.sort((a, b) => {
|
items.sort((a, b) => {
|
||||||
const providerCompare = a.providerName.localeCompare(b.providerName)
|
const providerCompare = a.providerName.localeCompare(b.providerName)
|
||||||
if (providerCompare !== 0) return providerCompare
|
if (providerCompare !== 0) return providerCompare
|
||||||
|
|
||||||
|
// 模型按 release_date 从近到远排序(没有日期的排到最后)
|
||||||
|
const aDate = a.releaseDate ? new Date(a.releaseDate).getTime() : 0
|
||||||
|
const bDate = b.releaseDate ? new Date(b.releaseDate).getTime() : 0
|
||||||
|
if (aDate !== bDate) return bDate - aDate // 降序:新的在前
|
||||||
|
|
||||||
|
// 日期相同或都没有日期时,按模型名称排序
|
||||||
return a.modelName.localeCompare(b.modelName)
|
return a.modelName.localeCompare(b.modelName)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ export const usageApi = {
|
|||||||
async getAllUsageRecords(params?: {
|
async getAllUsageRecords(params?: {
|
||||||
start_date?: string
|
start_date?: string
|
||||||
end_date?: string
|
end_date?: string
|
||||||
|
search?: string // 通用搜索:用户名、密钥名、模型名、提供商名
|
||||||
user_id?: string // UUID
|
user_id?: string // UUID
|
||||||
username?: string
|
username?: string
|
||||||
model?: string
|
model?: string
|
||||||
|
|||||||
@@ -32,6 +32,17 @@
|
|||||||
<!-- 分隔线 -->
|
<!-- 分隔线 -->
|
||||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||||
|
|
||||||
|
<!-- 通用搜索 -->
|
||||||
|
<div class="relative">
|
||||||
|
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
|
||||||
|
<Input
|
||||||
|
id="usage-records-search"
|
||||||
|
v-model="localSearch"
|
||||||
|
:placeholder="isAdmin ? '搜索用户/密钥/模型/提供商' : '搜索密钥/模型'"
|
||||||
|
class="w-32 sm:w-48 h-8 text-xs border-border/60 pl-8"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
<!-- 用户筛选(仅管理员可见) -->
|
<!-- 用户筛选(仅管理员可见) -->
|
||||||
<Select
|
<Select
|
||||||
v-if="isAdmin && availableUsers.length > 0"
|
v-if="isAdmin && availableUsers.length > 0"
|
||||||
@@ -164,6 +175,12 @@
|
|||||||
>
|
>
|
||||||
用户
|
用户
|
||||||
</TableHead>
|
</TableHead>
|
||||||
|
<TableHead
|
||||||
|
v-if="!isAdmin"
|
||||||
|
class="h-12 font-semibold w-[100px]"
|
||||||
|
>
|
||||||
|
密钥
|
||||||
|
</TableHead>
|
||||||
<TableHead class="h-12 font-semibold w-[140px]">
|
<TableHead class="h-12 font-semibold w-[140px]">
|
||||||
模型
|
模型
|
||||||
</TableHead>
|
</TableHead>
|
||||||
@@ -196,7 +213,7 @@
|
|||||||
<TableBody>
|
<TableBody>
|
||||||
<TableRow v-if="records.length === 0">
|
<TableRow v-if="records.length === 0">
|
||||||
<TableCell
|
<TableCell
|
||||||
:colspan="isAdmin ? 9 : 7"
|
:colspan="isAdmin ? 9 : 8"
|
||||||
class="text-center py-12 text-muted-foreground"
|
class="text-center py-12 text-muted-foreground"
|
||||||
>
|
>
|
||||||
暂无请求记录
|
暂无请求记录
|
||||||
@@ -218,7 +235,34 @@
|
|||||||
class="py-4 w-[100px] truncate"
|
class="py-4 w-[100px] truncate"
|
||||||
:title="record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户')"
|
:title="record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户')"
|
||||||
>
|
>
|
||||||
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
|
<div class="flex flex-col text-xs gap-0.5">
|
||||||
|
<span class="truncate">
|
||||||
|
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
|
||||||
|
</span>
|
||||||
|
<span
|
||||||
|
v-if="record.api_key?.name"
|
||||||
|
class="text-muted-foreground truncate"
|
||||||
|
:title="record.api_key.name"
|
||||||
|
>
|
||||||
|
{{ record.api_key.name }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</TableCell>
|
||||||
|
<!-- 用户页面的密钥列 -->
|
||||||
|
<TableCell
|
||||||
|
v-if="!isAdmin"
|
||||||
|
class="py-4 w-[100px]"
|
||||||
|
:title="record.api_key?.name || '-'"
|
||||||
|
>
|
||||||
|
<div class="flex flex-col text-xs gap-0.5">
|
||||||
|
<span class="truncate">{{ record.api_key?.name || '-' }}</span>
|
||||||
|
<span
|
||||||
|
v-if="record.api_key?.display"
|
||||||
|
class="text-muted-foreground truncate"
|
||||||
|
>
|
||||||
|
{{ record.api_key.display }}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell
|
<TableCell
|
||||||
class="font-medium py-4 w-[140px]"
|
class="font-medium py-4 w-[140px]"
|
||||||
@@ -438,6 +482,7 @@ import {
|
|||||||
TableCard,
|
TableCard,
|
||||||
Badge,
|
Badge,
|
||||||
Button,
|
Button,
|
||||||
|
Input,
|
||||||
Select,
|
Select,
|
||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
@@ -451,7 +496,7 @@ import {
|
|||||||
TableCell,
|
TableCell,
|
||||||
Pagination,
|
Pagination,
|
||||||
} from '@/components/ui'
|
} from '@/components/ui'
|
||||||
import { RefreshCcw } from 'lucide-vue-next'
|
import { RefreshCcw, Search } from 'lucide-vue-next'
|
||||||
import { formatTokens, formatCurrency } from '@/utils/format'
|
import { formatTokens, formatCurrency } from '@/utils/format'
|
||||||
import { formatDateTime } from '../composables'
|
import { formatDateTime } from '../composables'
|
||||||
import { useRowClick } from '@/composables/useRowClick'
|
import { useRowClick } from '@/composables/useRowClick'
|
||||||
@@ -471,6 +516,7 @@ const props = defineProps<{
|
|||||||
// 时间段
|
// 时间段
|
||||||
selectedPeriod: string
|
selectedPeriod: string
|
||||||
// 筛选
|
// 筛选
|
||||||
|
filterSearch: string
|
||||||
filterUser: string
|
filterUser: string
|
||||||
filterModel: string
|
filterModel: string
|
||||||
filterProvider: string
|
filterProvider: string
|
||||||
@@ -489,6 +535,7 @@ const props = defineProps<{
|
|||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
'update:selectedPeriod': [value: string]
|
'update:selectedPeriod': [value: string]
|
||||||
|
'update:filterSearch': [value: string]
|
||||||
'update:filterUser': [value: string]
|
'update:filterUser': [value: string]
|
||||||
'update:filterModel': [value: string]
|
'update:filterModel': [value: string]
|
||||||
'update:filterProvider': [value: string]
|
'update:filterProvider': [value: string]
|
||||||
@@ -507,6 +554,23 @@ const filterModelSelectOpen = ref(false)
|
|||||||
const filterProviderSelectOpen = ref(false)
|
const filterProviderSelectOpen = ref(false)
|
||||||
const filterStatusSelectOpen = ref(false)
|
const filterStatusSelectOpen = ref(false)
|
||||||
|
|
||||||
|
// 通用搜索(输入防抖)
|
||||||
|
const localSearch = ref(props.filterSearch)
|
||||||
|
let searchDebounceTimer: ReturnType<typeof setTimeout> | null = null
|
||||||
|
|
||||||
|
watch(() => props.filterSearch, (value) => {
|
||||||
|
if (value !== localSearch.value) {
|
||||||
|
localSearch.value = value
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
watch(localSearch, (value) => {
|
||||||
|
if (searchDebounceTimer) clearTimeout(searchDebounceTimer)
|
||||||
|
searchDebounceTimer = setTimeout(() => {
|
||||||
|
emit('update:filterSearch', value)
|
||||||
|
}, 300)
|
||||||
|
})
|
||||||
|
|
||||||
// 动态计时器相关
|
// 动态计时器相关
|
||||||
const now = ref(Date.now())
|
const now = ref(Date.now())
|
||||||
let timerInterval: ReturnType<typeof setInterval> | null = null
|
let timerInterval: ReturnType<typeof setInterval> | null = null
|
||||||
@@ -574,6 +638,10 @@ function handleRowClick(event: MouseEvent, id: string) {
|
|||||||
// 组件卸载时清理
|
// 组件卸载时清理
|
||||||
onUnmounted(() => {
|
onUnmounted(() => {
|
||||||
stopTimer()
|
stopTimer()
|
||||||
|
if (searchDebounceTimer) {
|
||||||
|
clearTimeout(searchDebounceTimer)
|
||||||
|
searchDebounceTimer = null
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// 格式化 API 格式显示名称
|
// 格式化 API 格式显示名称
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ export interface PaginationParams {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface FilterParams {
|
export interface FilterParams {
|
||||||
|
search?: string
|
||||||
user_id?: string
|
user_id?: string
|
||||||
model?: string
|
model?: string
|
||||||
provider?: string
|
provider?: string
|
||||||
@@ -234,11 +235,6 @@ export function useUsageData(options: UseUsageDataOptions) {
|
|||||||
pagination: PaginationParams,
|
pagination: PaginationParams,
|
||||||
filters?: FilterParams
|
filters?: FilterParams
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
if (!isAdminPage.value) {
|
|
||||||
// 用户页面不需要分页加载,记录已在 loadStats 中获取
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
isLoadingRecords.value = true
|
isLoadingRecords.value = true
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -252,24 +248,34 @@ export function useUsageData(options: UseUsageDataOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 添加筛选条件
|
// 添加筛选条件
|
||||||
if (filters?.user_id) {
|
if (filters?.search?.trim()) {
|
||||||
params.user_id = filters.user_id
|
params.search = filters.search.trim()
|
||||||
}
|
|
||||||
if (filters?.model) {
|
|
||||||
params.model = filters.model
|
|
||||||
}
|
|
||||||
if (filters?.provider) {
|
|
||||||
params.provider = filters.provider
|
|
||||||
}
|
|
||||||
if (filters?.status) {
|
|
||||||
params.status = filters.status
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await usageApi.getAllUsageRecords(params)
|
if (isAdminPage.value) {
|
||||||
|
// 管理员页面:使用管理员 API
|
||||||
currentRecords.value = (response.records || []) as UsageRecord[]
|
if (filters?.user_id) {
|
||||||
totalRecords.value = response.total || 0
|
params.user_id = filters.user_id
|
||||||
|
}
|
||||||
|
if (filters?.model) {
|
||||||
|
params.model = filters.model
|
||||||
|
}
|
||||||
|
if (filters?.provider) {
|
||||||
|
params.provider = filters.provider
|
||||||
|
}
|
||||||
|
if (filters?.status) {
|
||||||
|
params.status = filters.status
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await usageApi.getAllUsageRecords(params)
|
||||||
|
currentRecords.value = (response.records || []) as UsageRecord[]
|
||||||
|
totalRecords.value = response.total || 0
|
||||||
|
} else {
|
||||||
|
// 用户页面:使用用户 API
|
||||||
|
const userData = await meApi.getUsage(params)
|
||||||
|
currentRecords.value = (userData.records || []) as UsageRecord[]
|
||||||
|
totalRecords.value = userData.pagination?.total || currentRecords.value.length
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log.error('加载记录失败:', error)
|
log.error('加载记录失败:', error)
|
||||||
currentRecords.value = []
|
currentRecords.value = []
|
||||||
|
|||||||
@@ -61,6 +61,11 @@ export interface UsageRecord {
|
|||||||
user_id?: string
|
user_id?: string
|
||||||
username?: string
|
username?: string
|
||||||
user_email?: string
|
user_email?: string
|
||||||
|
api_key?: {
|
||||||
|
id: string | null
|
||||||
|
name: string | null
|
||||||
|
display: string | null
|
||||||
|
} | null
|
||||||
provider: string
|
provider: string
|
||||||
api_key_name?: string
|
api_key_name?: string
|
||||||
rate_multiplier?: number
|
rate_multiplier?: number
|
||||||
|
|||||||
@@ -367,6 +367,11 @@ function generateMockUsageRecords(count: number = 100) {
|
|||||||
user_id: user.id,
|
user_id: user.id,
|
||||||
username: user.username,
|
username: user.username,
|
||||||
user_email: user.email,
|
user_email: user.email,
|
||||||
|
api_key: {
|
||||||
|
id: `key-${user.id}-${Math.ceil(Math.random() * 2)}`,
|
||||||
|
name: `${user.username} Key ${Math.ceil(Math.random() * 3)}`,
|
||||||
|
display: `sk-ae...${String(1000 + Math.floor(Math.random() * 9000))}`
|
||||||
|
},
|
||||||
provider: model.provider,
|
provider: model.provider,
|
||||||
api_key_name: `${model.provider}-key-${Math.ceil(Math.random() * 3)}`,
|
api_key_name: `${model.provider}-key-${Math.ceil(Math.random() * 3)}`,
|
||||||
rate_multiplier: 1.0,
|
rate_multiplier: 1.0,
|
||||||
@@ -835,10 +840,26 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
|||||||
'GET /api/admin/usage/records': async (config) => {
|
'GET /api/admin/usage/records': async (config) => {
|
||||||
await delay()
|
await delay()
|
||||||
requireAdmin()
|
requireAdmin()
|
||||||
const records = getUsageRecords()
|
let records = getUsageRecords()
|
||||||
const params = config.params || {}
|
const params = config.params || {}
|
||||||
const limit = parseInt(params.limit) || 20
|
const limit = parseInt(params.limit) || 20
|
||||||
const offset = parseInt(params.offset) || 0
|
const offset = parseInt(params.offset) || 0
|
||||||
|
|
||||||
|
// 通用搜索:用户名、密钥名、模型名、提供商名
|
||||||
|
// 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||||
|
if (typeof params.search === 'string' && params.search.trim()) {
|
||||||
|
const keywords = params.search.trim().toLowerCase().split(/\s+/)
|
||||||
|
records = records.filter(r => {
|
||||||
|
// 每个关键词都要匹配至少一个字段
|
||||||
|
return keywords.every((keyword: string) =>
|
||||||
|
(r.username || '').toLowerCase().includes(keyword) ||
|
||||||
|
(r.api_key?.name || '').toLowerCase().includes(keyword) ||
|
||||||
|
(r.model || '').toLowerCase().includes(keyword) ||
|
||||||
|
(r.provider || '').toLowerCase().includes(keyword)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return createMockResponse({
|
return createMockResponse({
|
||||||
records: records.slice(offset, offset + limit),
|
records: records.slice(offset, offset + limit),
|
||||||
total: records.length,
|
total: records.length,
|
||||||
|
|||||||
@@ -464,6 +464,30 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</CardSection>
|
</CardSection>
|
||||||
|
|
||||||
|
<!-- 系统版本信息 -->
|
||||||
|
<CardSection
|
||||||
|
title="系统信息"
|
||||||
|
description="当前系统版本和构建信息"
|
||||||
|
>
|
||||||
|
<div class="flex items-center gap-4">
|
||||||
|
<div class="flex items-center gap-2">
|
||||||
|
<Label class="text-sm font-medium text-muted-foreground">版本:</Label>
|
||||||
|
<span
|
||||||
|
v-if="systemVersion"
|
||||||
|
class="text-sm font-mono"
|
||||||
|
>
|
||||||
|
{{ systemVersion }}
|
||||||
|
</span>
|
||||||
|
<span
|
||||||
|
v-else
|
||||||
|
class="text-sm text-muted-foreground"
|
||||||
|
>
|
||||||
|
加载中...
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</CardSection>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- 导入配置对话框 -->
|
<!-- 导入配置对话框 -->
|
||||||
@@ -475,7 +499,7 @@
|
|||||||
<div class="space-y-4">
|
<div class="space-y-4">
|
||||||
<div
|
<div
|
||||||
v-if="importPreview"
|
v-if="importPreview"
|
||||||
class="p-3 bg-muted rounded-lg text-sm"
|
class="text-sm"
|
||||||
>
|
>
|
||||||
<p class="font-medium mb-2">
|
<p class="font-medium mb-2">
|
||||||
配置预览
|
配置预览
|
||||||
@@ -557,7 +581,7 @@
|
|||||||
class="space-y-4"
|
class="space-y-4"
|
||||||
>
|
>
|
||||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
<div>
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
全局模型
|
全局模型
|
||||||
</p>
|
</p>
|
||||||
@@ -567,7 +591,7 @@
|
|||||||
跳过: {{ importResult.stats.global_models.skipped }}
|
跳过: {{ importResult.stats.global_models.skipped }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
<div>
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
提供商
|
提供商
|
||||||
</p>
|
</p>
|
||||||
@@ -577,7 +601,7 @@
|
|||||||
跳过: {{ importResult.stats.providers.skipped }}
|
跳过: {{ importResult.stats.providers.skipped }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
<div>
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
端点
|
端点
|
||||||
</p>
|
</p>
|
||||||
@@ -587,7 +611,7 @@
|
|||||||
跳过: {{ importResult.stats.endpoints.skipped }}
|
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
<div>
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
API Keys
|
API Keys
|
||||||
</p>
|
</p>
|
||||||
@@ -596,7 +620,7 @@
|
|||||||
跳过: {{ importResult.stats.keys.skipped }}
|
跳过: {{ importResult.stats.keys.skipped }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="p-3 bg-muted rounded-lg col-span-2">
|
<div class="col-span-2">
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
模型配置
|
模型配置
|
||||||
</p>
|
</p>
|
||||||
@@ -642,7 +666,7 @@
|
|||||||
<div class="space-y-4">
|
<div class="space-y-4">
|
||||||
<div
|
<div
|
||||||
v-if="importUsersPreview"
|
v-if="importUsersPreview"
|
||||||
class="p-3 bg-muted rounded-lg text-sm"
|
class="text-sm"
|
||||||
>
|
>
|
||||||
<p class="font-medium mb-2">
|
<p class="font-medium mb-2">
|
||||||
数据预览
|
数据预览
|
||||||
@@ -652,6 +676,9 @@
|
|||||||
<li>
|
<li>
|
||||||
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
||||||
</li>
|
</li>
|
||||||
|
<li v-if="importUsersPreview.standalone_keys?.length">
|
||||||
|
独立余额 Keys: {{ importUsersPreview.standalone_keys.length }} 个
|
||||||
|
</li>
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -720,7 +747,7 @@
|
|||||||
class="space-y-4"
|
class="space-y-4"
|
||||||
>
|
>
|
||||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
<div>
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
用户
|
用户
|
||||||
</p>
|
</p>
|
||||||
@@ -730,7 +757,7 @@
|
|||||||
跳过: {{ importUsersResult.stats.users.skipped }}
|
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="p-3 bg-muted rounded-lg">
|
<div>
|
||||||
<p class="font-medium">
|
<p class="font-medium">
|
||||||
API Keys
|
API Keys
|
||||||
</p>
|
</p>
|
||||||
@@ -739,6 +766,18 @@
|
|||||||
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
<div
|
||||||
|
v-if="importUsersResult.stats.standalone_keys"
|
||||||
|
class="col-span-2"
|
||||||
|
>
|
||||||
|
<p class="font-medium">
|
||||||
|
独立余额 Keys
|
||||||
|
</p>
|
||||||
|
<p class="text-muted-foreground">
|
||||||
|
创建: {{ importUsersResult.stats.standalone_keys.created }},
|
||||||
|
跳过: {{ importUsersResult.stats.standalone_keys.skipped }}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
@@ -839,6 +878,9 @@ const importUsersResult = ref<UsersImportResponse | null>(null)
|
|||||||
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||||
const usersMergeModeSelectOpen = ref(false)
|
const usersMergeModeSelectOpen = ref(false)
|
||||||
|
|
||||||
|
// 系统版本信息
|
||||||
|
const systemVersion = ref<string>('')
|
||||||
|
|
||||||
const systemConfig = ref<SystemConfig>({
|
const systemConfig = ref<SystemConfig>({
|
||||||
// 基础配置
|
// 基础配置
|
||||||
default_user_quota_usd: 10.0,
|
default_user_quota_usd: 10.0,
|
||||||
@@ -890,9 +932,21 @@ const sensitiveHeadersStr = computed({
|
|||||||
})
|
})
|
||||||
|
|
||||||
onMounted(async () => {
|
onMounted(async () => {
|
||||||
await loadSystemConfig()
|
await Promise.all([
|
||||||
|
loadSystemConfig(),
|
||||||
|
loadSystemVersion()
|
||||||
|
])
|
||||||
})
|
})
|
||||||
|
|
||||||
|
async function loadSystemVersion() {
|
||||||
|
try {
|
||||||
|
const data = await adminApi.getSystemVersion()
|
||||||
|
systemVersion.value = data.version
|
||||||
|
} catch (err) {
|
||||||
|
log.error('加载系统版本失败:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function loadSystemConfig() {
|
async function loadSystemConfig() {
|
||||||
try {
|
try {
|
||||||
const configs = [
|
const configs = [
|
||||||
@@ -1178,12 +1232,6 @@ function handleUsersFileSelect(event: Event) {
|
|||||||
const content = e.target?.result as string
|
const content = e.target?.result as string
|
||||||
const data = JSON.parse(content) as UsersExportData
|
const data = JSON.parse(content) as UsersExportData
|
||||||
|
|
||||||
// 验证版本
|
|
||||||
if (data.version !== '1.0') {
|
|
||||||
error(`不支持的配置版本: ${data.version}`)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
importUsersPreview.value = data
|
importUsersPreview.value = data
|
||||||
usersMergeMode.value = 'skip'
|
usersMergeMode.value = 'skip'
|
||||||
importUsersDialogOpen.value = true
|
importUsersDialogOpen.value = true
|
||||||
|
|||||||
@@ -56,6 +56,7 @@
|
|||||||
:show-actual-cost="authStore.isAdmin"
|
:show-actual-cost="authStore.isAdmin"
|
||||||
:loading="isLoadingRecords"
|
:loading="isLoadingRecords"
|
||||||
:selected-period="selectedPeriod"
|
:selected-period="selectedPeriod"
|
||||||
|
:filter-search="filterSearch"
|
||||||
:filter-user="filterUser"
|
:filter-user="filterUser"
|
||||||
:filter-model="filterModel"
|
:filter-model="filterModel"
|
||||||
:filter-provider="filterProvider"
|
:filter-provider="filterProvider"
|
||||||
@@ -69,6 +70,7 @@
|
|||||||
:page-size-options="pageSizeOptions"
|
:page-size-options="pageSizeOptions"
|
||||||
:auto-refresh="globalAutoRefresh"
|
:auto-refresh="globalAutoRefresh"
|
||||||
@update:selected-period="handlePeriodChange"
|
@update:selected-period="handlePeriodChange"
|
||||||
|
@update:filter-search="handleFilterSearchChange"
|
||||||
@update:filter-user="handleFilterUserChange"
|
@update:filter-user="handleFilterUserChange"
|
||||||
@update:filter-model="handleFilterModelChange"
|
@update:filter-model="handleFilterModelChange"
|
||||||
@update:filter-provider="handleFilterProviderChange"
|
@update:filter-provider="handleFilterProviderChange"
|
||||||
@@ -133,6 +135,7 @@ const pageSize = ref(20)
|
|||||||
const pageSizeOptions = [10, 20, 50, 100]
|
const pageSizeOptions = [10, 20, 50, 100]
|
||||||
|
|
||||||
// 筛选状态
|
// 筛选状态
|
||||||
|
const filterSearch = ref('')
|
||||||
const filterUser = ref('__all__')
|
const filterUser = ref('__all__')
|
||||||
const filterModel = ref('__all__')
|
const filterModel = ref('__all__')
|
||||||
const filterProvider = ref('__all__')
|
const filterProvider = ref('__all__')
|
||||||
@@ -392,14 +395,17 @@ onMounted(async () => {
|
|||||||
// 热力图加载失败不提示,因为 UI 已显示占位符
|
// 热力图加载失败不提示,因为 UI 已显示占位符
|
||||||
}
|
}
|
||||||
|
|
||||||
// 管理员页面加载用户列表和第一页记录
|
// 加载记录和用户列表
|
||||||
if (isAdminPage.value) {
|
if (isAdminPage.value) {
|
||||||
// 并行加载用户列表和记录
|
// 管理员页面:并行加载用户列表和记录
|
||||||
const [users] = await Promise.all([
|
const [users] = await Promise.all([
|
||||||
usersApi.getAllUsers(),
|
usersApi.getAllUsers(),
|
||||||
loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||||
])
|
])
|
||||||
availableUsers.value = users.map(u => ({ id: u.id, username: u.username, email: u.email }))
|
availableUsers.value = users.map(u => ({ id: u.id, username: u.username, email: u.email }))
|
||||||
|
} else {
|
||||||
|
// 用户页面:加载记录
|
||||||
|
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -410,34 +416,26 @@ async function handlePeriodChange(value: string) {
|
|||||||
|
|
||||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||||
await loadStats(dateRange)
|
await loadStats(dateRange)
|
||||||
|
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||||
if (isAdminPage.value) {
|
|
||||||
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理分页变化
|
// 处理分页变化
|
||||||
async function handlePageChange(page: number) {
|
async function handlePageChange(page: number) {
|
||||||
currentPage.value = page
|
currentPage.value = page
|
||||||
|
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
|
||||||
if (isAdminPage.value) {
|
|
||||||
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理每页大小变化
|
// 处理每页大小变化
|
||||||
async function handlePageSizeChange(size: number) {
|
async function handlePageSizeChange(size: number) {
|
||||||
pageSize.value = size
|
pageSize.value = size
|
||||||
currentPage.value = 1 // 重置到第一页
|
currentPage.value = 1 // 重置到第一页
|
||||||
|
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
|
||||||
if (isAdminPage.value) {
|
|
||||||
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取当前筛选参数
|
// 获取当前筛选参数
|
||||||
function getCurrentFilters() {
|
function getCurrentFilters() {
|
||||||
return {
|
return {
|
||||||
|
search: filterSearch.value.trim() || undefined,
|
||||||
user_id: filterUser.value !== '__all__' ? filterUser.value : undefined,
|
user_id: filterUser.value !== '__all__' ? filterUser.value : undefined,
|
||||||
model: filterModel.value !== '__all__' ? filterModel.value : undefined,
|
model: filterModel.value !== '__all__' ? filterModel.value : undefined,
|
||||||
provider: filterProvider.value !== '__all__' ? filterProvider.value : undefined,
|
provider: filterProvider.value !== '__all__' ? filterProvider.value : undefined,
|
||||||
@@ -446,6 +444,13 @@ function getCurrentFilters() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理筛选变化
|
// 处理筛选变化
|
||||||
|
async function handleFilterSearchChange(value: string) {
|
||||||
|
filterSearch.value = value
|
||||||
|
currentPage.value = 1
|
||||||
|
|
||||||
|
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
|
||||||
|
}
|
||||||
|
|
||||||
async function handleFilterUserChange(value: string) {
|
async function handleFilterUserChange(value: string) {
|
||||||
filterUser.value = value
|
filterUser.value = value
|
||||||
currentPage.value = 1 // 重置到第一页
|
currentPage.value = 1 // 重置到第一页
|
||||||
@@ -486,10 +491,7 @@ async function handleFilterStatusChange(value: string) {
|
|||||||
async function refreshData() {
|
async function refreshData() {
|
||||||
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
|
||||||
await loadStats(dateRange)
|
await loadStats(dateRange)
|
||||||
|
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
||||||
if (isAdminPage.value) {
|
|
||||||
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 显示请求详情
|
// 显示请求详情
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""系统设置API端点。"""
|
"""系统设置API端点。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -17,6 +19,46 @@ from src.services.email.email_template import EmailTemplate
|
|||||||
from src.services.system.config import SystemConfigService
|
from src.services.system.config import SystemConfigService
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_version_from_git() -> str | None:
|
||||||
|
"""从 git describe 获取版本号"""
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "describe", "--tags", "--always"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if result.returncode == 0:
|
||||||
|
version = result.stdout.strip()
|
||||||
|
if version.startswith("v"):
|
||||||
|
version = version[1:]
|
||||||
|
return version
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/version")
|
||||||
|
async def get_system_version():
|
||||||
|
"""获取系统版本信息"""
|
||||||
|
# 优先从 git 获取
|
||||||
|
version = _get_version_from_git()
|
||||||
|
if version:
|
||||||
|
return {"version": version}
|
||||||
|
|
||||||
|
# 回退到静态版本文件
|
||||||
|
try:
|
||||||
|
from src._version import __version__
|
||||||
|
|
||||||
|
return {"version": __version__}
|
||||||
|
except ImportError:
|
||||||
|
return {"version": "unknown"}
|
||||||
|
|
||||||
|
|
||||||
pipeline = ApiRequestPipeline()
|
pipeline = ApiRequestPipeline()
|
||||||
|
|
||||||
|
|
||||||
@@ -950,6 +992,31 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
db = context.db
|
db = context.db
|
||||||
|
|
||||||
|
def _serialize_api_key(key: ApiKey, include_is_standalone: bool = False) -> dict:
|
||||||
|
"""序列化 API Key 为导出格式"""
|
||||||
|
data = {
|
||||||
|
"key_hash": key.key_hash,
|
||||||
|
"key_encrypted": key.key_encrypted,
|
||||||
|
"name": key.name,
|
||||||
|
"balance_used_usd": key.balance_used_usd,
|
||||||
|
"current_balance_usd": key.current_balance_usd,
|
||||||
|
"allowed_providers": key.allowed_providers,
|
||||||
|
"allowed_endpoints": key.allowed_endpoints,
|
||||||
|
"allowed_api_formats": key.allowed_api_formats,
|
||||||
|
"allowed_models": key.allowed_models,
|
||||||
|
"rate_limit": key.rate_limit,
|
||||||
|
"concurrent_limit": key.concurrent_limit,
|
||||||
|
"force_capabilities": key.force_capabilities,
|
||||||
|
"is_active": key.is_active,
|
||||||
|
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
|
||||||
|
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
||||||
|
"total_requests": key.total_requests,
|
||||||
|
"total_cost_usd": key.total_cost_usd,
|
||||||
|
}
|
||||||
|
if include_is_standalone:
|
||||||
|
data["is_standalone"] = key.is_standalone
|
||||||
|
return data
|
||||||
|
|
||||||
# 导出 Users(排除管理员)
|
# 导出 Users(排除管理员)
|
||||||
users = db.query(User).filter(
|
users = db.query(User).filter(
|
||||||
User.is_deleted.is_(False),
|
User.is_deleted.is_(False),
|
||||||
@@ -957,31 +1024,12 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
|||||||
).all()
|
).all()
|
||||||
users_data = []
|
users_data = []
|
||||||
for user in users:
|
for user in users:
|
||||||
# 导出用户的 API Keys(保留加密数据)
|
# 导出用户的 API Keys(排除独立余额Key,独立Key单独导出)
|
||||||
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
|
api_keys = db.query(ApiKey).filter(
|
||||||
api_keys_data = []
|
ApiKey.user_id == user.id,
|
||||||
for key in api_keys:
|
ApiKey.is_standalone.is_(False)
|
||||||
api_keys_data.append(
|
).all()
|
||||||
{
|
api_keys_data = [_serialize_api_key(key, include_is_standalone=True) for key in api_keys]
|
||||||
"key_hash": key.key_hash,
|
|
||||||
"key_encrypted": key.key_encrypted,
|
|
||||||
"name": key.name,
|
|
||||||
"is_standalone": key.is_standalone,
|
|
||||||
"balance_used_usd": key.balance_used_usd,
|
|
||||||
"current_balance_usd": key.current_balance_usd,
|
|
||||||
"allowed_providers": key.allowed_providers,
|
|
||||||
"allowed_endpoints": key.allowed_endpoints,
|
|
||||||
"allowed_api_formats": key.allowed_api_formats,
|
|
||||||
"allowed_models": key.allowed_models,
|
|
||||||
"rate_limit": key.rate_limit,
|
|
||||||
"concurrent_limit": key.concurrent_limit,
|
|
||||||
"force_capabilities": key.force_capabilities,
|
|
||||||
"is_active": key.is_active,
|
|
||||||
"auto_delete_on_expiry": key.auto_delete_on_expiry,
|
|
||||||
"total_requests": key.total_requests,
|
|
||||||
"total_cost_usd": key.total_cost_usd,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
users_data.append(
|
users_data.append(
|
||||||
{
|
{
|
||||||
@@ -1001,10 +1049,15 @@ class AdminExportUsersAdapter(AdminApiAdapter):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 导出独立余额 Keys(管理员创建的,不属于普通用户)
|
||||||
|
standalone_keys = db.query(ApiKey).filter(ApiKey.is_standalone.is_(True)).all()
|
||||||
|
standalone_keys_data = [_serialize_api_key(key) for key in standalone_keys]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"version": "1.0",
|
"version": "1.1",
|
||||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||||
"users": users_data,
|
"users": users_data,
|
||||||
|
"standalone_keys": standalone_keys_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1024,21 +1077,72 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
|||||||
db = context.db
|
db = context.db
|
||||||
payload = context.ensure_json_body()
|
payload = context.ensure_json_body()
|
||||||
|
|
||||||
# 验证配置版本
|
|
||||||
version = payload.get("version")
|
|
||||||
if version != "1.0":
|
|
||||||
raise InvalidRequestException(f"不支持的配置版本: {version}")
|
|
||||||
|
|
||||||
# 获取导入选项
|
# 获取导入选项
|
||||||
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
|
||||||
users_data = payload.get("users", [])
|
users_data = payload.get("users", [])
|
||||||
|
standalone_keys_data = payload.get("standalone_keys", [])
|
||||||
|
|
||||||
stats = {
|
stats = {
|
||||||
"users": {"created": 0, "updated": 0, "skipped": 0},
|
"users": {"created": 0, "updated": 0, "skipped": 0},
|
||||||
"api_keys": {"created": 0, "skipped": 0},
|
"api_keys": {"created": 0, "skipped": 0},
|
||||||
|
"standalone_keys": {"created": 0, "skipped": 0},
|
||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _create_api_key_from_data(
|
||||||
|
key_data: dict,
|
||||||
|
owner_id: str,
|
||||||
|
is_standalone: bool = False,
|
||||||
|
) -> tuple[ApiKey | None, str]:
|
||||||
|
"""从导入数据创建 ApiKey 对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(ApiKey, "created"): 成功创建
|
||||||
|
(None, "skipped"): key 已存在,跳过
|
||||||
|
(None, "invalid"): 数据无效,跳过
|
||||||
|
"""
|
||||||
|
key_hash = key_data.get("key_hash", "").strip()
|
||||||
|
if not key_hash:
|
||||||
|
return None, "invalid"
|
||||||
|
|
||||||
|
# 检查是否已存在
|
||||||
|
existing = db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
|
||||||
|
if existing:
|
||||||
|
return None, "skipped"
|
||||||
|
|
||||||
|
# 解析 expires_at
|
||||||
|
expires_at = None
|
||||||
|
if key_data.get("expires_at"):
|
||||||
|
try:
|
||||||
|
expires_at = datetime.fromisoformat(key_data["expires_at"])
|
||||||
|
except ValueError:
|
||||||
|
stats["errors"].append(
|
||||||
|
f"API Key '{key_data.get('name', key_hash[:8])}' 的 expires_at 格式无效"
|
||||||
|
)
|
||||||
|
|
||||||
|
return ApiKey(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=owner_id,
|
||||||
|
key_hash=key_hash,
|
||||||
|
key_encrypted=key_data.get("key_encrypted"),
|
||||||
|
name=key_data.get("name"),
|
||||||
|
is_standalone=is_standalone or key_data.get("is_standalone", False),
|
||||||
|
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
||||||
|
current_balance_usd=key_data.get("current_balance_usd"),
|
||||||
|
allowed_providers=key_data.get("allowed_providers"),
|
||||||
|
allowed_endpoints=key_data.get("allowed_endpoints"),
|
||||||
|
allowed_api_formats=key_data.get("allowed_api_formats"),
|
||||||
|
allowed_models=key_data.get("allowed_models"),
|
||||||
|
rate_limit=key_data.get("rate_limit"),
|
||||||
|
concurrent_limit=key_data.get("concurrent_limit", 5),
|
||||||
|
force_capabilities=key_data.get("force_capabilities"),
|
||||||
|
is_active=key_data.get("is_active", True),
|
||||||
|
expires_at=expires_at,
|
||||||
|
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
||||||
|
total_requests=key_data.get("total_requests", 0),
|
||||||
|
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
||||||
|
), "created"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for user_data in users_data:
|
for user_data in users_data:
|
||||||
# 跳过管理员角色的导入(不区分大小写)
|
# 跳过管理员角色的导入(不区分大小写)
|
||||||
@@ -1109,40 +1213,31 @@ class AdminImportUsersAdapter(AdminApiAdapter):
|
|||||||
|
|
||||||
# 导入 API Keys
|
# 导入 API Keys
|
||||||
for key_data in user_data.get("api_keys", []):
|
for key_data in user_data.get("api_keys", []):
|
||||||
# 检查是否已存在相同的 key_hash
|
new_key, status = _create_api_key_from_data(key_data, user_id)
|
||||||
if key_data.get("key_hash"):
|
if new_key:
|
||||||
existing_key = (
|
db.add(new_key)
|
||||||
db.query(ApiKey)
|
stats["api_keys"]["created"] += 1
|
||||||
.filter(ApiKey.key_hash == key_data["key_hash"])
|
elif status == "skipped":
|
||||||
.first()
|
stats["api_keys"]["skipped"] += 1
|
||||||
)
|
# invalid 数据不计入统计
|
||||||
if existing_key:
|
|
||||||
stats["api_keys"]["skipped"] += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_key = ApiKey(
|
# 导入独立余额 Keys(需要找一个管理员用户作为 owner)
|
||||||
id=str(uuid.uuid4()),
|
if standalone_keys_data:
|
||||||
user_id=user_id,
|
# 查找一个管理员用户作为独立Key的owner
|
||||||
key_hash=key_data.get("key_hash", ""),
|
admin_user = db.query(User).filter(User.role == UserRole.ADMIN).first()
|
||||||
key_encrypted=key_data.get("key_encrypted"),
|
if not admin_user:
|
||||||
name=key_data.get("name"),
|
stats["errors"].append("无法导入独立余额Key: 系统中没有管理员用户")
|
||||||
is_standalone=key_data.get("is_standalone", False),
|
else:
|
||||||
balance_used_usd=key_data.get("balance_used_usd", 0.0),
|
for key_data in standalone_keys_data:
|
||||||
current_balance_usd=key_data.get("current_balance_usd"),
|
new_key, status = _create_api_key_from_data(
|
||||||
allowed_providers=key_data.get("allowed_providers"),
|
key_data, admin_user.id, is_standalone=True
|
||||||
allowed_endpoints=key_data.get("allowed_endpoints"),
|
)
|
||||||
allowed_api_formats=key_data.get("allowed_api_formats"),
|
if new_key:
|
||||||
allowed_models=key_data.get("allowed_models"),
|
db.add(new_key)
|
||||||
rate_limit=key_data.get("rate_limit"), # None = 无限制
|
stats["standalone_keys"]["created"] += 1
|
||||||
concurrent_limit=key_data.get("concurrent_limit", 5),
|
elif status == "skipped":
|
||||||
force_capabilities=key_data.get("force_capabilities"),
|
stats["standalone_keys"]["skipped"] += 1
|
||||||
is_active=key_data.get("is_active", True),
|
# invalid 数据不计入统计
|
||||||
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
|
|
||||||
total_requests=key_data.get("total_requests", 0),
|
|
||||||
total_cost_usd=key_data.get("total_cost_usd", 0.0),
|
|
||||||
)
|
|
||||||
db.add(new_key)
|
|
||||||
stats["api_keys"]["created"] += 1
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ async def get_usage_records(
|
|||||||
request: Request,
|
request: Request,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: Optional[datetime] = None,
|
||||||
end_date: Optional[datetime] = None,
|
end_date: Optional[datetime] = None,
|
||||||
|
search: Optional[str] = None, # 通用搜索:用户名、密钥名、模型名、提供商名
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
@@ -104,6 +105,7 @@ async def get_usage_records(
|
|||||||
adapter = AdminUsageRecordsAdapter(
|
adapter = AdminUsageRecordsAdapter(
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
|
search=search,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
username=username,
|
username=username,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -500,6 +502,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
self,
|
self,
|
||||||
start_date: Optional[datetime],
|
start_date: Optional[datetime],
|
||||||
end_date: Optional[datetime],
|
end_date: Optional[datetime],
|
||||||
|
search: Optional[str],
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
username: Optional[str],
|
username: Optional[str],
|
||||||
model: Optional[str],
|
model: Optional[str],
|
||||||
@@ -510,6 +513,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
):
|
):
|
||||||
self.start_date = start_date
|
self.start_date = start_date
|
||||||
self.end_date = end_date
|
self.end_date = end_date
|
||||||
|
self.search = search
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.username = username
|
self.username = username
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -519,25 +523,54 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
self.offset = offset
|
self.offset = offset
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
|
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
|
||||||
|
|
||||||
db = context.db
|
db = context.db
|
||||||
query = (
|
query = (
|
||||||
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
|
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey, ApiKey)
|
||||||
.outerjoin(User, Usage.user_id == User.id)
|
.outerjoin(User, Usage.user_id == User.id)
|
||||||
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
|
||||||
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
|
||||||
|
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 如果需要按 Provider 名称搜索/筛选,统一在这里 JOIN
|
||||||
|
if self.search or self.provider:
|
||||||
|
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
||||||
|
|
||||||
|
# 通用搜索:用户名、密钥名、模型名、提供商名
|
||||||
|
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||||
|
# 限制:最多 10 个关键词,转义后每个关键词最长 100 字符
|
||||||
|
if self.search:
|
||||||
|
keywords = [kw for kw in self.search.strip().split() if kw][:10]
|
||||||
|
for keyword in keywords:
|
||||||
|
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
|
||||||
|
search_pattern = f"%{escaped}%"
|
||||||
|
query = query.filter(
|
||||||
|
or_(
|
||||||
|
User.username.ilike(search_pattern, escape="\\"),
|
||||||
|
ApiKey.name.ilike(search_pattern, escape="\\"),
|
||||||
|
Usage.model.ilike(search_pattern, escape="\\"),
|
||||||
|
Provider.name.ilike(search_pattern, escape="\\"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self.user_id:
|
if self.user_id:
|
||||||
query = query.filter(Usage.user_id == self.user_id)
|
query = query.filter(Usage.user_id == self.user_id)
|
||||||
if self.username:
|
if self.username:
|
||||||
# 支持用户名模糊搜索
|
# 支持用户名模糊搜索
|
||||||
query = query.filter(User.username.ilike(f"%{self.username}%"))
|
escaped = escape_like_pattern(self.username)
|
||||||
|
query = query.filter(User.username.ilike(f"%{escaped}%", escape="\\"))
|
||||||
if self.model:
|
if self.model:
|
||||||
# 支持模型名模糊搜索
|
# 支持模型名模糊搜索
|
||||||
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
|
escaped = escape_like_pattern(self.model)
|
||||||
|
query = query.filter(Usage.model.ilike(f"%{escaped}%", escape="\\"))
|
||||||
if self.provider:
|
if self.provider:
|
||||||
# 支持提供商名称搜索(通过 Provider 表)
|
# 支持提供商名称搜索
|
||||||
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
|
escaped = escape_like_pattern(self.provider)
|
||||||
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
|
query = query.filter(Provider.name.ilike(f"%{escaped}%", escape="\\"))
|
||||||
if self.status:
|
if self.status:
|
||||||
# 状态筛选
|
# 状态筛选
|
||||||
# 旧的筛选值(基于 is_stream 和 status_code):stream, standard, error
|
# 旧的筛选值(基于 is_stream 和 status_code):stream, standard, error
|
||||||
@@ -575,7 +608,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
|
request_ids = [usage.request_id for usage, _, _, _, _ in records if usage.request_id]
|
||||||
fallback_map = {}
|
fallback_map = {}
|
||||||
if request_ids:
|
if request_ids:
|
||||||
# 只统计实际执行的候选(success 或 failed),不包括 skipped/pending/available
|
# 只统计实际执行的候选(success 或 failed),不包括 skipped/pending/available
|
||||||
@@ -595,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
action="usage_records",
|
action="usage_records",
|
||||||
start_date=self.start_date.isoformat() if self.start_date else None,
|
start_date=self.start_date.isoformat() if self.start_date else None,
|
||||||
end_date=self.end_date.isoformat() if self.end_date else None,
|
end_date=self.end_date.isoformat() if self.end_date else None,
|
||||||
|
search=self.search,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
username=self.username,
|
username=self.username,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
@@ -606,7 +640,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
|
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
|
||||||
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
|
provider_ids = [usage.provider_id for usage, _, _, _, _ in records if usage.provider_id]
|
||||||
provider_map = {}
|
provider_map = {}
|
||||||
if provider_ids:
|
if provider_ids:
|
||||||
providers_data = (
|
providers_data = (
|
||||||
@@ -615,7 +649,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
provider_map = {str(p.id): p.name for p in providers_data}
|
provider_map = {str(p.id): p.name for p in providers_data}
|
||||||
|
|
||||||
data = []
|
data = []
|
||||||
for usage, user, endpoint, api_key in records:
|
for usage, user, endpoint, provider_api_key, user_api_key in records:
|
||||||
actual_cost = (
|
actual_cost = (
|
||||||
float(usage.actual_total_cost_usd)
|
float(usage.actual_total_cost_usd)
|
||||||
if usage.actual_total_cost_usd is not None
|
if usage.actual_total_cost_usd is not None
|
||||||
@@ -636,6 +670,15 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
"user_id": user.id if user else None,
|
"user_id": user.id if user else None,
|
||||||
"user_email": user.email if user else "已删除用户",
|
"user_email": user.email if user else "已删除用户",
|
||||||
"username": user.username if user else "已删除用户",
|
"username": user.username if user else "已删除用户",
|
||||||
|
"api_key": (
|
||||||
|
{
|
||||||
|
"id": user_api_key.id,
|
||||||
|
"name": user_api_key.name,
|
||||||
|
"display": user_api_key.get_display_key(),
|
||||||
|
}
|
||||||
|
if user_api_key
|
||||||
|
else None
|
||||||
|
),
|
||||||
"provider": provider_name,
|
"provider": provider_name,
|
||||||
"model": usage.model,
|
"model": usage.model,
|
||||||
"target_model": usage.target_model, # 映射后的目标模型名
|
"target_model": usage.target_model, # 映射后的目标模型名
|
||||||
@@ -661,7 +704,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
|||||||
"has_fallback": fallback_map.get(usage.request_id, False),
|
"has_fallback": fallback_map.get(usage.request_id, False),
|
||||||
"api_format": usage.api_format
|
"api_format": usage.api_format
|
||||||
or (endpoint.api_format if endpoint and endpoint.api_format else None),
|
or (endpoint.api_format if endpoint and endpoint.api_format else None),
|
||||||
"api_key_name": api_key.name if api_key else None,
|
"api_key_name": provider_api_key.name if provider_api_key else None,
|
||||||
"request_metadata": usage.request_metadata, # Provider 响应元数据
|
"request_metadata": usage.request_metadata, # Provider 响应元数据
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from src.core.exceptions import (
|
|||||||
UpstreamClientException,
|
UpstreamClientException,
|
||||||
)
|
)
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||||
from src.services.request.result import RequestResult
|
from src.services.request.result import RequestResult
|
||||||
from src.services.usage.recorder import UsageRecorder
|
from src.services.usage.recorder import UsageRecorder
|
||||||
|
|
||||||
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
name: str = "chat.base"
|
name: str = "chat.base"
|
||||||
mode = ApiMode.STANDARD
|
mode = ApiMode.STANDARD
|
||||||
|
|
||||||
|
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||||
|
BILLING_TEMPLATE: str = "claude"
|
||||||
|
|
||||||
# 子类可以配置的特殊方法(用于check_endpoint)
|
# 子类可以配置的特殊方法(用于check_endpoint)
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_endpoint_url(cls, base_url: str) -> str:
|
def build_endpoint_url(cls, base_url: str) -> str:
|
||||||
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
"""
|
"""
|
||||||
return input_tokens + cache_read_input_tokens
|
return input_tokens + cache_read_input_tokens
|
||||||
|
|
||||||
def get_cache_read_price_for_ttl(
|
|
||||||
self,
|
|
||||||
tier: dict,
|
|
||||||
cache_ttl_minutes: Optional[int] = None,
|
|
||||||
) -> Optional[float]:
|
|
||||||
"""
|
|
||||||
根据缓存 TTL 获取缓存读取价格
|
|
||||||
|
|
||||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
|
||||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tier: 当前阶梯配置
|
|
||||||
cache_ttl_minutes: 缓存时长(分钟)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
缓存读取价格(每 1M tokens)
|
|
||||||
"""
|
|
||||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
|
||||||
if ttl_pricing and cache_ttl_minutes is not None:
|
|
||||||
matched_price = None
|
|
||||||
for ttl_config in ttl_pricing:
|
|
||||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
|
||||||
if cache_ttl_minutes <= ttl_limit:
|
|
||||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
|
||||||
break
|
|
||||||
if matched_price is not None:
|
|
||||||
return matched_price
|
|
||||||
# 超过所有配置的 TTL,使用最后一个
|
|
||||||
if ttl_pricing:
|
|
||||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
|
||||||
|
|
||||||
return tier.get("cache_read_price_per_1m")
|
|
||||||
|
|
||||||
def compute_cost(
|
def compute_cost(
|
||||||
self,
|
self,
|
||||||
input_tokens: int,
|
input_tokens: int,
|
||||||
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
"""
|
"""
|
||||||
计算请求成本
|
计算请求成本
|
||||||
|
|
||||||
默认实现:支持固定价格和阶梯计费
|
使用 billing 模块的配置驱动计费。
|
||||||
子类可覆盖此方法实现完全不同的计费逻辑
|
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||||
|
或覆盖此方法实现完全自定义的计费逻辑。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_tokens: 输入 token 数
|
input_tokens: 输入 token 数
|
||||||
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
|
|||||||
"tier_index": Optional[int], # 命中的阶梯索引
|
"tier_index": Optional[int], # 命中的阶梯索引
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
tier_index = None
|
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||||
effective_input_price = input_price_per_1m
|
total_input_context = self.compute_total_input_context(
|
||||||
effective_output_price = output_price_per_1m
|
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||||
effective_cache_creation_price = cache_creation_price_per_1m
|
)
|
||||||
effective_cache_read_price = cache_read_price_per_1m
|
|
||||||
|
|
||||||
# 检查阶梯计费
|
return _calculate_request_cost(
|
||||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
input_tokens=input_tokens,
|
||||||
total_input_context = self.compute_total_input_context(
|
output_tokens=output_tokens,
|
||||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
)
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
input_price_per_1m=input_price_per_1m,
|
||||||
|
output_price_per_1m=output_price_per_1m,
|
||||||
if tier:
|
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||||
tier_index = tiered_pricing["tiers"].index(tier)
|
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
price_per_request=price_per_request,
|
||||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
tiered_pricing=tiered_pricing,
|
||||||
effective_cache_creation_price = tier.get(
|
cache_ttl_minutes=cache_ttl_minutes,
|
||||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
total_input_context=total_input_context,
|
||||||
)
|
billing_template=self.BILLING_TEMPLATE,
|
||||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
)
|
||||||
tier, cache_ttl_minutes
|
|
||||||
)
|
|
||||||
if effective_cache_read_price is None:
|
|
||||||
effective_cache_read_price = cache_read_price_per_1m
|
|
||||||
|
|
||||||
# 计算各项成本
|
|
||||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
|
||||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
|
||||||
|
|
||||||
cache_creation_cost = 0.0
|
|
||||||
cache_read_cost = 0.0
|
|
||||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
|
||||||
cache_creation_cost = (
|
|
||||||
cache_creation_input_tokens / 1_000_000
|
|
||||||
) * effective_cache_creation_price
|
|
||||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
|
||||||
cache_read_cost = (
|
|
||||||
cache_read_input_tokens / 1_000_000
|
|
||||||
) * effective_cache_read_price
|
|
||||||
|
|
||||||
cache_cost = cache_creation_cost + cache_read_cost
|
|
||||||
request_cost = price_per_request if price_per_request else 0.0
|
|
||||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_cost": input_cost,
|
|
||||||
"output_cost": output_cost,
|
|
||||||
"cache_creation_cost": cache_creation_cost,
|
|
||||||
"cache_read_cost": cache_read_cost,
|
|
||||||
"cache_cost": cache_cost,
|
|
||||||
"request_cost": request_cost,
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"tier_index": tier_index,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
根据总输入 token 数确定价格阶梯
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
|
|
||||||
total_input_tokens: 总输入 token 数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
匹配的阶梯配置
|
|
||||||
"""
|
|
||||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tiers = tiered_pricing.get("tiers", [])
|
|
||||||
if not tiers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for tier in tiers:
|
|
||||||
up_to = tier.get("up_to")
|
|
||||||
if up_to is None or total_input_tokens <= up_to:
|
|
||||||
return tier
|
|
||||||
|
|
||||||
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
|
||||||
return tiers[-1] if tiers else None
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 模型列表查询 - 子类应覆盖此方法
|
# 模型列表查询 - 子类应覆盖此方法
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from src.core.exceptions import (
|
|||||||
UpstreamClientException,
|
UpstreamClientException,
|
||||||
)
|
)
|
||||||
from src.core.logger import logger
|
from src.core.logger import logger
|
||||||
|
from src.services.billing import calculate_request_cost as _calculate_request_cost
|
||||||
from src.services.request.result import RequestResult
|
from src.services.request.result import RequestResult
|
||||||
from src.services.usage.recorder import UsageRecorder
|
from src.services.usage.recorder import UsageRecorder
|
||||||
|
|
||||||
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
name: str = "cli.base"
|
name: str = "cli.base"
|
||||||
mode = ApiMode.PROXY
|
mode = ApiMode.PROXY
|
||||||
|
|
||||||
|
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini")
|
||||||
|
BILLING_TEMPLATE: str = "claude"
|
||||||
|
|
||||||
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
|
||||||
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
|
||||||
|
|
||||||
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
"""
|
"""
|
||||||
return input_tokens + cache_read_input_tokens
|
return input_tokens + cache_read_input_tokens
|
||||||
|
|
||||||
def get_cache_read_price_for_ttl(
|
|
||||||
self,
|
|
||||||
tier: dict,
|
|
||||||
cache_ttl_minutes: Optional[int] = None,
|
|
||||||
) -> Optional[float]:
|
|
||||||
"""
|
|
||||||
根据缓存 TTL 获取缓存读取价格
|
|
||||||
|
|
||||||
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
|
|
||||||
子类可覆盖此方法实现不同的 TTL 定价逻辑
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tier: 当前阶梯配置
|
|
||||||
cache_ttl_minutes: 缓存时长(分钟)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
缓存读取价格(每 1M tokens)
|
|
||||||
"""
|
|
||||||
ttl_pricing = tier.get("cache_ttl_pricing")
|
|
||||||
if ttl_pricing and cache_ttl_minutes is not None:
|
|
||||||
matched_price = None
|
|
||||||
for ttl_config in ttl_pricing:
|
|
||||||
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
|
||||||
if cache_ttl_minutes <= ttl_limit:
|
|
||||||
matched_price = ttl_config.get("cache_read_price_per_1m")
|
|
||||||
break
|
|
||||||
if matched_price is not None:
|
|
||||||
return matched_price
|
|
||||||
# 超过所有配置的 TTL,使用最后一个
|
|
||||||
if ttl_pricing:
|
|
||||||
return ttl_pricing[-1].get("cache_read_price_per_1m")
|
|
||||||
|
|
||||||
return tier.get("cache_read_price_per_1m")
|
|
||||||
|
|
||||||
def compute_cost(
|
def compute_cost(
|
||||||
self,
|
self,
|
||||||
input_tokens: int,
|
input_tokens: int,
|
||||||
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
"""
|
"""
|
||||||
计算请求成本
|
计算请求成本
|
||||||
|
|
||||||
默认实现:支持固定价格和阶梯计费
|
使用 billing 模块的配置驱动计费。
|
||||||
子类可覆盖此方法实现完全不同的计费逻辑
|
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
|
||||||
|
或覆盖此方法实现完全自定义的计费逻辑。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_tokens: 输入 token 数
|
input_tokens: 输入 token 数
|
||||||
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
|
|||||||
Returns:
|
Returns:
|
||||||
包含各项成本的字典
|
包含各项成本的字典
|
||||||
"""
|
"""
|
||||||
tier_index = None
|
# 计算总输入上下文(使用子类可覆盖的方法)
|
||||||
effective_input_price = input_price_per_1m
|
total_input_context = self.compute_total_input_context(
|
||||||
effective_output_price = output_price_per_1m
|
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
||||||
effective_cache_creation_price = cache_creation_price_per_1m
|
)
|
||||||
effective_cache_read_price = cache_read_price_per_1m
|
|
||||||
|
|
||||||
# 检查阶梯计费
|
return _calculate_request_cost(
|
||||||
if tiered_pricing and tiered_pricing.get("tiers"):
|
input_tokens=input_tokens,
|
||||||
total_input_context = self.compute_total_input_context(
|
output_tokens=output_tokens,
|
||||||
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
|
cache_creation_input_tokens=cache_creation_input_tokens,
|
||||||
)
|
cache_read_input_tokens=cache_read_input_tokens,
|
||||||
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
|
input_price_per_1m=input_price_per_1m,
|
||||||
|
output_price_per_1m=output_price_per_1m,
|
||||||
if tier:
|
cache_creation_price_per_1m=cache_creation_price_per_1m,
|
||||||
tier_index = tiered_pricing["tiers"].index(tier)
|
cache_read_price_per_1m=cache_read_price_per_1m,
|
||||||
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
|
price_per_request=price_per_request,
|
||||||
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
|
tiered_pricing=tiered_pricing,
|
||||||
effective_cache_creation_price = tier.get(
|
cache_ttl_minutes=cache_ttl_minutes,
|
||||||
"cache_creation_price_per_1m", cache_creation_price_per_1m
|
total_input_context=total_input_context,
|
||||||
)
|
billing_template=self.BILLING_TEMPLATE,
|
||||||
effective_cache_read_price = self.get_cache_read_price_for_ttl(
|
)
|
||||||
tier, cache_ttl_minutes
|
|
||||||
)
|
|
||||||
if effective_cache_read_price is None:
|
|
||||||
effective_cache_read_price = cache_read_price_per_1m
|
|
||||||
|
|
||||||
# 计算各项成本
|
|
||||||
input_cost = (input_tokens / 1_000_000) * effective_input_price
|
|
||||||
output_cost = (output_tokens / 1_000_000) * effective_output_price
|
|
||||||
|
|
||||||
cache_creation_cost = 0.0
|
|
||||||
cache_read_cost = 0.0
|
|
||||||
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
|
|
||||||
cache_creation_cost = (
|
|
||||||
cache_creation_input_tokens / 1_000_000
|
|
||||||
) * effective_cache_creation_price
|
|
||||||
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
|
|
||||||
cache_read_cost = (
|
|
||||||
cache_read_input_tokens / 1_000_000
|
|
||||||
) * effective_cache_read_price
|
|
||||||
|
|
||||||
cache_cost = cache_creation_cost + cache_read_cost
|
|
||||||
request_cost = price_per_request if price_per_request else 0.0
|
|
||||||
total_cost = input_cost + output_cost + cache_cost + request_cost
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_cost": input_cost,
|
|
||||||
"output_cost": output_cost,
|
|
||||||
"cache_creation_cost": cache_creation_cost,
|
|
||||||
"cache_read_cost": cache_read_cost,
|
|
||||||
"cache_cost": cache_cost,
|
|
||||||
"request_cost": request_cost,
|
|
||||||
"total_cost": total_cost,
|
|
||||||
"tier_index": tier_index,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
|
|
||||||
"""根据总输入 token 数确定价格阶梯"""
|
|
||||||
if not tiered_pricing or "tiers" not in tiered_pricing:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tiers = tiered_pricing.get("tiers", [])
|
|
||||||
if not tiers:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for tier in tiers:
|
|
||||||
up_to = tier.get("up_to")
|
|
||||||
if up_to is None or total_input_tokens <= up_to:
|
|
||||||
return tier
|
|
||||||
|
|
||||||
return tiers[-1] if tiers else None
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# 模型列表查询 - 子类应覆盖此方法
|
# 模型列表查询 - 子类应覆盖此方法
|
||||||
|
|||||||
@@ -1497,8 +1497,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
retry_after=int(resp.headers.get("retry-after", 0)) or None,
|
||||||
)
|
)
|
||||||
elif resp.status_code >= 500:
|
elif resp.status_code >= 500:
|
||||||
|
error_text = resp.text
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
|
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=resp.status_code,
|
||||||
|
upstream_response=error_text,
|
||||||
)
|
)
|
||||||
elif 300 <= resp.status_code < 400:
|
elif 300 <= resp.status_code < 400:
|
||||||
redirect_url = resp.headers.get("location", "unknown")
|
redirect_url = resp.headers.get("location", "unknown")
|
||||||
@@ -1508,7 +1512,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
|||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
error_text = resp.text
|
error_text = resp.text
|
||||||
raise ProviderNotAvailableException(
|
raise ProviderNotAvailableException(
|
||||||
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
|
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
|
||||||
|
provider_name=str(provider.name),
|
||||||
|
upstream_status=resp.status_code,
|
||||||
|
upstream_response=error_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 安全解析 JSON 响应,处理可能的编码错误
|
# 安全解析 JSON 响应,处理可能的编码错误
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FORMAT_ID = "CLAUDE"
|
FORMAT_ID = "CLAUDE"
|
||||||
|
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||||
name = "claude.chat"
|
name = "claude.chat"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FORMAT_ID = "CLAUDE_CLI"
|
FORMAT_ID = "CLAUDE_CLI"
|
||||||
|
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
|
||||||
name = "claude.cli"
|
name = "claude.cli"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FORMAT_ID = "GEMINI"
|
FORMAT_ID = "GEMINI"
|
||||||
|
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||||
name = "gemini.chat"
|
name = "gemini.chat"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FORMAT_ID = "GEMINI_CLI"
|
FORMAT_ID = "GEMINI_CLI"
|
||||||
|
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
|
||||||
name = "gemini.cli"
|
name = "gemini.cli"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FORMAT_ID = "OPENAI"
|
FORMAT_ID = "OPENAI"
|
||||||
|
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||||
name = "openai.chat"
|
name = "openai.chat"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
FORMAT_ID = "OPENAI_CLI"
|
FORMAT_ID = "OPENAI_CLI"
|
||||||
|
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
|
||||||
name = "openai.cli"
|
name = "openai.cli"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -104,11 +104,14 @@ async def get_my_usage(
|
|||||||
request: Request,
|
request: Request,
|
||||||
start_date: Optional[datetime] = None,
|
start_date: Optional[datetime] = None,
|
||||||
end_date: Optional[datetime] = None,
|
end_date: Optional[datetime] = None,
|
||||||
|
search: Optional[str] = None, # 通用搜索:密钥名、模型名
|
||||||
limit: int = Query(100, ge=1, le=200, description="每页记录数,默认100,最大200"),
|
limit: int = Query(100, ge=1, le=200, description="每页记录数,默认100,最大200"),
|
||||||
offset: int = Query(0, ge=0, le=2000, description="偏移量,用于分页,最大2000"),
|
offset: int = Query(0, ge=0, le=2000, description="偏移量,用于分页,最大2000"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date, limit=limit, offset=offset)
|
adapter = GetUsageAdapter(
|
||||||
|
start_date=start_date, end_date=end_date, search=search, limit=limit, offset=offset
|
||||||
|
)
|
||||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||||
|
|
||||||
|
|
||||||
@@ -487,10 +490,15 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
|
|||||||
class GetUsageAdapter(AuthenticatedApiAdapter):
|
class GetUsageAdapter(AuthenticatedApiAdapter):
|
||||||
start_date: Optional[datetime]
|
start_date: Optional[datetime]
|
||||||
end_date: Optional[datetime]
|
end_date: Optional[datetime]
|
||||||
|
search: Optional[str] = None
|
||||||
limit: int = 100
|
limit: int = 100
|
||||||
offset: int = 0
|
offset: int = 0
|
||||||
|
|
||||||
async def handle(self, context): # type: ignore[override]
|
async def handle(self, context): # type: ignore[override]
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
|
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
|
||||||
|
|
||||||
db = context.db
|
db = context.db
|
||||||
user = context.user
|
user = context.user
|
||||||
summary_list = UsageService.get_usage_summary(
|
summary_list = UsageService.get_usage_summary(
|
||||||
@@ -595,12 +603,30 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
})
|
})
|
||||||
summary_by_provider = sorted(summary_by_provider, key=lambda x: x["requests"], reverse=True)
|
summary_by_provider = sorted(summary_by_provider, key=lambda x: x["requests"], reverse=True)
|
||||||
|
|
||||||
query = db.query(Usage).filter(Usage.user_id == user.id)
|
query = (
|
||||||
|
db.query(Usage, ApiKey)
|
||||||
|
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
|
||||||
|
.filter(Usage.user_id == user.id)
|
||||||
|
)
|
||||||
if self.start_date:
|
if self.start_date:
|
||||||
query = query.filter(Usage.created_at >= self.start_date)
|
query = query.filter(Usage.created_at >= self.start_date)
|
||||||
if self.end_date:
|
if self.end_date:
|
||||||
query = query.filter(Usage.created_at <= self.end_date)
|
query = query.filter(Usage.created_at <= self.end_date)
|
||||||
|
|
||||||
|
# 通用搜索:密钥名、模型名
|
||||||
|
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
|
||||||
|
if self.search and self.search.strip():
|
||||||
|
keywords = [kw for kw in self.search.strip().split() if kw][:10]
|
||||||
|
for keyword in keywords:
|
||||||
|
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
|
||||||
|
search_pattern = f"%{escaped}%"
|
||||||
|
query = query.filter(
|
||||||
|
or_(
|
||||||
|
ApiKey.name.ilike(search_pattern, escape="\\"),
|
||||||
|
Usage.model.ilike(search_pattern, escape="\\"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# 计算总数用于分页
|
# 计算总数用于分页
|
||||||
total_records = query.count()
|
total_records = query.count()
|
||||||
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
|
||||||
@@ -659,8 +685,17 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
"output_price_per_1m": r.output_price_per_1m,
|
"output_price_per_1m": r.output_price_per_1m,
|
||||||
"cache_creation_price_per_1m": r.cache_creation_price_per_1m,
|
"cache_creation_price_per_1m": r.cache_creation_price_per_1m,
|
||||||
"cache_read_price_per_1m": r.cache_read_price_per_1m,
|
"cache_read_price_per_1m": r.cache_read_price_per_1m,
|
||||||
|
"api_key": (
|
||||||
|
{
|
||||||
|
"id": str(api_key.id),
|
||||||
|
"name": api_key.name,
|
||||||
|
"display": api_key.get_display_key(),
|
||||||
|
}
|
||||||
|
if api_key
|
||||||
|
else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
for r in usage_records
|
for r, api_key in usage_records
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -668,7 +703,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
|
|||||||
if user.role == "admin":
|
if user.role == "admin":
|
||||||
response_data["total_actual_cost"] = total_actual_cost
|
response_data["total_actual_cost"] = total_actual_cost
|
||||||
# 为每条记录添加真实成本和倍率信息
|
# 为每条记录添加真实成本和倍率信息
|
||||||
for i, r in enumerate(usage_records):
|
for i, (r, _) in enumerate(usage_records):
|
||||||
# 确保字段有值,避免前端显示 -
|
# 确保字段有值,避免前端显示 -
|
||||||
actual_cost = (
|
actual_cost = (
|
||||||
r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0
|
r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0
|
||||||
|
|||||||
51
src/services/billing/__init__.py
Normal file
51
src/services/billing/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""
|
||||||
|
计费模块
|
||||||
|
|
||||||
|
提供配置驱动的计费计算,支持不同厂商的差异化计费模式:
|
||||||
|
- Claude: input + output + cache_creation + cache_read
|
||||||
|
- OpenAI: input + output + cache_read (无缓存创建费用)
|
||||||
|
- 豆包: input + output + cache_read + cache_storage (缓存按时计费)
|
||||||
|
- 按次计费: per_request
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
from src.services.billing import BillingCalculator, UsageMapper, StandardizedUsage
|
||||||
|
|
||||||
|
# 1. 将原始 usage 映射为标准格式
|
||||||
|
usage = UsageMapper.map(raw_usage, api_format="OPENAI")
|
||||||
|
|
||||||
|
# 2. 使用计费计算器计算费用
|
||||||
|
calculator = BillingCalculator(template="openai")
|
||||||
|
result = calculator.calculate(usage, prices)
|
||||||
|
|
||||||
|
# 3. 获取费用明细
|
||||||
|
print(result.total_cost)
|
||||||
|
print(result.costs) # {"input": 0.01, "output": 0.02, ...}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.services.billing.calculator import BillingCalculator, calculate_request_cost
|
||||||
|
from src.services.billing.models import (
|
||||||
|
BillingDimension,
|
||||||
|
BillingUnit,
|
||||||
|
CostBreakdown,
|
||||||
|
StandardizedUsage,
|
||||||
|
)
|
||||||
|
from src.services.billing.templates import BILLING_TEMPLATE_REGISTRY, BillingTemplates
|
||||||
|
from src.services.billing.usage_mapper import UsageMapper, map_usage, map_usage_from_response
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 数据模型
|
||||||
|
"BillingDimension",
|
||||||
|
"BillingUnit",
|
||||||
|
"CostBreakdown",
|
||||||
|
"StandardizedUsage",
|
||||||
|
# 模板
|
||||||
|
"BillingTemplates",
|
||||||
|
"BILLING_TEMPLATE_REGISTRY",
|
||||||
|
# 计算器
|
||||||
|
"BillingCalculator",
|
||||||
|
"calculate_request_cost",
|
||||||
|
# 映射器
|
||||||
|
"UsageMapper",
|
||||||
|
"map_usage",
|
||||||
|
"map_usage_from_response",
|
||||||
|
]
|
||||||
339
src/services/billing/calculator.py
Normal file
339
src/services/billing/calculator.py
Normal file
@@ -0,0 +1,339 @@
|
|||||||
|
"""
|
||||||
|
计费计算器
|
||||||
|
|
||||||
|
配置驱动的计费计算,支持:
|
||||||
|
- 固定价格计费
|
||||||
|
- 阶梯计费
|
||||||
|
- 多种计费模板
|
||||||
|
- 自定义计费维度
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from src.services.billing.models import (
|
||||||
|
BillingDimension,
|
||||||
|
BillingUnit,
|
||||||
|
CostBreakdown,
|
||||||
|
StandardizedUsage,
|
||||||
|
)
|
||||||
|
from src.services.billing.templates import (
|
||||||
|
BILLING_TEMPLATE_REGISTRY,
|
||||||
|
BillingTemplates,
|
||||||
|
get_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BillingCalculator:
|
||||||
|
"""
|
||||||
|
配置驱动的计费计算器
|
||||||
|
|
||||||
|
支持多种计费模式:
|
||||||
|
- 使用预定义模板(claude, openai, doubao 等)
|
||||||
|
- 自定义计费维度
|
||||||
|
- 阶梯计费
|
||||||
|
|
||||||
|
示例:
|
||||||
|
# 使用模板
|
||||||
|
calculator = BillingCalculator(template="openai")
|
||||||
|
|
||||||
|
# 自定义维度
|
||||||
|
calculator = BillingCalculator(dimensions=[
|
||||||
|
BillingDimension(name="input", usage_field="input_tokens", price_field="input_price_per_1m"),
|
||||||
|
BillingDimension(name="output", usage_field="output_tokens", price_field="output_price_per_1m"),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 计算费用
|
||||||
|
usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
|
||||||
|
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
|
||||||
|
result = calculator.calculate(usage, prices)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dimensions: Optional[List[BillingDimension]] = None,
|
||||||
|
template: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化计费计算器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dimensions: 自定义计费维度列表(优先级高于模板)
|
||||||
|
template: 使用预定义模板名称 ("claude", "openai", "doubao", "per_request" 等)
|
||||||
|
"""
|
||||||
|
if dimensions:
|
||||||
|
self.dimensions = dimensions
|
||||||
|
elif template:
|
||||||
|
self.dimensions = get_template(template)
|
||||||
|
else:
|
||||||
|
# 默认使用 Claude 模板(向后兼容)
|
||||||
|
self.dimensions = BillingTemplates.CLAUDE_STANDARD
|
||||||
|
|
||||||
|
self.template_name = template
|
||||||
|
|
||||||
|
def calculate(
|
||||||
|
self,
|
||||||
|
usage: StandardizedUsage,
|
||||||
|
prices: Dict[str, float],
|
||||||
|
tiered_pricing: Optional[Dict[str, Any]] = None,
|
||||||
|
cache_ttl_minutes: Optional[int] = None,
|
||||||
|
total_input_context: Optional[int] = None,
|
||||||
|
) -> CostBreakdown:
|
||||||
|
"""
|
||||||
|
计算费用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usage: 标准化的 usage 数据
|
||||||
|
prices: 价格配置 {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0, ...}
|
||||||
|
tiered_pricing: 阶梯计费配置(可选)
|
||||||
|
cache_ttl_minutes: 缓存 TTL 分钟数(用于 TTL 差异化定价)
|
||||||
|
total_input_context: 总输入上下文(用于阶梯判定,可选)
|
||||||
|
如果提供,将使用该值进行阶梯判定;否则使用默认计算逻辑
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
费用明细 (CostBreakdown)
|
||||||
|
"""
|
||||||
|
result = CostBreakdown()
|
||||||
|
|
||||||
|
# 处理阶梯计费
|
||||||
|
effective_prices = prices.copy()
|
||||||
|
if tiered_pricing and tiered_pricing.get("tiers"):
|
||||||
|
tier, tier_index = self._get_tier(usage, tiered_pricing, total_input_context)
|
||||||
|
if tier:
|
||||||
|
result.tier_index = tier_index
|
||||||
|
# 阶梯价格覆盖默认价格
|
||||||
|
for key, value in tier.items():
|
||||||
|
if key not in ("up_to", "cache_ttl_pricing") and value is not None:
|
||||||
|
effective_prices[key] = value
|
||||||
|
|
||||||
|
# 处理 TTL 差异化定价
|
||||||
|
if cache_ttl_minutes is not None:
|
||||||
|
ttl_price = self._get_cache_read_price_for_ttl(tier, cache_ttl_minutes)
|
||||||
|
if ttl_price is not None:
|
||||||
|
effective_prices["cache_read_price_per_1m"] = ttl_price
|
||||||
|
|
||||||
|
# 记录使用的价格
|
||||||
|
result.effective_prices = effective_prices.copy()
|
||||||
|
|
||||||
|
# 计算各维度费用
|
||||||
|
total = 0.0
|
||||||
|
for dim in self.dimensions:
|
||||||
|
usage_value = usage.get(dim.usage_field, 0)
|
||||||
|
price = effective_prices.get(dim.price_field, dim.default_price)
|
||||||
|
|
||||||
|
if usage_value and price:
|
||||||
|
cost = dim.calculate(usage_value, price)
|
||||||
|
result.costs[dim.name] = cost
|
||||||
|
total += cost
|
||||||
|
|
||||||
|
result.total_cost = total
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_tier(
|
||||||
|
self,
|
||||||
|
usage: StandardizedUsage,
|
||||||
|
tiered_pricing: Dict[str, Any],
|
||||||
|
total_input_context: Optional[int] = None,
|
||||||
|
) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
|
||||||
|
"""
|
||||||
|
确定价格阶梯
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usage: usage 数据
|
||||||
|
tiered_pricing: 阶梯配置 {"tiers": [...]}
|
||||||
|
total_input_context: 预计算的总输入上下文(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(匹配的阶梯配置, 阶梯索引)
|
||||||
|
"""
|
||||||
|
tiers = tiered_pricing.get("tiers", [])
|
||||||
|
if not tiers:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# 使用传入的 total_input_context,或者默认计算
|
||||||
|
if total_input_context is None:
|
||||||
|
total_input_context = self._compute_total_input_context(usage)
|
||||||
|
|
||||||
|
for i, tier in enumerate(tiers):
|
||||||
|
up_to = tier.get("up_to")
|
||||||
|
# up_to 为 None 表示无上限(最后一个阶梯)
|
||||||
|
if up_to is None or total_input_context <= up_to:
|
||||||
|
return tier, i
|
||||||
|
|
||||||
|
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
|
||||||
|
return tiers[-1], len(tiers) - 1
|
||||||
|
|
||||||
|
def _compute_total_input_context(self, usage: StandardizedUsage) -> int:
|
||||||
|
"""
|
||||||
|
计算总输入上下文(用于阶梯计费判定)
|
||||||
|
|
||||||
|
默认: input_tokens + cache_read_tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usage: usage 数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
总输入 token 数
|
||||||
|
"""
|
||||||
|
return usage.input_tokens + usage.cache_read_tokens
|
||||||
|
|
||||||
|
def _get_cache_read_price_for_ttl(
|
||||||
|
self,
|
||||||
|
tier: Dict[str, Any],
|
||||||
|
cache_ttl_minutes: int,
|
||||||
|
) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
根据缓存 TTL 获取缓存读取价格
|
||||||
|
|
||||||
|
某些厂商(如 Claude)对不同 TTL 的缓存有不同定价。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tier: 当前阶梯配置
|
||||||
|
cache_ttl_minutes: 缓存时长(分钟)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
缓存读取价格,如果没有 TTL 差异化配置返回 None
|
||||||
|
"""
|
||||||
|
ttl_pricing = tier.get("cache_ttl_pricing")
|
||||||
|
if not ttl_pricing:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 找到匹配或最接近的 TTL 价格
|
||||||
|
for ttl_config in ttl_pricing:
|
||||||
|
ttl_limit = ttl_config.get("ttl_minutes", 0)
|
||||||
|
if cache_ttl_minutes <= ttl_limit:
|
||||||
|
price = ttl_config.get("cache_read_price_per_1m")
|
||||||
|
return float(price) if price is not None else None
|
||||||
|
|
||||||
|
# 超过所有配置的 TTL,使用最后一个
|
||||||
|
if ttl_pricing:
|
||||||
|
price = ttl_pricing[-1].get("cache_read_price_per_1m")
|
||||||
|
return float(price) if price is not None else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: Dict[str, Any]) -> "BillingCalculator":
|
||||||
|
"""
|
||||||
|
从配置创建计费计算器
|
||||||
|
|
||||||
|
Config 格式:
|
||||||
|
{
|
||||||
|
"template": "claude", # 或 "openai", "doubao", "per_request"
|
||||||
|
# 或者自定义维度:
|
||||||
|
"dimensions": [
|
||||||
|
{"name": "input", "usage_field": "input_tokens", "price_field": "input_price_per_1m"},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 配置字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BillingCalculator 实例
|
||||||
|
"""
|
||||||
|
if "dimensions" in config:
|
||||||
|
dimensions = [BillingDimension.from_dict(d) for d in config["dimensions"]]
|
||||||
|
return cls(dimensions=dimensions)
|
||||||
|
|
||||||
|
return cls(template=config.get("template", "claude"))
|
||||||
|
|
||||||
|
def get_dimension_names(self) -> List[str]:
|
||||||
|
"""获取所有计费维度名称"""
|
||||||
|
return [dim.name for dim in self.dimensions]
|
||||||
|
|
||||||
|
def get_required_price_fields(self) -> List[str]:
|
||||||
|
"""获取所需的价格字段名称"""
|
||||||
|
return [dim.price_field for dim in self.dimensions]
|
||||||
|
|
||||||
|
def get_required_usage_fields(self) -> List[str]:
|
||||||
|
"""获取所需的 usage 字段名称"""
|
||||||
|
return [dim.usage_field for dim in self.dimensions]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_request_cost(
|
||||||
|
input_tokens: int,
|
||||||
|
output_tokens: int,
|
||||||
|
cache_creation_input_tokens: int,
|
||||||
|
cache_read_input_tokens: int,
|
||||||
|
input_price_per_1m: float,
|
||||||
|
output_price_per_1m: float,
|
||||||
|
cache_creation_price_per_1m: Optional[float],
|
||||||
|
cache_read_price_per_1m: Optional[float],
|
||||||
|
price_per_request: Optional[float],
|
||||||
|
tiered_pricing: Optional[Dict[str, Any]] = None,
|
||||||
|
cache_ttl_minutes: Optional[int] = None,
|
||||||
|
total_input_context: Optional[int] = None,
|
||||||
|
billing_template: str = "claude",
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
计算请求成本的便捷函数
|
||||||
|
|
||||||
|
封装了 BillingCalculator 的调用逻辑,返回兼容旧格式的字典。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_tokens: 输入 token 数
|
||||||
|
output_tokens: 输出 token 数
|
||||||
|
cache_creation_input_tokens: 缓存创建 token 数
|
||||||
|
cache_read_input_tokens: 缓存读取 token 数
|
||||||
|
input_price_per_1m: 输入价格(每 1M tokens)
|
||||||
|
output_price_per_1m: 输出价格(每 1M tokens)
|
||||||
|
cache_creation_price_per_1m: 缓存创建价格(每 1M tokens)
|
||||||
|
cache_read_price_per_1m: 缓存读取价格(每 1M tokens)
|
||||||
|
price_per_request: 按次计费价格
|
||||||
|
tiered_pricing: 阶梯计费配置
|
||||||
|
cache_ttl_minutes: 缓存时长(分钟)
|
||||||
|
total_input_context: 总输入上下文(用于阶梯判定)
|
||||||
|
billing_template: 计费模板名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含各项成本的字典:
|
||||||
|
{
|
||||||
|
"input_cost": float,
|
||||||
|
"output_cost": float,
|
||||||
|
"cache_creation_cost": float,
|
||||||
|
"cache_read_cost": float,
|
||||||
|
"cache_cost": float,
|
||||||
|
"request_cost": float,
|
||||||
|
"total_cost": float,
|
||||||
|
"tier_index": Optional[int],
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# 构建标准化 usage
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
output_tokens=output_tokens,
|
||||||
|
cache_creation_tokens=cache_creation_input_tokens,
|
||||||
|
cache_read_tokens=cache_read_input_tokens,
|
||||||
|
request_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建价格配置
|
||||||
|
prices: Dict[str, float] = {
|
||||||
|
"input_price_per_1m": input_price_per_1m,
|
||||||
|
"output_price_per_1m": output_price_per_1m,
|
||||||
|
}
|
||||||
|
if cache_creation_price_per_1m is not None:
|
||||||
|
prices["cache_creation_price_per_1m"] = cache_creation_price_per_1m
|
||||||
|
if cache_read_price_per_1m is not None:
|
||||||
|
prices["cache_read_price_per_1m"] = cache_read_price_per_1m
|
||||||
|
if price_per_request is not None:
|
||||||
|
prices["price_per_request"] = price_per_request
|
||||||
|
|
||||||
|
# 使用 BillingCalculator 计算
|
||||||
|
calculator = BillingCalculator(template=billing_template)
|
||||||
|
result = calculator.calculate(
|
||||||
|
usage, prices, tiered_pricing, cache_ttl_minutes, total_input_context
|
||||||
|
)
|
||||||
|
|
||||||
|
# 返回兼容旧格式的字典
|
||||||
|
return {
|
||||||
|
"input_cost": result.input_cost,
|
||||||
|
"output_cost": result.output_cost,
|
||||||
|
"cache_creation_cost": result.cache_creation_cost,
|
||||||
|
"cache_read_cost": result.cache_read_cost,
|
||||||
|
"cache_cost": result.cache_cost,
|
||||||
|
"request_cost": result.request_cost,
|
||||||
|
"total_cost": result.total_cost,
|
||||||
|
"tier_index": result.tier_index,
|
||||||
|
}
|
||||||
281
src/services/billing/models.py
Normal file
281
src/services/billing/models.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""
|
||||||
|
计费模块数据模型
|
||||||
|
|
||||||
|
定义计费相关的核心数据结构:
|
||||||
|
- BillingUnit: 计费单位枚举
|
||||||
|
- BillingDimension: 计费维度定义
|
||||||
|
- StandardizedUsage: 标准化的 usage 数据
|
||||||
|
- CostBreakdown: 计费明细结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class BillingUnit(str, Enum):
|
||||||
|
"""计费单位"""
|
||||||
|
|
||||||
|
PER_1M_TOKENS = "per_1m_tokens" # 每百万 token
|
||||||
|
PER_1M_TOKENS_HOUR = "per_1m_tokens_hour" # 每百万 token 每小时(豆包缓存存储)
|
||||||
|
PER_REQUEST = "per_request" # 每次请求
|
||||||
|
FIXED = "fixed" # 固定费用
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BillingDimension:
|
||||||
|
"""
|
||||||
|
计费维度定义
|
||||||
|
|
||||||
|
每个维度描述一种计费方式,例如:
|
||||||
|
- 输入 token 计费
|
||||||
|
- 输出 token 计费
|
||||||
|
- 缓存读取计费
|
||||||
|
- 按次计费
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str # 维度名称,如 "input", "output", "cache_read"
|
||||||
|
usage_field: str # 从 usage 中取值的字段名
|
||||||
|
price_field: str # 价格配置中的字段名
|
||||||
|
unit: BillingUnit = BillingUnit.PER_1M_TOKENS # 计费单位
|
||||||
|
default_price: float = 0.0 # 默认价格(当价格配置中没有时使用)
|
||||||
|
|
||||||
|
def calculate(self, usage_value: float, price: float) -> float:
|
||||||
|
"""
|
||||||
|
计算该维度的费用
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usage_value: 使用量数值
|
||||||
|
price: 单价
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
计算后的费用
|
||||||
|
"""
|
||||||
|
if usage_value <= 0 or price <= 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if self.unit == BillingUnit.PER_1M_TOKENS:
|
||||||
|
return (usage_value / 1_000_000) * price
|
||||||
|
elif self.unit == BillingUnit.PER_1M_TOKENS_HOUR:
|
||||||
|
# 缓存存储按 token 数 * 小时数计费
|
||||||
|
return (usage_value / 1_000_000) * price
|
||||||
|
elif self.unit == BillingUnit.PER_REQUEST:
|
||||||
|
return usage_value * price
|
||||||
|
elif self.unit == BillingUnit.FIXED:
|
||||||
|
return price
|
||||||
|
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典(用于序列化)"""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"usage_field": self.usage_field,
|
||||||
|
"price_field": self.price_field,
|
||||||
|
"unit": self.unit.value,
|
||||||
|
"default_price": self.default_price,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "BillingDimension":
|
||||||
|
"""从字典创建实例"""
|
||||||
|
return cls(
|
||||||
|
name=data["name"],
|
||||||
|
usage_field=data["usage_field"],
|
||||||
|
price_field=data["price_field"],
|
||||||
|
unit=BillingUnit(data.get("unit", "per_1m_tokens")),
|
||||||
|
default_price=data.get("default_price", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StandardizedUsage:
|
||||||
|
"""
|
||||||
|
标准化的 Usage 数据
|
||||||
|
|
||||||
|
将不同 API 格式的 usage 统一为标准格式,便于计费计算。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 基础 token 计数
|
||||||
|
input_tokens: int = 0
|
||||||
|
output_tokens: int = 0
|
||||||
|
|
||||||
|
# 缓存相关
|
||||||
|
cache_creation_tokens: int = 0 # Claude: 缓存创建
|
||||||
|
cache_read_tokens: int = 0 # Claude/OpenAI/豆包: 缓存读取/命中
|
||||||
|
|
||||||
|
# 特殊 token 类型
|
||||||
|
reasoning_tokens: int = 0 # o1/豆包: 推理 token(通常包含在 output 中,单独记录用于分析)
|
||||||
|
|
||||||
|
# 时间相关(用于按时计费)
|
||||||
|
cache_storage_token_hours: float = 0.0 # 豆包: 缓存存储 token*小时
|
||||||
|
|
||||||
|
# 请求计数(用于按次计费)
|
||||||
|
request_count: int = 1
|
||||||
|
|
||||||
|
# 扩展字段(未来可能需要的额外维度)
|
||||||
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def get(self, field_name: str, default: Any = 0) -> Any:
|
||||||
|
"""
|
||||||
|
通用字段获取
|
||||||
|
|
||||||
|
支持获取标准字段和扩展字段。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: 字段名
|
||||||
|
default: 默认值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段值
|
||||||
|
"""
|
||||||
|
if hasattr(self, field_name):
|
||||||
|
value = getattr(self, field_name)
|
||||||
|
# 对于 extra 字段,不直接返回
|
||||||
|
if field_name != "extra":
|
||||||
|
return value
|
||||||
|
return self.extra.get(field_name, default)
|
||||||
|
|
||||||
|
def set(self, field_name: str, value: Any) -> None:
|
||||||
|
"""
|
||||||
|
通用字段设置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: 字段名
|
||||||
|
value: 字段值
|
||||||
|
"""
|
||||||
|
if hasattr(self, field_name) and field_name != "extra":
|
||||||
|
setattr(self, field_name, value)
|
||||||
|
else:
|
||||||
|
self.extra[field_name] = value
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
result: Dict[str, Any] = {
|
||||||
|
"input_tokens": self.input_tokens,
|
||||||
|
"output_tokens": self.output_tokens,
|
||||||
|
"cache_creation_tokens": self.cache_creation_tokens,
|
||||||
|
"cache_read_tokens": self.cache_read_tokens,
|
||||||
|
"reasoning_tokens": self.reasoning_tokens,
|
||||||
|
"cache_storage_token_hours": self.cache_storage_token_hours,
|
||||||
|
"request_count": self.request_count,
|
||||||
|
}
|
||||||
|
if self.extra:
|
||||||
|
result["extra"] = self.extra
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "StandardizedUsage":
|
||||||
|
"""从字典创建实例"""
|
||||||
|
extra = data.pop("extra", {}) if "extra" in data else {}
|
||||||
|
# 只取已知字段
|
||||||
|
known_fields = {
|
||||||
|
"input_tokens",
|
||||||
|
"output_tokens",
|
||||||
|
"cache_creation_tokens",
|
||||||
|
"cache_read_tokens",
|
||||||
|
"reasoning_tokens",
|
||||||
|
"cache_storage_token_hours",
|
||||||
|
"request_count",
|
||||||
|
}
|
||||||
|
filtered = {k: v for k, v in data.items() if k in known_fields}
|
||||||
|
return cls(**filtered, extra=extra)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CostBreakdown:
|
||||||
|
"""
|
||||||
|
计费明细结果
|
||||||
|
|
||||||
|
包含各维度的费用和总费用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 各维度费用 {"input": 0.01, "output": 0.02, "cache_read": 0.001, ...}
|
||||||
|
costs: Dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# 总费用
|
||||||
|
total_cost: float = 0.0
|
||||||
|
|
||||||
|
# 命中的阶梯索引(如果使用阶梯计费)
|
||||||
|
tier_index: Optional[int] = None
|
||||||
|
|
||||||
|
# 货币单位
|
||||||
|
currency: str = "USD"
|
||||||
|
|
||||||
|
# 使用的价格(用于记录和审计)
|
||||||
|
effective_prices: Dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 兼容旧接口的属性(便于渐进式迁移)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_cost(self) -> float:
|
||||||
|
"""输入费用"""
|
||||||
|
return self.costs.get("input", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_cost(self) -> float:
|
||||||
|
"""输出费用"""
|
||||||
|
return self.costs.get("output", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_creation_cost(self) -> float:
|
||||||
|
"""缓存创建费用"""
|
||||||
|
return self.costs.get("cache_creation", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_read_cost(self) -> float:
|
||||||
|
"""缓存读取费用"""
|
||||||
|
return self.costs.get("cache_read", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_cost(self) -> float:
|
||||||
|
"""总缓存费用(创建 + 读取)"""
|
||||||
|
return self.cache_creation_cost + self.cache_read_cost
|
||||||
|
|
||||||
|
@property
|
||||||
|
def request_cost(self) -> float:
|
||||||
|
"""按次计费费用"""
|
||||||
|
return self.costs.get("request", 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache_storage_cost(self) -> float:
|
||||||
|
"""缓存存储费用(豆包等)"""
|
||||||
|
return self.costs.get("cache_storage", 0.0)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {
|
||||||
|
"costs": self.costs,
|
||||||
|
"total_cost": self.total_cost,
|
||||||
|
"tier_index": self.tier_index,
|
||||||
|
"currency": self.currency,
|
||||||
|
"effective_prices": self.effective_prices,
|
||||||
|
# 兼容字段
|
||||||
|
"input_cost": self.input_cost,
|
||||||
|
"output_cost": self.output_cost,
|
||||||
|
"cache_creation_cost": self.cache_creation_cost,
|
||||||
|
"cache_read_cost": self.cache_read_cost,
|
||||||
|
"cache_cost": self.cache_cost,
|
||||||
|
"request_cost": self.request_cost,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_legacy_tuple(self) -> tuple:
|
||||||
|
"""
|
||||||
|
转换为旧接口的元组格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(input_cost, output_cost, cache_creation_cost, cache_read_cost,
|
||||||
|
cache_cost, request_cost, total_cost, tier_index)
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.input_cost,
|
||||||
|
self.output_cost,
|
||||||
|
self.cache_creation_cost,
|
||||||
|
self.cache_read_cost,
|
||||||
|
self.cache_cost,
|
||||||
|
self.request_cost,
|
||||||
|
self.total_cost,
|
||||||
|
self.tier_index,
|
||||||
|
)
|
||||||
213
src/services/billing/templates.py
Normal file
213
src/services/billing/templates.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""
|
||||||
|
预定义计费模板
|
||||||
|
|
||||||
|
提供常见厂商的计费配置模板,避免重复配置:
|
||||||
|
- CLAUDE_STANDARD: Claude/Anthropic 标准计费
|
||||||
|
- OPENAI_STANDARD: OpenAI 标准计费
|
||||||
|
- DOUBAO_STANDARD: 豆包计费(含缓存存储)
|
||||||
|
- GEMINI_STANDARD: Gemini 标准计费
|
||||||
|
- PER_REQUEST: 按次计费
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from src.services.billing.models import BillingDimension, BillingUnit
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTemplates:
|
||||||
|
"""预定义的计费模板"""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Claude/Anthropic 标准计费
|
||||||
|
# - 输入 token
|
||||||
|
# - 输出 token
|
||||||
|
# - 缓存创建(创建时收费,约 1.25x 输入价格)
|
||||||
|
# - 缓存读取(约 0.1x 输入价格)
|
||||||
|
# =========================================================================
|
||||||
|
CLAUDE_STANDARD: List[BillingDimension] = [
|
||||||
|
BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="output",
|
||||||
|
usage_field="output_tokens",
|
||||||
|
price_field="output_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="cache_creation",
|
||||||
|
usage_field="cache_creation_tokens",
|
||||||
|
price_field="cache_creation_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="cache_read",
|
||||||
|
usage_field="cache_read_tokens",
|
||||||
|
price_field="cache_read_price_per_1m",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# OpenAI 标准计费
|
||||||
|
# - 输入 token
|
||||||
|
# - 输出 token
|
||||||
|
# - 缓存读取(部分模型支持,无缓存创建费用)
|
||||||
|
# =========================================================================
|
||||||
|
OPENAI_STANDARD: List[BillingDimension] = [
|
||||||
|
BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="output",
|
||||||
|
usage_field="output_tokens",
|
||||||
|
price_field="output_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="cache_read",
|
||||||
|
usage_field="cache_read_tokens",
|
||||||
|
price_field="cache_read_price_per_1m",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 豆包计费
|
||||||
|
# - 推理输入 (input_tokens)
|
||||||
|
# - 推理输出 (output_tokens)
|
||||||
|
# - 缓存命中 (cache_read_tokens) - 类似 Claude 的缓存读取
|
||||||
|
# - 缓存存储 (cache_storage_token_hours) - 按 token 数 * 存储时长计费
|
||||||
|
#
|
||||||
|
# 注意:豆包的缓存创建是免费的,但存储需要按时付费
|
||||||
|
# =========================================================================
|
||||||
|
DOUBAO_STANDARD: List[BillingDimension] = [
|
||||||
|
BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="output",
|
||||||
|
usage_field="output_tokens",
|
||||||
|
price_field="output_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="cache_read",
|
||||||
|
usage_field="cache_read_tokens",
|
||||||
|
price_field="cache_read_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="cache_storage",
|
||||||
|
usage_field="cache_storage_token_hours",
|
||||||
|
price_field="cache_storage_price_per_1m_hour",
|
||||||
|
unit=BillingUnit.PER_1M_TOKENS_HOUR,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Gemini 标准计费
|
||||||
|
# - 输入 token
|
||||||
|
# - 输出 token
|
||||||
|
# - 缓存读取
|
||||||
|
# =========================================================================
|
||||||
|
GEMINI_STANDARD: List[BillingDimension] = [
|
||||||
|
BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="output",
|
||||||
|
usage_field="output_tokens",
|
||||||
|
price_field="output_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="cache_read",
|
||||||
|
usage_field="cache_read_tokens",
|
||||||
|
price_field="cache_read_price_per_1m",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 按次计费
|
||||||
|
# - 适用于某些图片生成模型、特殊 API 等
|
||||||
|
# - 仅按请求次数计费,不按 token 计费
|
||||||
|
# =========================================================================
|
||||||
|
PER_REQUEST: List[BillingDimension] = [
|
||||||
|
BillingDimension(
|
||||||
|
name="request",
|
||||||
|
usage_field="request_count",
|
||||||
|
price_field="price_per_request",
|
||||||
|
unit=BillingUnit.PER_REQUEST,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 混合计费(按次 + 按 token)
|
||||||
|
# - 某些模型既有固定费用又有 token 费用
|
||||||
|
# =========================================================================
|
||||||
|
HYBRID_STANDARD: List[BillingDimension] = [
|
||||||
|
BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="output",
|
||||||
|
usage_field="output_tokens",
|
||||||
|
price_field="output_price_per_1m",
|
||||||
|
),
|
||||||
|
BillingDimension(
|
||||||
|
name="request",
|
||||||
|
usage_field="request_count",
|
||||||
|
price_field="price_per_request",
|
||||||
|
unit=BillingUnit.PER_REQUEST,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 模板注册表
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
BILLING_TEMPLATE_REGISTRY: Dict[str, List[BillingDimension]] = {
|
||||||
|
# 按厂商名称
|
||||||
|
"claude": BillingTemplates.CLAUDE_STANDARD,
|
||||||
|
"anthropic": BillingTemplates.CLAUDE_STANDARD,
|
||||||
|
"openai": BillingTemplates.OPENAI_STANDARD,
|
||||||
|
"doubao": BillingTemplates.DOUBAO_STANDARD,
|
||||||
|
"bytedance": BillingTemplates.DOUBAO_STANDARD,
|
||||||
|
"gemini": BillingTemplates.GEMINI_STANDARD,
|
||||||
|
"google": BillingTemplates.GEMINI_STANDARD,
|
||||||
|
# 按计费模式
|
||||||
|
"per_request": BillingTemplates.PER_REQUEST,
|
||||||
|
"hybrid": BillingTemplates.HYBRID_STANDARD,
|
||||||
|
# 默认
|
||||||
|
"default": BillingTemplates.CLAUDE_STANDARD,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_template(name: Optional[str]) -> List[BillingDimension]:
|
||||||
|
"""
|
||||||
|
获取计费模板
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 模板名称(不区分大小写)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
计费维度列表
|
||||||
|
"""
|
||||||
|
if not name:
|
||||||
|
return BILLING_TEMPLATE_REGISTRY["default"]
|
||||||
|
|
||||||
|
template = BILLING_TEMPLATE_REGISTRY.get(name.lower())
|
||||||
|
if template is None:
|
||||||
|
available = ", ".join(sorted(BILLING_TEMPLATE_REGISTRY.keys()))
|
||||||
|
raise ValueError(f"Unknown billing template: {name!r}. Available: {available}")
|
||||||
|
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
def list_templates() -> List[str]:
|
||||||
|
"""列出所有可用的模板名称"""
|
||||||
|
return list(BILLING_TEMPLATE_REGISTRY.keys())
|
||||||
267
src/services/billing/usage_mapper.py
Normal file
267
src/services/billing/usage_mapper.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
Usage 字段映射器
|
||||||
|
|
||||||
|
将不同 API 格式的原始 usage 数据映射为标准化格式。
|
||||||
|
|
||||||
|
支持的格式:
|
||||||
|
- OPENAI / OPENAI_CLI: OpenAI Chat Completions API
|
||||||
|
- CLAUDE / CLAUDE_CLI: Anthropic Messages API
|
||||||
|
- GEMINI / GEMINI_CLI: Google Gemini API
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.services.billing.models import StandardizedUsage
|
||||||
|
|
||||||
|
|
||||||
|
class UsageMapper:
|
||||||
|
"""
|
||||||
|
Usage 字段映射器
|
||||||
|
|
||||||
|
将不同 API 格式的 usage 统一映射为 StandardizedUsage。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
# OpenAI 格式
|
||||||
|
raw_usage = {
|
||||||
|
"prompt_tokens": 100,
|
||||||
|
"completion_tokens": 50,
|
||||||
|
"prompt_tokens_details": {"cached_tokens": 20},
|
||||||
|
"completion_tokens_details": {"reasoning_tokens": 10}
|
||||||
|
}
|
||||||
|
usage = UsageMapper.map(raw_usage, "OPENAI")
|
||||||
|
|
||||||
|
# Claude 格式
|
||||||
|
raw_usage = {
|
||||||
|
"input_tokens": 100,
|
||||||
|
"output_tokens": 50,
|
||||||
|
"cache_creation_input_tokens": 30,
|
||||||
|
"cache_read_input_tokens": 20
|
||||||
|
}
|
||||||
|
usage = UsageMapper.map(raw_usage, "CLAUDE")
|
||||||
|
"""
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 字段映射配置
|
||||||
|
# 格式: "source_path" -> "target_field"
|
||||||
|
# source_path 支持点号分隔的嵌套路径
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
# OpenAI 格式字段映射
|
||||||
|
OPENAI_MAPPING: Dict[str, str] = {
|
||||||
|
"prompt_tokens": "input_tokens",
|
||||||
|
"completion_tokens": "output_tokens",
|
||||||
|
"prompt_tokens_details.cached_tokens": "cache_read_tokens",
|
||||||
|
"completion_tokens_details.reasoning_tokens": "reasoning_tokens",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Claude 格式字段映射
|
||||||
|
CLAUDE_MAPPING: Dict[str, str] = {
|
||||||
|
"input_tokens": "input_tokens",
|
||||||
|
"output_tokens": "output_tokens",
|
||||||
|
"cache_creation_input_tokens": "cache_creation_tokens",
|
||||||
|
"cache_read_input_tokens": "cache_read_tokens",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Gemini 格式字段映射
|
||||||
|
GEMINI_MAPPING: Dict[str, str] = {
|
||||||
|
"promptTokenCount": "input_tokens",
|
||||||
|
"candidatesTokenCount": "output_tokens",
|
||||||
|
"cachedContentTokenCount": "cache_read_tokens",
|
||||||
|
# Gemini 的 usageMetadata 格式
|
||||||
|
"usageMetadata.promptTokenCount": "input_tokens",
|
||||||
|
"usageMetadata.candidatesTokenCount": "output_tokens",
|
||||||
|
"usageMetadata.cachedContentTokenCount": "cache_read_tokens",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 格式名称到映射的对应关系
|
||||||
|
FORMAT_MAPPINGS: Dict[str, Dict[str, str]] = {
|
||||||
|
"OPENAI": OPENAI_MAPPING,
|
||||||
|
"OPENAI_CLI": OPENAI_MAPPING,
|
||||||
|
"CLAUDE": CLAUDE_MAPPING,
|
||||||
|
"CLAUDE_CLI": CLAUDE_MAPPING,
|
||||||
|
"GEMINI": GEMINI_MAPPING,
|
||||||
|
"GEMINI_CLI": GEMINI_MAPPING,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def map(
|
||||||
|
cls,
|
||||||
|
raw_usage: Dict[str, Any],
|
||||||
|
api_format: str,
|
||||||
|
extra_mapping: Optional[Dict[str, str]] = None,
|
||||||
|
) -> StandardizedUsage:
|
||||||
|
"""
|
||||||
|
将原始 usage 映射为标准化格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_usage: 原始 usage 字典
|
||||||
|
api_format: API 格式 ("OPENAI", "CLAUDE", "GEMINI" 等)
|
||||||
|
extra_mapping: 额外的字段映射(用于自定义扩展)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的 usage 对象
|
||||||
|
"""
|
||||||
|
if not raw_usage:
|
||||||
|
return StandardizedUsage()
|
||||||
|
|
||||||
|
# 获取对应格式的字段映射
|
||||||
|
mapping = cls._get_mapping(api_format)
|
||||||
|
|
||||||
|
# 合并额外映射
|
||||||
|
if extra_mapping:
|
||||||
|
mapping = {**mapping, **extra_mapping}
|
||||||
|
|
||||||
|
result = StandardizedUsage()
|
||||||
|
|
||||||
|
# 执行映射
|
||||||
|
for source_path, target_field in mapping.items():
|
||||||
|
value = cls._get_nested_value(raw_usage, source_path)
|
||||||
|
if value is not None:
|
||||||
|
result.set(target_field, value)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def map_from_response(
|
||||||
|
cls,
|
||||||
|
response: Dict[str, Any],
|
||||||
|
api_format: str,
|
||||||
|
) -> StandardizedUsage:
|
||||||
|
"""
|
||||||
|
从完整响应中提取并映射 usage
|
||||||
|
|
||||||
|
不同 API 格式的 usage 位置可能不同:
|
||||||
|
- OpenAI: response["usage"]
|
||||||
|
- Claude: response["usage"] 或 message_delta 中
|
||||||
|
- Gemini: response["usageMetadata"]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: 完整的 API 响应
|
||||||
|
api_format: API 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的 usage 对象
|
||||||
|
"""
|
||||||
|
format_upper = api_format.upper() if api_format else ""
|
||||||
|
|
||||||
|
# 提取 usage 部分
|
||||||
|
usage_data: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if format_upper.startswith("GEMINI"):
|
||||||
|
# Gemini: usageMetadata
|
||||||
|
usage_data = response.get("usageMetadata", {})
|
||||||
|
if not usage_data:
|
||||||
|
# 尝试从 candidates 中获取
|
||||||
|
candidates = response.get("candidates", [])
|
||||||
|
if candidates:
|
||||||
|
usage_data = candidates[0].get("usageMetadata", {})
|
||||||
|
else:
|
||||||
|
# OpenAI/Claude: usage
|
||||||
|
usage_data = response.get("usage", {})
|
||||||
|
|
||||||
|
return cls.map(usage_data, api_format)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_mapping(cls, api_format: str) -> Dict[str, str]:
|
||||||
|
"""获取对应格式的字段映射"""
|
||||||
|
if not api_format:
|
||||||
|
return cls.CLAUDE_MAPPING
|
||||||
|
|
||||||
|
format_upper = api_format.upper()
|
||||||
|
|
||||||
|
# 精确匹配
|
||||||
|
if format_upper in cls.FORMAT_MAPPINGS:
|
||||||
|
return cls.FORMAT_MAPPINGS[format_upper]
|
||||||
|
|
||||||
|
# 前缀匹配
|
||||||
|
for key, mapping in cls.FORMAT_MAPPINGS.items():
|
||||||
|
if format_upper.startswith(key.split("_")[0]):
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
# 默认使用 Claude 映射
|
||||||
|
return cls.CLAUDE_MAPPING
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_nested_value(cls, data: Dict[str, Any], path: str) -> Any:
|
||||||
|
"""
|
||||||
|
获取嵌套字段值
|
||||||
|
|
||||||
|
支持点号分隔的路径,如 "prompt_tokens_details.cached_tokens"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 数据字典
|
||||||
|
path: 字段路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段值,不存在则返回 None
|
||||||
|
"""
|
||||||
|
if not data or not path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
keys = path.split(".")
|
||||||
|
value: Any = data
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = value.get(key)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_format(cls, format_name: str, mapping: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
注册新的格式映射
|
||||||
|
|
||||||
|
Args:
|
||||||
|
format_name: 格式名称(会自动转为大写)
|
||||||
|
mapping: 字段映射
|
||||||
|
"""
|
||||||
|
cls.FORMAT_MAPPINGS[format_name.upper()] = mapping
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_supported_formats(cls) -> list:
|
||||||
|
"""获取所有支持的格式"""
|
||||||
|
return list(cls.FORMAT_MAPPINGS.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# 便捷函数
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def map_usage(
|
||||||
|
raw_usage: Dict[str, Any],
|
||||||
|
api_format: str,
|
||||||
|
) -> StandardizedUsage:
|
||||||
|
"""
|
||||||
|
便捷函数:将原始 usage 映射为标准化格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_usage: 原始 usage 字典
|
||||||
|
api_format: API 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StandardizedUsage 对象
|
||||||
|
"""
|
||||||
|
return UsageMapper.map(raw_usage, api_format)
|
||||||
|
|
||||||
|
|
||||||
|
def map_usage_from_response(
|
||||||
|
response: Dict[str, Any],
|
||||||
|
api_format: str,
|
||||||
|
) -> StandardizedUsage:
|
||||||
|
"""
|
||||||
|
便捷函数:从响应中提取并映射 usage
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: API 响应
|
||||||
|
api_format: API 格式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
StandardizedUsage 对象
|
||||||
|
"""
|
||||||
|
return UsageMapper.map_from_response(response, api_format)
|
||||||
@@ -7,6 +7,59 @@ from typing import Any
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
|
||||||
|
def escape_like_pattern(pattern: str) -> str:
|
||||||
|
"""
|
||||||
|
转义 SQL LIKE 语句中的特殊字符(%、_、\\)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: 原始搜索模式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
转义后的模式,可安全用于 LIKE 查询(需配合 escape="\\\\")
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> escape_like_pattern("hello_world%test")
|
||||||
|
'hello\\\\_world\\\\%test'
|
||||||
|
"""
|
||||||
|
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||||
|
|
||||||
|
|
||||||
|
def safe_truncate_escaped(escaped: str, max_len: int) -> str:
|
||||||
|
"""
|
||||||
|
安全截断已转义的字符串,避免截断在转义序列中间
|
||||||
|
|
||||||
|
转义后的字符串中,反斜杠总是成对出现(\\\\)或作为转义符(\\%, \\_)。
|
||||||
|
如果在某个位置截断导致末尾有奇数个反斜杠,说明截断发生在转义序列中间,
|
||||||
|
需要去掉最后一个反斜杠以保持转义完整性。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
escaped: 已经过 escape_like_pattern 处理的字符串
|
||||||
|
max_len: 最大长度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
截断后的字符串,保证不会破坏转义序列
|
||||||
|
"""
|
||||||
|
if len(escaped) <= max_len:
|
||||||
|
return escaped
|
||||||
|
|
||||||
|
truncated = escaped[:max_len]
|
||||||
|
|
||||||
|
# 统计末尾连续的反斜杠数量
|
||||||
|
trailing_backslashes = 0
|
||||||
|
for i in range(len(truncated) - 1, -1, -1):
|
||||||
|
if truncated[i] == "\\":
|
||||||
|
trailing_backslashes += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果末尾反斜杠数量为奇数,说明截断在转义序列中间
|
||||||
|
# 需要去掉最后一个反斜杠
|
||||||
|
if trailing_backslashes % 2 == 1:
|
||||||
|
truncated = truncated[:-1]
|
||||||
|
|
||||||
|
return truncated
|
||||||
|
|
||||||
|
|
||||||
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
|
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
跨数据库的日期截断函数
|
跨数据库的日期截断函数
|
||||||
|
|||||||
0
tests/services/billing/__init__.py
Normal file
0
tests/services/billing/__init__.py
Normal file
440
tests/services/billing/test_billing.py
Normal file
440
tests/services/billing/test_billing.py
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
"""
|
||||||
|
Billing 模块测试
|
||||||
|
|
||||||
|
测试计费模块的核心功能:
|
||||||
|
- BillingCalculator 计费计算
|
||||||
|
- 计费模板
|
||||||
|
- 阶梯计费
|
||||||
|
- calculate_request_cost 便捷函数
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.services.billing import (
|
||||||
|
BillingCalculator,
|
||||||
|
BillingDimension,
|
||||||
|
BillingTemplates,
|
||||||
|
BillingUnit,
|
||||||
|
CostBreakdown,
|
||||||
|
StandardizedUsage,
|
||||||
|
calculate_request_cost,
|
||||||
|
)
|
||||||
|
from src.services.billing.templates import get_template, list_templates
|
||||||
|
|
||||||
|
|
||||||
|
class TestBillingDimension:
|
||||||
|
"""测试计费维度"""
|
||||||
|
|
||||||
|
def test_calculate_per_1m_tokens(self) -> None:
|
||||||
|
"""测试 per_1m_tokens 计费"""
|
||||||
|
dim = BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1000 tokens * $3 / 1M = $0.003
|
||||||
|
cost = dim.calculate(1000, 3.0)
|
||||||
|
assert abs(cost - 0.003) < 0.0001
|
||||||
|
|
||||||
|
def test_calculate_per_request(self) -> None:
|
||||||
|
"""测试按次计费"""
|
||||||
|
dim = BillingDimension(
|
||||||
|
name="request",
|
||||||
|
usage_field="request_count",
|
||||||
|
price_field="price_per_request",
|
||||||
|
unit=BillingUnit.PER_REQUEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按次计费:cost = request_count * price
|
||||||
|
cost = dim.calculate(1, 0.05)
|
||||||
|
assert cost == 0.05
|
||||||
|
|
||||||
|
# 多次请求应按次数计费
|
||||||
|
cost = dim.calculate(3, 0.05)
|
||||||
|
assert abs(cost - 0.15) < 0.0001
|
||||||
|
|
||||||
|
def test_calculate_zero_usage(self) -> None:
|
||||||
|
"""测试零用量"""
|
||||||
|
dim = BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
)
|
||||||
|
|
||||||
|
cost = dim.calculate(0, 3.0)
|
||||||
|
assert cost == 0.0
|
||||||
|
|
||||||
|
def test_calculate_zero_price(self) -> None:
|
||||||
|
"""测试零价格"""
|
||||||
|
dim = BillingDimension(
|
||||||
|
name="input",
|
||||||
|
usage_field="input_tokens",
|
||||||
|
price_field="input_price_per_1m",
|
||||||
|
)
|
||||||
|
|
||||||
|
cost = dim.calculate(1000, 0.0)
|
||||||
|
assert cost == 0.0
|
||||||
|
|
||||||
|
def test_to_dict_and_from_dict(self) -> None:
|
||||||
|
"""测试序列化和反序列化"""
|
||||||
|
dim = BillingDimension(
|
||||||
|
name="cache_read",
|
||||||
|
usage_field="cache_read_tokens",
|
||||||
|
price_field="cache_read_price_per_1m",
|
||||||
|
unit=BillingUnit.PER_1M_TOKENS,
|
||||||
|
default_price=0.3,
|
||||||
|
)
|
||||||
|
|
||||||
|
d = dim.to_dict()
|
||||||
|
restored = BillingDimension.from_dict(d)
|
||||||
|
|
||||||
|
assert restored.name == dim.name
|
||||||
|
assert restored.usage_field == dim.usage_field
|
||||||
|
assert restored.price_field == dim.price_field
|
||||||
|
assert restored.unit == dim.unit
|
||||||
|
assert restored.default_price == dim.default_price
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandardizedUsage:
|
||||||
|
"""测试标准化 Usage"""
|
||||||
|
|
||||||
|
def test_basic_usage(self) -> None:
|
||||||
|
"""测试基础 usage"""
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert usage.input_tokens == 1000
|
||||||
|
assert usage.output_tokens == 500
|
||||||
|
assert usage.cache_creation_tokens == 0
|
||||||
|
assert usage.cache_read_tokens == 0
|
||||||
|
|
||||||
|
def test_get_field(self) -> None:
|
||||||
|
"""测试字段获取"""
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert usage.get("input_tokens") == 1000
|
||||||
|
assert usage.get("nonexistent", 0) == 0
|
||||||
|
|
||||||
|
def test_extra_fields(self) -> None:
|
||||||
|
"""测试扩展字段"""
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
extra={"custom_field": 123},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert usage.get("custom_field") == 123
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""测试转换为字典"""
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
cache_creation_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
d = usage.to_dict()
|
||||||
|
assert d["input_tokens"] == 1000
|
||||||
|
assert d["output_tokens"] == 500
|
||||||
|
assert d["cache_creation_tokens"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
class TestCostBreakdown:
|
||||||
|
"""测试费用明细"""
|
||||||
|
|
||||||
|
def test_basic_breakdown(self) -> None:
|
||||||
|
"""测试基础费用明细"""
|
||||||
|
breakdown = CostBreakdown(
|
||||||
|
costs={"input": 0.003, "output": 0.0075},
|
||||||
|
total_cost=0.0105,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert breakdown.input_cost == 0.003
|
||||||
|
assert breakdown.output_cost == 0.0075
|
||||||
|
assert breakdown.total_cost == 0.0105
|
||||||
|
|
||||||
|
def test_cache_cost_calculation(self) -> None:
|
||||||
|
"""测试缓存费用汇总"""
|
||||||
|
breakdown = CostBreakdown(
|
||||||
|
costs={
|
||||||
|
"input": 0.003,
|
||||||
|
"output": 0.0075,
|
||||||
|
"cache_creation": 0.001,
|
||||||
|
"cache_read": 0.0005,
|
||||||
|
},
|
||||||
|
total_cost=0.012,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cache_cost = cache_creation + cache_read
|
||||||
|
assert abs(breakdown.cache_cost - 0.0015) < 0.0001
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""测试转换为字典"""
|
||||||
|
breakdown = CostBreakdown(
|
||||||
|
costs={"input": 0.003, "output": 0.0075},
|
||||||
|
total_cost=0.0105,
|
||||||
|
tier_index=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
d = breakdown.to_dict()
|
||||||
|
assert d["total_cost"] == 0.0105
|
||||||
|
assert d["tier_index"] == 1
|
||||||
|
assert d["input_cost"] == 0.003
|
||||||
|
|
||||||
|
|
||||||
|
class TestBillingTemplates:
|
||||||
|
"""测试计费模板"""
|
||||||
|
|
||||||
|
def test_claude_template(self) -> None:
|
||||||
|
"""测试 Claude 模板"""
|
||||||
|
template = BillingTemplates.CLAUDE_STANDARD
|
||||||
|
dim_names = [d.name for d in template]
|
||||||
|
|
||||||
|
assert "input" in dim_names
|
||||||
|
assert "output" in dim_names
|
||||||
|
assert "cache_creation" in dim_names
|
||||||
|
assert "cache_read" in dim_names
|
||||||
|
|
||||||
|
def test_openai_template(self) -> None:
|
||||||
|
"""测试 OpenAI 模板"""
|
||||||
|
template = BillingTemplates.OPENAI_STANDARD
|
||||||
|
dim_names = [d.name for d in template]
|
||||||
|
|
||||||
|
assert "input" in dim_names
|
||||||
|
assert "output" in dim_names
|
||||||
|
assert "cache_read" in dim_names
|
||||||
|
# OpenAI 没有缓存创建费用
|
||||||
|
assert "cache_creation" not in dim_names
|
||||||
|
|
||||||
|
def test_gemini_template(self) -> None:
|
||||||
|
"""测试 Gemini 模板"""
|
||||||
|
template = BillingTemplates.GEMINI_STANDARD
|
||||||
|
dim_names = [d.name for d in template]
|
||||||
|
|
||||||
|
assert "input" in dim_names
|
||||||
|
assert "output" in dim_names
|
||||||
|
assert "cache_read" in dim_names
|
||||||
|
|
||||||
|
def test_per_request_template(self) -> None:
|
||||||
|
"""测试按次计费模板"""
|
||||||
|
template = BillingTemplates.PER_REQUEST
|
||||||
|
assert len(template) == 1
|
||||||
|
assert template[0].name == "request"
|
||||||
|
assert template[0].unit == BillingUnit.PER_REQUEST
|
||||||
|
|
||||||
|
def test_get_template(self) -> None:
|
||||||
|
"""测试获取模板"""
|
||||||
|
template = get_template("claude")
|
||||||
|
assert template == BillingTemplates.CLAUDE_STANDARD
|
||||||
|
|
||||||
|
template = get_template("openai")
|
||||||
|
assert template == BillingTemplates.OPENAI_STANDARD
|
||||||
|
|
||||||
|
# 不区分大小写
|
||||||
|
template = get_template("CLAUDE")
|
||||||
|
assert template == BillingTemplates.CLAUDE_STANDARD
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown billing template"):
|
||||||
|
get_template("unknown_template")
|
||||||
|
|
||||||
|
def test_list_templates(self) -> None:
|
||||||
|
"""测试列出模板"""
|
||||||
|
templates = list_templates()
|
||||||
|
|
||||||
|
assert "claude" in templates
|
||||||
|
assert "openai" in templates
|
||||||
|
assert "gemini" in templates
|
||||||
|
assert "per_request" in templates
|
||||||
|
|
||||||
|
|
||||||
|
class TestBillingCalculator:
|
||||||
|
"""测试计费计算器"""
|
||||||
|
|
||||||
|
def test_basic_calculation(self) -> None:
|
||||||
|
"""测试基础计费计算"""
|
||||||
|
calculator = BillingCalculator(template="claude")
|
||||||
|
usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
|
||||||
|
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
|
||||||
|
|
||||||
|
result = calculator.calculate(usage, prices)
|
||||||
|
|
||||||
|
# 1000 * 3 / 1M = 0.003
|
||||||
|
assert abs(result.input_cost - 0.003) < 0.0001
|
||||||
|
# 500 * 15 / 1M = 0.0075
|
||||||
|
assert abs(result.output_cost - 0.0075) < 0.0001
|
||||||
|
# Total = 0.0105
|
||||||
|
assert abs(result.total_cost - 0.0105) < 0.0001
|
||||||
|
|
||||||
|
def test_calculation_with_cache(self) -> None:
|
||||||
|
"""测试带缓存的计费计算"""
|
||||||
|
calculator = BillingCalculator(template="claude")
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
cache_creation_tokens=200,
|
||||||
|
cache_read_tokens=300,
|
||||||
|
)
|
||||||
|
prices = {
|
||||||
|
"input_price_per_1m": 3.0,
|
||||||
|
"output_price_per_1m": 15.0,
|
||||||
|
"cache_creation_price_per_1m": 3.75,
|
||||||
|
"cache_read_price_per_1m": 0.3,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = calculator.calculate(usage, prices)
|
||||||
|
|
||||||
|
# cache_creation: 200 * 3.75 / 1M = 0.00075
|
||||||
|
assert abs(result.cache_creation_cost - 0.00075) < 0.0001
|
||||||
|
# cache_read: 300 * 0.3 / 1M = 0.00009
|
||||||
|
assert abs(result.cache_read_cost - 0.00009) < 0.0001
|
||||||
|
|
||||||
|
def test_tiered_pricing(self) -> None:
|
||||||
|
"""测试阶梯计费"""
|
||||||
|
calculator = BillingCalculator(template="claude")
|
||||||
|
usage = StandardizedUsage(input_tokens=250000, output_tokens=10000)
|
||||||
|
|
||||||
|
# 大于 200k 进入第二阶梯
|
||||||
|
tiered_pricing = {
|
||||||
|
"tiers": [
|
||||||
|
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
|
||||||
|
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
|
||||||
|
|
||||||
|
result = calculator.calculate(usage, prices, tiered_pricing)
|
||||||
|
|
||||||
|
# 应该使用第二阶梯价格
|
||||||
|
assert result.tier_index == 1
|
||||||
|
# 250000 * 1.5 / 1M = 0.375
|
||||||
|
assert abs(result.input_cost - 0.375) < 0.0001
|
||||||
|
|
||||||
|
def test_openai_no_cache_creation(self) -> None:
|
||||||
|
"""测试 OpenAI 模板没有缓存创建费用"""
|
||||||
|
calculator = BillingCalculator(template="openai")
|
||||||
|
usage = StandardizedUsage(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
cache_creation_tokens=200, # 这个不应该计费
|
||||||
|
cache_read_tokens=300,
|
||||||
|
)
|
||||||
|
prices = {
|
||||||
|
"input_price_per_1m": 3.0,
|
||||||
|
"output_price_per_1m": 15.0,
|
||||||
|
"cache_creation_price_per_1m": 3.75,
|
||||||
|
"cache_read_price_per_1m": 0.3,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = calculator.calculate(usage, prices)
|
||||||
|
|
||||||
|
# OpenAI 模板不包含 cache_creation 维度
|
||||||
|
assert result.cache_creation_cost == 0.0
|
||||||
|
# 但 cache_read 应该计费
|
||||||
|
assert result.cache_read_cost > 0
|
||||||
|
|
||||||
|
def test_from_config(self) -> None:
|
||||||
|
"""测试从配置创建计算器"""
|
||||||
|
config = {"template": "openai"}
|
||||||
|
calculator = BillingCalculator.from_config(config)
|
||||||
|
|
||||||
|
assert calculator.template_name == "openai"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCalculateRequestCost:
|
||||||
|
"""测试便捷函数"""
|
||||||
|
|
||||||
|
def test_basic_usage(self) -> None:
|
||||||
|
"""测试基础用法"""
|
||||||
|
result = calculate_request_cost(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
cache_creation_input_tokens=0,
|
||||||
|
cache_read_input_tokens=0,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
cache_creation_price_per_1m=None,
|
||||||
|
cache_read_price_per_1m=None,
|
||||||
|
price_per_request=None,
|
||||||
|
billing_template="claude",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "input_cost" in result
|
||||||
|
assert "output_cost" in result
|
||||||
|
assert "total_cost" in result
|
||||||
|
assert abs(result["input_cost"] - 0.003) < 0.0001
|
||||||
|
assert abs(result["output_cost"] - 0.0075) < 0.0001
|
||||||
|
|
||||||
|
def test_with_cache(self) -> None:
|
||||||
|
"""测试带缓存"""
|
||||||
|
result = calculate_request_cost(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
cache_creation_input_tokens=200,
|
||||||
|
cache_read_input_tokens=300,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
cache_creation_price_per_1m=3.75,
|
||||||
|
cache_read_price_per_1m=0.3,
|
||||||
|
price_per_request=None,
|
||||||
|
billing_template="claude",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["cache_creation_cost"] > 0
|
||||||
|
assert result["cache_read_cost"] > 0
|
||||||
|
assert result["cache_cost"] == result["cache_creation_cost"] + result["cache_read_cost"]
|
||||||
|
|
||||||
|
def test_different_templates(self) -> None:
|
||||||
|
"""测试不同模板"""
|
||||||
|
prices = {
|
||||||
|
"input_tokens": 1000,
|
||||||
|
"output_tokens": 500,
|
||||||
|
"cache_creation_input_tokens": 200,
|
||||||
|
"cache_read_input_tokens": 300,
|
||||||
|
"input_price_per_1m": 3.0,
|
||||||
|
"output_price_per_1m": 15.0,
|
||||||
|
"cache_creation_price_per_1m": 3.75,
|
||||||
|
"cache_read_price_per_1m": 0.3,
|
||||||
|
"price_per_request": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Claude 模板有 cache_creation
|
||||||
|
result_claude = calculate_request_cost(**prices, billing_template="claude")
|
||||||
|
assert result_claude["cache_creation_cost"] > 0
|
||||||
|
|
||||||
|
# OpenAI 模板没有 cache_creation
|
||||||
|
result_openai = calculate_request_cost(**prices, billing_template="openai")
|
||||||
|
assert result_openai["cache_creation_cost"] == 0
|
||||||
|
|
||||||
|
def test_tiered_pricing_with_total_context(self) -> None:
|
||||||
|
"""测试使用自定义 total_input_context 的阶梯计费"""
|
||||||
|
tiered_pricing = {
|
||||||
|
"tiers": [
|
||||||
|
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
|
||||||
|
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 传入预计算的 total_input_context
|
||||||
|
result = calculate_request_cost(
|
||||||
|
input_tokens=1000,
|
||||||
|
output_tokens=500,
|
||||||
|
cache_creation_input_tokens=0,
|
||||||
|
cache_read_input_tokens=0,
|
||||||
|
input_price_per_1m=3.0,
|
||||||
|
output_price_per_1m=15.0,
|
||||||
|
cache_creation_price_per_1m=None,
|
||||||
|
cache_read_price_per_1m=None,
|
||||||
|
price_per_request=None,
|
||||||
|
tiered_pricing=tiered_pricing,
|
||||||
|
total_input_context=250000, # 预计算的值,超过 200k
|
||||||
|
billing_template="claude",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 应该使用第二阶梯价格
|
||||||
|
assert result["tier_index"] == 1
|
||||||
Reference in New Issue
Block a user