mirror of https://github.com/fawney19/Aether.git (synced 2026-01-02 15:52:26 +08:00)
refactor: optimize provider query and stats aggregation logic
@@ -20,10 +20,10 @@ depends_on = None

def upgrade() -> None:
# Create ENUM types
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
# Create ENUM types (with IF NOT EXISTS for idempotency)
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
op.execute(
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
)

# ==================== users ====================
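This diff does not show the matching downgrade; a minimal sketch of what an idempotent counterpart could look like (hypothetical, not part of this commit):

# Hypothetical downgrade counterpart (assumption, not in this diff): drop the
# ENUM types only when they exist, mirroring the guarded creation above.
def downgrade() -> None:
    op.execute("DROP TYPE IF EXISTS providerbillingtype")
    op.execute("DROP TYPE IF EXISTS userrole")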
@@ -35,7 +35,7 @@ def upgrade() -> None:
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("admin", "user", name="userrole", create_type=False),
|
||||
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
),
|
||||
@@ -67,7 +67,7 @@ def upgrade() -> None:
|
||||
sa.Column("website", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"billing_type",
|
||||
sa.Enum(
|
||||
postgresql.ENUM(
|
||||
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
||||
),
|
||||
nullable=False,
|
||||
|
||||
@@ -124,6 +124,27 @@ export interface ModelExport {
|
||||
config?: any
|
||||
}
|
||||
|
||||
// Provider 模型查询响应
|
||||
export interface ProviderModelsQueryResponse {
|
||||
success: boolean
|
||||
data: {
|
||||
models: Array<{
|
||||
id: string
|
||||
object?: string
|
||||
created?: number
|
||||
owned_by?: string
|
||||
display_name?: string
|
||||
api_format?: string
|
||||
}>
|
||||
error?: string
|
||||
}
|
||||
provider: {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ConfigImportRequest extends ConfigExportData {
|
||||
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||
}
|
||||
@@ -356,5 +377,14 @@ export const adminApi = {
|
||||
data
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 查询 Provider 可用模型(从上游 API 获取)
|
||||
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
|
||||
const response = await apiClient.post<ProviderModelsQueryResponse>(
|
||||
'/api/admin/provider-query/models',
|
||||
{ provider_id: providerId, api_key_id: apiKeyId }
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,25 @@
|
||||
// API 格式常量
|
||||
export const API_FORMATS = {
|
||||
CLAUDE: 'CLAUDE',
|
||||
CLAUDE_CLI: 'CLAUDE_CLI',
|
||||
OPENAI: 'OPENAI',
|
||||
OPENAI_CLI: 'OPENAI_CLI',
|
||||
GEMINI: 'GEMINI',
|
||||
GEMINI_CLI: 'GEMINI_CLI',
|
||||
} as const
|
||||
|
||||
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
|
||||
|
||||
// API 格式显示名称映射(按品牌分组:API 在前,CLI 在后)
|
||||
export const API_FORMAT_LABELS: Record<string, string> = {
|
||||
[API_FORMATS.CLAUDE]: 'Claude',
|
||||
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
|
||||
[API_FORMATS.OPENAI]: 'OpenAI',
|
||||
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
|
||||
[API_FORMATS.GEMINI]: 'Gemini',
|
||||
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||
}
|
||||
|
||||
export interface ProviderEndpoint {
|
||||
id: string
|
||||
provider_id: string
|
||||
@@ -214,6 +236,7 @@ export interface ConcurrencyStatus {
|
||||
export interface ProviderModelAlias {
|
||||
name: string
|
||||
priority: number // 优先级(数字越小优先级越高)
|
||||
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
|
||||
}
|
||||
|
||||
export interface Model {
|
||||
|
||||
@@ -396,15 +396,13 @@ interface ProviderGroup {
|
||||
|
||||
const groupedModels = computed(() => {
|
||||
let models = allModels.value.filter(m => !m.deprecated)
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
models = models.filter(model =>
|
||||
model.providerId.toLowerCase().includes(query) ||
|
||||
model.providerName.toLowerCase().includes(query) ||
|
||||
model.modelId.toLowerCase().includes(query) ||
|
||||
model.modelName.toLowerCase().includes(query) ||
|
||||
model.family?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
models = models.filter(model => {
|
||||
const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 按提供商分组
|
||||
@@ -425,10 +423,12 @@ const groupedModels = computed(() => {
|
||||
|
||||
// 如果有搜索词,把提供商名称/ID匹配的排在前面
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result.sort((a, b) => {
|
||||
const aProviderMatch = a.providerId.toLowerCase().includes(query) || a.providerName.toLowerCase().includes(query)
|
||||
const bProviderMatch = b.providerId.toLowerCase().includes(query) || b.providerName.toLowerCase().includes(query)
|
||||
const aText = `${a.providerId} ${a.providerName}`.toLowerCase()
|
||||
const bText = `${b.providerId} ${b.providerName}`.toLowerCase()
|
||||
const aProviderMatch = keywords.some(k => aText.includes(k))
|
||||
const bProviderMatch = keywords.some(k => bText.includes(k))
|
||||
if (aProviderMatch && !bProviderMatch) return -1
|
||||
if (!aProviderMatch && bProviderMatch) return 1
|
||||
return a.providerName.localeCompare(b.providerName)
|
||||
|
||||
@@ -526,7 +526,14 @@
|
||||
@edit-model="handleEditModel"
|
||||
@delete-model="handleDeleteModel"
|
||||
@batch-assign="handleBatchAssign"
|
||||
@manage-alias="handleManageAlias"
|
||||
/>
|
||||
|
||||
<!-- 模型名称映射 -->
|
||||
<ModelAliasesTab
|
||||
v-if="provider"
|
||||
:key="`aliases-${provider.id}`"
|
||||
:provider="provider"
|
||||
@refresh="handleRelatedDataRefresh"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -629,16 +636,6 @@
|
||||
@update:open="batchAssignDialogOpen = $event"
|
||||
@changed="handleBatchAssignChanged"
|
||||
/>
|
||||
|
||||
<!-- 模型别名管理对话框 -->
|
||||
<ModelAliasDialog
|
||||
v-if="open && provider"
|
||||
:open="aliasDialogOpen"
|
||||
:provider-id="provider.id"
|
||||
:model="aliasEditingModel"
|
||||
@update:open="aliasDialogOpen = $event"
|
||||
@saved="handleAliasSaved"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
@@ -667,8 +664,8 @@ import {
|
||||
KeyFormDialog,
|
||||
KeyAllowedModelsDialog,
|
||||
ModelsTab,
|
||||
BatchAssignModelsDialog,
|
||||
ModelAliasDialog
|
||||
ModelAliasesTab,
|
||||
BatchAssignModelsDialog
|
||||
} from '@/features/providers/components'
|
||||
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
|
||||
const modelToDelete = ref<Model | null>(null)
|
||||
const batchAssignDialogOpen = ref(false)
|
||||
|
||||
// 别名管理相关状态
|
||||
const aliasDialogOpen = ref(false)
|
||||
const aliasEditingModel = ref<Model | null>(null)
|
||||
|
||||
// 拖动排序相关状态
|
||||
const dragState = ref({
|
||||
isDragging: false,
|
||||
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
|
||||
deleteKeyConfirmOpen.value ||
|
||||
modelFormDialogOpen.value ||
|
||||
deleteModelConfirmOpen.value ||
|
||||
batchAssignDialogOpen.value ||
|
||||
aliasDialogOpen.value
|
||||
batchAssignDialogOpen.value
|
||||
)
|
||||
|
||||
// 监听 providerId 变化
|
||||
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
|
||||
keyAllowedModelsDialogOpen.value = false
|
||||
deleteKeyConfirmOpen.value = false
|
||||
batchAssignDialogOpen.value = false
|
||||
aliasDialogOpen.value = false
|
||||
|
||||
// 重置临时数据
|
||||
endpointToEdit.value = null
|
||||
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
|
||||
emit('refresh')
|
||||
}
|
||||
|
||||
// 处理管理映射 - 打开别名对话框
|
||||
function handleManageAlias(model: Model) {
|
||||
aliasEditingModel.value = model
|
||||
aliasDialogOpen.value = true
|
||||
}
|
||||
|
||||
// 处理别名保存完成
|
||||
async function handleAliasSaved() {
|
||||
aliasEditingModel.value = null
|
||||
await loadProvider()
|
||||
emit('refresh')
|
||||
}
|
||||
|
||||
// 处理模型保存完成
|
||||
async function handleModelSaved() {
|
||||
editingModel.value = null
|
||||
|
||||
@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
|
||||
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
||||
|
||||
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
||||
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'
|
||||
|
||||
File diff suppressed because it is too large
@@ -165,15 +165,6 @@
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="管理映射"
|
||||
@click="openAliasDialog(model)"
|
||||
>
|
||||
<Tag class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
@@ -218,7 +209,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
|
||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
@@ -233,7 +224,6 @@ const emit = defineEmits<{
|
||||
'editModel': [model: Model]
|
||||
'deleteModel': [model: Model]
|
||||
'batchAssign': []
|
||||
'manageAlias': [model: Model]
|
||||
}>()
|
||||
|
||||
const { error: showError, success: showSuccess } = useToast()
|
||||
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
|
||||
emit('batchAssign')
|
||||
}
|
||||
|
||||
// 打开别名管理对话框
|
||||
function openAliasDialog(model: Model) {
|
||||
emit('manageAlias', model)
|
||||
}
|
||||
|
||||
// 切换模型启用状态
|
||||
async function toggleModelActive(model: Model) {
|
||||
if (togglingModelId.value) return
|
||||
|
||||
@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
|
||||
const filteredApiKeys = computed(() => {
|
||||
let result = apiKeys.value
|
||||
|
||||
// 搜索筛选
|
||||
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(key =>
|
||||
(key.name && key.name.toLowerCase().includes(query)) ||
|
||||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
|
||||
(key.username && key.username.toLowerCase().includes(query)) ||
|
||||
(key.user_email && key.user_email.toLowerCase().includes(query))
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(key => {
|
||||
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
|
||||
@@ -1002,13 +1002,13 @@ async function batchRemoveSelectedProviders() {
|
||||
const filteredGlobalModels = computed(() => {
|
||||
let result = globalModels.value
|
||||
|
||||
// 搜索
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(m =>
|
||||
m.name.toLowerCase().includes(query) ||
|
||||
m.display_name?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 能力筛选
|
||||
|
||||
@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
|
||||
const filteredProviders = computed(() => {
|
||||
let result = [...providers.value]
|
||||
|
||||
// 搜索筛选
|
||||
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value.trim()) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(p =>
|
||||
p.display_name.toLowerCase().includes(query) ||
|
||||
p.name.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(p => {
|
||||
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 排序
|
||||
|
||||
@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
|
||||
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
||||
})
|
||||
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
filtered = filtered.filter(
|
||||
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
filtered = filtered.filter(u => {
|
||||
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
if (filterRole.value !== 'all') {
|
||||
|
||||
@@ -474,13 +474,13 @@ async function toggleCapability(modelName: string, capName: string) {
|
||||
const filteredModels = computed(() => {
|
||||
let result = models.value
|
||||
|
||||
// 搜索
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(m =>
|
||||
m.name.toLowerCase().includes(query) ||
|
||||
m.display_name?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 能力筛选
|
||||
|
||||
@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
from .endpoints import router as endpoints_router
from .models import router as models_router
from .monitoring import router as monitoring_router
from .provider_query import router as provider_query_router
from .provider_strategy import router as provider_strategy_router
from .providers import router as providers_router
from .security import router as security_router
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
router.include_router(adaptive_router)
router.include_router(models_router)
router.include_router(security_router)
router.include_router(provider_query_router)

__all__ = ["router"]
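Because the provider-query router now carries the full "/api/admin/provider-query" prefix itself, including it in the aggregate admin router exposes POST /api/admin/provider-query/models. A minimal standalone sketch of that composition (assuming the aggregate router is mounted without an extra prefix, with a stubbed handler):

# Minimal sketch (assumptions: no extra mount prefix; handler body is a stub).
from fastapi import APIRouter, FastAPI

provider_query_router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])

@provider_query_router.post("/models")
async def query_available_models() -> dict:
    return {"success": True, "data": {"models": [], "error": None}}

admin_router = APIRouter()
admin_router.include_router(provider_query_router)

app = FastAPI()
app.include_router(admin_router)

# The resulting route table contains: POST /api/admin/provider-query/models
print([route.path for route in app.routes if route.path.startswith("/api")])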
@@ -1,46 +1,28 @@
|
||||
"""
|
||||
Provider Query API 端点
|
||||
用于查询提供商的余额、使用记录等信息
|
||||
用于查询提供商的模型列表等信息
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||
|
||||
# 初始化适配器注册
|
||||
from src.plugins.provider_query import init # noqa
|
||||
from src.plugins.provider_query import get_query_registry
|
||||
from src.plugins.provider_query.base import QueryCapability
|
||||
from src.models.database import Provider, ProviderEndpoint, User
|
||||
from src.utils.auth_utils import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
||||
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||
|
||||
|
||||
# ============ Request/Response Models ============
|
||||
|
||||
|
||||
class BalanceQueryRequest(BaseModel):
|
||||
"""余额查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
|
||||
|
||||
|
||||
class UsageSummaryQueryRequest(BaseModel):
|
||||
"""使用汇总查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None
|
||||
period: str = "month" # day, week, month, year
|
||||
|
||||
|
||||
class ModelsQueryRequest(BaseModel):
|
||||
"""模型列表查询请求"""
|
||||
|
||||
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
|
||||
# ============ API Endpoints ============
|
||||
|
||||
|
||||
@router.get("/adapters")
|
||||
async def list_adapters(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取所有可用的查询适配器
|
||||
async def _fetch_openai_models(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
api_format: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 OpenAI 格式的模型列表
|
||||
|
||||
Returns:
|
||||
适配器列表
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
registry = get_query_registry()
|
||||
adapters = registry.list_adapters()
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if extra_headers:
|
||||
# 防止 extra_headers 覆盖 Authorization
|
||||
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||
headers.update(safe_headers)
|
||||
|
||||
return {"success": True, "data": adapters}
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
# 记录详细的错误信息
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
@router.get("/capabilities/{provider_id}")
|
||||
async def get_provider_capabilities(
|
||||
provider_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取提供商支持的查询能力
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
async def _fetch_claude_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Claude 格式的模型列表
|
||||
|
||||
Returns:
|
||||
支持的查询能力列表
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
# 获取提供商
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
capabilities = registry.get_capabilities_for_provider(provider.name)
|
||||
|
||||
if capabilities is None:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [],
|
||||
"has_adapter": False,
|
||||
"message": "No query adapter available for this provider",
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [c.name for c in capabilities],
|
||||
"has_adapter": True,
|
||||
},
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
@router.post("/balance")
|
||||
async def query_balance(
|
||||
request: BalanceQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商余额
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
async def _fetch_gemini_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Gemini 格式的模型列表
|
||||
|
||||
Returns:
|
||||
余额信息
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
# 兼容 base_url 已包含 /v1beta 的情况
|
||||
base_url_clean = base_url.rstrip("/")
|
||||
if base_url_clean.endswith("/v1beta"):
|
||||
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||
else:
|
||||
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
# 查找指定的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
try:
|
||||
response = await client.get(models_url)
|
||||
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "models" in data:
|
||||
# 转换为统一格式
|
||||
return [
|
||||
{
|
||||
"id": m.get("name", "").replace("models/", ""),
|
||||
"owned_by": "google",
|
||||
"display_name": m.get("displayName", ""),
|
||||
"api_format": api_format,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 使用第一个可用的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询余额
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_balance(
|
||||
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
|
||||
if not query_result.success:
|
||||
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/usage-summary")
|
||||
async def query_usage_summary(
|
||||
request: UsageSummaryQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商使用汇总
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
使用汇总信息
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key(逻辑同上)
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询使用汇总
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_usage(
|
||||
provider_type=provider.name,
|
||||
api_key=api_key_value,
|
||||
period=request.period,
|
||||
endpoint_config=endpoint_config,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
for m in data["models"]
|
||||
], None
|
||||
return [], None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
@router.post("/models")
|
||||
async def query_available_models(
|
||||
request: ModelsQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商可用模型
|
||||
|
||||
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
||||
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
||||
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
||||
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
所有端点的模型列表(合并)
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||
.filter(Provider.id == request.provider_id)
|
||||
.first()
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
# 收集所有活跃端点的配置
|
||||
endpoint_configs: list[dict] = []
|
||||
|
||||
if request.api_key_id:
|
||||
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break
|
||||
if api_key_value:
|
||||
if endpoint_configs:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
if not endpoint.is_active or not endpoint.api_keys:
|
||||
continue
|
||||
|
||||
if not api_key_value:
|
||||
# 找第一个可用的 Key
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
continue # 尝试下一个 Key
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break # 只取第一个可用的 Key
|
||||
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询模型
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
# 并发请求所有端点的模型列表
|
||||
all_models: list = []
|
||||
errors: list[str] = []
|
||||
|
||||
if not adapter:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
|
||||
async def fetch_endpoint_models(
|
||||
client: httpx.AsyncClient, config: dict
|
||||
) -> tuple[list, Optional[str]]:
|
||||
base_url = config["base_url"]
|
||||
if not base_url:
|
||||
return [], None
|
||||
base_url = base_url.rstrip("/")
|
||||
api_format = config["api_format"]
|
||||
api_key_value = config["api_key"]
|
||||
extra_headers = config["extra_headers"]
|
||||
|
||||
try:
|
||||
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
||||
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
||||
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
||||
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
||||
else:
|
||||
return await _fetch_openai_models(
|
||||
client, base_url, api_key_value, api_format, extra_headers
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||
return [], f"{api_format}: {str(e)}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
results = await asyncio.gather(
|
||||
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||
)
|
||||
for models, error in results:
|
||||
all_models.extend(models)
|
||||
if error:
|
||||
errors.append(error)
|
||||
|
||||
query_result = await adapter.query_available_models(
|
||||
api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
# 按 model id 去重(保留第一个)
|
||||
seen_ids: set[str] = set()
|
||||
unique_models: list = []
|
||||
for model in all_models:
|
||||
model_id = model.get("id")
|
||||
if model_id and model_id not in seen_ids:
|
||||
seen_ids.add(model_id)
|
||||
unique_models.append(model)
|
||||
|
||||
error = "; ".join(errors) if errors else None
|
||||
if not unique_models and not error:
|
||||
error = "No models returned from any endpoint"
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"success": len(unique_models) > 0,
|
||||
"data": {"models": unique_models, "error": error},
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cache/{provider_id}")
|
||||
async def clear_query_cache(
|
||||
provider_id: str,
|
||||
api_key_id: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
清除查询缓存
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
api_key_id: 可选,指定清除某个 API Key 的缓存
|
||||
|
||||
Returns:
|
||||
清除结果
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# 获取提供商
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
|
||||
if adapter:
|
||||
if api_key_id:
|
||||
# 获取 API Key 值来清除缓存
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
|
||||
api_key = result.scalar_one_or_none()
|
||||
if api_key:
|
||||
adapter.clear_cache(api_key.api_key)
|
||||
else:
|
||||
adapter.clear_cache()
|
||||
|
||||
return {"success": True, "message": "Cache cleared successfully"}
|
||||
|
||||
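For context, a hedged sketch of how a client might exercise the reworked /models endpoint (URL and field names follow the diff; host, token and provider id are placeholders):

# Hypothetical smoke test for the new endpoint (host, token and ids are
# placeholders; response fields follow the shape returned above).
import asyncio
import httpx

async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=30.0) as client:
        resp = await client.post(
            "/api/admin/provider-query/models",
            json={"provider_id": "prov-123"},  # api_key_id is optional
            headers={"Authorization": "Bearer <admin-token>"},
        )
        payload = resp.json()
        if payload.get("success"):
            for model in payload["data"]["models"]:
                print(model["id"], model.get("api_format"))
        else:
            print("query failed:", payload["data"].get("error"))

asyncio.run(main())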
@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
)
# stats_daily.date stores the UTC start time of the corresponding business date
# Convert back to the business timezone before taking the date, so it matches the date sequence
def _to_business_date_str(value: datetime) -> str:
if value.tzinfo is None:
value_utc = value.replace(tzinfo=timezone.utc)
else:
value_utc = value.astimezone(timezone.utc)
return value_utc.astimezone(app_tz).date().isoformat()

stats_map = {
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
_to_business_date_str(stat.date): {
"requests": stat.total_requests,
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
"cost": stat.total_cost,
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
"unique_providers": today_unique_providers,
|
||||
"fallback_count": today_fallback_count,
|
||||
}
|
||||
|
||||
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
|
||||
yesterday_date = today_local.date() - timedelta(days=1)
|
||||
historical_end = min(end_date_local.date(), yesterday_date)
|
||||
missing_dates: list[str] = []
|
||||
cursor = start_date_local.date()
|
||||
while cursor <= historical_end:
|
||||
date_str = cursor.isoformat()
|
||||
if date_str not in stats_map:
|
||||
missing_dates.append(date_str)
|
||||
cursor += timedelta(days=1)
|
||||
|
||||
if missing_dates:
|
||||
for date_str in missing_dates[-7:]:
|
||||
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
|
||||
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
|
||||
stats_map[date_str] = {
|
||||
"requests": computed["total_requests"],
|
||||
"tokens": (
|
||||
computed["input_tokens"]
|
||||
+ computed["output_tokens"]
|
||||
+ computed["cache_creation_tokens"]
|
||||
+ computed["cache_read_tokens"]
|
||||
),
|
||||
"cost": computed["total_cost"],
|
||||
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
|
||||
if computed["avg_response_time_ms"]
|
||||
else 0,
|
||||
"unique_models": computed["unique_models"],
|
||||
"unique_providers": computed["unique_providers"],
|
||||
"fallback_count": computed["fallback_count"],
|
||||
}
|
||||
else:
|
||||
# 普通用户:仍需实时查询(用户级预聚合可选)
|
||||
query = db.query(Usage).filter(
|
||||
|
||||
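The timezone handling above is easy to get wrong, so here is a tiny standalone illustration of the business-date rule (the business timezone is assumed to be Asia/Shanghai for the example):

# Standalone illustration (assumption: APP_TIMEZONE resolves to Asia/Shanghai).
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

app_tz = ZoneInfo("Asia/Shanghai")

def to_business_date_str(value: datetime) -> str:
    value_utc = value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value.astimezone(timezone.utc)
    return value_utc.astimezone(app_tz).date().isoformat()

# 16:00 UTC on Jan 1 is already Jan 2 in UTC+8, so naive UTC dates would be off by one.
print(to_business_date_str(datetime(2025, 1, 1, 16, 0)))  # -> 2025-01-02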
@@ -266,8 +266,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
if mapping and mapping.model:
# Use select_provider_model_name to support the alias feature
# Pass api_key.id as the affinity_key so the same user consistently picks the same alias
# Pass api_format to filter aliases by their applicable scope
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
return mapped_name


@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
if mapping and mapping.model:
# Use select_provider_model_name to support the alias feature
# Pass api_key.id as the affinity_key so the same user consistently picks the same alias
# Pass api_format to filter aliases by their applicable scope
affinity_key = self.api_key.id if self.api_key else None
mapped_name = mapping.model.select_provider_model_name(affinity_key)
mapped_name = mapping.model.select_provider_model_name(
affinity_key, api_format=self.FORMAT_ID
)
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
return mapped_name
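The affinity_key idea referenced in these comments can be illustrated on its own (a sketch of the concept only; the exact hashing inside select_provider_model_name is not shown in this hunk):

# Concept sketch only (assumption: some stable hash of the caller's key id is
# used to spread users over alias candidates while keeping each user stable).
import hashlib

def pick_alias(candidates: list[str], affinity_key: str | None) -> str:
    if not affinity_key:
        return candidates[0]
    digest = hashlib.sha256(affinity_key.encode("utf-8")).hexdigest()
    return candidates[int(digest, 16) % len(candidates)]

names = ["model-v1", "model-v1-backup"]
print(pick_alias(names, "api-key-1"))  # always the same alias for this key
print(pick_alias(names, "api-key-1"))
print(pick_alias(names, "api-key-2"))  # may land on the other alias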
@@ -813,7 +813,9 @@ class Model(Base):
def get_effective_supports_image_generation(self) -> bool:
return self._get_effective_capability("supports_image_generation", False)

def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
def select_provider_model_name(
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
) -> str:
"""Select the Provider model name to use, by priority

If provider_model_aliases is configured, select by priority (smaller numbers take precedence);
@@ -822,6 +824,7 @@ class Model(Base):

Args:
affinity_key: Affinity key used for hash dispersion (e.g. a hash of the user's API key), so the same user consistently selects the same alias
api_format: API format of the current request (e.g. CLAUDE, OPENAI), used to filter the applicable aliases
"""
import hashlib

@@ -840,6 +843,13 @@ class Model(Base):
if not isinstance(name, str) or not name.strip():
continue

# Check the api_formats scope (only if it is configured and an api_format is given)
alias_api_formats = raw.get("api_formats")
if api_format and alias_api_formats:
# When a scope is configured, the alias only applies if the format matches
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
continue

raw_priority = raw.get("priority", 1)
try:
priority = int(raw_priority)
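A standalone sketch of the scoping rule (the data shape follows the ProviderModelAlias interface earlier in this diff; the alias names are made up):

# Made-up alias data; the filtering rule mirrors the scope check above.
aliases = [
    {"name": "claude-sonnet-upstream", "priority": 1, "api_formats": ["CLAUDE", "CLAUDE_CLI"]},
    {"name": "generic-upstream", "priority": 2},  # no scope -> applies to every format
]

def applicable_names(api_format: str | None) -> list[str]:
    out: list[str] = []
    for raw in aliases:
        scoped = raw.get("api_formats")
        if api_format and scoped and api_format not in scoped:
            continue
        out.append(raw["name"])
    return out

print(applicable_names("OPENAI"))  # ['generic-upstream']
print(applicable_names("CLAUDE"))  # ['claude-sonnet-upstream', 'generic-upstream']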
@@ -35,6 +35,7 @@ class CleanupScheduler:
def __init__(self):
self.running = False
self._interval_tasks = []
self._stats_aggregation_lock = asyncio.Lock()

async def start(self):
"""Start the scheduler"""
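The lock added here backs the skip-if-busy guard used further down in _perform_stats_aggregation; the pattern in isolation:

# Isolated sketch of the guard: concurrent triggers bail out instead of
# running the aggregation twice.
import asyncio

lock = asyncio.Lock()

async def aggregate(trigger: str) -> None:
    if lock.locked():
        print(f"{trigger}: aggregation already running, skipping this trigger")
        return
    async with lock:
        print(f"{trigger}: aggregating...")
        await asyncio.sleep(0.1)

async def main() -> None:
    await asyncio.gather(aggregate("hourly job"), aggregate("backfill job"))

asyncio.run(main())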
@@ -56,6 +57,14 @@ class CleanupScheduler:
|
||||
job_id="stats_aggregation",
|
||||
name="统计数据聚合",
|
||||
)
|
||||
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
|
||||
scheduler.add_interval_job(
|
||||
self._scheduled_stats_aggregation,
|
||||
minutes=30,
|
||||
job_id="stats_aggregation_backfill",
|
||||
name="统计数据聚合补偿",
|
||||
backfill=True,
|
||||
)
|
||||
|
||||
# 清理任务 - 凌晨 3 点执行
|
||||
scheduler.add_cron_job(
|
||||
@@ -115,9 +124,9 @@ class CleanupScheduler:
|
||||
|
||||
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
||||
|
||||
async def _scheduled_stats_aggregation(self):
|
||||
async def _scheduled_stats_aggregation(self, backfill: bool = False):
|
||||
"""统计聚合任务(定时调用)"""
|
||||
await self._perform_stats_aggregation()
|
||||
await self._perform_stats_aggregation(backfill=backfill)
|
||||
|
||||
async def _scheduled_cleanup(self):
|
||||
"""清理任务(定时调用)"""
|
||||
@@ -144,136 +153,157 @@ class CleanupScheduler:
|
||||
Args:
|
||||
backfill: 是否回填历史数据(启动时检查缺失的日期)
|
||||
"""
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用统计聚合
|
||||
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
||||
logger.info("统计聚合已禁用,跳过聚合任务")
|
||||
return
|
||||
if self._stats_aggregation_lock.locked():
|
||||
logger.info("统计聚合任务正在运行,跳过本次触发")
|
||||
return
|
||||
|
||||
logger.info("开始执行统计数据聚合...")
|
||||
|
||||
from src.models.database import StatsDaily, User as DBUser
|
||||
from src.services.system.scheduler import APP_TIMEZONE
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
# 使用业务时区计算日期,确保与定时任务触发时间一致
|
||||
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
|
||||
app_tz = ZoneInfo(APP_TIMEZONE)
|
||||
now_local = datetime.now(app_tz)
|
||||
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if backfill:
|
||||
# 启动时检查并回填缺失的日期
|
||||
from src.models.database import StatsSummary
|
||||
|
||||
summary = db.query(StatsSummary).first()
|
||||
if not summary:
|
||||
# 首次运行,回填所有历史数据
|
||||
logger.info("检测到首次运行,开始回填历史统计数据...")
|
||||
days_to_backfill = SystemConfigService.get_config(
|
||||
db, "stats_backfill_days", 365
|
||||
)
|
||||
count = StatsAggregatorService.backfill_historical_data(
|
||||
db, days=days_to_backfill
|
||||
)
|
||||
logger.info(f"历史数据回填完成,共 {count} 天")
|
||||
async with self._stats_aggregation_lock:
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用统计聚合
|
||||
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
||||
logger.info("统计聚合已禁用,跳过聚合任务")
|
||||
return
|
||||
|
||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||
latest_stat = (
|
||||
db.query(StatsDaily)
|
||||
.order_by(StatsDaily.date.desc())
|
||||
.first()
|
||||
)
|
||||
logger.info("开始执行统计数据聚合...")
|
||||
|
||||
if latest_stat:
|
||||
latest_date_utc = latest_stat.date
|
||||
if latest_date_utc.tzinfo is None:
|
||||
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
||||
from src.models.database import StatsDaily, User as DBUser
|
||||
from src.services.system.scheduler import APP_TIMEZONE
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
||||
yesterday_business_date = (today_local.date() - timedelta(days=1))
|
||||
missing_start_date = latest_business_date + timedelta(days=1)
|
||||
# 使用业务时区计算日期,确保与定时任务触发时间一致
|
||||
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
|
||||
app_tz = ZoneInfo(APP_TIMEZONE)
|
||||
now_local = datetime.now(app_tz)
|
||||
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if missing_start_date <= yesterday_business_date:
|
||||
missing_days = (yesterday_business_date - missing_start_date).days + 1
|
||||
logger.info(
|
||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
||||
if backfill:
|
||||
# 启动时检查并回填缺失的日期
|
||||
from src.models.database import StatsSummary
|
||||
|
||||
summary = db.query(StatsSummary).first()
|
||||
if not summary:
|
||||
# 首次运行,回填所有历史数据
|
||||
logger.info("检测到首次运行,开始回填历史统计数据...")
|
||||
days_to_backfill = SystemConfigService.get_config(
|
||||
db, "stats_backfill_days", 365
|
||||
)
|
||||
count = StatsAggregatorService.backfill_historical_data(
|
||||
db, days=days_to_backfill
|
||||
)
|
||||
logger.info(f"历史数据回填完成,共 {count} 天")
|
||||
return
|
||||
|
||||
current_date = missing_start_date
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
|
||||
|
||||
while current_date <= yesterday_business_date:
|
||||
try:
|
||||
current_date_local = datetime.combine(
|
||||
current_date, datetime.min.time(), tzinfo=app_tz
|
||||
if latest_stat:
|
||||
latest_date_utc = latest_stat.date
|
||||
if latest_date_utc.tzinfo is None:
|
||||
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
||||
|
||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
||||
yesterday_business_date = today_local.date() - timedelta(days=1)
|
||||
missing_start_date = latest_business_date + timedelta(days=1)
|
||||
|
||||
if missing_start_date <= yesterday_business_date:
|
||||
missing_days = (
|
||||
yesterday_business_date - missing_start_date
|
||||
).days + 1
|
||||
|
||||
# 限制最大回填天数,防止停机很久后一次性回填太多
|
||||
max_backfill_days: int = SystemConfigService.get_config(
|
||||
db, "max_stats_backfill_days", 30
|
||||
) or 30
|
||||
if missing_days > max_backfill_days:
|
||||
logger.warning(
|
||||
f"缺失 {missing_days} 天数据超过最大回填限制 "
|
||||
f"{max_backfill_days} 天,只回填最近 {max_backfill_days} 天"
|
||||
)
|
||||
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
|
||||
# 聚合用户数据
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
db, user_id, current_date_local
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
||||
)
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
||||
missing_start_date = yesterday_business_date - timedelta(
|
||||
days=max_backfill_days - 1
|
||||
)
|
||||
missing_days = max_backfill_days
|
||||
|
||||
logger.info(
|
||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
||||
)
|
||||
|
||||
current_date = missing_start_date
|
||||
users = (
|
||||
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
)
|
||||
|
||||
while current_date <= yesterday_business_date:
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
current_date_local = datetime.combine(
|
||||
current_date, datetime.min.time(), tzinfo=app_tz
|
||||
)
|
||||
StatsAggregatorService.aggregate_daily_stats(
|
||||
db, current_date_local
|
||||
)
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
db, user_id, current_date_local
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
||||
)
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
current_date += timedelta(days=1)
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# 更新全局汇总
|
||||
StatsAggregatorService.update_summary(db)
|
||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
||||
else:
|
||||
logger.info("统计数据已是最新,无需回填")
|
||||
return
|
||||
StatsAggregatorService.update_summary(db)
|
||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
||||
else:
|
||||
logger.info("统计数据已是最新,无需回填")
|
||||
return
|
||||
|
||||
# 定时任务:聚合昨天的数据
|
||||
# 注意:aggregate_daily_stats 期望业务时区的日期,不是 UTC
|
||||
yesterday_local = today_local - timedelta(days=1)
|
||||
# 定时任务:聚合昨天的数据
|
||||
yesterday_local = today_local - timedelta(days=1)
|
||||
|
||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||
|
||||
# 聚合所有用户的昨日数据
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
|
||||
except Exception as e:
|
||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||
# 回滚当前用户的失败操作,继续处理其他用户
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
db, user_id, yesterday_local
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 更新全局汇总
|
||||
StatsAggregatorService.update_summary(db)
|
||||
StatsAggregatorService.update_summary(db)
|
||||
|
||||
logger.info("统计数据聚合完成")
|
||||
logger.info("统计数据聚合完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_pending_cleanup(self):
|
||||
"""执行 pending 状态清理"""
|
||||
|
||||
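The max_stats_backfill_days clamp above is plain date arithmetic; isolated for clarity (config key as in the diff, dates made up):

# Isolated version of the clamp: if the gap exceeds the limit, only the most
# recent max_backfill_days business days are recomputed.
from datetime import date, timedelta

def clamp_backfill(missing_start: date, yesterday: date, max_backfill_days: int) -> tuple[date, int]:
    missing_days = (yesterday - missing_start).days + 1
    if missing_days > max_backfill_days:
        missing_start = yesterday - timedelta(days=max_backfill_days - 1)
        missing_days = max_backfill_days
    return missing_start, missing_days

print(clamp_backfill(date(2025, 1, 1), date(2025, 3, 1), 30))
# -> (datetime.date(2025, 1, 31), 30)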
@@ -56,65 +56,44 @@ class StatsAggregatorService:
|
||||
"""统计数据聚合服务"""
|
||||
|
||||
@staticmethod
|
||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||
"""聚合指定日期的统计数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
||||
|
||||
Returns:
|
||||
StatsDaily 记录
|
||||
"""
|
||||
# 将业务日期转换为 UTC 时间范围
|
||||
def compute_daily_stats(db: Session, date: datetime) -> dict:
|
||||
"""计算指定业务日期的统计数据(不写入数据库)"""
|
||||
day_start, day_end = _get_business_day_range(date)
|
||||
|
||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||
# 检查是否已存在该日期的记录
|
||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||
if existing:
|
||||
stats = existing
|
||||
else:
|
||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||
|
||||
# 基础请求统计
|
||||
base_query = db.query(Usage).filter(
|
||||
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
||||
)
|
||||
|
||||
total_requests = base_query.count()
|
||||
|
||||
# 如果没有请求,直接返回空记录
|
||||
if total_requests == 0:
|
||||
stats.total_requests = 0
|
||||
stats.success_requests = 0
|
||||
stats.error_requests = 0
|
||||
stats.input_tokens = 0
|
||||
stats.output_tokens = 0
|
||||
stats.cache_creation_tokens = 0
|
||||
stats.cache_read_tokens = 0
|
||||
stats.total_cost = 0.0
|
||||
stats.actual_total_cost = 0.0
|
||||
stats.input_cost = 0.0
|
||||
stats.output_cost = 0.0
|
||||
stats.cache_creation_cost = 0.0
|
||||
stats.cache_read_cost = 0.0
|
||||
stats.avg_response_time_ms = 0.0
|
||||
stats.fallback_count = 0
|
||||
return {
|
||||
"day_start": day_start,
|
||||
"total_requests": 0,
|
||||
"success_requests": 0,
|
||||
"error_requests": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
"total_cost": 0.0,
|
||||
"actual_total_cost": 0.0,
|
||||
"input_cost": 0.0,
|
||||
"output_cost": 0.0,
|
||||
"cache_creation_cost": 0.0,
|
||||
"cache_read_cost": 0.0,
|
||||
"avg_response_time_ms": 0.0,
|
||||
"fallback_count": 0,
|
||||
"unique_models": 0,
|
||||
"unique_providers": 0,
|
||||
}
|
||||
|
||||
if not existing:
|
||||
db.add(stats)
|
||||
db.commit()
|
||||
return stats
|
||||
|
||||
# 错误请求数
|
||||
error_requests = (
|
||||
base_query.filter(
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||
).count()
|
||||
)
|
||||
|
||||
# Token 和成本聚合
|
||||
aggregated = (
|
||||
db.query(
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
@@ -157,7 +136,6 @@ class StatsAggregatorService:
|
||||
or 0
|
||||
)
|
||||
|
||||
# 使用维度统计
|
||||
unique_models = (
|
||||
db.query(func.count(func.distinct(Usage.model)))
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
@@ -171,31 +149,74 @@ class StatsAggregatorService:
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"day_start": day_start,
|
||||
"total_requests": total_requests,
|
||||
"success_requests": total_requests - error_requests,
|
||||
"error_requests": error_requests,
|
||||
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
|
||||
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
|
||||
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
|
||||
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
|
||||
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
|
||||
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
|
||||
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
|
||||
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
|
||||
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
|
||||
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
|
||||
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
|
||||
"fallback_count": fallback_count,
|
||||
"unique_models": unique_models,
|
||||
"unique_providers": unique_providers,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||
"""聚合指定日期的统计数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
||||
|
||||
Returns:
|
||||
StatsDaily 记录
|
||||
"""
|
||||
computed = StatsAggregatorService.compute_daily_stats(db, date)
|
||||
day_start = computed["day_start"]
|
||||
|
||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||
# 检查是否已存在该日期的记录
|
||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||
if existing:
|
||||
stats = existing
|
||||
else:
|
||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||
|
||||
# 更新统计记录
|
||||
stats.total_requests = total_requests
|
||||
stats.success_requests = total_requests - error_requests
|
||||
stats.error_requests = error_requests
|
||||
stats.input_tokens = int(aggregated.input_tokens or 0)
|
||||
stats.output_tokens = int(aggregated.output_tokens or 0)
|
||||
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
|
||||
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
|
||||
stats.total_cost = float(aggregated.total_cost or 0)
|
||||
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
|
||||
stats.input_cost = float(aggregated.input_cost or 0)
|
||||
stats.output_cost = float(aggregated.output_cost or 0)
|
||||
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
|
||||
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
|
||||
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
|
||||
stats.fallback_count = fallback_count
|
||||
stats.unique_models = unique_models
|
||||
stats.unique_providers = unique_providers
|
||||
stats.total_requests = computed["total_requests"]
|
||||
stats.success_requests = computed["success_requests"]
|
||||
stats.error_requests = computed["error_requests"]
|
||||
stats.input_tokens = computed["input_tokens"]
|
||||
stats.output_tokens = computed["output_tokens"]
|
||||
stats.cache_creation_tokens = computed["cache_creation_tokens"]
|
||||
stats.cache_read_tokens = computed["cache_read_tokens"]
|
||||
stats.total_cost = computed["total_cost"]
|
||||
stats.actual_total_cost = computed["actual_total_cost"]
|
||||
stats.input_cost = computed["input_cost"]
|
||||
stats.output_cost = computed["output_cost"]
|
||||
stats.cache_creation_cost = computed["cache_creation_cost"]
|
||||
stats.cache_read_cost = computed["cache_read_cost"]
|
||||
stats.avg_response_time_ms = computed["avg_response_time_ms"]
|
||||
stats.fallback_count = computed["fallback_count"]
|
||||
stats.unique_models = computed["unique_models"]
|
||||
stats.unique_providers = computed["unique_providers"]
|
||||
|
||||
if not existing:
|
||||
db.add(stats)
|
||||
db.commit()
|
||||
|
||||
# 日志使用业务日期(输入参数),而不是 UTC 日期
|
||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {total_requests} 请求")
|
||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
|
||||
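Finally, the copy from the computed dict into the StatsDaily row is mechanical; a stand-in sketch of the same pattern (the dataclass below is a stand-in, not the real ORM model):

# Stand-in sketch: computed keys and row attributes share names, so the copy
# could also be written generically.
from dataclasses import dataclass, fields

@dataclass
class StatsRowStandIn:
    total_requests: int = 0
    error_requests: int = 0
    total_cost: float = 0.0

computed = {"day_start": "2025-01-01", "total_requests": 42, "error_requests": 3, "total_cost": 1.25}

row = StatsRowStandIn()
for field in fields(StatsRowStandIn):
    if field.name in computed:
        setattr(row, field.name, computed[field.name])
print(row)  # StatsRowStandIn(total_requests=42, error_requests=3, total_cost=1.25)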