7 Commits

Author SHA1 Message Date
fawney19
0f78d5cbf3 fix: 增强 CLI 处理器的错误信息,包含上游响应详情 2026-01-05 19:44:38 +08:00
fawney19
431c6de8d2 feat: 用户用量页面支持分页、搜索和密钥信息展示
- 用户用量API增加search参数支持密钥名、模型名搜索
- 用户用量API返回api_key信息(id、name、display)
- 用户页面记录表格增加密钥列显示
- 前端统一管理员和用户页面的分页/搜索逻辑
- 后端LIKE查询增加特殊字符转义防止SQL注入
- 添加escape_like_pattern和safe_truncate_escaped工具函数
2026-01-05 19:35:14 +08:00
fawney19
142e15bbcc Merge pull request #69 from AoaoMH/feature/Record-optimization
feat: add usage statistics and records feature with new API routes, f…
2026-01-05 19:31:59 +08:00
AAEE86
31acc5c607 feat(models): sort models by release date within each provider
Models are now sorted by release date in descending order (newest first)
within each provider group. Models without release dates are placed at the
end. When release dates are identical or missing, models fall back to
alphabetical sorting by name.
2026-01-05 18:23:04 +08:00
fawney19
bfa0a26d41 feat: 用户导出支持独立余额Key,新增系统版本接口
- 用户导出/导入支持独立余额 Key (standalone_keys)
- API Key 导出增加 expires_at 字段
- 新增 /api/admin/system/version 接口获取版本信息
- 前端系统设置页面显示当前版本
- 移除导入对话框中多余的 bg-muted 背景样式
2026-01-05 18:18:45 +08:00
AoaoMH
93ab9b6a5e feat: add usage statistics and records feature with new API routes, frontend types, services, and UI components 2026-01-05 17:03:05 +08:00
fawney19
35e29d46bd refactor: 抽取统一计费模块,支持配置驱动的多厂商计费
- 新增 src/services/billing/ 模块,包含计费计算器、模板和使用量映射
- 将 ChatAdapterBase 和 CliAdapterBase 中的计费逻辑重构为调用 billing 模块
- 为每个 adapter 添加 BILLING_TEMPLATE 类属性,指定计费模板
- 支持 Claude/OpenAI/Gemini 三种计费模板,支持阶梯计费和缓存 TTL 定价
- 新增 tests/services/billing/ 单元测试
2026-01-05 16:48:59 +08:00
30 changed files with 2199 additions and 363 deletions

View File

@@ -13,6 +13,7 @@ export interface UsersExportData {
version: string
exported_at: string
users: UserExport[]
standalone_keys?: StandaloneKeyExport[]
}
export interface UserExport {
@@ -46,11 +47,15 @@ export interface UserApiKeyExport {
concurrent_limit?: number | null
force_capabilities?: any
is_active: boolean
expires_at?: string | null
auto_delete_on_expiry?: boolean
total_requests?: number
total_cost_usd?: number
}
// Export shape for a standalone-balance key: the same fields as UserApiKeyExport, minus `is_standalone`.
export type StandaloneKeyExport = Omit<UserApiKeyExport, 'is_standalone'>
export interface GlobalModelExport {
name: string
display_name: string
@@ -189,6 +194,7 @@ export interface UsersImportResponse {
stats: {
users: { created: number; updated: number; skipped: number }
api_keys: { created: number; skipped: number }
standalone_keys?: { created: number; skipped: number }
errors: string[]
}
}
@@ -473,5 +479,13 @@ export const adminApi = {
`/api/admin/system/email/templates/${templateType}/reset`
)
return response.data
},
/** Fetch the system version from `GET /api/admin/system/version`. */
async getSystemVersion(): Promise<{ version: string }> {
  const { data: versionInfo } = await apiClient.get<{ version: string }>('/api/admin/system/version')
  return versionInfo
}
}

View File

@@ -62,6 +62,11 @@ export interface UsageRecordDetail {
cache_creation_price_per_1m?: number
cache_read_price_per_1m?: number
price_per_request?: number // 按次计费价格
api_key?: {
id: string
name: string
display: string
}
}
// 模型统计接口
@@ -192,6 +197,7 @@ export const meApi = {
async getUsage(params?: {
start_date?: string
end_date?: string
search?: string // 通用搜索:密钥名、模型名
limit?: number
offset?: number
}): Promise<UsageResponse> {

View File

@@ -192,10 +192,17 @@ export async function getModelsDevList(officialOnly: boolean = true): Promise<Mo
}
}
// 按 provider 名称和模型名称排序
// Sort by provider name; within a provider, newest release date first.
// Entries whose release date is missing or unparseable go last, and ties
// (same date, or both undated) fall back to alphabetical model name.
items.sort((a, b) => {
  const byProvider = a.providerName.localeCompare(b.providerName)
  if (byProvider !== 0) return byProvider

  // `new Date(<invalid>).getTime()` is NaN, and a NaN comparator result makes
  // the ordering inconsistent, so anything non-finite is treated as "no date".
  const aTime = a.releaseDate ? new Date(a.releaseDate).getTime() : NaN
  const bTime = b.releaseDate ? new Date(b.releaseDate).getTime() : NaN
  const aDated = Number.isFinite(aTime)
  const bDated = Number.isFinite(bTime)

  if (aDated !== bDated) return aDated ? -1 : 1 // dated entries before undated ones
  if (aDated && aTime !== bTime) return bTime - aTime // descending: newest first

  return a.modelName.localeCompare(b.modelName)
})

View File

