feat: 实现 GlobalModel 别名匹配系统

主要更改:
- GlobalModel 支持 model_aliases 配置,允许使用正则表达式定义别名规则
- Provider Key 的 allowed_models 现在可以通过别名规则匹配 GlobalModel
- 新增 ModelAliasesTab 组件用于管理模型别名配置
- Provider 详情页新增别名映射预览功能,展示 Key 白名单与 GlobalModel 别名的匹配关系
- 路由预览 API 返回 Key 的 allowed_models 信息

安全特性:
- 使用 regex 库的原生超时保护(100ms)防止 ReDoS 攻击
- 别名规则数量限制(50 条/模型)和长度限制(200 字符)
- 别名映射预览 API 添加超时保护和结果截断

其他改进:
- GlobalModel 更新/删除时使用行级锁防止并发竞态
- 缓存失效逻辑优化,支持异步清理和正则缓存清空
- 路由 Tab 布局重构,使用 flexbox 替代绝对定位
This commit is contained in:
fawney19
2026-01-13 16:04:15 +08:00
parent 9fea71a70c
commit 85decd7487
21 changed files with 3845 additions and 2308 deletions

View File

@@ -95,3 +95,50 @@ export async function testModel(data: TestModelRequest): Promise<TestModelRespon
const response = await client.post('/api/admin/provider-query/test-model', data)
return response.data
}
/**
* 别名映射预览相关类型
*/
export interface AliasMatchedModel {
allowed_model: string
alias_pattern: string
}
export interface AliasMatchingGlobalModel {
global_model_id: string
global_model_name: string
display_name: string
is_active: boolean
matched_models: AliasMatchedModel[]
}
export interface AliasMatchingKey {
key_id: string
key_name: string
masked_key: string
is_active: boolean
allowed_models: string[]
matching_global_models: AliasMatchingGlobalModel[]
}
export interface ProviderAliasMappingPreviewResponse {
provider_id: string
provider_name: string
keys: AliasMatchingKey[]
total_keys: number
total_matches: number
// 截断提示
truncated: boolean
truncated_keys: number
truncated_models: number
}
/**
* 获取 Provider 别名映射预览
*/
export async function getProviderAliasMappingPreview(
providerId: string
): Promise<ProviderAliasMappingPreviewResponse> {
const response = await client.get(`/api/admin/providers/${providerId}/alias-mapping-preview`)
return response.data
}

View File

@@ -641,6 +641,7 @@ export interface RoutingKeyInfo {
health_score: number
is_active: boolean
api_formats: string[]
allowed_models?: string[] | null // 允许的模型列表null 表示不限制
circuit_breaker_open: boolean
circuit_breaker_formats: string[]
}

View File

@@ -0,0 +1,523 @@
<template>
<Card class="overflow-hidden">
<!-- 表头 -->
<div class="px-4 py-3 border-b border-border/60">
<div class="flex items-center justify-between">
<div class="flex items-baseline gap-2">
<h4 class="text-sm font-semibold">别名规则</h4>
<span class="text-xs text-muted-foreground">
支持正则表达式 ({{ localAliases.length }}/{{ MAX_ALIASES_PER_MODEL }})
</span>
</div>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="添加规则"
:disabled="localAliases.length >= MAX_ALIASES_PER_MODEL"
@click="addAlias"
>
<Plus class="w-4 h-4" />
</Button>
</div>
</div>
<!-- 规则列表 -->
<div v-if="localAliases.length > 0" class="divide-y">
<div
v-for="(alias, index) in localAliases"
:key="index"
>
<!-- 规则行 -->
<div
class="px-4 py-3 flex items-center gap-3 cursor-pointer hover:bg-muted/30 transition-colors"
@click="toggleExpand(index)"
>
<ChevronRight
class="w-4 h-4 text-muted-foreground transition-transform flex-shrink-0"
:class="{ 'rotate-90': expandedIndex === index }"
/>
<div class="flex-1 min-w-0">
<Input
v-model="localAliases[index]"
placeholder="例如: claude-haiku-.*"
:class="`font-mono text-sm ${alias.trim() && !getAliasValidation(alias).valid ? 'border-destructive' : ''}`"
@click.stop
@input="markDirty"
/>
<!-- 验证错误提示 -->
<div
v-if="alias.trim() && !getAliasValidation(alias).valid"
class="flex items-center gap-1 mt-1 text-xs text-destructive"
>
<AlertCircle class="w-3 h-3" />
<span>{{ getAliasValidation(alias).error }}</span>
</div>
</div>
<!-- 匹配统计 -->
<Badge
v-if="getAliasValidation(alias).valid && getMatchCount(alias) > 0"
variant="secondary"
class="text-xs flex-shrink-0 h-6 leading-none"
>
{{ getMatchCount(alias) }} 匹配
</Badge>
<Badge
v-else-if="alias.trim() && getAliasValidation(alias).valid"
variant="outline"
class="text-xs text-muted-foreground flex-shrink-0 h-6 leading-none"
>
无匹配
</Badge>
<!-- 操作按钮 -->
<div class="flex items-center gap-1 flex-shrink-0">
<Button
v-if="isDirty"
variant="ghost"
size="icon"
class="h-7 w-7 text-muted-foreground hover:text-primary"
title="保存"
:disabled="saving || hasValidationErrors"
@click.stop="saveAliases"
>
<Save v-if="!saving" class="w-4 h-4" />
<RefreshCw v-else class="w-4 h-4 animate-spin" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7 text-muted-foreground hover:text-destructive"
title="删除"
:disabled="saving"
@click.stop="removeAlias(index)"
>
<Trash2 class="w-4 h-4" />
</Button>
</div>
</div>
<!-- 展开内容匹配的 Key 列表 -->
<div
v-if="expandedIndex === index"
class="border-t bg-muted/10 px-4 py-3"
>
<div v-if="loadingPreview" class="flex items-center justify-center py-4">
<RefreshCw class="w-4 h-4 animate-spin text-muted-foreground" />
</div>
<div v-else-if="getMatchedKeysForAlias(alias).length === 0" class="text-center py-4">
<p class="text-sm text-muted-foreground">
{{ alias.trim() ? '此规则暂无匹配的 Key 白名单' : '请输入别名规则' }}
</p>
</div>
<div v-else class="space-y-2">
<div
v-for="item in getMatchedKeysForAlias(alias)"
:key="item.keyId"
class="bg-background rounded-md border p-3"
>
<div class="flex items-center gap-2 text-sm mb-2">
<span class="text-muted-foreground">{{ item.providerName }}</span>
<span class="text-muted-foreground">/</span>
<span class="font-medium">{{ item.keyName }}</span>
<span class="text-xs text-muted-foreground font-mono ml-auto">
{{ item.maskedKey }}
</span>
</div>
<div class="flex flex-wrap gap-1">
<Badge
v-for="model in item.matchedModels"
:key="model"
variant="secondary"
class="text-xs font-mono"
>
{{ model }}
</Badge>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 空状态 -->
<div
v-else
class="text-center py-8"
>
<GitMerge class="w-10 h-10 mx-auto text-muted-foreground/30 mb-3" />
<p class="text-sm text-muted-foreground">
暂无别名规则
</p>
<p class="text-xs text-muted-foreground mt-1">
添加别名可匹配 Provider Key 白名单中的模型
</p>
</div>
</Card>
</template>
<script setup lang="ts">
import { ref, watch, onMounted, onUnmounted, computed } from 'vue'
import { Card, Button, Input, Badge } from '@/components/ui'
import { Plus, Trash2, GitMerge, RefreshCw, ChevronRight, Save, AlertCircle } from 'lucide-vue-next'
import { updateGlobalModel, getGlobalModel, getGlobalModelRoutingPreview } from '@/api/global-models'
import type { ModelRoutingPreviewResponse } from '@/api/endpoints/types'
import { log } from '@/utils/logger'
import { useToast } from '@/composables/useToast'
const props = defineProps<{
globalModelId: string
modelName: string
aliases: string[]
}>()
const emit = defineEmits<{
update: [aliases: string[]]
}>()
// 安全限制常量(与后端保持一致)
const MAX_ALIASES_PER_MODEL = 50
const MAX_ALIAS_LENGTH = 200
// 危险的正则模式(可能导致 ReDoS与后端 model_permissions.py 保持一致)
// 注意:这些是用于检测用户输入字符串中的危险正则构造
const DANGEROUS_REGEX_PATTERNS = [
/\([^)]*[+*]\)[+*]/, // (x+)+, (x*)*, (x+)*, (x*)+
/\([^)]*\)\{[0-9]+,\}/, // (x){n,} 无上限
/\(\.\*\)\{[0-9]+,\}/, // (.*){n,} 贪婪量词 + 高重复
/\(\.\+\)\{[0-9]+,\}/, // (.+){n,} 贪婪量词 + 高重复
/\([^)]*\|[^)]*\)[+*]/, // (a|b)+ 选择分支 + 量词
/\(\.\*\)\+/, // (.*)+
/\(\.\+\)\+/, // (.+)+
/\([^)]*\*\)[+*]/, // 嵌套量词: (a*)+
/\(\\w\+\)\+/, // (\w+)+ - 检测字面量 \w
/\(\.\*\)\*/, // (.*)*
/\(.*\+.*\)\+/, // (a+b)+ 更通用的嵌套量词检测
/\[.*\]\{[0-9]+,\}\{/, // [x]{n,}{m,} 嵌套量词
/\.{2,}\*/, // ..* 连续通配
/\([^)]*\|[^)]*\)\*/, // (a|a)* 选择分支 + 星号
/\{[0-9]{2,},\}/, // {10,} 高重复次数无上限
/\(\[.*\]\+\)\+/, // ([x]+)+ 字符类嵌套量词
// 补充的危险模式(与后端保持一致)
/\([^)]*[+*]\)\{[0-9]+,/, // (a+){n,} 量词后跟大括号量词
/\(\([^)]*[+*]\)[+*]\)/, // ((a+)+) 三层嵌套量词
/\(\?:[^)]*[+*]\)[+*]/, // (?:a+)+ 非捕获组嵌套量词
]
// 正则匹配安全限制(与后端保持一致)
const REGEX_MATCH_MAX_INPUT_LENGTH = 200
const { success: toastSuccess, error: toastError } = useToast()
// 本地状态
const localAliases = ref<string[]>([...props.aliases])
const originalAliases = ref<string[]>([...props.aliases]) // 用于保存失败时恢复
const isDirty = ref(false)
const saving = ref(false)
const expandedIndex = ref<number | null>(null)
// 匹配预览状态
const loadingPreview = ref(false)
const routingData = ref<ModelRoutingPreviewResponse | null>(null)
// 正则编译缓存(简单的 LRU 实现)
const REGEX_CACHE_MAX_SIZE = 100
class LRURegexCache {
private cache = new Map<string, RegExp | null>()
private maxSize: number
constructor(maxSize: number) {
this.maxSize = maxSize
}
get(key: string): RegExp | null | undefined {
if (!this.cache.has(key)) return undefined
// 移到最后LRU
const value = this.cache.get(key)!
this.cache.delete(key)
this.cache.set(key, value)
return value
}
set(key: string, value: RegExp | null): void {
// 如果已存在,先删除(会重新添加到最后)
if (this.cache.has(key)) {
this.cache.delete(key)
} else if (this.cache.size >= this.maxSize) {
// 缓存已满,删除最早的条目
const firstKey = this.cache.keys().next().value as string | undefined
if (firstKey !== undefined) {
this.cache.delete(firstKey)
}
}
this.cache.set(key, value)
}
clear(): void {
this.cache.clear()
}
get size(): number {
return this.cache.size
}
}
const regexCache = new LRURegexCache(REGEX_CACHE_MAX_SIZE)
interface MatchedKeyForAlias {
keyId: string
keyName: string
maskedKey: string
providerName: string
matchedModels: string[]
}
interface ValidationResult {
valid: boolean
error?: string
}
/**
* 验证别名规则是否安全
*/
function validateAliasPattern(pattern: string): ValidationResult {
if (!pattern || !pattern.trim()) {
return { valid: false, error: '规则不能为空' }
}
if (pattern.length > MAX_ALIAS_LENGTH) {
return { valid: false, error: `规则过长 (最大 ${MAX_ALIAS_LENGTH} 字符)` }
}
// 检查危险模式
for (const dangerous of DANGEROUS_REGEX_PATTERNS) {
if (dangerous.test(pattern)) {
return { valid: false, error: '规则包含潜在危险的正则构造' }
}
}
// 尝试编译验证语法
try {
new RegExp(`^${pattern}$`, 'i')
} catch {
return { valid: false, error: `正则表达式语法错误` }
}
return { valid: true }
}
/**
* 获取别名的验证状态
*/
function getAliasValidation(alias: string): ValidationResult {
if (!alias.trim()) {
return { valid: true } // 空值暂不报错,保存时过滤
}
return validateAliasPattern(alias)
}
/**
* 检查是否有验证错误
*/
const hasValidationErrors = computed(() => {
return localAliases.value.some(alias => {
if (!alias.trim()) return false
return !validateAliasPattern(alias).valid
})
})
/**
* 安全的正则匹配(带缓存和保护)
*/
function matchPattern(pattern: string, text: string): boolean {
// 快速路径:精确匹配
if (pattern.toLowerCase() === text.toLowerCase()) {
return true
}
// 长度检查
if (pattern.length > MAX_ALIAS_LENGTH) {
return false
}
// 危险模式检查
for (const dangerous of DANGEROUS_REGEX_PATTERNS) {
if (dangerous.test(pattern)) {
return false
}
}
// 使用 LRU 缓存
let regex = regexCache.get(pattern)
if (regex === undefined) {
try {
regex = new RegExp(`^${pattern}$`, 'i')
regexCache.set(pattern, regex)
} catch {
regexCache.set(pattern, null)
return false
}
}
if (regex === null) {
return false
}
try {
// 额外保护:限制正则匹配的输入长度(与后端保持一致)
const matchInput = text.slice(0, REGEX_MATCH_MAX_INPUT_LENGTH)
return regex.test(matchInput)
} catch {
return false
}
}
// 获取指定别名匹配的 Key 列表
function getMatchedKeysForAlias(alias: string): MatchedKeyForAlias[] {
if (!routingData.value || !alias.trim()) return []
// 使用 Map 按 keyId 去重并合并匹配结果
const keyMap = new Map<string, MatchedKeyForAlias>()
for (const provider of routingData.value.providers) {
for (const endpoint of provider.endpoints) {
for (const key of endpoint.keys) {
if (!key.allowed_models || key.allowed_models.length === 0) continue
const matchedModels: string[] = []
for (const allowedModel of key.allowed_models) {
if (matchPattern(alias, allowedModel)) {
matchedModels.push(allowedModel)
}
}
if (matchedModels.length > 0) {
const existing = keyMap.get(key.id)
if (existing) {
// 合并匹配结果(去重)
const mergedModels = new Set([...existing.matchedModels, ...matchedModels])
existing.matchedModels = Array.from(mergedModels)
} else {
keyMap.set(key.id, {
keyId: key.id,
keyName: key.name,
maskedKey: key.masked_key,
providerName: provider.name,
matchedModels,
})
}
}
}
}
}
return Array.from(keyMap.values())
}
// 获取指定别名的匹配数量
function getMatchCount(alias: string): number {
return getMatchedKeysForAlias(alias).reduce((sum, item) => sum + item.matchedModels.length, 0)
}
function toggleExpand(index: number) {
expandedIndex.value = expandedIndex.value === index ? null : index
}
watch(() => props.aliases, (newAliases) => {
localAliases.value = [...newAliases]
originalAliases.value = [...newAliases]
isDirty.value = false
}, { deep: true })
// globalModelId 变化时清空缓存并重新加载预览
watch(() => props.globalModelId, () => {
regexCache.clear()
loadMatchPreview()
})
function markDirty() {
isDirty.value = true
}
function addAlias() {
if (localAliases.value.length >= MAX_ALIASES_PER_MODEL) {
toastError(`最多支持 ${MAX_ALIASES_PER_MODEL} 条别名规则`)
return
}
localAliases.value.push('')
isDirty.value = true
expandedIndex.value = localAliases.value.length - 1
}
function removeAlias(index: number) {
localAliases.value.splice(index, 1)
isDirty.value = true
if (expandedIndex.value === index) {
expandedIndex.value = null
} else if (expandedIndex.value !== null && expandedIndex.value > index) {
expandedIndex.value--
}
}
async function saveAliases() {
const cleanedAliases = localAliases.value
.map(a => a.trim())
.filter(a => a.length > 0)
saving.value = true
try {
const currentModel = await getGlobalModel(props.globalModelId)
const currentConfig = currentModel.config || {}
const updatedConfig = {
...currentConfig,
model_aliases: cleanedAliases.length > 0 ? cleanedAliases : undefined,
}
if (!updatedConfig.model_aliases || updatedConfig.model_aliases.length === 0) {
delete updatedConfig.model_aliases
}
await updateGlobalModel(props.globalModelId, {
config: updatedConfig,
})
localAliases.value = cleanedAliases
originalAliases.value = [...cleanedAliases] // 更新原始值
isDirty.value = false
toastSuccess('别名规则已保存')
emit('update', cleanedAliases)
} catch (err) {
log.error('保存别名规则失败:', err)
toastError('保存失败,请重试')
// 保存失败时恢复到原始值
localAliases.value = [...originalAliases.value]
isDirty.value = false
} finally {
saving.value = false
}
}
async function loadMatchPreview() {
// 清空正则缓存,确保使用最新数据
regexCache.clear()
loadingPreview.value = true
try {
routingData.value = await getGlobalModelRoutingPreview(props.globalModelId)
} catch (err) {
log.error('加载匹配预览失败:', err)
} finally {
loadingPreview.value = false
}
}
onMounted(() => {
loadMatchPreview()
})
// 组件卸载时清理缓存,防止内存泄漏
onUnmounted(() => {
regexCache.clear()
})
</script>