@@ -164,6 +164,7 @@ export const usageApi = {
async getAllUsageRecords(params?: {
start_date?: string
end_date?: string
search?: string // 通用搜索:用户名、密钥名、模型名、提供商名
user_id?: string // UUID
username?: string
model?: string

View File

@@ -32,6 +32,17 @@
<!-- 分隔线 -->
<div class="hidden sm:block h-4 w-px bg-border" />
<!-- 通用搜索 -->
<div class="relative">
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
<Input
id="usage-records-search"
v-model="localSearch"
:placeholder="isAdmin ? '搜索用户/密钥/模型/提供商' : '搜索密钥/模型'"
class="w-32 sm:w-48 h-8 text-xs border-border/60 pl-8"
/>
</div>
<!-- 用户筛选仅管理员可见 -->
<Select
v-if="isAdmin && availableUsers.length > 0"
@@ -164,6 +175,12 @@
>
用户
</TableHead>
<TableHead
v-if="!isAdmin"
class="h-12 font-semibold w-[100px]"
>
密钥
</TableHead>
<TableHead class="h-12 font-semibold w-[140px]">
模型
</TableHead>
@@ -196,7 +213,7 @@
<TableBody>
<TableRow v-if="records.length === 0">
<TableCell
:colspan="isAdmin ? 9 : 7"
:colspan="isAdmin ? 9 : 8"
class="text-center py-12 text-muted-foreground"
>
暂无请求记录
@@ -218,7 +235,34 @@
class="py-4 w-[100px] truncate"
:title="record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户')"
>
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
<div class="flex flex-col text-xs gap-0.5">
<span class="truncate">
{{ record.username || record.user_email || (record.user_id ? `User ${record.user_id}` : '已删除用户') }}
</span>
<span
v-if="record.api_key?.name"
class="text-muted-foreground truncate"
:title="record.api_key.name"
>
{{ record.api_key.name }}
</span>
</div>
</TableCell>
<!-- 用户页面的密钥列 -->
<TableCell
v-if="!isAdmin"
class="py-4 w-[100px]"
:title="record.api_key?.name || '-'"
>
<div class="flex flex-col text-xs gap-0.5">
<span class="truncate">{{ record.api_key?.name || '-' }}</span>
<span
v-if="record.api_key?.display"
class="text-muted-foreground truncate"
>
{{ record.api_key.display }}
</span>
</div>
</TableCell>
<TableCell
class="font-medium py-4 w-[140px]"
@@ -438,6 +482,7 @@ import {
TableCard,
Badge,
Button,
Input,
Select,
SelectTrigger,
SelectValue,
@@ -451,7 +496,7 @@ import {
TableCell,
Pagination,
} from '@/components/ui'
import { RefreshCcw } from 'lucide-vue-next'
import { RefreshCcw, Search } from 'lucide-vue-next'
import { formatTokens, formatCurrency } from '@/utils/format'
import { formatDateTime } from '../composables'
import { useRowClick } from '@/composables/useRowClick'
@@ -471,6 +516,7 @@ const props = defineProps<{
// 时间段
selectedPeriod: string
// 筛选
filterSearch: string
filterUser: string
filterModel: string
filterProvider: string
@@ -489,6 +535,7 @@ const props = defineProps<{
const emit = defineEmits<{
'update:selectedPeriod': [value: string]
'update:filterSearch': [value: string]
'update:filterUser': [value: string]
'update:filterModel': [value: string]
'update:filterProvider': [value: string]
@@ -507,6 +554,23 @@ const filterModelSelectOpen = ref(false)
const filterProviderSelectOpen = ref(false)
const filterStatusSelectOpen = ref(false)
// Free-text search with input debouncing.
// `localSearch` mirrors the `filterSearch` prop so the input can update on
// every keystroke while the parent is only notified after a pause in typing.
const localSearch = ref(props.filterSearch)
let searchDebounceTimer: ReturnType<typeof setTimeout> | null = null

// Keep the local input in sync when the parent changes the prop.
watch(() => props.filterSearch, (value) => {
  if (value !== localSearch.value) {
    localSearch.value = value
  }
})

// Restart a 300 ms timer on every change; only the last value is emitted.
watch(localSearch, (value) => {
  if (searchDebounceTimer) clearTimeout(searchDebounceTimer)
  searchDebounceTimer = setTimeout(() => {
    searchDebounceTimer = null
    // Skip the emit when the parent already holds this value (prop-driven sync,
    // or the user typed and reverted within the debounce window); emitting
    // again would only trigger a redundant reload.
    if (value !== props.filterSearch) {
      emit('update:filterSearch', value)
    }
  }, 300)
})
// 动态计时器相关
const now = ref(Date.now())
let timerInterval: ReturnType<typeof setInterval> | null = null
@@ -574,6 +638,10 @@ function handleRowClick(event: MouseEvent, id: string) {
// Cleanup on unmount: stop the timer and cancel any pending search debounce.
onUnmounted(() => {
  stopTimer()
  if (!searchDebounceTimer) return
  clearTimeout(searchDebounceTimer)
  searchDebounceTimer = null
})
// 格式化 API 格式显示名称

View File

@@ -23,6 +23,7 @@ export interface PaginationParams {
}
export interface FilterParams {
search?: string
user_id?: string
model?: string
provider?: string
@@ -234,11 +235,6 @@ export function useUsageData(options: UseUsageDataOptions) {
pagination: PaginationParams,
filters?: FilterParams
): Promise<void> {
if (!isAdminPage.value) {
// 用户页面不需要分页加载,记录已在 loadStats 中获取
return
}
isLoadingRecords.value = true
try {
@@ -252,24 +248,34 @@ export function useUsageData(options: UseUsageDataOptions) {
}
// 添加筛选条件
if (filters?.user_id) {
params.user_id = filters.user_id
}
if (filters?.model) {
params.model = filters.model
}
if (filters?.provider) {
params.provider = filters.provider
}
if (filters?.status) {
params.status = filters.status
if (filters?.search?.trim()) {
params.search = filters.search.trim()
}
const response = await usageApi.getAllUsageRecords(params)
currentRecords.value = (response.records || []) as UsageRecord[]
totalRecords.value = response.total || 0
if (isAdminPage.value) {
// 管理员页面:使用管理员 API
if (filters?.user_id) {
params.user_id = filters.user_id
}
if (filters?.model) {
params.model = filters.model
}
if (filters?.provider) {
params.provider = filters.provider
}
if (filters?.status) {
params.status = filters.status
}
const response = await usageApi.getAllUsageRecords(params)
currentRecords.value = (response.records || []) as UsageRecord[]
totalRecords.value = response.total || 0
} else {
// 用户页面:使用用户 API
const userData = await meApi.getUsage(params)
currentRecords.value = (userData.records || []) as UsageRecord[]
totalRecords.value = userData.pagination?.total || currentRecords.value.length
}
} catch (error) {
log.error('加载记录失败:', error)
currentRecords.value = []

View File

@@ -61,6 +61,11 @@ export interface UsageRecord {
user_id?: string
username?: string
user_email?: string
api_key?: {
id: string | null
name: string | null
display: string | null
} | null
provider: string
api_key_name?: string
rate_multiplier?: number

View File

@@ -367,6 +367,11 @@ function generateMockUsageRecords(count: number = 100) {
user_id: user.id,
username: user.username,
user_email: user.email,
api_key: {
id: `key-${user.id}-${Math.ceil(Math.random() * 2)}`,
name: `${user.username} Key ${Math.ceil(Math.random() * 3)}`,
display: `sk-ae...${String(1000 + Math.floor(Math.random() * 9000))}`
},
provider: model.provider,
api_key_name: `${model.provider}-key-${Math.ceil(Math.random() * 3)}`,
rate_multiplier: 1.0,
@@ -835,10 +840,26 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
'GET /api/admin/usage/records': async (config) => {
await delay()
requireAdmin()
const records = getUsageRecords()
let records = getUsageRecords()
const params = config.params || {}
const limit = parseInt(params.limit) || 20
const offset = parseInt(params.offset) || 0
// 通用搜索:用户名、密钥名、模型名、提供商名
// 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
if (typeof params.search === 'string' && params.search.trim()) {
const keywords = params.search.trim().toLowerCase().split(/\s+/)
records = records.filter(r => {
// 每个关键词都要匹配至少一个字段
return keywords.every((keyword: string) =>
(r.username || '').toLowerCase().includes(keyword) ||
(r.api_key?.name || '').toLowerCase().includes(keyword) ||
(r.model || '').toLowerCase().includes(keyword) ||
(r.provider || '').toLowerCase().includes(keyword)
)
})
}
return createMockResponse({
records: records.slice(offset, offset + limit),
total: records.length,

View File

@@ -464,6 +464,30 @@
</div>
</div>
</CardSection>
<!-- 系统版本信息 -->
<CardSection
title="系统信息"
description="当前系统版本和构建信息"
>
<div class="flex items-center gap-4">
<div class="flex items-center gap-2">
<Label class="text-sm font-medium text-muted-foreground">版本:</Label>
<span
v-if="systemVersion"
class="text-sm font-mono"
>
{{ systemVersion }}
</span>
<span
v-else
class="text-sm text-muted-foreground"
>
加载中...
</span>
</div>
</div>
</CardSection>
</div>
<!-- 导入配置对话框 -->
@@ -475,7 +499,7 @@
<div class="space-y-4">
<div
v-if="importPreview"
class="p-3 bg-muted rounded-lg text-sm"
class="text-sm"
>
<p class="font-medium mb-2">
配置预览
@@ -557,7 +581,7 @@
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
全局模型
</p>
@@ -567,7 +591,7 @@
跳过: {{ importResult.stats.global_models.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
提供商
</p>
@@ -577,7 +601,7 @@
跳过: {{ importResult.stats.providers.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
端点
</p>
@@ -587,7 +611,7 @@
跳过: {{ importResult.stats.endpoints.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
API Keys
</p>
@@ -596,7 +620,7 @@
跳过: {{ importResult.stats.keys.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg col-span-2">
<div class="col-span-2">
<p class="font-medium">
模型配置
</p>
@@ -642,7 +666,7 @@
<div class="space-y-4">
<div
v-if="importUsersPreview"
class="p-3 bg-muted rounded-lg text-sm"
class="text-sm"
>
<p class="font-medium mb-2">
数据预览
@@ -652,6 +676,9 @@
<li>
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }}
</li>
<li v-if="importUsersPreview.standalone_keys?.length">
独立余额 Keys: {{ importUsersPreview.standalone_keys.length }}
</li>
</ul>
</div>
@@ -720,7 +747,7 @@
class="space-y-4"
>
<div class="grid grid-cols-2 gap-4 text-sm">
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
用户
</p>
@@ -730,7 +757,7 @@
跳过: {{ importUsersResult.stats.users.skipped }}
</p>
</div>
<div class="p-3 bg-muted rounded-lg">
<div>
<p class="font-medium">
API Keys
</p>
@@ -739,6 +766,18 @@
跳过: {{ importUsersResult.stats.api_keys.skipped }}
</p>
</div>
<div
v-if="importUsersResult.stats.standalone_keys"
class="col-span-2"
>
<p class="font-medium">
独立余额 Keys
</p>
<p class="text-muted-foreground">
创建: {{ importUsersResult.stats.standalone_keys.created }},
跳过: {{ importUsersResult.stats.standalone_keys.skipped }}
</p>
</div>
</div>
<div
@@ -839,6 +878,9 @@ const importUsersResult = ref<UsersImportResponse | null>(null)
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
const usersMergeModeSelectOpen = ref(false)
// 系统版本信息
const systemVersion = ref<string>('')
const systemConfig = ref<SystemConfig>({
// 基础配置
default_user_quota_usd: 10.0,
@@ -890,9 +932,21 @@ const sensitiveHeadersStr = computed({
})
onMounted(async () => {
await loadSystemConfig()
await Promise.all([
loadSystemConfig(),
loadSystemVersion()
])
})
/**
 * Load the backend version string into `systemVersion`.
 *
 * The template shows a loading placeholder while `systemVersion` is empty, so a
 * failed (or empty) response falls back to 'unknown' instead of leaving the
 * placeholder on screen indefinitely.
 */
async function loadSystemVersion() {
  try {
    const { version } = await adminApi.getSystemVersion()
    systemVersion.value = version || 'unknown'
  } catch (err) {
    log.error('加载系统版本失败:', err)
    systemVersion.value = 'unknown'
  }
}
async function loadSystemConfig() {
try {
const configs = [
@@ -1178,12 +1232,6 @@ function handleUsersFileSelect(event: Event) {
const content = e.target?.result as string
const data = JSON.parse(content) as UsersExportData
// 验证版本
if (data.version !== '1.0') {
error(`不支持的配置版本: ${data.version}`)
return
}
importUsersPreview.value = data
usersMergeMode.value = 'skip'
importUsersDialogOpen.value = true

View File

@@ -56,6 +56,7 @@
:show-actual-cost="authStore.isAdmin"
:loading="isLoadingRecords"
:selected-period="selectedPeriod"
:filter-search="filterSearch"
:filter-user="filterUser"
:filter-model="filterModel"
:filter-provider="filterProvider"
@@ -69,6 +70,7 @@
:page-size-options="pageSizeOptions"
:auto-refresh="globalAutoRefresh"
@update:selected-period="handlePeriodChange"
@update:filter-search="handleFilterSearchChange"
@update:filter-user="handleFilterUserChange"
@update:filter-model="handleFilterModelChange"
@update:filter-provider="handleFilterProviderChange"
@@ -133,6 +135,7 @@ const pageSize = ref(20)
const pageSizeOptions = [10, 20, 50, 100]
// 筛选状态
const filterSearch = ref('')
const filterUser = ref('__all__')
const filterModel = ref('__all__')
const filterProvider = ref('__all__')
@@ -392,14 +395,17 @@ onMounted(async () => {
// 热力图加载失败不提示,因为 UI 已显示占位符
}
// 管理员页面加载用户列表和第一页记录
// 加载记录和用户列表
if (isAdminPage.value) {
// 并行加载用户列表和记录
// 管理员页面:并行加载用户列表和记录
const [users] = await Promise.all([
usersApi.getAllUsers(),
loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
])
availableUsers.value = users.map(u => ({ id: u.id, username: u.username, email: u.email }))
} else {
// 用户页面:加载记录
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
}
})
@@ -410,34 +416,26 @@ async function handlePeriodChange(value: string) {
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
await loadStats(dateRange)
if (isAdminPage.value) {
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
}
await loadRecords({ page: 1, pageSize: pageSize.value }, getCurrentFilters())
}
// 处理分页变化
async function handlePageChange(page: number) {
currentPage.value = page
if (isAdminPage.value) {
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
}
await loadRecords({ page, pageSize: pageSize.value }, getCurrentFilters())
}
// 处理每页大小变化
async function handlePageSizeChange(size: number) {
pageSize.value = size
currentPage.value = 1 // 重置到第一页
if (isAdminPage.value) {
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
}
await loadRecords({ page: 1, pageSize: size }, getCurrentFilters())
}
// 获取当前筛选参数
function getCurrentFilters() {
return {
search: filterSearch.value.trim() || undefined,
user_id: filterUser.value !== '__all__' ? filterUser.value : undefined,
model: filterModel.value !== '__all__' ? filterModel.value : undefined,
provider: filterProvider.value !== '__all__' ? filterProvider.value : undefined,
@@ -446,6 +444,13 @@ function getCurrentFilters() {
}
// 处理筛选变化
/** Apply a new search term, jump back to the first page and reload the records. */
async function handleFilterSearchChange(value: string) {
  filterSearch.value = value
  currentPage.value = 1
  const activeFilters = getCurrentFilters()
  await loadRecords({ page: 1, pageSize: pageSize.value }, activeFilters)
}
async function handleFilterUserChange(value: string) {
filterUser.value = value
currentPage.value = 1 // 重置到第一页
@@ -486,10 +491,7 @@ async function handleFilterStatusChange(value: string) {
async function refreshData() {
const dateRange = getDateRangeFromPeriod(selectedPeriod.value)
await loadStats(dateRange)
if (isAdminPage.value) {
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
}
await loadRecords({ page: currentPage.value, pageSize: pageSize.value }, getCurrentFilters())
}
// 显示请求详情

View File

@@ -1,5 +1,7 @@
"""系统设置API端点。"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
@@ -17,6 +19,46 @@ from src.services.email.email_template import EmailTemplate
from src.services.system.config import SystemConfigService
router = APIRouter(prefix="/api/admin/system", tags=["Admin - System"])
def _get_version_from_git() -> str | None:
"""从 git describe 获取版本号"""
import subprocess
try:
result = subprocess.run(
["git", "describe", "--tags", "--always"],
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
version = result.stdout.strip()
if version.startswith("v"):
version = version[1:]
return version
except Exception:
pass
return None
@router.get("/version")
async def get_system_version():
    """Return the system version as ``{"version": <str>}``.

    The git-derived version is preferred; otherwise the static
    ``src._version.__version__`` is used, and ``"unknown"`` is reported when
    that module cannot be imported.
    """
    git_version = _get_version_from_git()
    if not git_version:
        # No usable git version: fall back to the static version file.
        try:
            from src._version import __version__
        except ImportError:
            return {"version": "unknown"}
        return {"version": __version__}

    return {"version": git_version}
pipeline = ApiRequestPipeline()
@@ -950,6 +992,31 @@ class AdminExportUsersAdapter(AdminApiAdapter):
db = context.db
def _serialize_api_key(key: ApiKey, include_is_standalone: bool = False) -> dict:
"""序列化 API Key 为导出格式"""
data = {
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
"allowed_endpoints": key.allowed_endpoints,
"allowed_api_formats": key.allowed_api_formats,
"allowed_models": key.allowed_models,
"rate_limit": key.rate_limit,
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"expires_at": key.expires_at.isoformat() if key.expires_at else None,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
if include_is_standalone:
data["is_standalone"] = key.is_standalone
return data
# 导出 Users排除管理员
users = db.query(User).filter(
User.is_deleted.is_(False),
@@ -957,31 +1024,12 @@ class AdminExportUsersAdapter(AdminApiAdapter):
).all()
users_data = []
for user in users:
# 导出用户的 API Keys保留加密数据
api_keys = db.query(ApiKey).filter(ApiKey.user_id == user.id).all()
api_keys_data = []
for key in api_keys:
api_keys_data.append(
{
"key_hash": key.key_hash,
"key_encrypted": key.key_encrypted,
"name": key.name,
"is_standalone": key.is_standalone,
"balance_used_usd": key.balance_used_usd,
"current_balance_usd": key.current_balance_usd,
"allowed_providers": key.allowed_providers,
"allowed_endpoints": key.allowed_endpoints,
"allowed_api_formats": key.allowed_api_formats,
"allowed_models": key.allowed_models,
"rate_limit": key.rate_limit,
"concurrent_limit": key.concurrent_limit,
"force_capabilities": key.force_capabilities,
"is_active": key.is_active,
"auto_delete_on_expiry": key.auto_delete_on_expiry,
"total_requests": key.total_requests,
"total_cost_usd": key.total_cost_usd,
}
)
# 导出用户的 API Keys排除独立余额Key独立Key单独导出
api_keys = db.query(ApiKey).filter(
ApiKey.user_id == user.id,
ApiKey.is_standalone.is_(False)
).all()
api_keys_data = [_serialize_api_key(key, include_is_standalone=True) for key in api_keys]
users_data.append(
{
@@ -1001,10 +1049,15 @@ class AdminExportUsersAdapter(AdminApiAdapter):
}
)
# 导出独立余额 Keys管理员创建的不属于普通用户
standalone_keys = db.query(ApiKey).filter(ApiKey.is_standalone.is_(True)).all()
standalone_keys_data = [_serialize_api_key(key) for key in standalone_keys]
return {
"version": "1.0",
"version": "1.1",
"exported_at": datetime.now(timezone.utc).isoformat(),
"users": users_data,
"standalone_keys": standalone_keys_data,
}
@@ -1024,21 +1077,72 @@ class AdminImportUsersAdapter(AdminApiAdapter):
db = context.db
payload = context.ensure_json_body()
# 验证配置版本
version = payload.get("version")
if version != "1.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
# 获取导入选项
merge_mode = payload.get("merge_mode", "skip") # skip, overwrite, error
users_data = payload.get("users", [])
standalone_keys_data = payload.get("standalone_keys", [])
stats = {
"users": {"created": 0, "updated": 0, "skipped": 0},
"api_keys": {"created": 0, "skipped": 0},
"standalone_keys": {"created": 0, "skipped": 0},
"errors": [],
}
def _create_api_key_from_data(
key_data: dict,
owner_id: str,
is_standalone: bool = False,
) -> tuple[ApiKey | None, str]:
"""从导入数据创建 ApiKey 对象
Returns:
(ApiKey, "created"): 成功创建
(None, "skipped"): key 已存在,跳过
(None, "invalid"): 数据无效,跳过
"""
key_hash = key_data.get("key_hash", "").strip()
if not key_hash:
return None, "invalid"
# 检查是否已存在
existing = db.query(ApiKey).filter(ApiKey.key_hash == key_hash).first()
if existing:
return None, "skipped"
# 解析 expires_at
expires_at = None
if key_data.get("expires_at"):
try:
expires_at = datetime.fromisoformat(key_data["expires_at"])
except ValueError:
stats["errors"].append(
f"API Key '{key_data.get('name', key_hash[:8])}' 的 expires_at 格式无效"
)
return ApiKey(
id=str(uuid.uuid4()),
user_id=owner_id,
key_hash=key_hash,
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=is_standalone or key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit"),
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
expires_at=expires_at,
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
), "created"
try:
for user_data in users_data:
# 跳过管理员角色的导入(不区分大小写)
@@ -1109,40 +1213,31 @@ class AdminImportUsersAdapter(AdminApiAdapter):
# 导入 API Keys
for key_data in user_data.get("api_keys", []):
# 检查是否已存在相同的 key_hash
if key_data.get("key_hash"):
existing_key = (
db.query(ApiKey)
.filter(ApiKey.key_hash == key_data["key_hash"])
.first()
)
if existing_key:
stats["api_keys"]["skipped"] += 1
continue
new_key, status = _create_api_key_from_data(key_data, user_id)
if new_key:
db.add(new_key)
stats["api_keys"]["created"] += 1
elif status == "skipped":
stats["api_keys"]["skipped"] += 1
# invalid 数据不计入统计
new_key = ApiKey(
id=str(uuid.uuid4()),
user_id=user_id,
key_hash=key_data.get("key_hash", ""),
key_encrypted=key_data.get("key_encrypted"),
name=key_data.get("name"),
is_standalone=key_data.get("is_standalone", False),
balance_used_usd=key_data.get("balance_used_usd", 0.0),
current_balance_usd=key_data.get("current_balance_usd"),
allowed_providers=key_data.get("allowed_providers"),
allowed_endpoints=key_data.get("allowed_endpoints"),
allowed_api_formats=key_data.get("allowed_api_formats"),
allowed_models=key_data.get("allowed_models"),
rate_limit=key_data.get("rate_limit"), # None = 无限制
concurrent_limit=key_data.get("concurrent_limit", 5),
force_capabilities=key_data.get("force_capabilities"),
is_active=key_data.get("is_active", True),
auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
total_requests=key_data.get("total_requests", 0),
total_cost_usd=key_data.get("total_cost_usd", 0.0),
)
db.add(new_key)
stats["api_keys"]["created"] += 1
# 导入独立余额 Keys需要找一个管理员用户作为 owner
if standalone_keys_data:
# 查找一个管理员用户作为独立Key的owner
admin_user = db.query(User).filter(User.role == UserRole.ADMIN).first()
if not admin_user:
stats["errors"].append("无法导入独立余额Key: 系统中没有管理员用户")
else:
for key_data in standalone_keys_data:
new_key, status = _create_api_key_from_data(
key_data, admin_user.id, is_standalone=True
)
if new_key:
db.add(new_key)
stats["standalone_keys"]["created"] += 1
elif status == "skipped":
stats["standalone_keys"]["skipped"] += 1
# invalid 数据不计入统计
db.commit()

View File

@@ -92,6 +92,7 @@ async def get_usage_records(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
search: Optional[str] = None, # 通用搜索:用户名、密钥名、模型名、提供商名
user_id: Optional[str] = None,
username: Optional[str] = None,
model: Optional[str] = None,
@@ -104,6 +105,7 @@ async def get_usage_records(
adapter = AdminUsageRecordsAdapter(
start_date=start_date,
end_date=end_date,
search=search,
user_id=user_id,
username=username,
model=model,
@@ -500,6 +502,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
self,
start_date: Optional[datetime],
end_date: Optional[datetime],
search: Optional[str],
user_id: Optional[str],
username: Optional[str],
model: Optional[str],
@@ -510,6 +513,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
):
self.start_date = start_date
self.end_date = end_date
self.search = search
self.user_id = user_id
self.username = username
self.model = model
@@ -519,25 +523,54 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
self.offset = offset
async def handle(self, context): # type: ignore[override]
from sqlalchemy import or_
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
db = context.db
query = (
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey)
db.query(Usage, User, ProviderEndpoint, ProviderAPIKey, ApiKey)
.outerjoin(User, Usage.user_id == User.id)
.outerjoin(ProviderEndpoint, Usage.provider_endpoint_id == ProviderEndpoint.id)
.outerjoin(ProviderAPIKey, Usage.provider_api_key_id == ProviderAPIKey.id)
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
)
# 如果需要按 Provider 名称搜索/筛选,统一在这里 JOIN
if self.search or self.provider:
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
# 通用搜索:用户名、密钥名、模型名、提供商名
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
# 限制:最多 10 个关键词,转义后每个关键词最长 100 字符
if self.search:
keywords = [kw for kw in self.search.strip().split() if kw][:10]
for keyword in keywords:
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
search_pattern = f"%{escaped}%"
query = query.filter(
or_(
User.username.ilike(search_pattern, escape="\\"),
ApiKey.name.ilike(search_pattern, escape="\\"),
Usage.model.ilike(search_pattern, escape="\\"),
Provider.name.ilike(search_pattern, escape="\\"),
)
)
if self.user_id:
query = query.filter(Usage.user_id == self.user_id)
if self.username:
# 支持用户名模糊搜索
query = query.filter(User.username.ilike(f"%{self.username}%"))
escaped = escape_like_pattern(self.username)
query = query.filter(User.username.ilike(f"%{escaped}%", escape="\\"))
if self.model:
# 支持模型名模糊搜索
query = query.filter(Usage.model.ilike(f"%{self.model}%"))
escaped = escape_like_pattern(self.model)
query = query.filter(Usage.model.ilike(f"%{escaped}%", escape="\\"))
if self.provider:
# 支持提供商名称搜索(通过 Provider 表)
query = query.join(Provider, Usage.provider_id == Provider.id, isouter=True)
query = query.filter(Provider.name.ilike(f"%{self.provider}%"))
# 支持提供商名称搜索
escaped = escape_like_pattern(self.provider)
query = query.filter(Provider.name.ilike(f"%{escaped}%", escape="\\"))
if self.status:
# 状态筛选
# 旧的筛选值(基于 is_stream 和 status_codestream, standard, error
@@ -575,7 +608,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
)
request_ids = [usage.request_id for usage, _, _, _ in records if usage.request_id]
request_ids = [usage.request_id for usage, _, _, _, _ in records if usage.request_id]
fallback_map = {}
if request_ids:
# 只统计实际执行的候选success 或 failed不包括 skipped/pending/available
@@ -595,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
action="usage_records",
start_date=self.start_date.isoformat() if self.start_date else None,
end_date=self.end_date.isoformat() if self.end_date else None,
search=self.search,
user_id=self.user_id,
username=self.username,
model=self.model,
@@ -606,7 +640,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
)
# 构建 provider_id -> Provider 名称的映射,避免 N+1 查询
provider_ids = [usage.provider_id for usage, _, _, _ in records if usage.provider_id]
provider_ids = [usage.provider_id for usage, _, _, _, _ in records if usage.provider_id]
provider_map = {}
if provider_ids:
providers_data = (
@@ -615,7 +649,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
provider_map = {str(p.id): p.name for p in providers_data}
data = []
for usage, user, endpoint, api_key in records:
for usage, user, endpoint, provider_api_key, user_api_key in records:
actual_cost = (
float(usage.actual_total_cost_usd)
if usage.actual_total_cost_usd is not None
@@ -636,6 +670,15 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"user_id": user.id if user else None,
"user_email": user.email if user else "已删除用户",
"username": user.username if user else "已删除用户",
"api_key": (
{
"id": user_api_key.id,
"name": user_api_key.name,
"display": user_api_key.get_display_key(),
}
if user_api_key
else None
),
"provider": provider_name,
"model": usage.model,
"target_model": usage.target_model, # 映射后的目标模型名
@@ -661,7 +704,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
"has_fallback": fallback_map.get(usage.request_id, False),
"api_format": usage.api_format
or (endpoint.api_format if endpoint and endpoint.api_format else None),
"api_key_name": api_key.name if api_key else None,
"api_key_name": provider_api_key.name if provider_api_key else None,
"request_metadata": usage.request_metadata, # Provider 响应元数据
}
)

View File

@@ -40,6 +40,7 @@ from src.core.exceptions import (
UpstreamClientException,
)
from src.core.logger import logger
from src.services.billing import calculate_request_cost as _calculate_request_cost
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
@@ -63,6 +64,9 @@ class ChatAdapterBase(ApiAdapter):
name: str = "chat.base"
mode = ApiMode.STANDARD
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini"
BILLING_TEMPLATE: str = "claude"
# 子类可以配置的特殊方法用于check_endpoint
@classmethod
def build_endpoint_url(cls, base_url: str) -> str:
@@ -486,40 +490,6 @@ class ChatAdapterBase(ApiAdapter):
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
@@ -537,8 +507,9 @@ class ChatAdapterBase(ApiAdapter):
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
使用 billing 模块的配置驱动计费
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
或覆盖此方法实现完全自定义的计费逻辑。
Args:
input_tokens: 输入 token 数
@@ -566,88 +537,26 @@ class ChatAdapterBase(ApiAdapter):
"tier_index": Optional[int], # 命中的阶梯索引
}
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 计算总输入上下文(使用子类可覆盖的方法)
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""
根据总输入 token 数确定价格阶梯
Args:
tiered_pricing: 阶梯计费配置 {"tiers": [...]}
total_input_tokens: 总输入 token 数
Returns:
匹配的阶梯配置
"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
# 如果所有阶梯都有上限且都超过了,返回最后一个阶梯
return tiers[-1] if tiers else None
return _calculate_request_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
tiered_pricing=tiered_pricing,
cache_ttl_minutes=cache_ttl_minutes,
total_input_context=total_input_context,
billing_template=self.BILLING_TEMPLATE,
)
# =========================================================================
# 模型列表查询 - 子类应覆盖此方法

View File

@@ -38,6 +38,7 @@ from src.core.exceptions import (
UpstreamClientException,
)
from src.core.logger import logger
from src.services.billing import calculate_request_cost as _calculate_request_cost
from src.services.request.result import RequestResult
from src.services.usage.recorder import UsageRecorder
@@ -61,6 +62,9 @@ class CliAdapterBase(ApiAdapter):
name: str = "cli.base"
mode = ApiMode.PROXY
# 计费模板配置(子类可覆盖,如 "claude", "openai", "gemini"
BILLING_TEMPLATE: str = "claude"
def __init__(self, allowed_api_formats: Optional[list[str]] = None):
self.allowed_api_formats = allowed_api_formats or [self.FORMAT_ID]
@@ -438,40 +442,6 @@ class CliAdapterBase(ApiAdapter):
"""
return input_tokens + cache_read_input_tokens
def get_cache_read_price_for_ttl(
self,
tier: dict,
cache_ttl_minutes: Optional[int] = None,
) -> Optional[float]:
"""
根据缓存 TTL 获取缓存读取价格
默认实现:检查 cache_ttl_pricing 配置,按 TTL 选择价格
子类可覆盖此方法实现不同的 TTL 定价逻辑
Args:
tier: 当前阶梯配置
cache_ttl_minutes: 缓存时长(分钟)
Returns:
缓存读取价格(每 1M tokens
"""
ttl_pricing = tier.get("cache_ttl_pricing")
if ttl_pricing and cache_ttl_minutes is not None:
matched_price = None
for ttl_config in ttl_pricing:
ttl_limit = ttl_config.get("ttl_minutes", 0)
if cache_ttl_minutes <= ttl_limit:
matched_price = ttl_config.get("cache_read_price_per_1m")
break
if matched_price is not None:
return matched_price
# 超过所有配置的 TTL使用最后一个
if ttl_pricing:
return ttl_pricing[-1].get("cache_read_price_per_1m")
return tier.get("cache_read_price_per_1m")
def compute_cost(
self,
input_tokens: int,
@@ -489,8 +459,9 @@ class CliAdapterBase(ApiAdapter):
"""
计算请求成本
默认实现:支持固定价格和阶梯计费
子类可覆盖此方法实现完全不同的计费逻辑
使用 billing 模块的配置驱动计费
子类可通过设置 BILLING_TEMPLATE 类属性来指定计费模板,
或覆盖此方法实现完全自定义的计费逻辑。
Args:
input_tokens: 输入 token 数
@@ -508,78 +479,26 @@ class CliAdapterBase(ApiAdapter):
Returns:
包含各项成本的字典
"""
tier_index = None
effective_input_price = input_price_per_1m
effective_output_price = output_price_per_1m
effective_cache_creation_price = cache_creation_price_per_1m
effective_cache_read_price = cache_read_price_per_1m
# 计算总输入上下文(使用子类可覆盖的方法)
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
# 检查阶梯计费
if tiered_pricing and tiered_pricing.get("tiers"):
total_input_context = self.compute_total_input_context(
input_tokens, cache_read_input_tokens, cache_creation_input_tokens
)
tier = self._get_tier_for_tokens(tiered_pricing, total_input_context)
if tier:
tier_index = tiered_pricing["tiers"].index(tier)
effective_input_price = tier.get("input_price_per_1m", input_price_per_1m)
effective_output_price = tier.get("output_price_per_1m", output_price_per_1m)
effective_cache_creation_price = tier.get(
"cache_creation_price_per_1m", cache_creation_price_per_1m
)
effective_cache_read_price = self.get_cache_read_price_for_ttl(
tier, cache_ttl_minutes
)
if effective_cache_read_price is None:
effective_cache_read_price = cache_read_price_per_1m
# 计算各项成本
input_cost = (input_tokens / 1_000_000) * effective_input_price
output_cost = (output_tokens / 1_000_000) * effective_output_price
cache_creation_cost = 0.0
cache_read_cost = 0.0
if cache_creation_input_tokens > 0 and effective_cache_creation_price is not None:
cache_creation_cost = (
cache_creation_input_tokens / 1_000_000
) * effective_cache_creation_price
if cache_read_input_tokens > 0 and effective_cache_read_price is not None:
cache_read_cost = (
cache_read_input_tokens / 1_000_000
) * effective_cache_read_price
cache_cost = cache_creation_cost + cache_read_cost
request_cost = price_per_request if price_per_request else 0.0
total_cost = input_cost + output_cost + cache_cost + request_cost
return {
"input_cost": input_cost,
"output_cost": output_cost,
"cache_creation_cost": cache_creation_cost,
"cache_read_cost": cache_read_cost,
"cache_cost": cache_cost,
"request_cost": request_cost,
"total_cost": total_cost,
"tier_index": tier_index,
}
@staticmethod
def _get_tier_for_tokens(tiered_pricing: dict, total_input_tokens: int) -> Optional[dict]:
"""根据总输入 token 数确定价格阶梯"""
if not tiered_pricing or "tiers" not in tiered_pricing:
return None
tiers = tiered_pricing.get("tiers", [])
if not tiers:
return None
for tier in tiers:
up_to = tier.get("up_to")
if up_to is None or total_input_tokens <= up_to:
return tier
return tiers[-1] if tiers else None
return _calculate_request_cost(
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
input_price_per_1m=input_price_per_1m,
output_price_per_1m=output_price_per_1m,
cache_creation_price_per_1m=cache_creation_price_per_1m,
cache_read_price_per_1m=cache_read_price_per_1m,
price_per_request=price_per_request,
tiered_pricing=tiered_pricing,
cache_ttl_minutes=cache_ttl_minutes,
total_input_context=total_input_context,
billing_template=self.BILLING_TEMPLATE,
)
# =========================================================================
# 模型列表查询 - 子类应覆盖此方法

View File

@@ -1497,8 +1497,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
retry_after=int(resp.headers.get("retry-after", 0)) or None,
)
elif resp.status_code >= 500:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}"
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
@@ -1508,7 +1512,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}, 错误: {error_text[:200]}"
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
# 安全解析 JSON 响应,处理可能的编码错误

View File

@@ -63,6 +63,7 @@ class ClaudeChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "CLAUDE"
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
name = "claude.chat"
@property

View File

@@ -24,6 +24,7 @@ class ClaudeCliAdapter(CliAdapterBase):
"""
FORMAT_ID = "CLAUDE_CLI"
BILLING_TEMPLATE = "claude" # 使用 Claude 计费模板
name = "claude.cli"
@property

View File

@@ -27,6 +27,7 @@ class GeminiChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "GEMINI"
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
name = "gemini.chat"
@property

View File

@@ -24,6 +24,7 @@ class GeminiCliAdapter(CliAdapterBase):
"""
FORMAT_ID = "GEMINI_CLI"
BILLING_TEMPLATE = "gemini" # 使用 Gemini 计费模板
name = "gemini.cli"
@property

View File

@@ -26,6 +26,7 @@ class OpenAIChatAdapter(ChatAdapterBase):
"""
FORMAT_ID = "OPENAI"
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
name = "openai.chat"
@property

View File

@@ -24,6 +24,7 @@ class OpenAICliAdapter(CliAdapterBase):
"""
FORMAT_ID = "OPENAI_CLI"
BILLING_TEMPLATE = "openai" # 使用 OpenAI 计费模板
name = "openai.cli"
@property

View File

@@ -104,11 +104,14 @@ async def get_my_usage(
request: Request,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
search: Optional[str] = None, # 通用搜索:密钥名、模型名
limit: int = Query(100, ge=1, le=200, description="每页记录数默认100最大200"),
offset: int = Query(0, ge=0, le=2000, description="偏移量用于分页最大2000"),
db: Session = Depends(get_db),
):
adapter = GetUsageAdapter(start_date=start_date, end_date=end_date, limit=limit, offset=offset)
adapter = GetUsageAdapter(
start_date=start_date, end_date=end_date, search=search, limit=limit, offset=offset
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -487,10 +490,15 @@ class ToggleMyApiKeyAdapter(AuthenticatedApiAdapter):
class GetUsageAdapter(AuthenticatedApiAdapter):
start_date: Optional[datetime]
end_date: Optional[datetime]
search: Optional[str] = None
limit: int = 100
offset: int = 0
async def handle(self, context): # type: ignore[override]
from sqlalchemy import or_
from src.utils.database_helpers import escape_like_pattern, safe_truncate_escaped
db = context.db
user = context.user
summary_list = UsageService.get_usage_summary(
@@ -595,12 +603,30 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
})
summary_by_provider = sorted(summary_by_provider, key=lambda x: x["requests"], reverse=True)
query = db.query(Usage).filter(Usage.user_id == user.id)
query = (
db.query(Usage, ApiKey)
.outerjoin(ApiKey, Usage.api_key_id == ApiKey.id)
.filter(Usage.user_id == user.id)
)
if self.start_date:
query = query.filter(Usage.created_at >= self.start_date)
if self.end_date:
query = query.filter(Usage.created_at <= self.end_date)
# 通用搜索:密钥名、模型名
# 支持空格分隔的组合搜索,多个关键词之间是 AND 关系
if self.search and self.search.strip():
keywords = [kw for kw in self.search.strip().split() if kw][:10]
for keyword in keywords:
escaped = safe_truncate_escaped(escape_like_pattern(keyword), 100)
search_pattern = f"%{escaped}%"
query = query.filter(
or_(
ApiKey.name.ilike(search_pattern, escape="\\"),
Usage.model.ilike(search_pattern, escape="\\"),
)
)
# 计算总数用于分页
total_records = query.count()
usage_records = query.order_by(Usage.created_at.desc()).offset(self.offset).limit(self.limit).all()
@@ -659,8 +685,17 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
"output_price_per_1m": r.output_price_per_1m,
"cache_creation_price_per_1m": r.cache_creation_price_per_1m,
"cache_read_price_per_1m": r.cache_read_price_per_1m,
"api_key": (
{
"id": str(api_key.id),
"name": api_key.name,
"display": api_key.get_display_key(),
}
if api_key
else None
),
}
for r in usage_records
for r, api_key in usage_records
],
}
@@ -668,7 +703,7 @@ class GetUsageAdapter(AuthenticatedApiAdapter):
if user.role == "admin":
response_data["total_actual_cost"] = total_actual_cost
# 为每条记录添加真实成本和倍率信息
for i, r in enumerate(usage_records):
for i, (r, _) in enumerate(usage_records):
# 确保字段有值,避免前端显示 -
actual_cost = (
r.actual_total_cost_usd if r.actual_total_cost_usd is not None else 0.0

View File

@@ -0,0 +1,51 @@
"""
计费模块
提供配置驱动的计费计算,支持不同厂商的差异化计费模式:
- Claude: input + output + cache_creation + cache_read
- OpenAI: input + output + cache_read (无缓存创建费用)
- 豆包: input + output + cache_read + cache_storage (缓存按时计费)
- 按次计费: per_request
使用方式:
from src.services.billing import BillingCalculator, UsageMapper, StandardizedUsage
# 1. 将原始 usage 映射为标准格式
usage = UsageMapper.map(raw_usage, api_format="OPENAI")
# 2. 使用计费计算器计算费用
calculator = BillingCalculator(template="openai")
result = calculator.calculate(usage, prices)
# 3. 获取费用明细
print(result.total_cost)
print(result.costs) # {"input": 0.01, "output": 0.02, ...}
"""
from src.services.billing.calculator import BillingCalculator, calculate_request_cost
from src.services.billing.models import (
BillingDimension,
BillingUnit,
CostBreakdown,
StandardizedUsage,
)
from src.services.billing.templates import BILLING_TEMPLATE_REGISTRY, BillingTemplates
from src.services.billing.usage_mapper import UsageMapper, map_usage, map_usage_from_response
# Names re-exported as the public surface of the billing package.
__all__ = [
# Data models
"BillingDimension",
"BillingUnit",
"CostBreakdown",
"StandardizedUsage",
# Templates
"BillingTemplates",
"BILLING_TEMPLATE_REGISTRY",
# Calculator
"BillingCalculator",
"calculate_request_cost",
# Usage mappers
"UsageMapper",
"map_usage",
"map_usage_from_response",
]

View File

@@ -0,0 +1,339 @@
"""
计费计算器
配置驱动的计费计算,支持:
- 固定价格计费
- 阶梯计费
- 多种计费模板
- 自定义计费维度
"""
from typing import Any, Dict, List, Optional, Tuple
from src.services.billing.models import (
BillingDimension,
BillingUnit,
CostBreakdown,
StandardizedUsage,
)
from src.services.billing.templates import (
BILLING_TEMPLATE_REGISTRY,
BillingTemplates,
get_template,
)
class BillingCalculator:
    """
    Configuration-driven cost calculator.

    A calculator owns an ordered list of billing dimensions. The list comes
    either from an explicit ``dimensions`` argument or from a named template
    resolved through ``get_template``. ``calculate`` prices a usage record
    against a price table, optionally overridden by a pricing tier.

    Example:
        # From a template
        calculator = BillingCalculator(template="openai")

        # From custom dimensions
        calculator = BillingCalculator(dimensions=[
            BillingDimension(name="input", usage_field="input_tokens", price_field="input_price_per_1m"),
            BillingDimension(name="output", usage_field="output_tokens", price_field="output_price_per_1m"),
        ])

        usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
        prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}
        result = calculator.calculate(usage, prices)
    """

    def __init__(
        self,
        dimensions: Optional[List[BillingDimension]] = None,
        template: Optional[str] = None,
    ):
        """
        Build a calculator.

        Args:
            dimensions: Custom billing dimensions; a non-empty list wins over
                ``template``.
            template: Name of a predefined template, resolved with
                ``get_template``.

        When neither argument is usable, the Claude template is used.
        """
        if dimensions:
            chosen = dimensions
        elif template:
            chosen = get_template(template)
        else:
            # Backward-compatible default.
            chosen = BillingTemplates.CLAUDE_STANDARD
        self.dimensions = chosen
        self.template_name = template

    def calculate(
        self,
        usage: StandardizedUsage,
        prices: Dict[str, float],
        tiered_pricing: Optional[Dict[str, Any]] = None,
        cache_ttl_minutes: Optional[int] = None,
        total_input_context: Optional[int] = None,
    ) -> CostBreakdown:
        """
        Price a usage record.

        Args:
            usage: Standardized usage data.
            prices: Base price table, e.g. {"input_price_per_1m": 3.0, ...}.
            tiered_pricing: Optional tier configuration ({"tiers": [...]}).
            cache_ttl_minutes: Cache TTL in minutes; only consulted when a
                tier matched.
            total_input_context: Context size used to pick the tier; when
                None it is derived from ``usage``.

        Returns:
            A CostBreakdown with per-dimension costs, the total, the matched
            tier index and the prices that were actually applied.
        """
        breakdown = CostBreakdown()
        applied_prices = prices.copy()

        if tiered_pricing and tiered_pricing.get("tiers"):
            tier, tier_index = self._get_tier(usage, tiered_pricing, total_input_context)
            if tier:
                breakdown.tier_index = tier_index
                # Non-null tier entries override the base prices; structural
                # keys of the tier are not prices and are skipped.
                applied_prices.update(
                    (key, value)
                    for key, value in tier.items()
                    if key not in ("up_to", "cache_ttl_pricing") and value is not None
                )
                if cache_ttl_minutes is not None:
                    ttl_price = self._get_cache_read_price_for_ttl(tier, cache_ttl_minutes)
                    if ttl_price is not None:
                        applied_prices["cache_read_price_per_1m"] = ttl_price

        # Snapshot of the prices in effect.
        breakdown.effective_prices = applied_prices.copy()

        running_total = 0.0
        for dim in self.dimensions:
            quantity = usage.get(dim.usage_field, 0)
            unit_price = applied_prices.get(dim.price_field, dim.default_price)
            if not (quantity and unit_price):
                # Nothing used or nothing charged: the dimension is omitted.
                continue
            cost = dim.calculate(quantity, unit_price)
            breakdown.costs[dim.name] = cost
            running_total += cost

        breakdown.total_cost = running_total
        return breakdown

    def _get_tier(
        self,
        usage: StandardizedUsage,
        tiered_pricing: Dict[str, Any],
        total_input_context: Optional[int] = None,
    ) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
        """
        Pick the pricing tier for a request.

        Args:
            usage: Usage data, used only when no context size is supplied.
            tiered_pricing: Tier configuration ({"tiers": [...]}).
            total_input_context: Precomputed context size, if any.

        Returns:
            (tier, index) of the first tier whose ``up_to`` is None or not
            below the context size; the last tier when none matches;
            (None, None) when there are no tiers.
        """
        tiers = tiered_pricing.get("tiers", [])
        if not tiers:
            return None, None

        context_size = (
            total_input_context
            if total_input_context is not None
            else self._compute_total_input_context(usage)
        )

        for index, tier in enumerate(tiers):
            ceiling = tier.get("up_to")
            # A missing ceiling marks an unbounded tier.
            if ceiling is None or context_size <= ceiling:
                return tier, index

        # Every tier is bounded and all were exceeded.
        return tiers[-1], len(tiers) - 1

    def _compute_total_input_context(self, usage: StandardizedUsage) -> int:
        """
        Default context size for tier selection: input tokens plus
        cache-read tokens.
        """
        fresh_tokens = usage.input_tokens
        cached_tokens = usage.cache_read_tokens
        return fresh_tokens + cached_tokens

    def _get_cache_read_price_for_ttl(
        self,
        tier: Dict[str, Any],
        cache_ttl_minutes: int,
    ) -> Optional[float]:
        """
        Resolve the cache-read price for a given cache TTL.

        Args:
            tier: The matched tier configuration.
            cache_ttl_minutes: Cache lifetime in minutes.

        Returns:
            The price of the first ``cache_ttl_pricing`` entry whose
            ``ttl_minutes`` covers the TTL, else the price of the last entry;
            None when the tier has no TTL pricing or the chosen entry has no
            price.
        """
        ttl_pricing = tier.get("cache_ttl_pricing")
        if not ttl_pricing:
            return None

        for entry in ttl_pricing:
            if cache_ttl_minutes <= entry.get("ttl_minutes", 0):
                break
        else:
            # TTL exceeds every configured limit: fall back to the last entry.
            entry = ttl_pricing[-1]

        price = entry.get("cache_read_price_per_1m")
        return None if price is None else float(price)

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BillingCalculator":
        """
        Build a calculator from a configuration dict.

        Config format:
            {
                "template": "claude",
                # or custom dimensions:
                "dimensions": [
                    {"name": "input", "usage_field": "input_tokens", "price_field": "input_price_per_1m"},
                    ...
                ]
            }

        A "dimensions" key takes precedence; otherwise "template" is used,
        defaulting to "claude".
        """
        if "dimensions" not in config:
            return cls(template=config.get("template", "claude"))
        parsed = [BillingDimension.from_dict(item) for item in config["dimensions"]]
        return cls(dimensions=parsed)

    def get_dimension_names(self) -> List[str]:
        """Names of all configured billing dimensions, in order."""
        names: List[str] = []
        for dim in self.dimensions:
            names.append(dim.name)
        return names

    def get_required_price_fields(self) -> List[str]:
        """Price-table keys read by the configured dimensions, in order."""
        price_fields: List[str] = []
        for dim in self.dimensions:
            price_fields.append(dim.price_field)
        return price_fields

    def get_required_usage_fields(self) -> List[str]:
        """Usage fields read by the configured dimensions, in order."""
        usage_fields: List[str] = []
        for dim in self.dimensions:
            usage_fields.append(dim.usage_field)
        return usage_fields
def calculate_request_cost(
    input_tokens: int,
    output_tokens: int,
    cache_creation_input_tokens: int,
    cache_read_input_tokens: int,
    input_price_per_1m: float,
    output_price_per_1m: float,
    cache_creation_price_per_1m: Optional[float],
    cache_read_price_per_1m: Optional[float],
    price_per_request: Optional[float],
    tiered_pricing: Optional[Dict[str, Any]] = None,
    cache_ttl_minutes: Optional[int] = None,
    total_input_context: Optional[int] = None,
    billing_template: str = "claude",
) -> Dict[str, Any]:
    """
    Convenience wrapper that prices one request and returns a legacy-format dict.

    The per-request fee is always honoured: when the selected template has no
    "request" dimension and ``price_per_request`` is given, the PER_REQUEST
    dimension is appended so that ``request_cost`` and ``total_cost`` include
    the fee.

    Args:
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_creation_input_tokens: Number of cache-creation tokens.
        cache_read_input_tokens: Number of cache-read tokens.
        input_price_per_1m: Input price per 1M tokens.
        output_price_per_1m: Output price per 1M tokens.
        cache_creation_price_per_1m: Cache-creation price per 1M tokens.
        cache_read_price_per_1m: Cache-read price per 1M tokens.
        price_per_request: Flat price charged per request.
        tiered_pricing: Tiered pricing configuration.
        cache_ttl_minutes: Cache TTL in minutes.
        total_input_context: Context size used for tier selection.
        billing_template: Name of the billing template.

    Returns:
        {
            "input_cost": float,
            "output_cost": float,
            "cache_creation_cost": float,
            "cache_read_cost": float,
            "cache_cost": float,
            "request_cost": float,
            "total_cost": float,
            "tier_index": Optional[int],
        }
    """
    # Standardized usage for exactly one request.
    usage = StandardizedUsage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        cache_creation_tokens=cache_creation_input_tokens,
        cache_read_tokens=cache_read_input_tokens,
        request_count=1,
    )

    # Price table; optional prices are only included when provided.
    prices: Dict[str, float] = {
        "input_price_per_1m": input_price_per_1m,
        "output_price_per_1m": output_price_per_1m,
    }
    if cache_creation_price_per_1m is not None:
        prices["cache_creation_price_per_1m"] = cache_creation_price_per_1m
    if cache_read_price_per_1m is not None:
        prices["cache_read_price_per_1m"] = cache_read_price_per_1m
    if price_per_request is not None:
        prices["price_per_request"] = price_per_request

    # Copy the template so the shared template list is never mutated, then
    # make sure a per-request fee has a dimension to be billed through.
    dimensions = list(get_template(billing_template))
    if price_per_request is not None and not any(
        dim.name == "request" for dim in dimensions
    ):
        dimensions.extend(BillingTemplates.PER_REQUEST)

    calculator = BillingCalculator(dimensions=dimensions, template=billing_template)
    result = calculator.calculate(
        usage, prices, tiered_pricing, cache_ttl_minutes, total_input_context
    )

    # Legacy-compatible result shape.
    return {
        "input_cost": result.input_cost,
        "output_cost": result.output_cost,
        "cache_creation_cost": result.cache_creation_cost,
        "cache_read_cost": result.cache_read_cost,
        "cache_cost": result.cache_cost,
        "request_cost": result.request_cost,
        "total_cost": result.total_cost,
        "tier_index": result.tier_index,
    }

View File

@@ -0,0 +1,281 @@
"""
计费模块数据模型
定义计费相关的核心数据结构:
- BillingUnit: 计费单位枚举
- BillingDimension: 计费维度定义
- StandardizedUsage: 标准化的 usage 数据
- CostBreakdown: 计费明细结果
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional
class BillingUnit(str, Enum):
    """Unit in which the price of a billing dimension is expressed."""

    # Price per one million tokens.
    PER_1M_TOKENS = "per_1m_tokens"
    # Price per one million tokens per hour (Doubao cache storage).
    PER_1M_TOKENS_HOUR = "per_1m_tokens_hour"
    # Price per request.
    PER_REQUEST = "per_request"
    # Flat fee.
    FIXED = "fixed"
@dataclass
class BillingDimension:
    """
    Definition of one billing dimension.

    A dimension names one way of charging, for example input tokens, output
    tokens, cache reads or a per-request fee. It ties a usage field to a
    price field and a unit.
    """

    # Dimension name, e.g. "input", "output", "cache_read".
    name: str
    # Field of the usage record that holds the billed quantity.
    usage_field: str
    # Key of the price table that holds the unit price.
    price_field: str
    # Unit the price is expressed in.
    unit: BillingUnit = BillingUnit.PER_1M_TOKENS
    # Price used when the price table has no entry for price_field.
    default_price: float = 0.0

    def calculate(self, usage_value: float, price: float) -> float:
        """
        Cost of this dimension for a given quantity and unit price.

        Args:
            usage_value: Billed quantity.
            price: Unit price.

        Returns:
            The cost; 0.0 when the quantity or the price is not positive or
            the unit is not one of the known units.
        """
        if usage_value <= 0 or price <= 0:
            return 0.0

        unit = self.unit
        # Both per-million units use the same formula; for the hourly unit
        # the quantity is expected to already be token*hours.
        if unit == BillingUnit.PER_1M_TOKENS or unit == BillingUnit.PER_1M_TOKENS_HOUR:
            return (usage_value / 1_000_000) * price
        if unit == BillingUnit.PER_REQUEST:
            return usage_value * price
        if unit == BillingUnit.FIXED:
            return price
        return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; the unit is stored by its value."""
        serialized: Dict[str, Any] = {
            "name": self.name,
            "usage_field": self.usage_field,
            "price_field": self.price_field,
            "unit": self.unit.value,
            "default_price": self.default_price,
        }
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BillingDimension":
        """
        Build a dimension from a dict.

        "name", "usage_field" and "price_field" are required; "unit"
        defaults to "per_1m_tokens" and "default_price" to 0.0.
        """
        name = data["name"]
        usage_field = data["usage_field"]
        price_field = data["price_field"]
        unit = BillingUnit(data.get("unit", "per_1m_tokens"))
        default_price = data.get("default_price", 0.0)
        return cls(
            name=name,
            usage_field=usage_field,
            price_field=price_field,
            unit=unit,
            default_price=default_price,
        )
@dataclass
class StandardizedUsage:
    """
    Standardized usage record.

    Normalizes the usage payloads of different API formats into one shape so
    billing can be computed uniformly. Values that have no declared field
    live in ``extra``.
    """

    # Basic token counts.
    input_tokens: int = 0
    output_tokens: int = 0
    # Cache-related counts.
    cache_creation_tokens: int = 0  # Claude: cache creation
    cache_read_tokens: int = 0  # Claude/OpenAI/Doubao: cache read / hit
    # Special token kinds.
    reasoning_tokens: int = 0  # o1/Doubao: reasoning tokens, recorded separately for analysis
    # Time-based quantity (for time-based billing).
    cache_storage_token_hours: float = 0.0  # Doubao: cache storage in token*hours
    # Request count (for per-request billing).
    request_count: int = 1
    # Extension values for dimensions without a declared field.
    extra: Dict[str, Any] = field(default_factory=dict)

    def _is_standard_field(self, field_name: str) -> bool:
        """True when ``field_name`` is a declared dataclass field other than ``extra``."""
        # Checking declared fields (not hasattr) keeps method names such as
        # "get" or "to_dict" from being treated as usage fields.
        return field_name != "extra" and field_name in self.__dataclass_fields__

    def get(self, field_name: str, default: Any = 0) -> Any:
        """
        Read a usage value by name.

        Args:
            field_name: Declared field name or a key of ``extra``.
            default: Value returned when the name is neither.

        Returns:
            The declared field's value, else ``extra[field_name]``, else
            ``default``.
        """
        if self._is_standard_field(field_name):
            return getattr(self, field_name)
        return self.extra.get(field_name, default)

    def set(self, field_name: str, value: Any) -> None:
        """
        Write a usage value by name.

        Declared fields are assigned directly; any other name is stored in
        ``extra``, so methods and other attributes are never overwritten.

        Args:
            field_name: Declared field name or an ``extra`` key.
            value: Value to store.
        """
        if self._is_standard_field(field_name):
            setattr(self, field_name, value)
        else:
            self.extra[field_name] = value

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict; ``extra`` is included only when non-empty."""
        result: Dict[str, Any] = {
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cache_creation_tokens": self.cache_creation_tokens,
            "cache_read_tokens": self.cache_read_tokens,
            "reasoning_tokens": self.reasoning_tokens,
            "cache_storage_token_hours": self.cache_storage_token_hours,
            "request_count": self.request_count,
        }
        if self.extra:
            result["extra"] = self.extra
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StandardizedUsage":
        """
        Build an instance from a dict without modifying the dict.

        Unknown keys are ignored. ``extra`` is shallow-copied so the new
        instance does not share it with the caller; a missing or null
        ``extra`` becomes an empty dict.
        """
        # Read, don't pop: the caller's dict must stay intact.
        extra = dict(data.get("extra") or {})
        # Only declared usage fields are accepted.
        known_fields = {
            "input_tokens",
            "output_tokens",
            "cache_creation_tokens",
            "cache_read_tokens",
            "reasoning_tokens",
            "cache_storage_token_hours",
            "request_count",
        }
        filtered = {k: v for k, v in data.items() if k in known_fields}
        return cls(**filtered, extra=extra)
@dataclass
class CostBreakdown:
    """
    Result of a billing calculation.

    Holds the cost of every billed dimension together with the total, plus
    read-only accessors that mirror the fields of the legacy result format.
    """

    # Cost per dimension, e.g. {"input": 0.01, "output": 0.02, "cache_read": 0.001}.
    costs: Dict[str, float] = field(default_factory=dict)
    # Total cost.
    total_cost: float = 0.0
    # Index of the matched pricing tier, when tiered pricing applied.
    tier_index: Optional[int] = None
    # Currency of the amounts.
    currency: str = "USD"
    # Prices that were applied (kept for records and auditing).
    effective_prices: Dict[str, float] = field(default_factory=dict)

    def _cost_of(self, dimension: str) -> float:
        """Cost recorded for ``dimension``, or 0.0 when it was not billed."""
        return self.costs.get(dimension, 0.0)

    # =========================================================================
    # Accessors compatible with the legacy interface (eases gradual migration)
    # =========================================================================

    @property
    def input_cost(self) -> float:
        """Input cost."""
        return self._cost_of("input")

    @property
    def output_cost(self) -> float:
        """Output cost."""
        return self._cost_of("output")

    @property
    def cache_creation_cost(self) -> float:
        """Cache-creation cost."""
        return self._cost_of("cache_creation")

    @property
    def cache_read_cost(self) -> float:
        """Cache-read cost."""
        return self._cost_of("cache_read")

    @property
    def cache_cost(self) -> float:
        """Total cache cost (creation plus read)."""
        creation = self.cache_creation_cost
        read = self.cache_read_cost
        return creation + read

    @property
    def request_cost(self) -> float:
        """Per-request cost."""
        return self._cost_of("request")

    @property
    def cache_storage_cost(self) -> float:
        """Cache-storage cost (Doubao and similar)."""
        return self._cost_of("cache_storage")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict: the raw fields followed by the legacy cost fields."""
        serialized: Dict[str, Any] = {
            "costs": self.costs,
            "total_cost": self.total_cost,
            "tier_index": self.tier_index,
            "currency": self.currency,
            "effective_prices": self.effective_prices,
        }
        # Legacy-compatible fields.
        serialized["input_cost"] = self.input_cost
        serialized["output_cost"] = self.output_cost
        serialized["cache_creation_cost"] = self.cache_creation_cost
        serialized["cache_read_cost"] = self.cache_read_cost
        serialized["cache_cost"] = self.cache_cost
        serialized["request_cost"] = self.request_cost
        return serialized

    def to_legacy_tuple(self) -> tuple:
        """
        Convert to the tuple layout of the legacy interface.

        Returns:
            (input_cost, output_cost, cache_creation_cost, cache_read_cost,
            cache_cost, request_cost, total_cost, tier_index)
        """
        legacy = (
            self.input_cost,
            self.output_cost,
            self.cache_creation_cost,
            self.cache_read_cost,
            self.cache_cost,
            self.request_cost,
            self.total_cost,
            self.tier_index,
        )
        return legacy

View File

@@ -0,0 +1,213 @@
"""
预定义计费模板
提供常见厂商的计费配置模板,避免重复配置:
- CLAUDE_STANDARD: Claude/Anthropic 标准计费
- OPENAI_STANDARD: OpenAI 标准计费
- DOUBAO_STANDARD: 豆包计费(含缓存存储)
- GEMINI_STANDARD: Gemini 标准计费
- PER_REQUEST: 按次计费
"""
from typing import Dict, List, Optional
from src.services.billing.models import BillingDimension, BillingUnit
class BillingTemplates:
    """Predefined billing templates, each an ordered list of billing dimensions."""

    # -------------------------------------------------------------------------
    # Claude/Anthropic standard billing:
    # input tokens, output tokens, cache creation, cache read.
    # -------------------------------------------------------------------------
    CLAUDE_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension(
            "cache_creation", "cache_creation_tokens", "cache_creation_price_per_1m"
        ),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
    ]

    # -------------------------------------------------------------------------
    # OpenAI standard billing:
    # input tokens, output tokens, cache read (no cache-creation charge).
    # -------------------------------------------------------------------------
    OPENAI_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
    ]

    # -------------------------------------------------------------------------
    # Doubao billing:
    # input tokens, output tokens, cache hits (cache_read_tokens) and cache
    # storage (cache_storage_token_hours, tokens multiplied by storage time).
    # There is no cache-creation dimension; storage is billed over time.
    # -------------------------------------------------------------------------
    DOUBAO_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
        BillingDimension(
            "cache_storage",
            "cache_storage_token_hours",
            "cache_storage_price_per_1m_hour",
            unit=BillingUnit.PER_1M_TOKENS_HOUR,
        ),
    ]

    # -------------------------------------------------------------------------
    # Gemini standard billing:
    # input tokens, output tokens, cache read.
    # -------------------------------------------------------------------------
    GEMINI_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension("cache_read", "cache_read_tokens", "cache_read_price_per_1m"),
    ]

    # -------------------------------------------------------------------------
    # Per-request billing:
    # charged by request count only, not by tokens (e.g. some image
    # generation models or special APIs).
    # -------------------------------------------------------------------------
    PER_REQUEST: List[BillingDimension] = [
        BillingDimension(
            "request",
            "request_count",
            "price_per_request",
            unit=BillingUnit.PER_REQUEST,
        ),
    ]

    # -------------------------------------------------------------------------
    # Hybrid billing (per request plus per token):
    # for models that carry both a flat request fee and token fees.
    # -------------------------------------------------------------------------
    HYBRID_STANDARD: List[BillingDimension] = [
        BillingDimension("input", "input_tokens", "input_price_per_1m"),
        BillingDimension("output", "output_tokens", "output_price_per_1m"),
        BillingDimension(
            "request",
            "request_count",
            "price_per_request",
            unit=BillingUnit.PER_REQUEST,
        ),
    ]
# =========================================================================
# Template registry: maps a lowercase template name to its dimension list.
# Several names may point at the same list object.
# =========================================================================
BILLING_TEMPLATE_REGISTRY: Dict[str, List[BillingDimension]] = {
# By vendor name
"claude": BillingTemplates.CLAUDE_STANDARD,
"anthropic": BillingTemplates.CLAUDE_STANDARD,
"openai": BillingTemplates.OPENAI_STANDARD,
"doubao": BillingTemplates.DOUBAO_STANDARD,
"bytedance": BillingTemplates.DOUBAO_STANDARD,
"gemini": BillingTemplates.GEMINI_STANDARD,
"google": BillingTemplates.GEMINI_STANDARD,
# By billing mode
"per_request": BillingTemplates.PER_REQUEST,
"hybrid": BillingTemplates.HYBRID_STANDARD,
# Default
"default": BillingTemplates.CLAUDE_STANDARD,
}
def get_template(name: Optional[str]) -> List[BillingDimension]:
    """
    Look up a billing template by name.

    Args:
        name: Template name, matched case-insensitively. A falsy value
            (None or an empty string) selects the "default" entry.

    Returns:
        The list of billing dimensions registered under that name.

    Raises:
        ValueError: If the name is not registered. The message lists every
            available template name in sorted order.
    """
    if name:
        found = BILLING_TEMPLATE_REGISTRY.get(name.lower())
        if found is not None:
            return found
        known = ", ".join(sorted(BILLING_TEMPLATE_REGISTRY))
        raise ValueError(f"Unknown billing template: {name!r}. Available: {known}")
    return BILLING_TEMPLATE_REGISTRY["default"]
def list_templates() -> List[str]:
    """Return the names of all registered templates, in registration order."""
    return [*BILLING_TEMPLATE_REGISTRY]

View File

@@ -0,0 +1,267 @@
"""
Usage 字段映射器
将不同 API 格式的原始 usage 数据映射为标准化格式。
支持的格式:
- OPENAI / OPENAI_CLI: OpenAI Chat Completions API
- CLAUDE / CLAUDE_CLI: Anthropic Messages API
- GEMINI / GEMINI_CLI: Google Gemini API
"""
from typing import Any, Dict, Optional
from src.services.billing.models import StandardizedUsage
class UsageMapper:
    """
    Usage field mapper.

    Converts the raw ``usage`` payload of different API formats into a
    single ``StandardizedUsage`` object. Each format has a mapping table of
    ``source_path -> target_field``.

    Example:
        # OpenAI format
        raw_usage = {
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "prompt_tokens_details": {"cached_tokens": 20},
            "completion_tokens_details": {"reasoning_tokens": 10}
        }
        usage = UsageMapper.map(raw_usage, "OPENAI")

        # Claude format
        raw_usage = {
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_creation_input_tokens": 30,
            "cache_read_input_tokens": 20
        }
        usage = UsageMapper.map(raw_usage, "CLAUDE")
    """

    # =========================================================================
    # Field mapping configuration
    # Format: "source_path" -> "target_field"
    # source_path may be a dot-separated path into nested dictionaries
    # =========================================================================

    # OpenAI field mapping
    OPENAI_MAPPING: Dict[str, str] = {
        "prompt_tokens": "input_tokens",
        "completion_tokens": "output_tokens",
        "prompt_tokens_details.cached_tokens": "cache_read_tokens",
        "completion_tokens_details.reasoning_tokens": "reasoning_tokens",
    }

    # Claude field mapping
    CLAUDE_MAPPING: Dict[str, str] = {
        "input_tokens": "input_tokens",
        "output_tokens": "output_tokens",
        "cache_creation_input_tokens": "cache_creation_tokens",
        "cache_read_input_tokens": "cache_read_tokens",
    }

    # Gemini field mapping
    GEMINI_MAPPING: Dict[str, str] = {
        "promptTokenCount": "input_tokens",
        "candidatesTokenCount": "output_tokens",
        "cachedContentTokenCount": "cache_read_tokens",
        # Same fields when the caller passes an object that still wraps them
        # in "usageMetadata". These entries come later, so they win if both
        # forms are present.
        "usageMetadata.promptTokenCount": "input_tokens",
        "usageMetadata.candidatesTokenCount": "output_tokens",
        "usageMetadata.cachedContentTokenCount": "cache_read_tokens",
    }

    # Format name -> mapping table
    FORMAT_MAPPINGS: Dict[str, Dict[str, str]] = {
        "OPENAI": OPENAI_MAPPING,
        "OPENAI_CLI": OPENAI_MAPPING,
        "CLAUDE": CLAUDE_MAPPING,
        "CLAUDE_CLI": CLAUDE_MAPPING,
        "GEMINI": GEMINI_MAPPING,
        "GEMINI_CLI": GEMINI_MAPPING,
    }

    @classmethod
    def map(
        cls,
        raw_usage: Dict[str, Any],
        api_format: str,
        extra_mapping: Optional[Dict[str, str]] = None,
    ) -> StandardizedUsage:
        """
        Map a raw usage dictionary to the standardized form.

        Args:
            raw_usage: Raw usage dictionary. A falsy value yields an empty
                ``StandardizedUsage``.
            api_format: API format name ("OPENAI", "CLAUDE", "GEMINI", ...).
            extra_mapping: Additional ``source_path -> target_field`` entries
                merged over the format's mapping (extra entries win).

        Returns:
            The standardized usage object.
        """
        if not raw_usage:
            return StandardizedUsage()

        # Pick the mapping table for this format.
        mapping = cls._get_mapping(api_format)

        # Merge caller-supplied entries into a new dict so the class-level
        # table is never mutated.
        if extra_mapping:
            mapping = {**mapping, **extra_mapping}

        result = StandardizedUsage()

        # Copy every source value that is present (not None) to its target.
        for source_path, target_field in mapping.items():
            value = cls._get_nested_value(raw_usage, source_path)
            if value is not None:
                result.set(target_field, value)

        return result

    @classmethod
    def map_from_response(
        cls,
        response: Dict[str, Any],
        api_format: str,
    ) -> StandardizedUsage:
        """
        Extract the usage section from a full response and map it.

        The usage section is looked up by format:
        - Formats starting with "GEMINI": ``response["usageMetadata"]``,
          falling back to ``response["candidates"][0]["usageMetadata"]``.
        - Everything else (OpenAI/Claude): ``response["usage"]``.

        A malformed response (not a dict, ``candidates`` not a sequence,
        first candidate not a dict) yields an empty usage object rather than
        raising.

        Args:
            response: Full API response.
            api_format: API format name.

        Returns:
            The standardized usage object.
        """
        # Upstream payloads are not guaranteed to be well-formed.
        if not isinstance(response, dict):
            return StandardizedUsage()

        format_upper = api_format.upper() if api_format else ""

        usage_data: Any = None
        if format_upper.startswith("GEMINI"):
            # Gemini: usageMetadata at the top level.
            usage_data = response.get("usageMetadata")
            if not usage_data:
                # Fall back to the first candidate's usageMetadata.
                candidates = response.get("candidates")
                if isinstance(candidates, (list, tuple)) and candidates:
                    first = candidates[0]
                    if isinstance(first, dict):
                        usage_data = first.get("usageMetadata")
        else:
            # OpenAI/Claude: usage
            usage_data = response.get("usage")

        return cls.map(usage_data or {}, api_format)

    @classmethod
    def _get_mapping(cls, api_format: str) -> Dict[str, str]:
        """
        Return the mapping table for ``api_format``.

        Resolution order: exact (upper-cased) match, then the first registered
        format whose leading segment (text before the first "_") is a prefix
        of the requested name, then the Claude mapping as the default.
        """
        if not api_format:
            return cls.CLAUDE_MAPPING

        format_upper = api_format.upper()

        # Exact match
        if format_upper in cls.FORMAT_MAPPINGS:
            return cls.FORMAT_MAPPINGS[format_upper]

        # Prefix match
        for key, mapping in cls.FORMAT_MAPPINGS.items():
            if format_upper.startswith(key.split("_")[0]):
                return mapping

        # Default to the Claude mapping
        return cls.CLAUDE_MAPPING

    @classmethod
    def _get_nested_value(cls, data: Dict[str, Any], path: str) -> Any:
        """
        Read a possibly nested value.

        Args:
            data: Source dictionary.
            path: Dot-separated key path, e.g.
                "prompt_tokens_details.cached_tokens".

        Returns:
            The value at ``path``, or None when any segment is missing, is
            None, or is reached through a non-dict value.
        """
        if not data or not path:
            return None

        keys = path.split(".")
        value: Any = data

        for key in keys:
            if isinstance(value, dict):
                value = value.get(key)
                if value is None:
                    return None
            else:
                return None

        return value

    @classmethod
    def register_format(cls, format_name: str, mapping: Dict[str, str]) -> None:
        """
        Register (or replace) the mapping for a format.

        Args:
            format_name: Format name; stored upper-cased.
            mapping: ``source_path -> target_field`` table.
        """
        cls.FORMAT_MAPPINGS[format_name.upper()] = mapping

    @classmethod
    def get_supported_formats(cls) -> list:
        """Return the names of all registered formats."""
        return list(cls.FORMAT_MAPPINGS.keys())
# =========================================================================
# Convenience functions
# =========================================================================
def map_usage(
    raw_usage: Dict[str, Any],
    api_format: str,
) -> StandardizedUsage:
    """
    Shortcut for ``UsageMapper.map`` without extra mappings.

    Args:
        raw_usage: Raw usage dictionary.
        api_format: API format name.

    Returns:
        A StandardizedUsage object.
    """
    standardized = UsageMapper.map(raw_usage, api_format)
    return standardized
def map_usage_from_response(
    response: Dict[str, Any],
    api_format: str,
) -> StandardizedUsage:
    """
    Shortcut for ``UsageMapper.map_from_response``.

    Args:
        response: Full API response.
        api_format: API format name.

    Returns:
        A StandardizedUsage object.
    """
    standardized = UsageMapper.map_from_response(response, api_format)
    return standardized

View File

@@ -7,6 +7,59 @@ from typing import Any
from sqlalchemy import func
def escape_like_pattern(pattern: str) -> str:
    r"""
    Escape the characters that are special in a SQL LIKE pattern.

    A backslash is placed in front of every ``%``, ``_`` and ``\`` so they
    match literally.

    Args:
        pattern: Raw search text.

    Returns:
        The escaped text, intended for LIKE queries that declare the
        backslash as escape character (``escape="\\"``).

    Examples:
        >>> escape_like_pattern("hello_world%test")
        'hello\\_world\\%test'
    """
    # One pass over the input: each special character maps to its escaped
    # form, every other character is left untouched.
    escapes = str.maketrans({"\\": "\\\\", "%": "\\%", "_": "\\_"})
    return pattern.translate(escapes)
def safe_truncate_escaped(escaped: str, max_len: int) -> str:
    r"""
    Truncate an already-escaped string without cutting an escape sequence.

    In a string produced by ``escape_like_pattern`` a backslash only appears
    as part of a two-character sequence (``\\``, ``\%`` or ``\_``). If the
    cut leaves an odd number of trailing backslashes, the last one is the
    first half of a split sequence and is dropped.

    Args:
        escaped: String already processed by ``escape_like_pattern``.
        max_len: Maximum length of the result. Zero or a negative value
            yields an empty string.

    Returns:
        A prefix of ``escaped`` no longer than ``max_len`` that does not end
        in a dangling escape character.
    """
    # A negative limit would otherwise be treated as a slice offset from the
    # end and return a string longer than the limit.
    if max_len <= 0:
        return ""

    if len(escaped) <= max_len:
        return escaped

    truncated = escaped[:max_len]

    # Number of consecutive backslashes at the end of the cut string.
    trailing_backslashes = len(truncated) - len(truncated.rstrip("\\"))

    # Odd count: the cut landed in the middle of an escape sequence.
    if trailing_backslashes % 2 == 1:
        truncated = truncated[:-1]

    return truncated
def date_trunc_portable(dialect_name: str, interval: str, column: Any) -> Any:
"""
跨数据库的日期截断函数

View File

View File

@@ -0,0 +1,440 @@
"""
Billing 模块测试
测试计费模块的核心功能:
- BillingCalculator 计费计算
- 计费模板
- 阶梯计费
- calculate_request_cost 便捷函数
"""
import pytest
from src.services.billing import (
BillingCalculator,
BillingDimension,
BillingTemplates,
BillingUnit,
CostBreakdown,
StandardizedUsage,
calculate_request_cost,
)
from src.services.billing.templates import get_template, list_templates
class TestBillingDimension:
    """Tests for a single billing dimension."""

    def test_calculate_per_1m_tokens(self) -> None:
        """A default dimension prices usage per one million tokens."""
        token_dim = BillingDimension(
            price_field="input_price_per_1m",
            usage_field="input_tokens",
            name="input",
        )

        # 1000 tokens * $3 / 1M = $0.003
        deviation = token_dim.calculate(1000, 3.0) - 0.003
        assert abs(deviation) < 0.0001

    def test_calculate_per_request(self) -> None:
        """A PER_REQUEST dimension charges request_count * price."""
        request_dim = BillingDimension(
            unit=BillingUnit.PER_REQUEST,
            price_field="price_per_request",
            usage_field="request_count",
            name="request",
        )

        # A single request costs exactly the unit price.
        single = request_dim.calculate(1, 0.05)
        assert single == 0.05

        # Several requests are charged per request.
        triple = request_dim.calculate(3, 0.05)
        assert abs(triple - 0.15) < 0.0001

    def test_calculate_zero_usage(self) -> None:
        """Zero usage costs nothing."""
        token_dim = BillingDimension(
            price_field="input_price_per_1m",
            usage_field="input_tokens",
            name="input",
        )

        assert token_dim.calculate(0, 3.0) == 0.0

    def test_calculate_zero_price(self) -> None:
        """A zero price costs nothing."""
        token_dim = BillingDimension(
            price_field="input_price_per_1m",
            usage_field="input_tokens",
            name="input",
        )

        assert token_dim.calculate(1000, 0.0) == 0.0

    def test_to_dict_and_from_dict(self) -> None:
        """A dimension survives a to_dict / from_dict round trip."""
        original = BillingDimension(
            name="cache_read",
            usage_field="cache_read_tokens",
            price_field="cache_read_price_per_1m",
            unit=BillingUnit.PER_1M_TOKENS,
            default_price=0.3,
        )

        payload = original.to_dict()
        rebuilt = BillingDimension.from_dict(payload)

        assert rebuilt.name == original.name
        assert rebuilt.usage_field == original.usage_field
        assert rebuilt.price_field == original.price_field
        assert rebuilt.unit == original.unit
        assert rebuilt.default_price == original.default_price
class TestStandardizedUsage:
    """Tests for the standardized usage container."""

    def test_basic_usage(self) -> None:
        """Explicit fields are stored; cache fields default to zero."""
        std = StandardizedUsage(output_tokens=500, input_tokens=1000)

        assert std.input_tokens == 1000
        assert std.output_tokens == 500
        assert std.cache_creation_tokens == 0
        assert std.cache_read_tokens == 0

    def test_get_field(self) -> None:
        """get() returns a known field, or the fallback for an unknown one."""
        std = StandardizedUsage(output_tokens=500, input_tokens=1000)

        assert std.get("input_tokens") == 1000
        assert std.get("nonexistent", 0) == 0

    def test_extra_fields(self) -> None:
        """Values passed through ``extra`` are reachable via get()."""
        extras = {"custom_field": 123}
        std = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            extra=extras,
        )

        assert std.get("custom_field") == 123

    def test_to_dict(self) -> None:
        """to_dict() exposes the token counters under their field names."""
        std = StandardizedUsage(
            cache_creation_tokens=100,
            output_tokens=500,
            input_tokens=1000,
        )

        as_dict = std.to_dict()

        assert as_dict["input_tokens"] == 1000
        assert as_dict["output_tokens"] == 500
        assert as_dict["cache_creation_tokens"] == 100
class TestCostBreakdown:
    """Tests for the cost breakdown container."""

    def test_basic_breakdown(self) -> None:
        """Per-dimension costs are exposed as ``<name>_cost`` attributes."""
        per_dimension = {"input": 0.003, "output": 0.0075}
        detail = CostBreakdown(costs=per_dimension, total_cost=0.0105)

        assert detail.input_cost == 0.003
        assert detail.output_cost == 0.0075
        assert detail.total_cost == 0.0105

    def test_cache_cost_calculation(self) -> None:
        """cache_cost aggregates cache_creation and cache_read."""
        per_dimension = {
            "input": 0.003,
            "output": 0.0075,
            "cache_creation": 0.001,
            "cache_read": 0.0005,
        }
        detail = CostBreakdown(costs=per_dimension, total_cost=0.012)

        # cache_cost = cache_creation + cache_read
        deviation = detail.cache_cost - 0.0015
        assert abs(deviation) < 0.0001

    def test_to_dict(self) -> None:
        """to_dict() includes the total, the tier index and per-dimension costs."""
        detail = CostBreakdown(
            tier_index=1,
            total_cost=0.0105,
            costs={"input": 0.003, "output": 0.0075},
        )

        as_dict = detail.to_dict()

        assert as_dict["total_cost"] == 0.0105
        assert as_dict["tier_index"] == 1
        assert as_dict["input_cost"] == 0.003
class TestBillingTemplates:
    """Tests for the predefined billing templates and their registry."""

    def test_claude_template(self) -> None:
        """The Claude template has input, output and both cache dimensions."""
        names = [dim.name for dim in BillingTemplates.CLAUDE_STANDARD]

        for expected in ("input", "output", "cache_creation", "cache_read"):
            assert expected in names

    def test_openai_template(self) -> None:
        """The OpenAI template has cache_read but no cache_creation dimension."""
        names = [dim.name for dim in BillingTemplates.OPENAI_STANDARD]

        for expected in ("input", "output", "cache_read"):
            assert expected in names
        # OpenAI has no cache-creation charge.
        assert "cache_creation" not in names

    def test_gemini_template(self) -> None:
        """The Gemini template has input, output and cache_read dimensions."""
        names = [dim.name for dim in BillingTemplates.GEMINI_STANDARD]

        for expected in ("input", "output", "cache_read"):
            assert expected in names

    def test_per_request_template(self) -> None:
        """The per-request template is a single PER_REQUEST dimension."""
        dims = BillingTemplates.PER_REQUEST

        assert len(dims) == 1
        only = dims[0]
        assert only.name == "request"
        assert only.unit == BillingUnit.PER_REQUEST

    def test_get_template(self) -> None:
        """get_template resolves names case-insensitively and rejects unknown ones."""
        assert get_template("claude") == BillingTemplates.CLAUDE_STANDARD
        assert get_template("openai") == BillingTemplates.OPENAI_STANDARD

        # Lookup is case-insensitive.
        assert get_template("CLAUDE") == BillingTemplates.CLAUDE_STANDARD

        with pytest.raises(ValueError, match="Unknown billing template"):
            get_template("unknown_template")

    def test_list_templates(self) -> None:
        """list_templates includes the vendor and per-request names."""
        available = list_templates()

        for expected in ("claude", "openai", "gemini", "per_request"):
            assert expected in available
class TestBillingCalculator:
    """Tests for BillingCalculator.

    Expected costs are compared with ``pytest.approx`` (relative tolerance)
    rather than a fixed absolute tolerance: the previous 1e-4 bound was larger
    than some expected values (e.g. 9e-5), so those assertions could not fail.
    """

    def test_basic_calculation(self) -> None:
        """Input and output costs are priced per one million tokens."""
        calculator = BillingCalculator(template="claude")
        usage = StandardizedUsage(input_tokens=1000, output_tokens=500)
        prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}

        result = calculator.calculate(usage, prices)

        # 1000 * 3 / 1M = 0.003
        assert result.input_cost == pytest.approx(0.003)
        # 500 * 15 / 1M = 0.0075
        assert result.output_cost == pytest.approx(0.0075)
        # Total = 0.0105
        assert result.total_cost == pytest.approx(0.0105)

    def test_calculation_with_cache(self) -> None:
        """Cache creation and cache read are billed with their own prices."""
        calculator = BillingCalculator(template="claude")
        usage = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_tokens=200,
            cache_read_tokens=300,
        )
        prices = {
            "input_price_per_1m": 3.0,
            "output_price_per_1m": 15.0,
            "cache_creation_price_per_1m": 3.75,
            "cache_read_price_per_1m": 0.3,
        }

        result = calculator.calculate(usage, prices)

        # cache_creation: 200 * 3.75 / 1M = 0.00075
        assert result.cache_creation_cost == pytest.approx(0.00075)
        # cache_read: 300 * 0.3 / 1M = 0.00009
        assert result.cache_read_cost == pytest.approx(0.00009)

    def test_tiered_pricing(self) -> None:
        """Usage above the first tier's ``up_to`` is billed at the second tier."""
        calculator = BillingCalculator(template="claude")
        usage = StandardizedUsage(input_tokens=250000, output_tokens=10000)

        # More than 200k input tokens falls into the second tier.
        tiered_pricing = {
            "tiers": [
                {"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
                {"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
            ]
        }
        prices = {"input_price_per_1m": 3.0, "output_price_per_1m": 15.0}

        result = calculator.calculate(usage, prices, tiered_pricing)

        # The second tier's prices should apply.
        assert result.tier_index == 1
        # 250000 * 1.5 / 1M = 0.375
        assert result.input_cost == pytest.approx(0.375)

    def test_openai_no_cache_creation(self) -> None:
        """The OpenAI template ignores cache-creation usage but bills cache reads."""
        calculator = BillingCalculator(template="openai")
        usage = StandardizedUsage(
            input_tokens=1000,
            output_tokens=500,
            cache_creation_tokens=200,  # must not be billed
            cache_read_tokens=300,
        )
        prices = {
            "input_price_per_1m": 3.0,
            "output_price_per_1m": 15.0,
            "cache_creation_price_per_1m": 3.75,
            "cache_read_price_per_1m": 0.3,
        }

        result = calculator.calculate(usage, prices)

        # The OpenAI template has no cache_creation dimension.
        assert result.cache_creation_cost == 0.0
        # cache_read is still billed.
        assert result.cache_read_cost > 0

    def test_from_config(self) -> None:
        """from_config picks the template named in the config dict."""
        config = {"template": "openai"}
        calculator = BillingCalculator.from_config(config)

        assert calculator.template_name == "openai"
class TestCalculateRequestCost:
"""测试便捷函数"""
def test_basic_usage(self) -> None:
    """Without cache usage the result carries input, output and total costs."""
    call_args = dict(
        input_tokens=1000,
        output_tokens=500,
        cache_creation_input_tokens=0,
        cache_read_input_tokens=0,
        input_price_per_1m=3.0,
        output_price_per_1m=15.0,
        cache_creation_price_per_1m=None,
        cache_read_price_per_1m=None,
        price_per_request=None,
        billing_template="claude",
    )

    costs = calculate_request_cost(**call_args)

    for key in ("input_cost", "output_cost", "total_cost"):
        assert key in costs
    assert abs(costs["input_cost"] - 0.003) < 0.0001
    assert abs(costs["output_cost"] - 0.0075) < 0.0001
def test_with_cache(self) -> None:
    """Cache creation/read are billed and cache_cost is their sum."""
    result = calculate_request_cost(
        input_tokens=1000,
        output_tokens=500,
        cache_creation_input_tokens=200,
        cache_read_input_tokens=300,
        input_price_per_1m=3.0,
        output_price_per_1m=15.0,
        cache_creation_price_per_1m=3.75,
        cache_read_price_per_1m=0.3,
        price_per_request=None,
        billing_template="claude",
    )

    assert result["cache_creation_cost"] > 0
    assert result["cache_read_cost"] > 0
    # Compare the derived sum approximately: exact float equality would
    # break if the implementation rounds or sums in a different order.
    expected_cache_cost = result["cache_creation_cost"] + result["cache_read_cost"]
    assert result["cache_cost"] == pytest.approx(expected_cache_cost)
def test_different_templates(self) -> None:
    """The same usage is billed differently depending on the template."""
    # Shared keyword arguments (usage counts and prices) for both calls.
    shared_kwargs = {
        "input_tokens": 1000,
        "output_tokens": 500,
        "cache_creation_input_tokens": 200,
        "cache_read_input_tokens": 300,
        "input_price_per_1m": 3.0,
        "output_price_per_1m": 15.0,
        "cache_creation_price_per_1m": 3.75,
        "cache_read_price_per_1m": 0.3,
        "price_per_request": None,
    }

    # The Claude template bills cache creation.
    claude_costs = calculate_request_cost(billing_template="claude", **shared_kwargs)
    assert claude_costs["cache_creation_cost"] > 0

    # The OpenAI template does not.
    openai_costs = calculate_request_cost(billing_template="openai", **shared_kwargs)
    assert openai_costs["cache_creation_cost"] == 0
def test_tiered_pricing_with_total_context(self) -> None:
"""测试使用自定义 total_input_context 的阶梯计费"""
tiered_pricing = {
"tiers": [
{"up_to": 200000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
{"up_to": None, "input_price_per_1m": 1.5, "output_price_per_1m": 7.5},
]
}
# 传入预计算的 total_input_context
result = calculate_request_cost(
input_tokens=1000,
output_tokens=500,
cache_creation_input_tokens=0,
cache_read_input_tokens=0,
input_price_per_1m=3.0,
output_price_per_1m=15.0,
cache_creation_price_per_1m=None,
cache_read_price_per_1m=None,
price_per_request=None,
tiered_pricing=tiered_pricing,
total_input_context=250000, # 预计算的值,超过 200k
billing_template="claude",
)
# 应该使用第二阶梯价格
assert result["tier_index"] == 1