View File

@@ -104,6 +104,19 @@
<span class="hidden sm:inline">链路控制</span>
<span class="sm:hidden">链路</span>
</button>
<button
type="button"
class="flex-1 px-2 sm:px-4 py-2 text-xs sm:text-sm font-medium rounded-md transition-all duration-200"
:class="[
detailTab === 'aliases'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-background/50'
]"
@click="detailTab = 'aliases'"
>
<span class="hidden sm:inline">模型映射</span>
<span class="sm:hidden">映射</span>
</button>
</div>
<!-- Tab 内容 -->
@@ -419,6 +432,17 @@
@delete-provider="handleDeleteProviderFromRouting"
/>
</div>
<!-- Tab 3: 模型映射 -->
<div v-show="detailTab === 'aliases'">
<ModelAliasesTab
v-if="model"
:global-model-id="model.id"
:model-name="model.name"
:aliases="model.config?.model_aliases || []"
@update="handleAliasesUpdate"
/>
</div>
</div>
</Card>
</div>
@@ -456,6 +480,7 @@ import TableRow from '@/components/ui/table-row.vue'
import TableHead from '@/components/ui/table-head.vue'
import TableCell from '@/components/ui/table-cell.vue'
import RoutingTab from './RoutingTab.vue'
import ModelAliasesTab from './ModelAliasesTab.vue'
// 使用外部类型定义
import type { GlobalModelResponse } from '@/api/global-models'
@@ -518,6 +543,13 @@ function refreshRoutingData() {
routingTabRef.value?.loadRoutingData?.()
}
// 处理模型别名更新
function handleAliasesUpdate(_aliases: string[]) {
// 别名已在 ModelAliasesTab 内部保存到服务器
// 刷新路由数据以反映可能的候选变化
refreshRoutingData()
}
// 暴露刷新方法给父组件
defineExpose({
refreshRoutingData

View File

@@ -76,21 +76,19 @@
>
<!-- 格式标题栏 -->
<div
class="px-3 py-2 bg-muted/30 border-b border-border/40 flex items-center justify-between cursor-pointer"
class="px-4 py-3 bg-muted/30 flex items-center justify-between cursor-pointer hover:bg-muted/50 transition-colors"
@click="toggleFormat(formatGroup.api_format)"
>
<div class="flex items-center gap-2">
<div class="flex items-center gap-3">
<Badge
variant="secondary"
class="text-xs font-medium"
class="text-xs font-semibold px-2.5 py-1"
>
{{ formatGroup.api_format }}
</Badge>
<span class="text-xs text-muted-foreground">
<span class="text-sm text-muted-foreground">
{{ formatGroup.active_keys }}/{{ formatGroup.total_keys }} Keys
</span>
<span class="text-xs text-muted-foreground">·</span>
<span class="text-xs text-muted-foreground">
<span class="mx-1.5">·</span>
{{ formatGroup.active_providers }}/{{ formatGroup.total_providers }} 提供商
</span>
</div>
@@ -105,203 +103,220 @@
<div v-if="isFormatExpanded(formatGroup.api_format)">
<!-- ========== 全局 Key 优先模式 ========== -->
<template v-if="isGlobalKeyMode">
<div class="relative">
<!-- 垂直主线 -->
<div
v-if="formatGroup.keyGroups.length > 0"
class="absolute left-5 top-0 bottom-0 w-0.5 bg-border"
/>
<div class="py-2">
<template
v-for="(keyGroup, groupIndex) in formatGroup.keyGroups"
:key="groupIndex"
<div class="py-2 pl-3">
<template
v-for="(keyGroup, groupIndex) in formatGroup.keyGroups"
:key="groupIndex"
>
<!-- 第一组且有多个 key 时显示负载均衡标签 -->
<div
v-if="groupIndex === 0 && keyGroup.keys.length > 1"
class="ml-6 mr-3 mb-1 flex items-center gap-1 text-[10px] text-muted-foreground/60"
>
<!-- 第一组且有多个 key 时显示负载均衡标签 -->
<div
v-if="groupIndex === 0 && keyGroup.keys.length > 1"
class="ml-10 mr-3 mb-1 flex items-center gap-1 text-[10px] text-muted-foreground/60"
>
<span>负载均衡</span>
</div>
<span>负载均衡</span>
</div>
<!-- 该优先级组内的 Keys -->
<div
v-for="(keyEntry, keyIndex) in keyGroup.keys"
:key="keyEntry.key.id"
class="relative"
>
<!-- 该优先级组内的 Keys -->
<div
v-for="(keyEntry, keyIndex) in keyGroup.keys"
:key="keyEntry.key.id"
class="flex py-1"
>
<!-- 左侧节点 + 连线 -->
<div class="w-6 flex flex-col items-center shrink-0">
<!-- 上半段连线 -->
<div
class="w-0.5 flex-1"
:class="groupIndex === 0 && keyIndex === 0 ? 'bg-transparent' : 'bg-border'"
/>
<!-- 节点圆点 -->
<div
class="absolute left-[14px] top-4 w-3 h-3 rounded-full border-2 z-10"
class="w-3 h-3 rounded-full border-2 shrink-0"
:class="getGlobalKeyNodeClass(keyEntry, groupIndex, keyIndex)"
/>
<!-- Key 卡片无展开直接显示所有信息 -->
<!-- 下半段连线 -->
<div
class="ml-10 mr-3 mb-2"
:class="!keyEntry.key.is_active ? 'opacity-50' : ''"
class="w-0.5 flex-1"
:class="isLastKeyInFormat(formatGroup, groupIndex, keyIndex) ? 'bg-transparent' : 'bg-border'"
/>
</div>
<!-- Key 卡片 -->
<div
class="flex-1 mr-3"
:class="!keyEntry.key.is_active ? 'opacity-50' : ''"
>
<div
class="group rounded-lg transition-all p-2.5"
:class="getGlobalKeyCardClass(keyEntry, groupIndex, keyIndex)"
>
<div
class="group rounded-lg transition-all p-2.5"
:class="getGlobalKeyCardClass(keyEntry, groupIndex, keyIndex)"
>
<div class="flex items-center gap-2">
<!-- 第一列优先级标签 -->
<div
v-if="keyEntry.key.is_active"
class="px-1.5 py-0.5 rounded-full text-[10px] font-medium shrink-0"
:class="groupIndex === 0 && keyIndex === 0
? 'bg-primary text-primary-foreground'
: 'bg-muted-foreground/20 text-muted-foreground'"
>
<span v-if="groupIndex === 0 && keyIndex === 0">首选</span>
<span v-else>P{{ keyGroup.priority ?? '?' }}</span>
<div class="flex items-center gap-2">
<!-- 第一列优先级标签 -->
<div
v-if="keyEntry.key.is_active"
class="px-1.5 py-0.5 rounded-full text-[10px] font-medium shrink-0"
:class="groupIndex === 0 && keyIndex === 0
? 'bg-primary text-primary-foreground'
: 'bg-muted-foreground/20 text-muted-foreground'"
>
<span v-if="groupIndex === 0 && keyIndex === 0">首选</span>
<span v-else>P{{ keyGroup.priority ?? '?' }}</span>
</div>
<!-- 第二列状态指示灯 -->
<span
class="w-1.5 h-1.5 rounded-full shrink-0"
:class="getKeyStatusClass(keyEntry.key)"
/>
<!-- 第三列Key 名称 + Provider 信息 -->
<div class="min-w-0 flex-1">
<div class="flex items-center gap-1">
<span
class="text-sm font-medium truncate"
:class="keyEntry.key.circuit_breaker_open ? 'text-destructive' : ''"
>
{{ keyEntry.key.name }}
</span>
<code class="font-mono text-[10px] text-muted-foreground/60 shrink-0">
{{ keyEntry.key.masked_key }}
</code>
<Zap
v-if="keyEntry.key.circuit_breaker_open"
class="w-3 h-3 text-destructive shrink-0"
/>
</div>
<!-- Provider Endpoint 信息 -->
<div class="text-[10px] text-muted-foreground truncate">
{{ keyEntry.provider.name }}
<span v-if="hasModelMapping(keyEntry.provider)">
({{ keyEntry.provider.provider_model_name }})
</span>
<span v-if="keyEntry.provider.billing_type">
· {{ getBillingLabel(keyEntry.provider) }}
</span>
<span v-if="keyEntry.endpoint">
· {{ keyEntry.endpoint.base_url }}
</span>
</div>
</div>
<!-- 状态指示灯 -->
<span
class="w-1.5 h-1.5 rounded-full shrink-0"
:class="getKeyStatusClass(keyEntry.key)"
/>
<!-- 第三列Key 名称 + Provider 信息 -->
<div class="min-w-0 flex-1">
<div class="flex items-center gap-1">
<span
class="text-sm font-medium truncate"
:class="keyEntry.key.circuit_breaker_open ? 'text-destructive' : ''"
>
{{ keyEntry.key.name }}
</span>
<code class="font-mono text-[10px] text-muted-foreground/60 shrink-0">
{{ keyEntry.key.masked_key }}
</code>
<Zap
v-if="keyEntry.key.circuit_breaker_open"
class="w-3 h-3 text-destructive shrink-0"
<!-- 健康度 + RPM + 操作按钮 -->
<div class="flex items-center gap-1.5 shrink-0">
<!-- 健康度 -->
<div class="flex items-center gap-1">
<div class="w-8 h-1 bg-muted/80 rounded-full overflow-hidden">
<div
class="h-full transition-all duration-300"
:class="getHealthScoreBarColor(keyEntry.key.health_score)"
:style="{ width: `${keyEntry.key.health_score}%` }"
/>
</div>
<!-- Provider Endpoint 信息 -->
<div class="text-[10px] text-muted-foreground truncate">
{{ keyEntry.provider.name }}
<span v-if="hasModelMapping(keyEntry.provider)">
({{ keyEntry.provider.provider_model_name }})
</span>
<span v-if="keyEntry.provider.billing_type">
· {{ getBillingLabel(keyEntry.provider) }}
</span>
<span v-if="keyEntry.endpoint">
· {{ keyEntry.endpoint.base_url }}
</span>
</div>
</div>
<!-- 第四列健康度 + RPM + 操作按钮 -->
<div class="flex items-center gap-1.5 shrink-0">
<!-- 健康度 -->
<div class="flex items-center gap-1">
<div class="w-8 h-1 bg-muted/80 rounded-full overflow-hidden">
<div
class="h-full transition-all duration-300"
:class="getHealthScoreBarColor(keyEntry.key.health_score)"
:style="{ width: `${keyEntry.key.health_score}%` }"
/>
</div>
<span
class="text-[10px] font-medium tabular-nums"
:class="getHealthScoreTextColor(keyEntry.key.health_score)"
>
{{ Math.round(keyEntry.key.health_score) }}%
</span>
</div>
<!-- RPM -->
<span
v-if="keyEntry.key.effective_rpm"
class="text-[10px] text-muted-foreground/60"
class="text-[10px] font-medium tabular-nums"
:class="getHealthScoreTextColor(keyEntry.key.health_score)"
>
{{ keyEntry.key.is_adaptive ? '~' : '' }}{{ keyEntry.key.effective_rpm }}
{{ Math.round(keyEntry.key.health_score) }}%
</span>
<!-- 操作按钮 -->
<Button
variant="ghost"
size="icon"
class="h-6 w-6"
title="编辑此关联"
@click.stop="$emit('editProvider', keyEntry.provider)"
>
<Edit class="w-3 h-3" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-6 w-6"
:title="keyEntry.provider.model_is_active ? '停用此关联' : '启用此关联'"
@click.stop="$emit('toggleProviderStatus', keyEntry.provider)"
>
<Power class="w-3 h-3" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-6 w-6"
title="删除此关联"
@click.stop="$emit('deleteProvider', keyEntry.provider)"
>
<Trash2 class="w-3 h-3" />
</Button>
</div>
<!-- RPM -->
<span
v-if="keyEntry.key.effective_rpm"
class="text-[10px] text-muted-foreground/60"
>
{{ keyEntry.key.is_adaptive ? '~' : '' }}{{ keyEntry.key.effective_rpm }}
</span>
<!-- 操作按钮 -->
<Button
variant="ghost"
size="icon"
class="h-6 w-6"
title="编辑此关联"
@click.stop="$emit('editProvider', keyEntry.provider)"
>
<Edit class="w-3 h-3" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-6 w-6"
:title="keyEntry.provider.model_is_active ? '停用此关联' : '启用此关联'"
@click.stop="$emit('toggleProviderStatus', keyEntry.provider)"
>
<Power class="w-3 h-3" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-6 w-6"
title="删除此关联"
@click.stop="$emit('deleteProvider', keyEntry.provider)"
>
<Trash2 class="w-3 h-3" />
</Button>
</div>
<!-- 熔断详情如果有 -->
<div
v-if="keyEntry.key.circuit_breaker_open"
class="text-[10px] text-destructive mt-1.5 ml-6"
>
熔断中: {{ keyEntry.key.circuit_breaker_formats.join(', ') }}
</div>
</div>
<!-- 熔断详情如果有 -->
<div
v-if="keyEntry.key.circuit_breaker_open"
class="text-[10px] text-destructive mt-1.5 ml-6"
>
熔断中: {{ keyEntry.key.circuit_breaker_formats.join(', ') }}
</div>
</div>
</div>
</div>
<!-- 降级标记如果下一组有多个 key显示"降级 · 负载均衡" -->
<div
v-if="groupIndex < formatGroup.keyGroups.length - 1"
class="ml-10 -mt-1 mb-1 flex items-center gap-1"
>
<ArrowDown class="w-3 h-3 text-muted-foreground/50" />
<span class="text-[10px] text-muted-foreground/50">
<!-- 降级标记如果下一组有多个 key显示"降级 · 负载均衡" -->
<div
v-if="groupIndex < formatGroup.keyGroups.length - 1"
class="flex py-0.5"
>
<div class="w-6 flex justify-center shrink-0">
<div class="w-0.5 h-full bg-border" />
</div>
<div class="flex items-center gap-1 text-[10px] text-muted-foreground/50">
<ArrowDown class="w-3 h-3" />
<span>
{{ formatGroup.keyGroups[groupIndex + 1].keys.length > 1 ? '降级 · 负载均衡' : '降级' }}
</span>
</div>
</template>
</div>
</div>
</template>
</div>
</template>
<!-- ========== 提供商优先模式 ========== -->
<template v-else>
<div class="relative">
<!-- 垂直主线 -->
<div class="py-2 pl-3">
<div
v-if="formatGroup.providers.length > 0"
class="absolute left-5 top-0 bottom-0 w-0.5 bg-border"
/>
<div class="py-2">
<div
v-for="(providerEntry, providerIndex) in formatGroup.providers"
:key="`${providerEntry.provider.id}-${providerEntry.endpoint?.id || providerIndex}`"
class="relative"
>
<!-- 节点圆点 -->
<div
class="absolute left-[14px] top-4 w-3 h-3 rounded-full border-2 z-10"
:class="getFormatProviderNodeClass(providerEntry, providerIndex)"
/>
v-for="(providerEntry, providerIndex) in formatGroup.providers"
:key="`${providerEntry.provider.id}-${providerEntry.endpoint?.id || providerIndex}`"
>
<!-- 提供商行 -->
<div class="flex py-1">
<!-- 左侧节点 + 连线 -->
<div class="w-6 flex flex-col items-center shrink-0">
<!-- 上半段连线 -->
<div
class="w-0.5 flex-1"
:class="providerIndex === 0 ? 'bg-transparent' : 'bg-border'"
/>
<!-- 节点圆点 -->
<div
class="w-3 h-3 rounded-full border-2 shrink-0"
:class="getFormatProviderNodeClass(providerEntry, providerIndex)"
/>
<!-- 下半段连线 -->
<div
class="w-0.5 flex-1"
:class="providerIndex === formatGroup.providers.length - 1 ? 'bg-transparent' : 'bg-border'"
/>
</div>
<!-- 提供商卡片 -->
<div
class="ml-10 mr-3 mb-2"
class="flex-1 mr-3"
:class="!providerEntry.provider.is_active || !providerEntry.provider.model_is_active ? 'opacity-50' : ''"
>
<div
@@ -536,14 +551,19 @@
</Transition>
</div>
</div>
</div>
<!-- 降级标记 -->
<div
v-if="providerIndex < formatGroup.providers.length - 1"
class="ml-10 -mt-1 mb-1 flex items-center gap-1"
>
<ArrowDown class="w-3 h-3 text-muted-foreground/50" />
<span class="text-[10px] text-muted-foreground/50">降级</span>
<!-- 降级标记 -->
<div
v-if="providerIndex < formatGroup.providers.length - 1"
class="flex py-0.5"
>
<div class="w-6 flex justify-center shrink-0">
<div class="w-0.5 h-full bg-border" />
</div>
<div class="flex items-center gap-1 text-[10px] text-muted-foreground/50">
<ArrowDown class="w-3 h-3" />
<span>降级</span>
</div>
</div>
</div>
@@ -905,6 +925,13 @@ function getGlobalKeyNodeClass(entry: GlobalKeyEntry, groupIndex: number, keyInd
return 'bg-background border-border'
}
// 判断是否为格式组中的最后一个 Key
function isLastKeyInFormat(formatGroup: ApiFormatGroup, groupIndex: number, keyIndex: number): boolean {
const isLastGroup = groupIndex === formatGroup.keyGroups.length - 1
const isLastKeyInGroup = keyIndex === formatGroup.keyGroups[groupIndex].keys.length - 1
return isLastGroup && isLastKeyInGroup
}
// 获取全局 Key 卡片样式(全局 Key 优先模式)
function getGlobalKeyCardClass(entry: GlobalKeyEntry, groupIndex: number, keyIndex: number): string {
if (!entry.key.is_active || !entry.provider.is_active || !entry.provider.model_is_active) {

View File

@@ -1,3 +1,4 @@
export { default as GlobalModelFormDialog } from './GlobalModelFormDialog.vue'
export { default as ModelDetailDrawer } from './ModelDetailDrawer.vue'
export { default as TieredPricingEditor } from './TieredPricingEditor.vue'
export { default as ModelAliasesTab } from './ModelAliasesTab.vue'

View File

@@ -367,17 +367,106 @@
@edit-model="handleEditModel"
@delete-model="handleDeleteModel"
@batch-assign="handleBatchAssign"
@add-mapping="handleAddMapping"
/>
<!-- 模型名称映射 -->
<ModelAliasesTab
v-if="provider"
ref="modelAliasesTabRef"
:key="`aliases-${provider.id}`"
:provider="provider"
@refresh="handleRelatedDataRefresh"
/>
<!-- 别名映射预览 -->
<Card
v-if="aliasMappingLoading || (aliasMappingPreview && aliasMappingPreview.total_matches > 0)"
class="overflow-hidden"
>
<div class="px-4 py-3 border-b border-border/60">
<h3 class="text-sm font-semibold">
别名映射预览
</h3>
</div>
<!-- 加载状态 -->
<div v-if="aliasMappingLoading" class="flex items-center justify-center py-8">
<RefreshCw class="w-5 h-5 animate-spin text-muted-foreground" />
</div>
<!-- GlobalModel 列表 -->
<div v-else class="divide-y divide-border/40">
<div
v-for="(gmInfo, gmIndex) in computedAliasMappingByModel"
:key="gmInfo.global_model_id"
>
<!-- GlobalModel -->
<div
class="px-4 py-3 flex items-center gap-3 cursor-pointer hover:bg-muted/30 transition-colors"
@click="toggleAliasExpand(gmIndex)"
>
<ChevronRight
class="w-4 h-4 text-muted-foreground transition-transform flex-shrink-0"
:class="{ 'rotate-90': aliasExpandedIndex === gmIndex }"
/>
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="text-sm font-medium truncate">{{ gmInfo.display_name }}</span>
<Badge
v-if="!gmInfo.is_active"
variant="outline"
class="text-[10px] px-1.5 py-0 text-muted-foreground flex-shrink-0"
>
停用
</Badge>
</div>
<div class="flex items-center gap-2 text-xs text-muted-foreground">
<span class="font-mono">{{ gmInfo.global_model_name }}</span>
<span class="text-muted-foreground/50">|</span>
<span class="font-mono text-primary/80">{{ gmInfo.alias_patterns.join(' / ') }}</span>
</div>
</div>
<Badge
variant="secondary"
class="text-xs flex-shrink-0"
>
{{ gmInfo.matched_keys.length }} Key · {{ gmInfo.total_models }} 模型
</Badge>
</div>
<!-- 展开内容匹配的 Key 列表 -->
<div
v-if="aliasExpandedIndex === gmIndex"
class="border-t bg-muted/10 px-4 py-3"
>
<div class="space-y-2">
<div
v-for="keyItem in gmInfo.matched_keys"
:key="keyItem.key_id"
class="bg-background rounded-md border p-3"
>
<div class="flex items-center gap-2 text-sm mb-2">
<Key class="w-3.5 h-3.5 text-muted-foreground flex-shrink-0" />
<span class="font-medium truncate">{{ keyItem.key_name || '未命名密钥' }}</span>
<span class="text-xs text-muted-foreground font-mono ml-auto flex-shrink-0">
{{ keyItem.masked_key }}
</span>
<Badge
v-if="!keyItem.is_active"
variant="secondary"
class="text-[10px] px-1.5 py-0 flex-shrink-0"
>
禁用
</Badge>
</div>
<div class="flex flex-wrap gap-1.5">
<Badge
v-for="match in keyItem.matches"
:key="match.allowed_model"
variant="secondary"
class="text-xs font-mono"
:title="`匹配规则: ${match.alias_pattern}`"
>
{{ match.allowed_model }}
</Badge>
</div>
</div>
</div>
</div>
</div>
</div>
</Card>
</div>
</template>
</Card>
@@ -485,7 +574,6 @@
<script setup lang="ts">
import { ref, watch, computed, nextTick } from 'vue'
import {
Server,
Plus,
Key,
ChevronRight,
@@ -493,13 +581,9 @@ import {
Trash2,
RefreshCw,
X,
Loader2,
Power,
GripVertical,
Copy,
Eye,
EyeOff,
ExternalLink,
Shield
} from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey'
@@ -508,12 +592,11 @@ import Badge from '@/components/ui/badge.vue'
import Card from '@/components/ui/card.vue'
import { useToast } from '@/composables/useToast'
import { useClipboard } from '@/composables/useClipboard'
import { getProvider, getProviderEndpoints } from '@/api/endpoints'
import { getProvider, getProviderEndpoints, getProviderAliasMappingPreview, type ProviderAliasMappingPreviewResponse } from '@/api/endpoints'
import {
KeyFormDialog,
KeyAllowedModelsEditDialog,
ModelsTab,
ModelAliasesTab,
BatchAssignModelsDialog
} from '@/features/providers/components'
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
@@ -592,8 +675,83 @@ const deleteModelConfirmOpen = ref(false)
const modelToDelete = ref<Model | null>(null)
const batchAssignDialogOpen = ref(false)
// ModelAliasesTab 组件引用
const modelAliasesTabRef = ref<InstanceType<typeof ModelAliasesTab> | null>(null)
// 别名映射预览状态
const aliasMappingPreview = ref<ProviderAliasMappingPreviewResponse | null>(null)
const aliasMappingLoading = ref(false)
const aliasExpandedIndex = ref<number | null>(null)
// 切换别名展开
function toggleAliasExpand(index: number) {
aliasExpandedIndex.value = aliasExpandedIndex.value === index ? null : index
}
// 按 GlobalModel 分组的别名映射数据
interface MatchedKeyItem {
key_id: string
key_name: string
masked_key: string
is_active: boolean
matches: { allowed_model: string; alias_pattern: string }[]
}
interface GlobalModelAliasInfo {
global_model_id: string
global_model_name: string
display_name: string
is_active: boolean
alias_patterns: string[]
matched_keys: MatchedKeyItem[]
total_models: number
}
const computedAliasMappingByModel = computed<GlobalModelAliasInfo[]>(() => {
if (!aliasMappingPreview.value) return []
// 按 GlobalModel 分组
const modelMap = new Map<string, GlobalModelAliasInfo>()
for (const keyInfo of aliasMappingPreview.value.keys) {
for (const gm of keyInfo.matching_global_models) {
if (!modelMap.has(gm.global_model_id)) {
// 收集所有匹配用到的别名规则(去重)
const patterns = new Set<string>()
for (const match of gm.matched_models) {
patterns.add(match.alias_pattern)
}
modelMap.set(gm.global_model_id, {
global_model_id: gm.global_model_id,
global_model_name: gm.global_model_name,
display_name: gm.display_name,
is_active: gm.is_active,
alias_patterns: Array.from(patterns),
matched_keys: [],
total_models: 0,
})
}
const modelInfo = modelMap.get(gm.global_model_id)!
// 更新别名规则集合(可能来自不同 Key 的匹配)
for (const match of gm.matched_models) {
if (!modelInfo.alias_patterns.includes(match.alias_pattern)) {
modelInfo.alias_patterns.push(match.alias_pattern)
}
}
modelInfo.matched_keys.push({
key_id: keyInfo.key_id,
key_name: keyInfo.key_name,
masked_key: keyInfo.masked_key,
is_active: keyInfo.is_active,
matches: gm.matched_models,
})
modelInfo.total_models += gm.matched_models.length
}
}
return Array.from(modelMap.values())
})
// 拖动排序相关状态(旧的端点级别拖拽,保留以兼容)
const dragState = ref({
@@ -625,9 +783,7 @@ const hasBlockingDialogOpen = computed(() =>
deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value ||
deleteModelConfirmOpen.value ||
batchAssignDialogOpen.value ||
// 检测 ModelAliasesTab 子组件的 Dialog 是否打开
modelAliasesTabRef.value?.dialogOpen
batchAssignDialogOpen.value
)
// 所有密钥的扁平列表(带端点信息)
@@ -665,6 +821,7 @@ watch(() => props.providerId, (newId) => {
if (newId && props.open) {
loadProvider()
loadEndpoints()
loadAliasMappingPreview()
}
}, { immediate: true })
@@ -673,6 +830,7 @@ watch(() => props.open, (newOpen) => {
if (newOpen && props.providerId) {
loadProvider()
loadEndpoints()
loadAliasMappingPreview()
} else if (!newOpen) {
// 重置所有状态
provider.value = null
@@ -696,6 +854,10 @@ watch(() => props.open, (newOpen) => {
// 清除已显示的密钥(安全考虑)
revealedKeys.value.clear()
// 重置别名映射预览
aliasMappingPreview.value = null
aliasExpandedIndex.value = null
}
})
@@ -722,11 +884,6 @@ function toggleEndpoint(endpointId: string) {
}
}
async function handleRelatedDataRefresh() {
await loadProvider()
emit('refresh')
}
// 显示端点管理对话框
function showAddEndpointDialog() {
endpointDialogOpen.value = true
@@ -962,11 +1119,6 @@ function handleBatchAssign() {
batchAssignDialogOpen.value = true
}
// 处理添加映射(从 ModelsTab 触发)
function handleAddMapping(model: Model) {
modelAliasesTabRef.value?.openAddDialogForModel(model.id)
}
// 处理批量关联完成
async function handleBatchAssignChanged() {
await loadProvider()
@@ -1375,6 +1527,25 @@ async function loadEndpoints() {
}
}
// 加载别名映射预览
async function loadAliasMappingPreview() {
if (!props.providerId) return
aliasMappingLoading.value = true
try {
aliasMappingPreview.value = await getProviderAliasMappingPreview(props.providerId)
} catch (err: any) {
// 404 静默处理Provider 不存在或无别名配置)
if (err.response?.status !== 404) {
console.warn('加载别名映射预览失败:', err)
showError('加载别名映射预览失败')
}
aliasMappingPreview.value = null
} finally {
aliasMappingLoading.value = false
}
}
// 添加 ESC 键监听
useEscapeKey(() => {
if (props.open) {

View File

@@ -8,7 +8,5 @@ export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vu
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'
export { default as EndpointHealthTimeline } from './EndpointHealthTimeline.vue'
export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vue'
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'

View File

@@ -117,15 +117,6 @@
</td>
<td class="align-top px-4 py-3">
<div class="flex justify-center gap-1.5">
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="添加映射"
@click="addMapping(model)"
>
<Link class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
@@ -179,7 +170,7 @@
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { Box, Edit, Trash2, Layers, Power, Copy, Link } from 'lucide-vue-next'
import { Box, Edit, Trash2, Layers, Power, Copy } from 'lucide-vue-next'
import Card from '@/components/ui/card.vue'
import Button from '@/components/ui/button.vue'
import { useToast } from '@/composables/useToast'
@@ -195,7 +186,6 @@ const emit = defineEmits<{
'editModel': [model: Model]
'deleteModel': [model: Model]
'batchAssign': []
'addMapping': [model: Model]
}>()
const { error: showError, success: showSuccess } = useToast()
@@ -315,11 +305,6 @@ function deleteModel(model: Model) {
emit('deleteModel', model)
}
// 添加映射
function addMapping(model: Model) {
emit('addMapping', model)
}
// 打开批量关联对话框
function openBatchAssignDialog() {
emit('batchAssign')

View File

@@ -41,6 +41,7 @@ dependencies = [
"aiosqlite>=0.20.0",
"loguru>=0.7.3",
"tiktoken>=0.5.0",
"regex>=2024.0.0", # 支持超时的正则库,用于 ReDoS 防护
"aiofiles>=24.1.0",
"aiohttp>=3.12.15",
"aiosmtplib>=4.0.2",

View File

@@ -321,6 +321,14 @@ class AdminCreateGlobalModelAdapter(AdminApiAdapter):
payload: GlobalModelCreate
async def handle(self, context): # type: ignore[override]
from src.core.exceptions import InvalidRequestException
from src.core.model_permissions import validate_and_extract_model_aliases
# 验证 model_aliases如果有
is_valid, error, _ = validate_and_extract_model_aliases(self.payload.config)
if not is_valid:
raise InvalidRequestException(f"别名规则验证失败: {error}", "model_aliases")
# 将 TieredPricingConfig 转换为 dict
tiered_pricing_dict = self.payload.default_tiered_pricing.model_dump()
@@ -352,6 +360,40 @@ class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
payload: GlobalModelUpdate
async def handle(self, context): # type: ignore[override]
from src.core.exceptions import InvalidRequestException
from src.core.model_permissions import validate_and_extract_model_aliases
# 验证 model_aliases如果有
is_valid, error, _ = validate_and_extract_model_aliases(self.payload.config)
if not is_valid:
raise InvalidRequestException(f"别名规则验证失败: {error}", "model_aliases")
# 使用行级锁获取旧的 GlobalModel 信息,防止并发更新导致的竞态条件
# 设置 2 秒锁超时,允许短暂等待而非立即失败,提升并发操作的成功率
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from src.models.database import GlobalModel
try:
# 设置会话级别的锁超时(仅影响当前事务)
context.db.execute(text("SET LOCAL lock_timeout = '2s'"))
old_global_model = (
context.db.query(GlobalModel)
.filter(GlobalModel.id == self.global_model_id)
.with_for_update()
.first()
)
except OperationalError as e:
# 锁超时或锁冲突时返回友好的错误提示
error_msg = str(e).lower()
if "lock" in error_msg or "timeout" in error_msg:
raise InvalidRequestException("该模型正在被其他操作更新,请稍后重试")
raise
old_model_name = old_global_model.name if old_global_model else None
new_model_name = self.payload.name if self.payload.name else old_model_name
# 执行更新(此时仍持有行锁)
global_model = GlobalModelService.update_global_model(
db=context.db,
global_model_id=self.global_model_id,
@@ -360,11 +402,18 @@ class AdminUpdateGlobalModelAdapter(AdminApiAdapter):
logger.info(f"GlobalModel 已更新: id={global_model.id} name={global_model.name}")
# 失效相关缓存
# 更新成功后才失效缓存(避免回滚时缓存已被清除的竞态问题)
# 注意:此时事务已提交(由 pipeline 管理),数据已持久化
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.on_global_model_changed(global_model.name)
# 同步清理新旧两个名称的缓存(防止名称变更时的竞态)
if old_model_name:
cache_service.on_global_model_changed(old_model_name, self.global_model_id)
if new_model_name and new_model_name != old_model_name:
cache_service.on_global_model_changed(new_model_name, self.global_model_id)
# 异步失效更多缓存
await cache_service.on_global_model_changed_async(global_model.name, global_model.id)
return GlobalModelResponse.model_validate(global_model)
@@ -376,24 +425,44 @@ class AdminDeleteGlobalModelAdapter(AdminApiAdapter):
global_model_id: str
async def handle(self, context): # type: ignore[override]
# 获取 GlobalModel 信息(用于失效缓存)
# 使用行级锁获取 GlobalModel 信息,防止并发操作导致的竞态条件
# 设置 2 秒锁超时,允许短暂等待而非立即失败
from sqlalchemy import text
from sqlalchemy.exc import OperationalError
from src.core.exceptions import InvalidRequestException
from src.models.database import GlobalModel
global_model = (
context.db.query(GlobalModel).filter(GlobalModel.id == self.global_model_id).first()
)
try:
# 设置会话级别的锁超时(仅影响当前事务)
context.db.execute(text("SET LOCAL lock_timeout = '2s'"))
global_model = (
context.db.query(GlobalModel)
.filter(GlobalModel.id == self.global_model_id)
.with_for_update()
.first()
)
except OperationalError as e:
# 锁超时或锁冲突时返回友好的错误提示
error_msg = str(e).lower()
if "lock" in error_msg or "timeout" in error_msg:
raise InvalidRequestException("该模型正在被其他操作处理,请稍后重试")
raise
model_name = global_model.name if global_model else None
model_id = global_model.id if global_model else self.global_model_id
# 执行删除(此时仍持有行锁)
GlobalModelService.delete_global_model(context.db, self.global_model_id)
logger.info(f"GlobalModel 已删除: id={self.global_model_id}")
# 失效相关缓存
if model_name:
from src.services.cache.invalidation import get_cache_invalidation_service
# 删除成功后才失效缓存(避免回滚时缓存已被清除的竞态问题)
from src.services.cache.invalidation import get_cache_invalidation_service
cache_service = get_cache_invalidation_service()
cache_service.on_global_model_changed(model_name)
cache_service = get_cache_invalidation_service()
if model_name:
cache_service.on_global_model_changed(model_name, model_id)
await cache_service.on_global_model_changed_async(model_name, model_id)
return None
@@ -413,7 +482,9 @@ class AdminBatchAssignToProvidersAdapter(AdminApiAdapter):
create_models=self.payload.create_models,
)
logger.info(f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}")
logger.info(
f"批量为 Provider 添加 GlobalModel: global_model_id={self.global_model_id} success={len(result['success'])} errors={len(result['errors'])}"
)
return BatchAssignToProvidersResponse(**result)

View File

@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session, selectinload
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.model_permissions import parse_allowed_models_to_list
from src.database import get_db
from src.models.database import (
GlobalModel,
@@ -49,6 +50,8 @@ class RoutingKeyInfo(BaseModel):
health_score: float = Field(100.0, description="健康度分数")
is_active: bool
api_formats: List[str] = Field(default_factory=list, description="支持的 API 格式")
# 模型白名单
allowed_models: Optional[List[str]] = Field(None, description="允许的模型列表null 表示不限制")
# 熔断状态
circuit_breaker_open: bool = Field(False, description="熔断器是否打开")
circuit_breaker_formats: List[str] = Field(default_factory=list, description="熔断的 API 格式列表")
@@ -299,6 +302,21 @@ class AdminGetModelRoutingPreviewAdapter(AdminApiAdapter):
circuit_breaker_open = True
circuit_breaker_formats.append(fmt)
# 解析 allowed_models
# 语义说明:
# - None: 不限制(允许所有模型)
# - {}: 空字典 = 不限制normalize_allowed_models 返回 None
# - []: 空列表 = 拒绝所有模型
# - {"CLAUDE": []}: 指定格式空列表 = 该格式拒绝所有
raw_allowed_models = key.allowed_models
if raw_allowed_models is None:
allowed_models_list = None
elif isinstance(raw_allowed_models, dict) and not raw_allowed_models:
# 空 dict {} 在语义上等价于不限制
allowed_models_list = None
else:
allowed_models_list = parse_allowed_models_to_list(raw_allowed_models)
key_infos.append(
RoutingKeyInfo(
id=key.id or "",
@@ -313,6 +331,7 @@ class AdminGetModelRoutingPreviewAdapter(AdminApiAdapter):
health_score=health_score,
is_active=bool(key.is_active),
api_formats=key.api_formats or [],
allowed_models=allowed_models_list,
circuit_breaker_open=circuit_breaker_open,
circuit_breaker_formats=circuit_breaker_formats,
)

View File

@@ -1,10 +1,11 @@
"""管理员 Provider 管理路由。"""
import asyncio
from datetime import datetime, timezone
from typing import Optional
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, Query, Request
from pydantic import ValidationError
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
@@ -12,15 +13,81 @@ from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.core.model_permissions import match_model_with_pattern, parse_allowed_models_to_list
from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider
from src.models.database import GlobalModel, Provider, ProviderAPIKey
from src.services.cache.provider_cache import ProviderCacheService
router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline()
# 别名映射预览配置(管理后台功能,限制宽松)
ALIAS_PREVIEW_MAX_KEYS = 200
ALIAS_PREVIEW_MAX_MODELS = 500
ALIAS_PREVIEW_TIMEOUT_SECONDS = 10.0
# ========== Response Models ==========
class AliasMatchedModel(BaseModel):
"""匹配到的模型名称"""
allowed_model: str = Field(..., description="Key 白名单中匹配到的模型名")
alias_pattern: str = Field(..., description="匹配的别名规则")
class AliasMatchingGlobalModel(BaseModel):
"""有别名匹配的 GlobalModel"""
global_model_id: str
global_model_name: str
display_name: str
is_active: bool
matched_models: List[AliasMatchedModel] = Field(
default_factory=list, description="匹配到的模型列表"
)
model_config = ConfigDict(from_attributes=True)
class AliasMatchingKey(BaseModel):
"""有别名匹配的 Key"""
key_id: str
key_name: str
masked_key: str
is_active: bool
allowed_models: List[str] = Field(default_factory=list, description="Key 的模型白名单")
matching_global_models: List[AliasMatchingGlobalModel] = Field(
default_factory=list, description="匹配到的 GlobalModel 列表"
)
model_config = ConfigDict(from_attributes=True)
class ProviderAliasMappingPreviewResponse(BaseModel):
"""Provider 别名映射预览响应"""
provider_id: str
provider_name: str
keys: List[AliasMatchingKey] = Field(
default_factory=list, description="有白名单配置且匹配到别名的 Key 列表"
)
total_keys: int = Field(0, description="有匹配结果的 Key 数量")
total_matches: int = Field(
0, description="匹配到的 GlobalModel 数量(同一 GlobalModel 被多个 Key 匹配会重复计数)"
)
# 截断提示字段
truncated: bool = Field(False, description="是否因限制而截断结果")
truncated_keys: int = Field(0, description="被截断的 Key 数量")
truncated_models: int = Field(0, description="被截断的 GlobalModel 数量")
model_config = ConfigDict(from_attributes=True)
@router.get("/")
async def list_providers(
request: Request,
@@ -292,7 +359,9 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
setattr(provider, field, ProviderBillingType(value))
elif field == "proxy" and value is not None:
# proxy 需要转换为 dict如果是 Pydantic 模型)
setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
setattr(
provider, field, value if isinstance(value, dict) else value.model_dump()
)
else:
setattr(provider, field, value)
@@ -345,3 +414,232 @@ class AdminDeleteProviderAdapter(AdminApiAdapter):
db.delete(provider)
db.commit()
return {"message": "提供商已删除"}
@router.get(
"/{provider_id}/alias-mapping-preview",
response_model=ProviderAliasMappingPreviewResponse,
)
async def get_provider_alias_mapping_preview(
request: Request,
provider_id: str,
db: Session = Depends(get_db),
) -> ProviderAliasMappingPreviewResponse:
"""
获取 Provider 别名映射预览
查看该 Provider 的 Key 白名单能够被哪些 GlobalModel 的别名规则匹配。
**路径参数**:
- `provider_id`: Provider ID
**返回字段**:
- `provider_id`: Provider ID
- `provider_name`: Provider 名称
- `keys`: 有白名单配置的 Key 列表,每个包含:
- `key_id`: Key ID
- `key_name`: Key 名称
- `masked_key`: 脱敏的 Key
- `allowed_models`: Key 的白名单模型列表
- `matching_global_models`: 匹配到的 GlobalModel 列表
- `total_keys`: 有白名单配置的 Key 总数
- `total_matches`: 匹配到的 GlobalModel 总数
"""
adapter = AdminGetProviderAliasMappingPreviewAdapter(provider_id=provider_id)
# 添加超时保护,防止复杂匹配导致的 DoS
try:
return await asyncio.wait_for(
pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode),
timeout=ALIAS_PREVIEW_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
logger.warning(f"别名映射预览超时: provider_id={provider_id}")
raise InvalidRequestException("别名映射预览超时,请简化配置或稍后重试")
class AdminGetProviderAliasMappingPreviewAdapter(AdminApiAdapter):
"""获取 Provider 别名映射预览"""
def __init__(self, provider_id: str):
self.provider_id = provider_id
async def handle(self, context) -> ProviderAliasMappingPreviewResponse: # type: ignore[override]
db = context.db
# 获取 Provider
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException("提供商不存在", "provider")
# 统计截断情况
truncated_keys = 0
truncated_models = 0
# 获取该 Provider 有白名单配置的 Key 总数(用于截断统计)
from sqlalchemy import func
total_keys_with_allowed_models = (
db.query(func.count(ProviderAPIKey.id))
.filter(
ProviderAPIKey.provider_id == self.provider_id,
ProviderAPIKey.allowed_models.isnot(None),
)
.scalar()
or 0
)
# 获取该 Provider 有白名单配置的 Key只查询需要的字段
keys = (
db.query(
ProviderAPIKey.id,
ProviderAPIKey.name,
ProviderAPIKey.api_key,
ProviderAPIKey.is_active,
ProviderAPIKey.allowed_models,
)
.filter(
ProviderAPIKey.provider_id == self.provider_id,
ProviderAPIKey.allowed_models.isnot(None),
)
.limit(ALIAS_PREVIEW_MAX_KEYS)
.all()
)
# 计算被截断的 Key 数量
if total_keys_with_allowed_models > ALIAS_PREVIEW_MAX_KEYS:
truncated_keys = total_keys_with_allowed_models - ALIAS_PREVIEW_MAX_KEYS
# 获取有 model_aliases 配置的 GlobalModel 总数(用于截断统计)
total_models_with_aliases = (
db.query(func.count(GlobalModel.id))
.filter(
GlobalModel.config.isnot(None),
GlobalModel.config["model_aliases"].isnot(None),
func.jsonb_array_length(GlobalModel.config["model_aliases"]) > 0,
)
.scalar()
or 0
)
# 只查询有 model_aliases 配置的 GlobalModel使用 SQLAlchemy JSONB 操作符)
global_models = (
db.query(
GlobalModel.id,
GlobalModel.name,
GlobalModel.display_name,
GlobalModel.is_active,
GlobalModel.config,
)
.filter(
GlobalModel.config.isnot(None),
GlobalModel.config["model_aliases"].isnot(None),
func.jsonb_array_length(GlobalModel.config["model_aliases"]) > 0,
)
.limit(ALIAS_PREVIEW_MAX_MODELS)
.all()
)
# 计算被截断的 GlobalModel 数量
if total_models_with_aliases > ALIAS_PREVIEW_MAX_MODELS:
truncated_models = total_models_with_aliases - ALIAS_PREVIEW_MAX_MODELS
# 构建有别名配置的 GlobalModel 映射
models_with_aliases: Dict[str, tuple] = {} # id -> (model_info, aliases)
for gm in global_models:
config = gm.config or {}
aliases = config.get("model_aliases", [])
if aliases:
models_with_aliases[gm.id] = (gm, aliases)
# 如果没有任何带别名的 GlobalModel直接返回空结果
if not models_with_aliases:
return ProviderAliasMappingPreviewResponse(
provider_id=provider.id,
provider_name=provider.name,
keys=[],
total_keys=0,
total_matches=0,
truncated=False,
truncated_keys=0,
truncated_models=0,
)
key_infos: List[AliasMatchingKey] = []
total_matches = 0
# 创建 CryptoService 实例
from src.core.crypto import CryptoService
crypto = CryptoService()
for key in keys:
allowed_models_list = parse_allowed_models_to_list(key.allowed_models)
if not allowed_models_list:
continue
# 生成脱敏 Key
masked_key = "***"
if key.api_key:
try:
decrypted_key = crypto.decrypt(key.api_key, silent=True)
if len(decrypted_key) > 8:
masked_key = f"{decrypted_key[:4]}***{decrypted_key[-4:]}"
else:
masked_key = f"{decrypted_key[:2]}***"
except Exception:
pass
# 查找匹配的 GlobalModel
matching_global_models: List[AliasMatchingGlobalModel] = []
for gm_id, (gm, aliases) in models_with_aliases.items():
matched_models: List[AliasMatchedModel] = []
for allowed_model in allowed_models_list:
for alias_pattern in aliases:
if match_model_with_pattern(alias_pattern, allowed_model):
matched_models.append(
AliasMatchedModel(
allowed_model=allowed_model,
alias_pattern=alias_pattern,
)
)
break # 一个 allowed_model 只需匹配一个别名
if matched_models:
matching_global_models.append(
AliasMatchingGlobalModel(
global_model_id=gm.id,
global_model_name=gm.name,
display_name=gm.display_name,
is_active=bool(gm.is_active),
matched_models=matched_models,
)
)
total_matches += 1
if matching_global_models:
key_infos.append(
AliasMatchingKey(
key_id=key.id or "",
key_name=key.name or "",
masked_key=masked_key,
is_active=bool(key.is_active),
allowed_models=allowed_models_list,
matching_global_models=matching_global_models,
)
)
is_truncated = truncated_keys > 0 or truncated_models > 0
return ProviderAliasMappingPreviewResponse(
provider_id=provider.id,
provider_name=provider.name,
keys=key_infos,
total_keys=len(key_infos),
total_matches=total_matches,
truncated=is_truncated,
truncated_keys=truncated_keys,
truncated_models=truncated_models,
)

View File

@@ -53,6 +53,7 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.services.cache.aware_scheduler import ProviderCandidate
from src.services.provider.transport import build_provider_url
@@ -312,6 +313,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
candidate: ProviderCandidate,
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
@@ -322,6 +324,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
original_request_body,
original_headers,
query_params,
candidate,
)
try:
@@ -411,6 +414,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
candidate: Optional[ProviderCandidate] = None,
) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据)
@@ -425,11 +429,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
provider_api_format=str(endpoint.api_format) if endpoint.api_format else None,
)
# 获取模型映射
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=str(provider.id),
)
# 获取模型映射(优先使用别名匹配到的模型,其次是 Provider 级别的映射)
mapped_model = candidate.alias_matched_model if candidate else None
if not mapped_model:
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体
if mapped_model:
@@ -650,17 +656,20 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
candidate: ProviderCandidate,
) -> Dict[str, Any]:
nonlocal provider_name, response_json, status_code, response_headers
nonlocal provider_request_headers, provider_request_body, mapped_model_result
provider_name = str(provider.name)
# 获取模型映射
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=str(provider.id),
)
# 获取模型映射(优先使用别名匹配到的模型,其次是 Provider 级别的映射)
mapped_model = candidate.alias_matched_model if candidate else None
if not mapped_model:
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=str(provider.id),
)
# 应用模型映射
if mapped_model:

View File

@@ -64,6 +64,7 @@ from src.models.database import (
ProviderEndpoint,
User,
)
from src.services.cache.aware_scheduler import ProviderCandidate
from src.services.provider.transport import build_provider_url
from src.utils.sse_parser import SSEEventParser
from src.utils.timeout import read_first_chunk_with_ttfb_timeout
@@ -317,6 +318,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
candidate: ProviderCandidate,
) -> AsyncGenerator[bytes, None]:
return await self._execute_stream_request(
ctx,
@@ -326,6 +328,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
original_request_body,
original_headers,
query_params,
candidate,
)
try:
@@ -405,6 +408,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
original_request_body: Dict[str, Any],
original_headers: Dict[str, str],
query_params: Optional[Dict[str, str]] = None,
candidate: Optional[ProviderCandidate] = None,
) -> AsyncGenerator[bytes, None]:
"""执行流式请求并返回流生成器"""
# 重置上下文状态(重试时清除之前的数据,避免累积)
@@ -432,11 +436,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
ctx.provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
ctx.client_api_format = ctx.api_format # 已在 process_stream 中设置
# 获取模型映射(映射名称 → 实际模型名
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=str(provider.id),
)
# 获取模型映射(优先使用别名匹配到的模型,其次是 Provider 级别的映射)
mapped_model = candidate.alias_matched_model if candidate else None
if not mapped_model:
mapped_model = await self._get_mapped_model(
source_model=ctx.model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
if mapped_model:
@@ -1247,14 +1253,29 @@ class CliMessageHandlerBase(BaseMessageHandler):
stream_generator: AsyncGenerator[bytes, None],
) -> AsyncGenerator[bytes, None]:
"""创建带监控的流生成器"""
import time as time_module
last_chunk_time = time_module.time()
chunk_count = 0
try:
async for chunk in stream_generator:
last_chunk_time = time_module.time()
chunk_count += 1
yield chunk
except asyncio.CancelledError:
# 计算距离上次收到 chunk 的时间
time_since_last_chunk = time_module.time() - last_chunk_time
# 如果响应已完成,不标记为失败
if not ctx.has_completion:
ctx.status_code = 499
ctx.error_message = "Client disconnected"
logger.warning(
f"ID:{ctx.request_id} | Stream cancelled: "
f"chunks={chunk_count}, "
f"has_completion={ctx.has_completion}, "
f"time_since_last_chunk={time_since_last_chunk:.2f}s, "
f"output_tokens={ctx.output_tokens}"
)
raise
except httpx.TimeoutException as e:
ctx.status_code = 504
@@ -1536,16 +1557,19 @@ class CliMessageHandlerBase(BaseMessageHandler):
provider: Provider,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
candidate: ProviderCandidate,
) -> Dict[str, Any]:
nonlocal provider_name, response_json, status_code, response_headers, provider_api_format, provider_request_headers, provider_request_body, mapped_model_result, response_metadata_result
provider_name = str(provider.name)
provider_api_format = str(endpoint.api_format) if endpoint.api_format else ""
# 获取模型映射(映射名称 → 实际模型名
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=str(provider.id),
)
# 获取模型映射(优先使用别名匹配到的模型,其次是 Provider 级别的映射)
mapped_model = candidate.alias_matched_model if candidate else None
if not mapped_model:
mapped_model = await self._get_mapped_model(
source_model=model,
provider_id=str(provider.id),
)
# 应用模型映射到请求体(子类可覆盖此方法处理不同格式)
if mapped_model:

View File

@@ -268,7 +268,7 @@ async def test_connection(
orchestrator = FallbackOrchestrator(db, redis_client)
# 定义请求函数
async def test_request_func(_prov, endpoint, key):
async def test_request_func(_prov, endpoint, key, _candidate):
request_builder = PassthroughRequestBuilder()
provider_payload, provider_headers = request_builder.build(
payload, {}, endpoint, key, is_stream=False

View File

@@ -6,9 +6,27 @@
2. 按格式模式(字典): {"OPENAI": ["gpt-4o"], "CLAUDE": ["claude-sonnet-4"]}
使用 None/null 表示不限制(允许所有模型)
支持模型别名匹配:
- GlobalModel.config.model_aliases 定义别名模式
- 别名模式支持正则表达式语法
- 例如claude-haiku-.* 可匹配 claude-haiku-4.5, claude-haiku-last
- 使用 regex 库的原生超时保护100ms防止 ReDoS
"""
from typing import Any, Dict, List, Optional, Set, Union
import re
from functools import lru_cache
from typing import Dict, List, Optional, Set, Tuple, Union
import regex
from src.core.logger import logger
# 别名规则限制
MAX_ALIASES_PER_MODEL = 50
MAX_ALIAS_LENGTH = 200
MAX_MODEL_NAME_LENGTH = 200 # 与 MAX_ALIAS_LENGTH 保持一致
REGEX_MATCH_TIMEOUT_MS = 100 # 正则匹配超时(毫秒)
# 类型别名
AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]
@@ -284,3 +302,288 @@ def convert_to_simple_mode(allowed_models: AllowedModels) -> Optional[List[str]]
return sorted(all_models) if all_models else None
return None
def parse_allowed_models_to_list(allowed_models: AllowedModels) -> List[str]:
"""
解析 allowed_models支持 list 和 dict 格式)为统一的列表
与 convert_to_simple_mode 的区别:
- 本函数返回空列表而非 None用于 UI 展示)
- convert_to_simple_mode 返回 None 表示不限制
Args:
allowed_models: 允许的模型配置(列表或字典)
Returns:
模型名称列表(可能为空)
"""
if allowed_models is None:
return []
if isinstance(allowed_models, list):
return allowed_models
if isinstance(allowed_models, dict):
all_models: Set[str] = set()
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
return sorted(all_models)
return []
def validate_alias_pattern(pattern: str) -> Tuple[bool, Optional[str]]:
"""
验证别名模式是否安全
Args:
pattern: 待验证的正则模式
Returns:
(is_valid, error_message)
"""
if not pattern or not pattern.strip():
return False, "别名规则不能为空"
if len(pattern) > MAX_ALIAS_LENGTH:
return False, f"别名规则过长 (最大 {MAX_ALIAS_LENGTH} 字符)"
# 尝试编译验证语法
try:
re.compile(f"^{pattern}$", re.IGNORECASE)
except re.error as e:
return False, f"正则表达式语法错误: {e}"
return True, None
def validate_model_aliases(aliases: Optional[List[str]]) -> Tuple[bool, Optional[str]]:
"""
验证别名列表是否合法
Args:
aliases: 别名列表
Returns:
(is_valid, error_message)
"""
if not aliases:
return True, None
if len(aliases) > MAX_ALIASES_PER_MODEL:
return False, f"别名规则数量超限 (最大 {MAX_ALIASES_PER_MODEL} 条)"
for i, alias in enumerate(aliases):
is_valid, error = validate_alias_pattern(alias)
if not is_valid:
return False, f"{i + 1} 条规则无效: {error}"
return True, None
def validate_and_extract_model_aliases(
config: Optional[dict],
) -> Tuple[bool, Optional[str], Optional[List[str]]]:
"""
从 config 中验证并提取 model_aliases
用于 GlobalModel 创建/更新时的统一验证
Args:
config: GlobalModel 的 config 字典
Returns:
(is_valid, error_message, aliases):
- is_valid: 验证是否通过
- error_message: 错误信息(验证失败时)
- aliases: 提取的别名列表(验证成功时)
"""
if not config or "model_aliases" not in config:
return True, None, None
aliases = config.get("model_aliases")
# 允许显式设置为 None表示清除别名
if aliases is None:
return True, None, None
# 类型验证:必须是列表
if not isinstance(aliases, list):
return False, "model_aliases 必须是数组类型", None
# 元素类型验证:必须是字符串
if not all(isinstance(a, str) for a in aliases):
return False, "model_aliases 数组元素必须是字符串", None
# 业务规则验证
is_valid, error = validate_model_aliases(aliases)
if not is_valid:
return False, error, None
return True, None, aliases
@lru_cache(maxsize=2000)
def _compile_pattern_cached(pattern: str) -> Optional[regex.Pattern]:
"""
编译正则模式(带 LRU 缓存)
Args:
pattern: 正则模式字符串
Returns:
编译后的正则对象,如果无效则返回 None
"""
try:
return regex.compile(f"^{pattern}$", regex.IGNORECASE)
except regex.error as e:
logger.debug(f"正则编译失败: pattern={pattern}, error={e}")
return None
def clear_regex_cache() -> None:
"""
清空正则缓存
在 GlobalModel 别名更新时调用此函数以确保缓存一致性
"""
_compile_pattern_cached.cache_clear()
logger.debug("[RegexCache] 缓存已清空")
def _match_with_timeout(
compiled_regex: regex.Pattern, text: str, timeout_ms: int = REGEX_MATCH_TIMEOUT_MS
) -> Optional[bool]:
"""
带超时的正则匹配(使用 regex 库的原生超时支持)
相比 ThreadPoolExecutor 方案的优势:
- C 层面中断匹配,不会留下僵尸线程
- 更低的性能开销
- 更精确的超时控制
Args:
compiled_regex: 编译后的 regex.Pattern 对象
text: 待匹配的文本
timeout_ms: 超时时间(毫秒)
Returns:
True: 匹配成功
False: 匹配失败
None: 超时或异常
"""
try:
# regex 库的 timeout 参数单位是秒
result = compiled_regex.match(text, timeout=timeout_ms / 1000.0)
return result is not None
except TimeoutError:
logger.warning(
f"正则匹配超时 ({timeout_ms}ms): pattern={compiled_regex.pattern[:50]}..., text={text[:50]}..."
)
return None
except Exception as e:
logger.warning(f"正则匹配异常: {e}")
return None
def match_model_with_pattern(pattern: str, model_name: str) -> bool:
"""
检查模型名是否匹配别名模式(支持正则表达式)
安全特性:
- 长度限制检查
- 正则编译缓存
- 正则匹配超时保护100ms使用 regex 库原生超时)
Args:
pattern: 别名模式,支持正则表达式语法
model_name: 被检查的模型名(来自 Key 的 allowed_models
Returns:
True 如果匹配
示例:
match_model_with_pattern("claude-haiku-.*", "claude-haiku-4.5") -> True
match_model_with_pattern("gpt-4o", "gpt-4o") -> True
match_model_with_pattern("gpt-4o", "gpt-4") -> False
"""
# 快速路径:精确匹配
if pattern.lower() == model_name.lower():
return True
# 长度检查
if len(pattern) > MAX_ALIAS_LENGTH or len(model_name) > MAX_MODEL_NAME_LENGTH:
return False
# 使用缓存的编译结果
compiled = _compile_pattern_cached(pattern)
if compiled is None:
return False
# 使用带超时的匹配regex 库原生支持)
result = _match_with_timeout(compiled, model_name)
return result is True
def check_model_allowed_with_aliases(
model_name: str,
allowed_models: AllowedModels,
api_format: Optional[str] = None,
resolved_model_name: Optional[str] = None,
model_aliases: Optional[List[str]] = None,
) -> tuple[bool, Optional[str]]:
"""
检查模型是否被允许(支持别名通配符匹配)
匹配优先级:
1. 精确匹配 model_name用户请求的模型名
2. 精确匹配 resolved_model_nameGlobalModel.name
3. 遍历 model_aliases检查每个别名是否匹配 allowed_models 中的任一项
别名匹配顺序说明:
- 按 allowed_models 集合的迭代顺序遍历(通常为字母顺序,因为内部使用 set
- 对于每个 allowed_model按 model_aliases 数组顺序依次尝试匹配
- 返回第一个成功匹配的 allowed_model
- 如需确定性行为,请确保 model_aliases 中的规则从最具体到最通用排序
Args:
model_name: 请求的模型名称
allowed_models: 允许的模型配置(来自 Provider Key
api_format: 当前请求的 API 格式
resolved_model_name: 解析后的 GlobalModel.name
model_aliases: GlobalModel 的别名列表(来自 config.model_aliases
Returns:
(is_allowed, matched_model_name):
- is_allowed: 是否允许使用该模型
- matched_model_name: 通过别名匹配到的模型名(仅别名匹配时有值,精确匹配时为 None
"""
# 先尝试精确匹配(使用原有逻辑)
if check_model_allowed(model_name, allowed_models, api_format, resolved_model_name):
return True, None
# 如果精确匹配失败且有别名配置,尝试别名匹配
if not model_aliases:
return False, None
# 获取 allowed_models 的集合
allowed_set = normalize_allowed_models(allowed_models, api_format)
if allowed_set is None:
# 不限制,已在 check_model_allowed 中返回 True
return True, None
if len(allowed_set) == 0:
# 空集合 = 拒绝所有
return False, None
# 遍历 allowed_models 中的每个模型名,检查是否有别名能匹配
# 注意:返回第一个匹配的模型名,匹配顺序由 allowed_set 迭代顺序和 model_aliases 数组顺序决定
for allowed_model in allowed_set:
for alias_pattern in model_aliases:
if match_model_with_pattern(alias_pattern, allowed_model):
# 返回匹配到的模型名,用于实际请求
return True, allowed_model
return False, None

View File

@@ -32,6 +32,7 @@ from __future__ import annotations
import hashlib
import random
import re
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
@@ -76,6 +77,7 @@ class ProviderCandidate:
is_cached: bool = False
is_skipped: bool = False # 是否被跳过
skip_reason: Optional[str] = None # 跳过原因
alias_matched_model: Optional[str] = None # 通过别名匹配到的模型名(用于实际请求)
@dataclass
@@ -590,6 +592,9 @@ class CacheAwareScheduler:
requested_model_name = model_name
resolved_model_name = str(global_model.name)
# 提取模型别名(用于 Provider Key 的 allowed_models 匹配)
model_aliases: List[str] = (global_model.config or {}).get("model_aliases", [])
# 获取合并后的访问限制ApiKey + User
restrictions = self._get_effective_restrictions(user_api_key)
allowed_api_formats = restrictions["allowed_api_formats"]
@@ -657,6 +662,7 @@ class CacheAwareScheduler:
target_format=target_format,
model_name=requested_model_name,
resolved_model_name=resolved_model_name,
model_aliases=model_aliases,
affinity_key=affinity_key,
max_candidates=max_candidates,
is_stream=is_stream,
@@ -852,7 +858,8 @@ class CacheAwareScheduler:
model_name: str,
capability_requirements: Optional[Dict[str, bool]] = None,
resolved_model_name: Optional[str] = None,
) -> Tuple[bool, Optional[str]]:
model_aliases: Optional[List[str]] = None,
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
检查 API Key 的可用性
@@ -864,28 +871,53 @@ class CacheAwareScheduler:
model_name: 模型名称
capability_requirements: 能力需求(可选)
resolved_model_name: 解析后的 GlobalModel.name可选
model_aliases: GlobalModel 的别名列表(用于通配符匹配)
Returns:
(is_available, skip_reason)
(is_available, skip_reason, alias_matched_model)
- is_available: Key 是否可用
- skip_reason: 不可用时的原因
- alias_matched_model: 通过别名匹配到的模型名(用于实际请求)
"""
# 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因,按 API 格式)
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(
key, api_format=api_format
)
if not is_available:
return False, circuit_reason or "熔断器已打开"
return False, circuit_reason or "熔断器已打开", None
# 模型权限检查:使用 allowed_models 白名单(支持简单列表和按格式字典两种模式)
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
# 支持通配符别名匹配(通过 model_aliases
from src.core.model_permissions import (
check_model_allowed_with_aliases,
get_allowed_models_preview,
)
if not check_model_allowed(
model_name=model_name,
allowed_models=key.allowed_models,
api_format=api_format,
resolved_model_name=resolved_model_name,
):
return False, f"模型权限不匹配(允许: {get_allowed_models_preview(key.allowed_models)})"
try:
is_allowed, alias_matched_model = check_model_allowed_with_aliases(
model_name=model_name,
allowed_models=key.allowed_models,
api_format=api_format,
resolved_model_name=resolved_model_name,
model_aliases=model_aliases,
)
except TimeoutError:
# 正则匹配超时(可能是 ReDoS 攻击或复杂模式)
logger.warning(f"别名匹配超时: key_id={key.id}, model={model_name}")
return False, "别名匹配超时,请简化配置", None
except re.error as e:
# 正则语法错误(配置问题)
logger.warning(f"别名规则无效: key_id={key.id}, model={model_name}, error={e}")
return False, f"别名规则无效: {str(e)}", None
except Exception as e:
# 其他未知异常
logger.error(f"别名匹配异常: key_id={key.id}, model={model_name}, error={e}", exc_info=True)
# 异常时保守处理:不允许使用该 Key
return False, "别名匹配失败", None
if not is_allowed:
return False, f"模型权限不匹配(允许: {get_allowed_models_preview(key.allowed_models)})", None
# Key 级别的能力匹配检查
# 注意:模型级别的能力检查已在 _check_model_support 中完成
@@ -896,9 +928,9 @@ class CacheAwareScheduler:
key_caps: Dict[str, bool] = dict(key.capabilities or {})
is_match, skip_reason = check_capability_match(key_caps, capability_requirements)
if not is_match:
return False, skip_reason
return False, skip_reason, None
return True, None
return True, None, alias_matched_model
async def _build_candidates(
self,
@@ -908,6 +940,7 @@ class CacheAwareScheduler:
model_name: str,
affinity_key: Optional[str],
resolved_model_name: Optional[str] = None,
model_aliases: Optional[List[str]] = None,
max_candidates: Optional[int] = None,
is_stream: bool = False,
capability_requirements: Optional[Dict[str, bool]] = None,
@@ -924,6 +957,7 @@ class CacheAwareScheduler:
model_name: 模型名称(用户请求的名称,可能是映射名)
affinity_key: 亲和性标识符通常为API Key ID
resolved_model_name: 解析后的 GlobalModel.name用于 Key.allowed_models 校验)
model_aliases: GlobalModel 的别名列表(用于 Key.allowed_models 通配符匹配)
max_candidates: 最大候选数
is_stream: 是否是流式请求,如果为 True 则过滤不支持流式的 Provider
capability_requirements: 能力需求(可选)
@@ -981,12 +1015,13 @@ class CacheAwareScheduler:
for key in keys:
# Key 级别的能力检查
is_available, skip_reason = self._check_key_availability(
is_available, skip_reason, alias_matched_model = self._check_key_availability(
key,
target_format_str,
model_name,
capability_requirements,
resolved_model_name=resolved_model_name,
model_aliases=model_aliases,
)
candidate = ProviderCandidate(
@@ -995,6 +1030,7 @@ class CacheAwareScheduler:
key=key,
is_skipped=not is_available,
skip_reason=skip_reason,
alias_matched_model=alias_matched_model,
)
candidates.append(candidate)

View File

@@ -1,10 +1,7 @@
"""
缓存失效服务
统一管理各种缓存的失效逻辑,支持:
1. GlobalModel 变更时失效相关缓存
2. Model 变更时失效模型映射缓存
3. 支持同步和异步缓存后端
统一管理各种缓存的失效逻辑
"""
from typing import Optional
@@ -13,56 +10,54 @@ from src.core.logger import logger
class CacheInvalidationService:
"""
缓存失效服务
提供统一的缓存失效接口,当数据库模型变更时自动清理相关缓存
"""
"""缓存失效服务"""
def __init__(self):
"""初始化缓存失效服务"""
self._model_mappers = [] # 可能有多个 ModelMapperMiddleware 实例
self._model_mappers = []
def register_model_mapper(self, model_mapper):
"""注册 ModelMapper 实例"""
if model_mapper not in self._model_mappers:
self._model_mappers.append(model_mapper)
logger.debug(f"[CacheInvalidation] ModelMapper 已注册 (实例: {id(model_mapper)},总数: {len(self._model_mappers)})")
def on_global_model_changed(self, model_name: str):
async def on_global_model_changed(
self, model_name: str, global_model_id: Optional[str] = None
) -> None:
"""
GlobalModel 变更时的缓存失效
Args:
model_name: 变更的 GlobalModel.name
global_model_id: GlobalModel ID可选
"""
logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
# 失效所有 ModelMapper 中与此模型相关的缓存
# 1. 清空正则缓存
from src.core.model_permissions import clear_regex_cache
clear_regex_cache()
# 2. 清空 ModelMapper 缓存
for mapper in self._model_mappers:
# 清空所有缓存(因为不知道哪些 provider 使用了这个模型)
mapper.clear_cache()
logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存")
# 3. 清空 ModelCacheService 缓存
from src.services.cache.model_cache import ModelCacheService
try:
await ModelCacheService.invalidate_global_model_cache(
global_model_id=global_model_id or "", name=model_name
)
except Exception as e:
logger.error(f"[CacheInvalidation] 失效 ModelCacheService 缓存失败: {e}")
def on_model_changed(self, provider_id: str, global_model_id: str):
"""
Model 变更时的缓存失效
Args:
provider_id: Provider ID
global_model_id: GlobalModel ID
"""
logger.info(f"[CacheInvalidation] Model 变更: provider={provider_id[:8]}..., "
f"global_model={global_model_id[:8]}...")
# 失效 ModelMapper 中特定 Provider 的缓存
"""Model 变更时的缓存失效"""
for mapper in self._model_mappers:
mapper.refresh_cache(provider_id)
def clear_all_caches(self):
"""清空所有缓存"""
logger.info("[CacheInvalidation] 清空所有缓存")
for mapper in self._model_mappers:
mapper.clear_cache()
@@ -72,16 +67,10 @@ _cache_invalidation_service: Optional[CacheInvalidationService] = None
def get_cache_invalidation_service() -> CacheInvalidationService:
"""
获取全局缓存失效服务实例
Returns:
CacheInvalidationService 实例
"""
"""获取全局缓存失效服务实例"""
global _cache_invalidation_service
if _cache_invalidation_service is None:
_cache_invalidation_service = CacheInvalidationService()
logger.debug("[CacheInvalidation] 初始化缓存失效服务")
return _cache_invalidation_service

View File

@@ -138,7 +138,7 @@ class RequestExecutor:
context.concurrent_requests = key_rpm_count # 用于记录,实际是 RPM 计数
context.start_time = time.time()
response = await request_func(provider, endpoint, key)
response = await request_func(provider, endpoint, key, candidate)
context.elapsed_ms = int((time.time() - context.start_time) * 1000)

3986
uv.lock generated

File diff suppressed because it is too large Load Diff