mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-04 08:42:27 +08:00
Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6b9582978 | ||
|
|
3d583b0a8d | ||
|
|
f849a54027 | ||
|
|
f2cd96c34c | ||
|
|
34d480910a | ||
|
|
f16fb28405 | ||
|
|
a0ffc2c406 | ||
|
|
a7bfab1475 | ||
|
|
84d4db0f8d | ||
|
|
903b182fdf | ||
|
|
d9bd0790fe | ||
|
|
c6fcc7982d | ||
|
|
aaa6a8f60d | ||
|
|
11774c69b6 | ||
|
|
8f0a0cbdb1 | ||
|
|
51b85915d2 | ||
|
|
b0d295c6c9 | ||
|
|
c94f011462 | ||
|
|
3296d026e3 | ||
|
|
2e01c7cf5a | ||
|
|
88e37594cf | ||
|
|
03ee6c16d9 | ||
|
|
743f23e640 | ||
|
|
7068aa9130 | ||
|
|
56fb6bf36c | ||
|
|
728f9bb126 | ||
|
|
5319c06f0e | ||
|
|
0ef6e04593 | ||
|
|
beae7a2616 | ||
|
|
21eedbe331 | ||
|
|
393d4d13ff | ||
|
|
77613795ed | ||
|
|
f54127cba5 | ||
|
|
54370cb3f9 | ||
|
|
07b81351d9 | ||
|
|
5d829a100a | ||
|
|
006cd2c3e5 | ||
|
|
90ca5065ee | ||
|
|
66307f8f49 | ||
|
|
fc0ca3944e | ||
|
|
25a049d607 | ||
|
|
15b4f665d1 | ||
|
|
36a84e19b4 | ||
|
|
1f7db361ad | ||
|
|
766a3280d6 | ||
|
|
4cbe0c38f7 | ||
|
|
18ce6637b6 | ||
|
|
0e1de65eb3 | ||
|
|
b64d507c6e | ||
|
|
6e63116cc9 | ||
|
|
9ac56662da | ||
|
|
624d81f453 | ||
|
|
2423edec98 | ||
|
|
d8516e42ea | ||
|
|
a73e0d51db |
@@ -23,7 +23,7 @@ RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# Python 依赖(安装到系统,不用 -e 模式)
|
||||
COPY pyproject.toml README.md ./
|
||||
RUN mkdir -p src && touch src/__init__.py && \
|
||||
pip install --no-cache-dir .
|
||||
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir .
|
||||
|
||||
# 前端依赖
|
||||
COPY frontend/package*.json /tmp/frontend/
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
"""remove_model_mappings_add_aliases
|
||||
|
||||
合并迁移:
|
||||
1. 添加 provider_model_aliases 字段到 models 表
|
||||
2. 迁移 model_mappings 数据到 provider_model_aliases
|
||||
3. 删除 model_mappings 表
|
||||
4. 添加索引优化别名解析性能
|
||||
|
||||
Revision ID: e9b3d63f0cbf
|
||||
Revises: 20251210_baseline
|
||||
Create Date: 2025-12-14 13:00:22.828183+00:00
|
||||
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e9b3d63f0cbf'
|
||||
down_revision = '20251210_baseline'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def _normalize_alias_list(value) -> list[dict]:
    """Normalize a raw JSON column value into ``[{'name': str, 'priority': int}]``.

    Args:
        value: Whatever the driver returned for the JSON column: ``None``,
            a JSON-encoded string, or an already-decoded Python object.

    Returns:
        A new list containing only well-formed entries. Names are stripped
        and must be non-empty strings; a priority that is missing,
        non-numeric, or below 1 becomes 1. Anything that is not a list of
        dicts yields an empty list.
    """
    if value is None:
        return []

    if isinstance(value, str):
        # The column may come back as raw JSON text; decode it ourselves.
        try:
            value = json.loads(value) if value else []
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return []

    if not isinstance(value, list):
        return []

    normalized: list[dict] = []
    for item in value:
        if not isinstance(item, dict):
            continue

        raw_name = item.get("name")
        if not isinstance(raw_name, str):
            continue
        name = raw_name.strip()
        if not name:
            continue

        try:
            priority = int(item.get("priority", 1))
        except (TypeError, ValueError, OverflowError):
            priority = 1

        normalized.append({"name": name, "priority": max(priority, 1)})

    return normalized


def _group_provider_aliases(rows) -> dict:
    """Group ``(source_model, target_global_model_id, provider_id)`` rows.

    Rows whose fields are not non-empty strings are skipped. Within each
    ``(provider_id, target_global_model_id)`` group the aliases receive
    priorities 1, 2, 3, ... in the order the rows are supplied.
    """
    groups: dict = {}
    for source_model, target_global_model_id, provider_id in rows:
        if not isinstance(source_model, str):
            continue
        source_model = source_model.strip()
        if not source_model:
            continue
        if not isinstance(provider_id, str) or not provider_id:
            continue
        if not isinstance(target_global_model_id, str) or not target_global_model_id:
            continue

        group = groups.setdefault((provider_id, target_global_model_id), [])
        group.append({"name": source_model, "priority": len(group) + 1})

    return groups


def _merge_aliases(existing: list[dict], incoming: list[dict]) -> list[dict]:
    """Return ``existing`` plus every alias of ``incoming`` whose name is new.

    Appended aliases get priority ``len(result) + 1`` at the time they are
    added; entries already present keep their priority. Neither argument is
    mutated.
    """
    merged = list(existing)
    seen_names = {alias["name"] for alias in merged}
    for alias in incoming:
        name = alias["name"]
        if name in seen_names:
            continue
        merged.append({"name": name, "priority": len(merged) + 1})
        seen_names.add(name)
    return merged


def upgrade() -> None:
    """Add ``provider_model_aliases``, migrate the data, drop ``model_mappings``.

    Steps:
        1. Add the nullable JSON column ``models.provider_model_aliases``.
        2. Copy active, provider-scoped rows of type ``alias`` from
           ``model_mappings`` onto the matching ``models`` row.
        3. Drop the ``model_mappings`` table.
        4. Create the lookup indexes (the GIN index only on PostgreSQL).
    """
    # 1. Add the provider_model_aliases column.
    op.add_column(
        "models",
        sa.Column("provider_model_aliases", sa.JSON(), nullable=True),
    )

    # 2. Migrate the model_mappings data.
    bind = op.get_bind()
    session = Session(bind=bind)

    # Lightweight table stubs: only the columns this migration touches.
    model_mappings_table = sa.table(
        "model_mappings",
        sa.column("source_model", sa.String),
        sa.column("target_global_model_id", sa.String),
        sa.column("provider_id", sa.String),
        sa.column("mapping_type", sa.String),
        sa.column("is_active", sa.Boolean),
    )
    models_table = sa.table(
        "models",
        sa.column("id", sa.String),
        sa.column("provider_id", sa.String),
        sa.column("global_model_id", sa.String),
        sa.column("provider_model_aliases", sa.JSON),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )

    # Only active provider-level aliases are migrated (is_active = TRUE and
    # mapping_type = 'alias'). Global aliases/mappings are not migrated: the
    # new architecture no longer resolves source_model -> GlobalModel.name.
    mappings = session.execute(
        sa.select(
            model_mappings_table.c.source_model,
            model_mappings_table.c.target_global_model_id,
            model_mappings_table.c.provider_id,
        )
        .where(
            model_mappings_table.c.is_active.is_(True),
            model_mappings_table.c.provider_id.isnot(None),
            model_mappings_table.c.mapping_type == "alias",
        )
        # Deterministic ordering makes the assigned priorities reproducible.
        .order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
    ).all()

    alias_groups = _group_provider_aliases(mappings)

    # Write each alias group onto the corresponding models row.
    for (provider_id, global_model_id), aliases in alias_groups.items():
        model_row = session.execute(
            sa.select(models_table.c.id, models_table.c.provider_model_aliases)
            .where(
                models_table.c.provider_id == provider_id,
                models_table.c.global_model_id == global_model_id,
            )
            .limit(1)
        ).first()

        if not model_row:
            # No models row for this (provider, global model) pair: these
            # aliases have nowhere to go and are not migrated.
            continue

        model_id = model_row[0]
        merged_aliases = _merge_aliases(_normalize_alias_list(model_row[1]), aliases)

        session.execute(
            models_table.update()
            .where(models_table.c.id == model_id)
            .values(
                provider_model_aliases=merged_aliases or None,
                updated_at=datetime.now(timezone.utc),
            )
        )

    session.commit()

    # 3. Drop the model_mappings table.
    op.drop_table("model_mappings")

    # 4. Indexes that speed up alias resolution.
    # Index on provider_model_name (exact-match lookups).
    op.create_index(
        "idx_model_provider_model_name",
        "models",
        ["provider_model_name"],
        unique=False,
        postgresql_where=sa.text("is_active = true"),
    )

    # GIN index on provider_model_aliases (JSONB queries, PostgreSQL only).
    if bind.dialect.name == "postgresql":
        # Convert the json column to jsonb so it can carry the GIN index below.
        op.execute(
            """
            ALTER TABLE models
            ALTER COLUMN provider_model_aliases TYPE jsonb
            USING provider_model_aliases::jsonb
            """
        )
        # Create the GIN index.
        op.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_model_provider_model_aliases_gin
            ON models USING gin(provider_model_aliases jsonb_path_ops)
            WHERE is_active = true
            """
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Recreate ``model_mappings`` and remove ``provider_model_aliases``.

    Only the schema is restored: nothing here copies alias data back into
    the recreated ``model_mappings`` table, so it starts out empty.
    """
    is_postgresql = op.get_bind().dialect.name == "postgresql"

    # 1. Drop the indexes.
    op.drop_index("idx_model_provider_model_name", table_name="models")

    if is_postgresql:
        op.execute("DROP INDEX IF EXISTS idx_model_provider_model_aliases_gin")
        # Turn the jsonb column back into json.
        op.execute(
            """
            ALTER TABLE models
            ALTER COLUMN provider_model_aliases TYPE json
            USING provider_model_aliases::json
            """
        )

    # 2. Restore the model_mappings table.
    timestamp_columns = [
        sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        )
        for column_name in ("created_at", "updated_at")
    ]
    op.create_table(
        "model_mappings",
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("source_model", sa.String(200), nullable=False),
        sa.Column(
            "target_global_model_id",
            sa.String(36),
            sa.ForeignKey("global_models.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("provider_id", sa.String(36), sa.ForeignKey("providers.id"), nullable=True),
        sa.Column("mapping_type", sa.String(20), nullable=False, server_default="alias"),
        sa.Column("is_active", sa.Boolean(), nullable=False, server_default="true"),
        *timestamp_columns,
        sa.UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"),
    )
    # One single-column index per lookup column, created in this order.
    for indexed_column in (
        "source_model",
        "target_global_model_id",
        "provider_id",
        "mapping_type",
    ):
        op.create_index(
            f"ix_model_mappings_{indexed_column}",
            "model_mappings",
            [indexed_column],
        )

    # 3. Remove the provider_model_aliases column.
    op.drop_column("models", "provider_model_aliases")
|
||||
@@ -271,3 +271,71 @@ export const cacheAnalysisApi = {
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Model mapping cache management API ====================
|
||||
|
||||
/**
 * A single mapping entry as listed in `ModelMappingCacheStats.mappings`.
 */
export interface ModelMappingItem {
  /** Name of the mapping. */
  mapping_name: string
  /** Name of the associated global model; may be null. */
  global_model_name: string | null
  /** Display name of the associated global model; may be null. */
  global_model_display_name: string | null
  /** Providers listed for this mapping. */
  providers: Array<string>
  /** TTL of the entry; may be null. NOTE(review): unit is presumably seconds, confirm against the backend. */
  ttl: number | null
}
|
||||
|
||||
/**
 * An entry that is not mapped; `status` tells which case applies
 * (not found, invalid, or error).
 */
export interface UnmappedEntry {
  /** Name of the mapping. */
  mapping_name: string
  /** Why the entry is considered unmapped. */
  status: 'not_found' | 'invalid' | 'error'
  /** TTL of the entry; may be null. NOTE(review): unit is presumably seconds, confirm against the backend. */
  ttl: number | null
}
|
||||
|
||||
/**
 * Result type of `modelMappingCacheApi.getStats()`.
 * Every field except `available` is optional.
 */
export interface ModelMappingCacheStats {
  /** Availability flag. NOTE(review): presumably whether the cache backend is reachable, confirm. */
  available: boolean
  /** Optional message accompanying the stats. */
  message?: string
  /** Optional TTL value. NOTE(review): presumably the configured cache TTL in seconds, confirm. */
  ttl_seconds?: number
  /** Optional total key count. */
  total_keys?: number
  /** Optional per-category counts. */
  breakdown?: {
    model_by_id: number
    model_by_provider_global: number
    global_model_by_id: number
    global_model_by_name: number
    global_model_resolve: number
  }
  /** Optional list of mapping entries. */
  mappings?: Array<ModelMappingItem>
  /** Optional list of unmapped entries; may also be null. */
  unmapped?: Array<UnmappedEntry> | null
}
|
||||
|
||||
/**
 * Result type of the `clearAll` and `clearByName` calls on `modelMappingCacheApi`.
 */
export interface ClearModelMappingCacheResponse {
  /** Status string reported by the server. */
  status: string
  /** Message reported by the server. */
  message: string
  /** Optional count of deleted entries. */
  deleted_count?: number
  /** Optional model name. NOTE(review): presumably the name whose cache was cleared, confirm. */
  model_name?: string
  /** Optional list of deleted cache keys. */
  deleted_keys?: Array<string>
}
|
||||
|
||||
export const modelMappingCacheApi = {
  /**
   * Fetch the model mapping cache statistics.
   *
   * Returns the `data` field of the response body (i.e. `response.data.data`).
   */
  async getStats(): Promise<ModelMappingCacheStats> {
    const { data: body } = await api.get('/api/admin/monitoring/cache/model-mapping/stats')
    return body.data
  },

  /**
   * Clear the entire model mapping cache.
   *
   * Returns the response body as-is.
   */
  async clearAll(): Promise<ClearModelMappingCacheResponse> {
    const { data: body } = await api.delete('/api/admin/monitoring/cache/model-mapping')
    return body
  },

  /**
   * Clear the mapping cache for one model name.
   *
   * The name is URI-encoded before being placed in the request path.
   * Returns the response body as-is.
   */
  async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
    const encodedName = encodeURIComponent(modelName)
    const { data: body } = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodedName}`)
    return body
  }
}
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
/**
|
||||
* 模型别名管理 API
|
||||
*/
|
||||
|
||||
import client from '../client'
|
||||
import type { ModelMapping, ModelMappingCreate, ModelMappingUpdate } from './types'
|
||||
|
||||
export interface ModelAlias {
|
||||
id: string
|
||||
alias: string
|
||||
global_model_id: string
|
||||
global_model_name: string | null
|
||||
global_model_display_name: string | null
|
||||
provider_id: string | null
|
||||
provider_name: string | null
|
||||
scope: 'global' | 'provider'
|
||||
mapping_type: 'alias' | 'mapping'
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface CreateModelAliasRequest {
|
||||
alias: string
|
||||
global_model_id: string
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface UpdateModelAliasRequest {
|
||||
alias?: string
|
||||
global_model_id?: string
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
function transformMapping(mapping: ModelMapping): ModelAlias {
|
||||
return {
|
||||
id: mapping.id,
|
||||
alias: mapping.source_model,
|
||||
global_model_id: mapping.target_global_model_id,
|
||||
global_model_name: mapping.target_global_model_name,
|
||||
global_model_display_name: mapping.target_global_model_display_name,
|
||||
provider_id: mapping.provider_id ?? null,
|
||||
provider_name: mapping.provider_name ?? null,
|
||||
scope: mapping.scope,
|
||||
mapping_type: mapping.mapping_type || 'alias',
|
||||
is_active: mapping.is_active,
|
||||
created_at: mapping.created_at,
|
||||
updated_at: mapping.updated_at
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取别名列表
|
||||
*/
|
||||
export async function getAliases(params?: {
|
||||
provider_id?: string
|
||||
global_model_id?: string
|
||||
is_active?: boolean
|
||||
skip?: number
|
||||
limit?: number
|
||||
}): Promise<ModelAlias[]> {
|
||||
const response = await client.get('/api/admin/models/mappings', {
|
||||
params: {
|
||||
provider_id: params?.provider_id,
|
||||
target_global_model_id: params?.global_model_id,
|
||||
is_active: params?.is_active,
|
||||
skip: params?.skip,
|
||||
limit: params?.limit
|
||||
}
|
||||
})
|
||||
return (response.data as ModelMapping[]).map(transformMapping)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取单个别名
|
||||
*/
|
||||
export async function getAlias(id: string): Promise<ModelAlias> {
|
||||
const response = await client.get(`/api/admin/models/mappings/${id}`)
|
||||
return transformMapping(response.data)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建别名
|
||||
*/
|
||||
export async function createAlias(data: CreateModelAliasRequest): Promise<ModelAlias> {
|
||||
const payload: ModelMappingCreate = {
|
||||
source_model: data.alias,
|
||||
target_global_model_id: data.global_model_id,
|
||||
provider_id: data.provider_id ?? null,
|
||||
mapping_type: data.mapping_type ?? 'alias',
|
||||
is_active: data.is_active ?? true
|
||||
}
|
||||
const response = await client.post('/api/admin/models/mappings', payload)
|
||||
return transformMapping(response.data)
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新别名
|
||||
*/
|
||||
export async function updateAlias(id: string, data: UpdateModelAliasRequest): Promise<ModelAlias> {
|
||||
const payload: ModelMappingUpdate = {
|
||||
source_model: data.alias,
|
||||
target_global_model_id: data.global_model_id,
|
||||
provider_id: data.provider_id ?? null,
|
||||
mapping_type: data.mapping_type,
|
||||
is_active: data.is_active
|
||||
}
|
||||
const response = await client.patch(`/api/admin/models/mappings/${id}`, payload)
|
||||
return transformMapping(response.data)
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除别名
|
||||
*/
|
||||
export async function deleteAlias(id: string): Promise<void> {
|
||||
await client.delete(`/api/admin/models/mappings/${id}`)
|
||||
}
|
||||
@@ -4,6 +4,5 @@ export * from './endpoints'
|
||||
export * from './keys'
|
||||
export * from './health'
|
||||
export * from './models'
|
||||
export * from './aliases'
|
||||
export * from './adaptive'
|
||||
export * from './global-models'
|
||||
|
||||
@@ -5,9 +5,6 @@ import type {
|
||||
ModelUpdate,
|
||||
ModelCatalogResponse,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
DeleteModelMappingResponse
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
@@ -99,27 +96,6 @@ export async function getProviderAvailableSourceModels(
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新目录中的模型映射
|
||||
*/
|
||||
export async function updateCatalogMapping(
|
||||
mappingId: string,
|
||||
data: UpdateModelMappingRequest
|
||||
): Promise<UpdateModelMappingResponse> {
|
||||
const response = await client.put(`/api/admin/models/catalog/mappings/${mappingId}`, data)
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除目录中的模型映射
|
||||
*/
|
||||
export async function deleteCatalogMapping(
|
||||
mappingId: string
|
||||
): Promise<DeleteModelMappingResponse> {
|
||||
const response = await client.delete(`/api/admin/models/catalog/mappings/${mappingId}`)
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量为 Provider 关联 GlobalModels
|
||||
*/
|
||||
|
||||
@@ -211,11 +211,17 @@ export interface ConcurrencyStatus {
|
||||
key_max_concurrent?: number
|
||||
}
|
||||
|
||||
/**
 * One alias entry with its priority.
 */
export interface ProviderModelAlias {
  /** The alias name. */
  name: string
  /** Priority of the alias; a smaller number means a higher priority. */
  priority: number
}
|
||||
|
||||
export interface Model {
|
||||
id: string
|
||||
provider_id: string
|
||||
global_model_id?: string // 关联的 GlobalModel ID
|
||||
provider_model_name: string // Provider 侧的模型名称(原 name)
|
||||
provider_model_name: string // Provider 侧的主模型名称
|
||||
provider_model_aliases?: ProviderModelAlias[] | null // 模型名称别名列表(带优先级)
|
||||
// 原始配置值(可能为空,为空时使用 GlobalModel 默认值)
|
||||
price_per_request?: number | null // 按次计费价格
|
||||
tiered_pricing?: TieredPricingConfig | null // 阶梯计费配置
|
||||
@@ -244,7 +250,8 @@ export interface Model {
|
||||
}
|
||||
|
||||
export interface ModelCreate {
|
||||
provider_model_name: string // Provider 侧的模型名称(原 name)
|
||||
provider_model_name: string // Provider 侧的主模型名称
|
||||
provider_model_aliases?: ProviderModelAlias[] // 模型名称别名列表(带优先级)
|
||||
global_model_id: string // 关联的 GlobalModel ID(必填)
|
||||
// 计费配置(可选,为空时使用 GlobalModel 默认值)
|
||||
price_per_request?: number // 按次计费价格
|
||||
@@ -261,6 +268,7 @@ export interface ModelCreate {
|
||||
|
||||
export interface ModelUpdate {
|
||||
provider_model_name?: string
|
||||
provider_model_aliases?: ProviderModelAlias[] | null // 模型名称别名列表(带优先级)
|
||||
global_model_id?: string
|
||||
price_per_request?: number | null // 按次计费价格(null 表示清空/使用默认值)
|
||||
tiered_pricing?: TieredPricingConfig | null // 阶梯计费配置
|
||||
@@ -273,21 +281,6 @@ export interface ModelUpdate {
|
||||
is_available?: boolean
|
||||
}
|
||||
|
||||
export interface ModelMapping {
|
||||
id: string
|
||||
source_model: string // 别名/源模型名
|
||||
target_global_model_id: string // 目标 GlobalModel ID
|
||||
target_global_model_name: string | null
|
||||
target_global_model_display_name: string | null
|
||||
provider_id: string | null
|
||||
provider_name: string | null
|
||||
scope: 'global' | 'provider'
|
||||
mapping_type: 'alias' | 'mapping'
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface ModelCapabilities {
|
||||
supports_vision: boolean
|
||||
supports_function_calling: boolean
|
||||
@@ -335,7 +328,6 @@ export interface ModelCatalogItem {
|
||||
global_model_name: string // GlobalModel.name(原 source_model)
|
||||
display_name: string // GlobalModel.display_name
|
||||
description?: string | null // GlobalModel.description
|
||||
aliases: string[] // 所有指向该 GlobalModel 的别名列表
|
||||
providers: ModelCatalogProviderDetail[] // 支持该模型的 Provider 列表
|
||||
price_range: ModelPriceRange // 价格区间
|
||||
total_providers: number
|
||||
@@ -351,8 +343,6 @@ export interface ProviderAvailableSourceModel {
|
||||
global_model_name: string // GlobalModel.name(原 source_model)
|
||||
display_name: string // GlobalModel.display_name
|
||||
provider_model_name: string // Model.provider_model_name(Provider 侧的模型名)
|
||||
has_alias: boolean // 是否有别名指向该 GlobalModel
|
||||
aliases: string[] // 别名列表
|
||||
model_id?: string | null // Model.id
|
||||
price: ProviderModelPriceInfo
|
||||
capabilities: ModelCapabilities
|
||||
@@ -371,65 +361,6 @@ export interface BatchAssignProviderConfig {
|
||||
model_id?: string
|
||||
}
|
||||
|
||||
export interface BatchAssignModelMappingRequest {
|
||||
global_model_id: string // 要分配的 GlobalModel ID(原 source_model)
|
||||
providers: BatchAssignProviderConfig[]
|
||||
}
|
||||
|
||||
export interface BatchAssignProviderResult {
|
||||
provider_id: string
|
||||
mapping_id?: string | null
|
||||
created_model: boolean
|
||||
model_id?: string | null
|
||||
updated: boolean
|
||||
}
|
||||
|
||||
export interface BatchAssignError {
|
||||
provider_id: string
|
||||
error: string
|
||||
}
|
||||
|
||||
export interface BatchAssignModelMappingResponse {
|
||||
success: boolean
|
||||
created_mappings: BatchAssignProviderResult[]
|
||||
errors: BatchAssignError[]
|
||||
}
|
||||
|
||||
export interface ModelMappingCreate {
|
||||
source_model: string // 源模型名或别名
|
||||
target_global_model_id: string // 目标 GlobalModel ID
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface ModelMappingUpdate {
|
||||
source_model?: string // 源模型名或别名
|
||||
target_global_model_id?: string // 目标 GlobalModel ID
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface UpdateModelMappingRequest {
|
||||
source_model?: string
|
||||
target_global_model_id?: string
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface UpdateModelMappingResponse {
|
||||
success: boolean
|
||||
mapping_id: string
|
||||
message?: string
|
||||
}
|
||||
|
||||
export interface DeleteModelMappingResponse {
|
||||
success: boolean
|
||||
message?: string
|
||||
}
|
||||
|
||||
export interface AdaptiveStatsResponse {
|
||||
adaptive_mode: boolean
|
||||
current_limit: number | null
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
<template>
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
:title="dialogTitle"
|
||||
:description="dialogDescription"
|
||||
:icon="dialogIcon"
|
||||
size="md"
|
||||
@update:model-value="handleDialogUpdate"
|
||||
>
|
||||
<form
|
||||
class="space-y-4"
|
||||
@submit.prevent="handleSubmit"
|
||||
>
|
||||
<!-- 模式选择(仅创建时显示) -->
|
||||
<div
|
||||
v-if="!isEditMode"
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label>创建类型 *</Label>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="p-3 rounded-lg border-2 text-left transition-all"
|
||||
:class="[
|
||||
form.mapping_type === 'alias'
|
||||
? 'border-primary bg-primary/5'
|
||||
: 'border-border hover:border-primary/50'
|
||||
]"
|
||||
@click="form.mapping_type = 'alias'"
|
||||
>
|
||||
<div class="font-medium text-sm">
|
||||
别名
|
||||
</div>
|
||||
<div class="text-xs text-muted-foreground mt-1">
|
||||
名称简写,按目标模型计费
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="p-3 rounded-lg border-2 text-left transition-all"
|
||||
:class="[
|
||||
form.mapping_type === 'mapping'
|
||||
? 'border-primary bg-primary/5'
|
||||
: 'border-border hover:border-primary/50'
|
||||
]"
|
||||
@click="form.mapping_type = 'mapping'"
|
||||
>
|
||||
<div class="font-medium text-sm">
|
||||
映射
|
||||
</div>
|
||||
<div class="text-xs text-muted-foreground mt-1">
|
||||
模型降级,按源模型计费
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模式说明 -->
|
||||
<div class="rounded-lg border border-border bg-muted/50 p-3 text-sm">
|
||||
<p class="text-foreground font-medium mb-1">
|
||||
{{ form.mapping_type === 'alias' ? '别名模式' : '映射模式' }}
|
||||
</p>
|
||||
<p class="text-muted-foreground text-xs">
|
||||
{{ form.mapping_type === 'alias'
|
||||
? '用户请求此别名时,会路由到目标模型,并按目标模型价格计费。'
|
||||
: '将源模型的请求转发到目标模型处理,按源模型价格计费。' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Provider 选择/作用范围 -->
|
||||
<div
|
||||
v-if="showProviderSelect"
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label>作用范围</Label>
|
||||
<!-- 固定 Provider 时显示只读 -->
|
||||
<div
|
||||
v-if="fixedProvider"
|
||||
class="px-3 py-2 border rounded-md bg-muted/50 text-sm"
|
||||
>
|
||||
仅 {{ fixedProvider.display_name || fixedProvider.name }}
|
||||
</div>
|
||||
<!-- 否则显示可选择的下拉 -->
|
||||
<Select
|
||||
v-else
|
||||
v-model:open="providerSelectOpen"
|
||||
:model-value="form.provider_id || 'global'"
|
||||
@update:model-value="handleProviderChange"
|
||||
>
|
||||
<SelectTrigger class="w-full">
|
||||
<SelectValue placeholder="选择作用范围" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="global">
|
||||
全局(所有 Provider)
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
v-for="p in providers"
|
||||
:key="p.id"
|
||||
:value="p.id"
|
||||
>
|
||||
仅 {{ p.display_name || p.name }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<!-- 别名模式:别名名称 -->
|
||||
<div
|
||||
v-if="form.mapping_type === 'alias'"
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label for="alias-name">别名名称 *</Label>
|
||||
<Input
|
||||
id="alias-name"
|
||||
v-model="form.alias"
|
||||
placeholder="如:sonnet, opus"
|
||||
:disabled="isEditMode"
|
||||
required
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{{ isEditMode ? '创建后不可修改' : '用户将使用此名称请求模型' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 映射模式:选择源模型 -->
|
||||
<div
|
||||
v-else
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label>源模型 (用户请求的模型) *</Label>
|
||||
<Select
|
||||
v-model:open="sourceModelSelectOpen"
|
||||
:model-value="form.alias"
|
||||
:disabled="isEditMode"
|
||||
@update:model-value="form.alias = $event"
|
||||
>
|
||||
<SelectTrigger
|
||||
class="w-full"
|
||||
:class="{ 'opacity-50': isEditMode }"
|
||||
>
|
||||
<SelectValue placeholder="请选择源模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="model in availableSourceModels"
|
||||
:key="model.id"
|
||||
:value="model.name"
|
||||
>
|
||||
{{ model.display_name }} ({{ model.name }})
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{{ isEditMode ? '创建后不可修改' : '选择要被映射的源模型,计费将按此模型价格' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 目标模型选择 -->
|
||||
<div class="space-y-2">
|
||||
<Label>
|
||||
{{ form.mapping_type === 'alias' ? '目标模型 *' : '目标模型 (实际处理请求) *' }}
|
||||
</Label>
|
||||
<!-- 固定目标模型时显示只读信息 -->
|
||||
<div
|
||||
v-if="fixedTargetModel"
|
||||
class="px-3 py-2 border rounded-md bg-muted/50"
|
||||
>
|
||||
<span class="font-medium">{{ fixedTargetModel.display_name }}</span>
|
||||
<span class="text-muted-foreground ml-1">({{ fixedTargetModel.name }})</span>
|
||||
</div>
|
||||
<!-- 否则显示下拉选择 -->
|
||||
<Select
|
||||
v-else
|
||||
v-model:open="targetModelSelectOpen"
|
||||
:model-value="form.global_model_id"
|
||||
@update:model-value="form.global_model_id = $event"
|
||||
>
|
||||
<SelectTrigger class="w-full">
|
||||
<SelectValue placeholder="请选择模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="model in availableTargetModels"
|
||||
:key="model.id"
|
||||
:value="model.id"
|
||||
>
|
||||
{{ model.display_name }} ({{ model.name }})
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
@click="handleCancel"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
v-if="submitting"
|
||||
class="w-4 h-4 mr-2 animate-spin"
|
||||
/>
|
||||
{{ isEditMode ? '保存' : '创建' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { Loader2, Tag, SquarePen } from 'lucide-vue-next'
|
||||
import { Dialog, Select, SelectTrigger, SelectValue, SelectContent, SelectItem } from '@/components/ui'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Input from '@/components/ui/input.vue'
|
||||
import Label from '@/components/ui/label.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
import type { ModelAlias, CreateModelAliasRequest, UpdateModelAliasRequest } from '@/api/endpoints/aliases'
|
||||
import type { GlobalModelResponse } from '@/api/global-models'
|
||||
|
||||
export interface ProviderOption {
|
||||
id: string
|
||||
name: string
|
||||
display_name?: string
|
||||
}
|
||||
|
||||
interface AliasFormData {
|
||||
alias: string
|
||||
global_model_id: string
|
||||
provider_id: string | null
|
||||
mapping_type: 'alias' | 'mapping'
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
open: boolean
|
||||
editingAlias?: ModelAlias | null
|
||||
globalModels: GlobalModelResponse[]
|
||||
providers?: ProviderOption[]
|
||||
fixedTargetModel?: GlobalModelResponse | null // 用于从模型详情抽屉打开时固定目标模型
|
||||
fixedProvider?: ProviderOption | null // 用于 Provider 特定别名固定 Provider
|
||||
showProviderSelect?: boolean // 是否显示 Provider 选择(默认 true)
|
||||
}>(), {
|
||||
editingAlias: null,
|
||||
providers: () => [],
|
||||
fixedTargetModel: null,
|
||||
fixedProvider: null,
|
||||
showProviderSelect: true
|
||||
})
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:open': [value: boolean]
|
||||
'submit': [data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean]
|
||||
}>()
|
||||
|
||||
const { error: showError } = useToast()
|
||||
|
||||
// 状态
|
||||
const submitting = ref(false)
|
||||
const providerSelectOpen = ref(false)
|
||||
const sourceModelSelectOpen = ref(false)
|
||||
const targetModelSelectOpen = ref(false)
|
||||
const form = ref<AliasFormData>({
|
||||
alias: '',
|
||||
global_model_id: '',
|
||||
provider_id: null,
|
||||
mapping_type: 'alias',
|
||||
is_active: true,
|
||||
})
|
||||
|
||||
// 处理 Provider 选择变化
|
||||
function handleProviderChange(value: string) {
|
||||
form.value.provider_id = value === 'global' ? null : value
|
||||
}
|
||||
|
||||
// 重置表单
|
||||
function resetForm() {
|
||||
form.value = {
|
||||
alias: '',
|
||||
global_model_id: props.fixedTargetModel?.id || '',
|
||||
provider_id: props.fixedProvider?.id || null,
|
||||
mapping_type: 'alias',
|
||||
is_active: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 加载别名数据(编辑模式)
|
||||
function loadAliasData() {
|
||||
if (!props.editingAlias) return
|
||||
form.value = {
|
||||
alias: props.editingAlias.alias,
|
||||
global_model_id: props.editingAlias.global_model_id,
|
||||
provider_id: props.editingAlias.provider_id,
|
||||
mapping_type: props.editingAlias.mapping_type || 'alias',
|
||||
is_active: props.editingAlias.is_active,
|
||||
}
|
||||
}
|
||||
|
||||
// 使用 useFormDialog 统一处理对话框逻辑
|
||||
const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
|
||||
isOpen: () => props.open,
|
||||
entity: () => props.editingAlias,
|
||||
isLoading: submitting,
|
||||
onClose: () => emit('update:open', false),
|
||||
loadData: loadAliasData,
|
||||
resetForm,
|
||||
})
|
||||
|
||||
// 对话框标题
|
||||
const dialogTitle = computed(() => {
|
||||
if (isEditMode.value) {
|
||||
return form.value.mapping_type === 'mapping' ? '编辑映射' : '编辑别名'
|
||||
}
|
||||
if (props.fixedProvider) {
|
||||
return `创建 ${props.fixedProvider.display_name || props.fixedProvider.name} 特定别名/映射`
|
||||
}
|
||||
return '创建别名/映射'
|
||||
})
|
||||
|
||||
// Dialog description shown under the title.
const dialogDescription = computed(() => {
  if (!isEditMode.value) return '为模型创建别名或映射规则'
  if (form.value.mapping_type === 'mapping') return '修改模型映射配置'
  return '修改别名设置'
})
|
||||
|
||||
// Dialog icon: pen icon in edit mode, tag icon otherwise.
const dialogIcon = computed(() => {
  if (isEditMode.value) return SquarePen
  return Tag
})
|
||||
|
||||
// Source models selectable in mapping mode: every global model except the chosen target.
const availableSourceModels = computed(() => {
  const targetId = form.value.global_model_id
  return props.globalModels.filter(model => model.id !== targetId)
})
|
||||
|
||||
// Target models selectable in the form. In mapping mode the global model whose
// name equals the chosen source (form.alias) is excluded; otherwise all are offered.
const availableTargetModels = computed(() => {
  if (form.value.mapping_type !== 'mapping' || !form.value.alias) {
    return props.globalModels
  }

  // Resolve the source name to its GlobalModel, if one exists.
  const sourceName = form.value.alias
  const sourceModel = props.globalModels.find(model => model.name === sourceName)
  if (!sourceModel) return props.globalModels

  return props.globalModels.filter(model => model.id !== sourceModel.id)
})
|
||||
|
||||
/**
 * Validate the form and emit `submit` with the request payload.
 *
 * The alias is trimmed before validation, so a whitespace-only value is rejected
 * and the submitted value carries no surrounding whitespace.
 * A fixed target model / fixed provider from props takes precedence over the form values.
 * The second emit argument tells the parent whether this is an edit (`true`) or a create.
 */
async function handleSubmit() {
  const alias = (form.value.alias || '').trim()
  if (!alias) {
    showError(form.value.mapping_type === 'alias' ? '请输入别名名称' : '请选择源模型', '错误')
    return
  }

  const targetModelId = props.fixedTargetModel?.id || form.value.global_model_id
  if (!targetModelId) {
    showError('请选择目标模型', '错误')
    return
  }

  submitting.value = true
  try {
    const data: CreateModelAliasRequest | UpdateModelAliasRequest = {
      alias,
      global_model_id: targetModelId,
      provider_id: props.fixedProvider?.id || form.value.provider_id,
      mapping_type: form.value.mapping_type,
      is_active: form.value.is_active,
    }

    // NOTE(review): emit is synchronous, so `submitting` is reset before the parent's
    // asynchronous work finishes — confirm whether the parent should own this flag.
    emit('submit', data, !!props.editingAlias)
  } finally {
    submitting.value = false
  }
}
|
||||
</script>
|
||||
@@ -104,19 +104,6 @@
|
||||
<span class="hidden sm:inline">关联提供商</span>
|
||||
<span class="sm:hidden">提供商</span>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="flex-1 px-2 sm:px-4 py-2 text-xs sm:text-sm font-medium rounded-md transition-all duration-200"
|
||||
:class="[
|
||||
detailTab === 'aliases'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-background/50'
|
||||
]"
|
||||
@click="detailTab = 'aliases'"
|
||||
>
|
||||
<span class="hidden sm:inline">别名/映射</span>
|
||||
<span class="sm:hidden">别名</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Tab 内容 -->
|
||||
@@ -684,236 +671,6 @@
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<!-- Tab 3: 别名 -->
|
||||
<div v-show="detailTab === 'aliases'">
|
||||
<Card class="overflow-hidden">
|
||||
<!-- 标题栏 -->
|
||||
<div class="px-4 py-3 border-b border-border/60">
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<div>
|
||||
<h4 class="text-sm font-semibold">
|
||||
别名与映射
|
||||
</h4>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="添加别名/映射"
|
||||
@click="$emit('addAlias')"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="刷新"
|
||||
@click="$emit('refreshAliases')"
|
||||
>
|
||||
<RefreshCw
|
||||
class="w-3.5 h-3.5"
|
||||
:class="loadingAliases ? 'animate-spin' : ''"
|
||||
/>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 表格内容 -->
|
||||
<div
|
||||
v-if="loadingAliases"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<Loader2 class="w-6 h-6 animate-spin text-primary" />
|
||||
</div>
|
||||
|
||||
<template v-else-if="aliases.length > 0">
|
||||
<!-- 桌面端表格 -->
|
||||
<Table class="hidden sm:table">
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="h-10 font-semibold">
|
||||
别名
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] h-10 font-semibold">
|
||||
类型
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] h-10 font-semibold">
|
||||
作用域
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] h-10 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="alias in aliases"
|
||||
:key="alias.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="alias.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="alias.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<code class="text-sm font-medium bg-muted px-1.5 py-0.5 rounded">{{ alias.alias }}</code>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs truncate max-w-[90px]"
|
||||
:title="alias.provider_name || 'Provider'"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="py-3 text-center">
|
||||
<div class="flex items-center justify-center gap-0.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑"
|
||||
@click="$emit('editAlias', alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="alias.is_active ? '停用' : '启用'"
|
||||
@click="$emit('toggleAliasStatus', alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除"
|
||||
@click="$emit('deleteAlias', alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div class="sm:hidden divide-y divide-border/40">
|
||||
<div
|
||||
v-for="alias in aliases"
|
||||
:key="alias.id"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex items-center gap-2 min-w-0 flex-1">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="alias.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
/>
|
||||
<code class="text-sm font-medium bg-muted px-1.5 py-0.5 rounded truncate">{{ alias.alias }}</code>
|
||||
</div>
|
||||
<div class="flex items-center gap-1 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('editAlias', alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('toggleAliasStatus', alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('deleteAlias', alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs truncate max-w-[120px]"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div
|
||||
v-else
|
||||
class="text-center py-12"
|
||||
>
|
||||
<!-- 空状态 -->
|
||||
<Tag class="w-12 h-12 mx-auto text-muted-foreground/30 mb-3" />
|
||||
<p class="text-sm text-muted-foreground">
|
||||
暂无别名或映射
|
||||
</p>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
class="mt-4"
|
||||
@click="$emit('addAlias')"
|
||||
>
|
||||
<Plus class="w-4 h-4 mr-1" />
|
||||
添加别名/映射
|
||||
</Button>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
@@ -931,7 +688,6 @@ import {
|
||||
Zap,
|
||||
Image,
|
||||
Building2,
|
||||
Tag,
|
||||
Plus,
|
||||
Edit,
|
||||
Trash2,
|
||||
@@ -955,13 +711,11 @@ import TableCell from '@/components/ui/table-cell.vue'
|
||||
|
||||
// 使用外部类型定义
|
||||
import type { GlobalModelResponse } from '@/api/global-models'
|
||||
import type { ModelAlias } from '@/api/endpoints/aliases'
|
||||
import type { TieredPricingConfig, PricingTier } from '@/api/endpoints/types'
|
||||
import type { CapabilityDefinition } from '@/api/endpoints'
|
||||
|
||||
const props = withDefaults(defineProps<Props>(), {
|
||||
loadingProviders: false,
|
||||
loadingAliases: false,
|
||||
hasBlockingDialogOpen: false,
|
||||
})
|
||||
const emit = defineEmits<{
|
||||
@@ -973,11 +727,6 @@ const emit = defineEmits<{
|
||||
'deleteProvider': [provider: any]
|
||||
'toggleProviderStatus': [provider: any]
|
||||
'refreshProviders': []
|
||||
'addAlias': []
|
||||
'editAlias': [alias: ModelAlias]
|
||||
'toggleAliasStatus': [alias: ModelAlias]
|
||||
'deleteAlias': [alias: ModelAlias]
|
||||
'refreshAliases': []
|
||||
}>()
|
||||
const { success: showSuccess, error: showError } = useToast()
|
||||
|
||||
@@ -985,9 +734,7 @@ interface Props {
|
||||
model: GlobalModelResponse | null
|
||||
open: boolean
|
||||
providers: any[]
|
||||
aliases: ModelAlias[]
|
||||
loadingProviders?: boolean
|
||||
loadingAliases?: boolean
|
||||
hasBlockingDialogOpen?: boolean
|
||||
capabilities?: CapabilityDefinition[]
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
export { default as GlobalModelFormDialog } from './GlobalModelFormDialog.vue'
|
||||
export { default as AliasDialog } from './AliasDialog.vue'
|
||||
export { default as ModelDetailDrawer } from './ModelDetailDrawer.vue'
|
||||
export { default as TieredPricingEditor } from './TieredPricingEditor.vue'
|
||||
|
||||
333
frontend/src/features/providers/components/ModelAliasDialog.vue
Normal file
333
frontend/src/features/providers/components/ModelAliasDialog.vue
Normal file
@@ -0,0 +1,333 @@
|
||||
<template>
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
title="管理模型名称映射"
|
||||
description="配置 Provider 对此模型使用的名称变体,系统会按优先级顺序选择"
|
||||
:icon="Tag"
|
||||
size="lg"
|
||||
@update:model-value="handleClose"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<!-- 模型信息 -->
|
||||
<div class="rounded-lg border bg-muted/30 p-3">
|
||||
<p class="font-medium">
|
||||
{{ model?.global_model_display_name || model?.provider_model_name }}
|
||||
</p>
|
||||
<p class="text-sm text-muted-foreground font-mono">
|
||||
主名称: {{ model?.provider_model_name }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 别名列表 -->
|
||||
<div class="space-y-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<Label class="text-sm font-medium">名称映射</Label>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
@click="addAlias"
|
||||
>
|
||||
<Plus class="w-4 h-4 mr-1" />
|
||||
添加
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<!-- 提示信息 -->
|
||||
<div
|
||||
v-if="aliases.length > 0"
|
||||
class="flex items-center gap-2 px-3 py-2 text-xs text-muted-foreground bg-muted/30 rounded-md"
|
||||
>
|
||||
<Info class="w-3.5 h-3.5 shrink-0" />
|
||||
<span>拖拽调整顺序,点击序号可编辑(相同数字为同级,负载均衡)</span>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="aliases.length > 0"
|
||||
class="space-y-2"
|
||||
>
|
||||
<div
|
||||
v-for="(alias, index) in aliases"
|
||||
:key="index"
|
||||
class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200"
|
||||
:class="[
|
||||
draggedIndex === index
|
||||
? 'border-primary/50 bg-primary/5 shadow-md scale-[1.01]'
|
||||
: dragOverIndex === index
|
||||
? 'border-primary/30 bg-primary/5'
|
||||
: 'border-border/50 bg-background hover:border-border hover:bg-muted/30'
|
||||
]"
|
||||
draggable="true"
|
||||
@dragstart="handleDragStart(index, $event)"
|
||||
@dragend="handleDragEnd"
|
||||
@dragover.prevent="handleDragOver(index)"
|
||||
@dragleave="handleDragLeave"
|
||||
@drop="handleDrop(index)"
|
||||
>
|
||||
<!-- 拖拽手柄 -->
|
||||
<div class="cursor-grab active:cursor-grabbing p-1 rounded hover:bg-muted text-muted-foreground/40 group-hover:text-muted-foreground transition-colors shrink-0">
|
||||
<GripVertical class="w-4 h-4" />
|
||||
</div>
|
||||
|
||||
<!-- 可编辑优先级 -->
|
||||
<div class="shrink-0">
|
||||
<input
|
||||
v-if="editingPriorityIndex === index"
|
||||
type="number"
|
||||
min="1"
|
||||
:value="alias.priority"
|
||||
class="w-8 h-6 rounded-md bg-background border border-primary text-xs font-medium text-center focus:outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
|
||||
autofocus
|
||||
@blur="finishEditPriority(index, $event)"
|
||||
@keydown.enter="($event.target as HTMLInputElement).blur()"
|
||||
@keydown.escape="cancelEditPriority"
|
||||
>
|
||||
<div
|
||||
v-else
|
||||
class="w-6 h-6 rounded-md bg-muted/50 flex items-center justify-center text-xs font-medium text-muted-foreground cursor-pointer hover:bg-primary/10 hover:text-primary transition-colors"
|
||||
title="点击编辑优先级,相同数字为同级(负载均衡)"
|
||||
@click.stop="startEditPriority(index)"
|
||||
>
|
||||
{{ alias.priority }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 别名输入框 -->
|
||||
<Input
|
||||
v-model="alias.name"
|
||||
placeholder="映射名称,如 Claude-Sonnet-4.5"
|
||||
class="flex-1"
|
||||
/>
|
||||
|
||||
<!-- 删除按钮 -->
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="shrink-0 text-destructive hover:text-destructive h-8 w-8"
|
||||
@click="removeAlias(index)"
|
||||
>
|
||||
<X class="w-4 h-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-else
|
||||
class="text-center py-6 text-muted-foreground border rounded-lg border-dashed"
|
||||
>
|
||||
<Tag class="w-8 h-8 mx-auto mb-2 opacity-50" />
|
||||
<p class="text-sm">未配置映射</p>
|
||||
<p class="text-xs mt-1">将只使用主模型名称</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="handleClose(false)"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
v-if="submitting"
|
||||
class="w-4 h-4 mr-2 animate-spin"
|
||||
/>
|
||||
保存
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import { Tag, Plus, X, Loader2, GripVertical, Info } from 'lucide-vue-next'
|
||||
import { Dialog, Button, Input, Label } from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { updateModel } from '@/api/endpoints/models'
|
||||
import type { Model, ProviderModelAlias } from '@/api/endpoints'
|
||||
|
||||
/** Props accepted by the model alias dialog. */
interface Props {
  /** Whether the dialog is shown (forwarded to the Dialog's model-value). */
  open: boolean
  /** Provider id passed to `updateModel` when saving. */
  providerId: string
  /** Model whose `provider_model_aliases` are edited; `null` when none is selected. */
  model: Model | null
}

const props = defineProps<Props>()

// Events: `update:open` carries the requested open state, `saved` fires after a successful save.
const emit = defineEmits<{
  (event: 'update:open', value: boolean): void
  (event: 'saved'): void
}>()
|
||||
|
||||
const { success: showSuccess, error: showError } = useToast()

// True while the save request is in flight.
const submitting = ref<boolean>(false)
// Editable working copy of the model's alias list.
const aliases = ref<ProviderModelAlias[]>([])

// Drag-and-drop state: row being dragged and row currently hovered (null = none).
const draggedIndex = ref<number | null>(null)
const dragOverIndex = ref<number | null>(null)

// Index of the row whose priority is being edited inline (null = none).
const editingPriorityIndex = ref<number | null>(null)
|
||||
|
||||
// When the dialog opens with a model, load a private copy of its aliases and clear UI state.
watch(() => props.open, (isOpen) => {
  if (!isOpen || !props.model) return

  // Deep-copy through JSON so edits here never touch the prop until saved.
  const existing = props.model.provider_model_aliases
  aliases.value = Array.isArray(existing)
    ? JSON.parse(JSON.stringify(existing))
    : []

  // Reset transient editing / drag state.
  editingPriorityIndex.value = null
  draggedIndex.value = null
  dragOverIndex.value = null
})
|
||||
|
||||
/**
 * Append an empty alias row.
 * Its priority is one more than the current highest priority, or 1 for an empty list.
 */
function addAlias() {
  let highest = 0
  if (aliases.value.length > 0) {
    highest = aliases.value.reduce(
      (max, entry) => Math.max(max, entry.priority),
      -Infinity,
    )
  }
  aliases.value.push({ name: '', priority: highest + 1 })
}
|
||||
|
||||
/**
 * Remove the alias row at `index` from the working list (in place).
 */
function removeAlias(index: number) {
  const list = aliases.value
  list.splice(index, 1)
}
|
||||
|
||||
// ===== Drag-and-drop reordering =====
/**
 * Remember which row is being dragged and mark the drag as a move operation.
 */
function handleDragStart(index: number, event: DragEvent) {
  draggedIndex.value = index

  const transfer = event.dataTransfer
  if (!transfer) return
  transfer.effectAllowed = 'move'
}
|
||||
|
||||
/** Clear all drag state once the drag finishes. */
function handleDragEnd() {
  dragOverIndex.value = null
  draggedIndex.value = null
}
|
||||
|
||||
/**
 * Mark `index` as the hovered drop target, but only while a drag is active
 * and the hovered row is not the dragged row itself.
 */
function handleDragOver(index: number) {
  const source = draggedIndex.value
  if (source === null || source === index) return
  dragOverIndex.value = index
}
|
||||
|
||||
/** Drop the hover highlight when the pointer leaves a row. */
function handleDragLeave(): void {
  dragOverIndex.value = null
}
|
||||
|
||||
/**
 * Move the dragged row to `targetIndex` and renumber priorities in the new order.
 *
 * Renumbering rules:
 * - rows that shared a priority before the drop keep sharing one afterwards;
 * - the dragged row always gets a priority of its own;
 * - priorities are handed out as 1, 2, 3, … in list order of first appearance.
 *
 * Dropping a row onto itself, or dropping without an active drag, only clears the hover state.
 */
function handleDrop(targetIndex: number) {
  const dragIndex = draggedIndex.value
  if (dragIndex === null || dragIndex === targetIndex) {
    dragOverIndex.value = null
    return
  }

  const items = [...aliases.value]
  const draggedItem = items[dragIndex]
  if (!draggedItem) {
    // Stale drag index (the row no longer exists): just reset the drag state.
    draggedIndex.value = null
    dragOverIndex.value = null
    return
  }

  // Snapshot every row's priority before anything is mutated, keyed by the row
  // object itself so no index lookups are needed after reordering.
  const originalPriority = new Map<ProviderModelAlias, number>()
  for (const alias of items) {
    originalPriority.set(alias, alias.priority)
  }

  // Reorder the list.
  items.splice(dragIndex, 1)
  items.splice(targetIndex, 0, draggedItem)

  // Assign new priorities group by group, in the new order.
  const groupNewPriority = new Map<number, number>() // old priority -> new priority
  let nextPriority = 1

  for (const alias of items) {
    if (alias === draggedItem) {
      // The dragged row forms a new group of its own.
      alias.priority = nextPriority++
      continue
    }

    const previous = originalPriority.get(alias) ?? alias.priority
    let assigned = groupNewPriority.get(previous)
    if (assigned === undefined) {
      // First row of this group: allocate the next priority for the whole group.
      assigned = nextPriority++
      groupNewPriority.set(previous, assigned)
    }
    alias.priority = assigned
  }

  aliases.value = items
  draggedIndex.value = null
  dragOverIndex.value = null
}
|
||||
|
||||
// ===== Inline priority editing =====
/** Switch the row at `index` into priority-edit mode. */
function startEditPriority(index: number): void {
  editingPriorityIndex.value = index
}
|
||||
|
||||
/**
 * Commit the priority typed into the inline input when it loses focus.
 *
 * The value is parsed as a base-10 integer; anything that is not a number, or is
 * below 1, becomes 1. The commit is skipped when this row is no longer the one
 * being edited — e.g. the edit was cancelled and the blur only fires because the
 * input is being removed (some browsers dispatch blur on removal; NOTE(review): confirm
 * on the browsers you target).
 */
function finishEditPriority(index: number, event: FocusEvent) {
  if (editingPriorityIndex.value !== index) return

  const input = event.target as HTMLInputElement
  const target = aliases.value[index]
  if (target) {
    const parsed = Number.parseInt(input.value, 10)
    target.priority = Number.isNaN(parsed) ? 1 : Math.max(1, parsed)
  }
  editingPriorityIndex.value = null
}
|
||||
|
||||
/** Leave priority-edit mode without touching the stored priority. */
function cancelEditPriority(): void {
  editingPriorityIndex.value = null
}
|
||||
|
||||
/**
 * Forward an open-state change to the parent.
 * Ignored while a save is in progress so the dialog cannot be closed mid-request.
 */
function handleClose(value: boolean) {
  if (submitting.value) return
  emit('update:open', value)
}
|
||||
|
||||
/**
 * Save the alias list for the current model.
 *
 * Names are trimmed before saving and rows whose name is empty after trimming are
 * dropped; an empty result is sent as `null`. On success the dialog is closed and
 * `saved` is emitted; on failure the server's `detail` message (if any) is shown.
 * Does nothing when a save is already running or no model is set.
 */
async function handleSubmit() {
  if (submitting.value || !props.model) return

  submitting.value = true
  try {
    // Normalise names first so stored values never carry surrounding whitespace.
    const validAliases = aliases.value
      .map(entry => ({ ...entry, name: entry.name.trim() }))
      .filter(entry => entry.name)

    await updateModel(props.providerId, props.model.id, {
      provider_model_aliases: validAliases.length > 0 ? validAliases : null
    })

    showSuccess('映射配置已保存')
    emit('update:open', false)
    emit('saved')
  } catch (err: any) {
    showError(err.response?.data?.detail || '保存失败', '错误')
  } finally {
    submitting.value = false
  }
}
|
||||
</script>
|
||||
@@ -526,14 +526,7 @@
|
||||
@edit-model="handleEditModel"
|
||||
@delete-model="handleDeleteModel"
|
||||
@batch-assign="handleBatchAssign"
|
||||
/>
|
||||
|
||||
<!-- 模型映射 -->
|
||||
<MappingsTab
|
||||
v-if="provider"
|
||||
:key="`mappings-${provider.id}`"
|
||||
:provider="provider"
|
||||
@refresh="handleRelatedDataRefresh"
|
||||
@manage-alias="handleManageAlias"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -636,6 +629,16 @@
|
||||
@update:open="batchAssignDialogOpen = $event"
|
||||
@changed="handleBatchAssignChanged"
|
||||
/>
|
||||
|
||||
<!-- 模型别名管理对话框 -->
|
||||
<ModelAliasDialog
|
||||
v-if="open && provider"
|
||||
:open="aliasDialogOpen"
|
||||
:provider-id="provider.id"
|
||||
:model="aliasEditingModel"
|
||||
@update:open="aliasDialogOpen = $event"
|
||||
@saved="handleAliasSaved"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
@@ -663,9 +666,9 @@ import { getProvider, getProviderEndpoints } from '@/api/endpoints'
|
||||
import {
|
||||
KeyFormDialog,
|
||||
KeyAllowedModelsDialog,
|
||||
MappingsTab,
|
||||
ModelsTab,
|
||||
BatchAssignModelsDialog
|
||||
BatchAssignModelsDialog,
|
||||
ModelAliasDialog
|
||||
} from '@/features/providers/components'
|
||||
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||
@@ -734,6 +737,10 @@ const deleteModelConfirmOpen = ref(false)
|
||||
const modelToDelete = ref<Model | null>(null)
|
||||
const batchAssignDialogOpen = ref(false)
|
||||
|
||||
// 别名管理相关状态
|
||||
const aliasDialogOpen = ref(false)
|
||||
const aliasEditingModel = ref<Model | null>(null)
|
||||
|
||||
// 拖动排序相关状态
|
||||
const dragState = ref({
|
||||
isDragging: false,
|
||||
@@ -755,7 +762,8 @@ const hasBlockingDialogOpen = computed(() =>
|
||||
deleteKeyConfirmOpen.value ||
|
||||
modelFormDialogOpen.value ||
|
||||
deleteModelConfirmOpen.value ||
|
||||
batchAssignDialogOpen.value
|
||||
batchAssignDialogOpen.value ||
|
||||
aliasDialogOpen.value
|
||||
)
|
||||
|
||||
// 监听 providerId 变化
|
||||
@@ -784,6 +792,7 @@ watch(() => props.open, (newOpen) => {
|
||||
keyAllowedModelsDialogOpen.value = false
|
||||
deleteKeyConfirmOpen.value = false
|
||||
batchAssignDialogOpen.value = false
|
||||
aliasDialogOpen.value = false
|
||||
|
||||
// 重置临时数据
|
||||
endpointToEdit.value = null
|
||||
@@ -1021,6 +1030,19 @@ async function handleBatchAssignChanged() {
|
||||
emit('refresh')
|
||||
}
|
||||
|
||||
/**
 * Open the alias dialog for the given model.
 * The model is stored first so the dialog receives it when it opens.
 */
function handleManageAlias(model: Model): void {
  aliasEditingModel.value = model
  aliasDialogOpen.value = true
}
|
||||
|
||||
/**
 * Called after the alias dialog has saved: clear the edited model,
 * reload the provider, then emit `refresh` to the parent.
 */
async function handleAliasSaved(): Promise<void> {
  aliasEditingModel.value = null

  await loadProvider()
  emit('refresh')
}
|
||||
|
||||
// 处理模型保存完成
|
||||
async function handleModelSaved() {
|
||||
editingModel.value = null
|
||||
|
||||
@@ -7,6 +7,6 @@ export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vu
|
||||
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'
|
||||
export { default as EndpointHealthTimeline } from './EndpointHealthTimeline.vue'
|
||||
export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vue'
|
||||
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
||||
|
||||
export { default as MappingsTab } from './provider-tabs/MappingsTab.vue'
|
||||
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
||||
|
||||
@@ -1,310 +0,0 @@
|
||||
<template>
|
||||
<Card class="overflow-hidden">
|
||||
<!-- 标题头部 -->
|
||||
<div class="p-4 border-b border-border/60">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center gap-2">
|
||||
<h3 class="text-sm font-semibold leading-none">
|
||||
别名与映射管理
|
||||
</h3>
|
||||
</div>
|
||||
<Button
|
||||
v-if="!hideAddButton"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
class="h-8"
|
||||
@click="openCreateDialog"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5 mr-1.5" />
|
||||
创建别名/映射
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div
|
||||
v-if="loading"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<div class="animate-spin rounded-full h-8 w-8 border-b-2 border-primary" />
|
||||
</div>
|
||||
|
||||
<!-- 别名列表 -->
|
||||
<div
|
||||
v-else-if="mappings.length > 0"
|
||||
class="overflow-x-auto"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<thead class="bg-muted/50 text-xs uppercase tracking-wide text-muted-foreground">
|
||||
<tr>
|
||||
<th class="text-left px-4 py-3 font-semibold">
|
||||
名称
|
||||
</th>
|
||||
<th class="text-left px-4 py-3 font-semibold w-24">
|
||||
类型
|
||||
</th>
|
||||
<th class="text-left px-4 py-3 font-semibold">
|
||||
指向模型
|
||||
</th>
|
||||
<th
|
||||
v-if="!hideAddButton"
|
||||
class="px-4 py-3 font-semibold w-28 text-center"
|
||||
>
|
||||
操作
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="mapping in mappings"
|
||||
:key="mapping.id"
|
||||
class="border-b border-border/40 last:border-b-0 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<td class="px-4 py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<!-- 状态指示灯 -->
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="mapping.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="mapping.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<span class="font-mono">{{ mapping.alias }}</span>
|
||||
</div>
|
||||
</td>
|
||||
<td class="px-4 py-3">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ mapping.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
</td>
|
||||
<td class="px-4 py-3">
|
||||
{{ mapping.global_model_display_name || mapping.global_model_name }}
|
||||
</td>
|
||||
<td
|
||||
v-if="!hideAddButton"
|
||||
class="px-4 py-3"
|
||||
>
|
||||
<div class="flex justify-center gap-1.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="编辑"
|
||||
@click="openEditDialog(mapping)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
:disabled="togglingId === mapping.id"
|
||||
:title="mapping.is_active ? '点击停用' : '点击启用'"
|
||||
@click="toggleActive(mapping)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8 text-destructive hover:text-destructive"
|
||||
title="删除"
|
||||
@click="confirmDelete(mapping)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<div
|
||||
v-else
|
||||
class="p-8 text-center text-muted-foreground"
|
||||
>
|
||||
<ArrowLeftRight class="w-12 h-12 mx-auto mb-3 opacity-50" />
|
||||
<p class="text-sm">
|
||||
暂无特定别名/映射
|
||||
</p>
|
||||
<p class="text-xs mt-1">
|
||||
点击上方按钮添加
|
||||
</p>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- 使用共享的 AliasDialog 组件 -->
|
||||
<AliasDialog
|
||||
:open="dialogOpen"
|
||||
:editing-alias="editingAlias"
|
||||
:global-models="availableModels"
|
||||
:fixed-provider="fixedProviderOption"
|
||||
:show-provider-select="true"
|
||||
@update:open="handleDialogVisibility"
|
||||
@submit="handleAliasSubmit"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ArrowLeftRight, Plus, Edit, Trash2, Power } from 'lucide-vue-next'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import AliasDialog from '@/features/models/components/AliasDialog.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import {
|
||||
getAliases,
|
||||
createAlias,
|
||||
updateAlias,
|
||||
deleteAlias,
|
||||
type ModelAlias,
|
||||
type CreateModelAliasRequest,
|
||||
type UpdateModelAliasRequest,
|
||||
} from '@/api/endpoints/aliases'
|
||||
import { listGlobalModels, type GlobalModelResponse } from '@/api/global-models'
|
||||
|
||||
/** Props of the mappings tab. */
interface MappingsTabProps {
  /** Provider whose aliases/mappings are listed; its id, name and display_name are read. */
  provider: any
  /** Hides the create button and the actions column when true. */
  hideAddButton?: boolean
}

const props = withDefaults(defineProps<MappingsTabProps>(), {
  hideAddButton: false
})

// `refresh` is emitted after a create, update or delete succeeds.
const emit = defineEmits<{
  (event: 'refresh'): void
}>()
|
||||
|
||||
const { error: showError, success } = useToast()

// Component state
// True while the alias list is loading.
const loading = ref<boolean>(false)
// True while a create/update request is running.
const submitting = ref<boolean>(false)
// Id of the alias whose active flag is being toggled (null = none).
const togglingId = ref<string | null>(null)
// Aliases/mappings fetched for this provider.
const mappings = ref<ModelAlias[]>([])
// Global models offered in the dialog.
const availableModels = ref<GlobalModelResponse[]>([])
// Dialog visibility and the alias being edited (null = create mode).
const dialogOpen = ref<boolean>(false)
const editingAlias = ref<ModelAlias | null>(null)
|
||||
|
||||
// Fixed provider option handed to AliasDialog: id, name and display_name of this provider.
const fixedProviderOption = computed(() => {
  const { id, name, display_name } = props.provider
  return { id, name, display_name }
})
|
||||
|
||||
/**
 * Fetch the aliases that belong to this provider into `mappings`.
 * Shows an error toast on failure; `loading` is true for the duration of the call.
 */
async function loadMappings() {
  loading.value = true
  try {
    const providerId = props.provider.id
    mappings.value = await getAliases({ provider_id: providerId })
  } catch (err: any) {
    const message = err.response?.data?.detail || '加载失败'
    showError(message, '错误')
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
/**
 * Fetch up to 1000 active global models for the dialog's model selects.
 * A response without `models` yields an empty list; failures show an error toast.
 */
async function loadAvailableModels() {
  try {
    const { models } = await listGlobalModels({ limit: 1000, is_active: true })
    availableModels.value = models || []
  } catch (err: any) {
    const message = err.response?.data?.detail || '加载模型列表失败'
    showError(message, '错误')
  }
}
|
||||
|
||||
/** Open the dialog in create mode (no alias selected for editing). */
function openCreateDialog(): void {
  editingAlias.value = null
  dialogOpen.value = true
}
|
||||
|
||||
/** Open the dialog in edit mode for the given alias. */
function openEditDialog(alias: ModelAlias): void {
  editingAlias.value = alias
  dialogOpen.value = true
}
|
||||
|
||||
/**
 * Mirror the dialog's open state; closing also forgets the alias being edited.
 */
function handleDialogVisibility(value: boolean) {
  dialogOpen.value = value
  if (value) return
  editingAlias.value = null
}
|
||||
|
||||
/**
 * Persist the payload emitted by AliasDialog.
 *
 * In edit mode (with an alias selected) the alias is updated; otherwise a new alias
 * is created with `provider_id` forced to this provider. The incoming `data` object
 * is never modified — the create payload is a copy. On success the dialog closes,
 * the list reloads and `refresh` is emitted. On failure an error toast is shown,
 * with the server's "already exists" detail rewritten into a friendlier message.
 */
async function handleAliasSubmit(data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean) {
  submitting.value = true
  try {
    const current = editingAlias.value
    if (isEdit && current) {
      // Update the existing alias.
      await updateAlias(current.id, data as UpdateModelAliasRequest)
      success(data.mapping_type === 'mapping' ? '映射已更新' : '别名已更新')
    } else {
      // Create: copy the payload and pin it to the current provider.
      const createData: CreateModelAliasRequest = {
        ...(data as CreateModelAliasRequest),
        provider_id: props.provider.id,
      }
      await createAlias(createData)
      success(data.mapping_type === 'mapping' ? '映射已创建' : '别名已创建')
    }

    dialogOpen.value = false
    editingAlias.value = null
    await loadMappings()
    emit('refresh')
  } catch (err: any) {
    const detail = err.response?.data?.detail || err.message
    // Translate the backend's duplicate-name detail into a clearer hint.
    const errorMessage = detail === '映射已存在' ? '该名称已存在,请使用其他名称' : detail
    showError(errorMessage, isEdit ? '更新失败' : '创建失败')
  } finally {
    submitting.value = false
  }
}
|
||||
|
||||
// 切换启用状态
|
||||
/**
 * Flip the `is_active` flag of an alias.
 *
 * Only one toggle may be in flight at a time: while `togglingId` is set,
 * further calls are ignored. The local object is updated only after the
 * request succeeded.
 *
 * @param alias The alias/mapping whose status should be toggled.
 */
async function toggleActive(alias: ModelAlias) {
  if (togglingId.value) {
    return
  }

  togglingId.value = alias.id
  try {
    const nextActive = !alias.is_active
    await updateAlias(alias.id, { is_active: nextActive })
    alias.is_active = nextActive
  } catch (err: any) {
    const detail = err.response?.data?.detail
    showError(detail || '操作失败', '错误')
  } finally {
    togglingId.value = null
  }
}
|
||||
|
||||
// 确认删除
|
||||
/**
 * Ask for confirmation (native `confirm`) and delete the given entry.
 * After a successful deletion the list is reloaded and `refresh` is emitted.
 *
 * @param alias The alias/mapping to delete.
 */
async function confirmDelete(alias: ModelAlias) {
  const isMapping = alias.mapping_type === 'mapping'
  const typeName = isMapping ? '映射' : '别名'

  const confirmed = confirm(`确定要删除${typeName} "${alias.alias}" 吗?`)
  if (!confirmed) return

  try {
    await deleteAlias(alias.id)
    success(`${typeName}已删除`)
    await loadMappings()
    emit('refresh')
  } catch (err: any) {
    const message = err.response?.data?.detail || err.message
    showError(message, '删除失败')
  }
}
|
||||
|
||||
// Kick off both initial loads on mount; they run concurrently and are not awaited.
onMounted(() => {
  void loadMappings()
  void loadAvailableModels()
})
|
||||
</script>
|
||||
@@ -165,6 +165,15 @@
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="管理映射"
|
||||
@click="openAliasDialog(model)"
|
||||
>
|
||||
<Tag class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
@@ -209,7 +218,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
@@ -224,6 +233,7 @@ const emit = defineEmits<{
|
||||
'editModel': [model: Model]
|
||||
'deleteModel': [model: Model]
|
||||
'batchAssign': []
|
||||
'manageAlias': [model: Model]
|
||||
}>()
|
||||
|
||||
const { error: showError, success: showSuccess } = useToast()
|
||||
@@ -363,6 +373,11 @@ function openBatchAssignDialog() {
|
||||
emit('batchAssign')
|
||||
}
|
||||
|
||||
// 打开别名管理对话框
|
||||
/**
 * Ask the parent to open the alias management dialog for a model
 * by emitting `manageAlias`.
 *
 * @param model The model whose aliases should be managed.
 */
function openAliasDialog(model: Model): void {
  emit('manageAlias', model)
}
|
||||
|
||||
// 切换模型启用状态
|
||||
async function toggleModelActive(model: Model) {
|
||||
if (togglingModelId.value) return
|
||||
|
||||
@@ -543,13 +543,14 @@ function formatApiFormat(format: string): string {
|
||||
}
|
||||
|
||||
// 获取实际使用的模型(优先 target_model,其次 model_version)
|
||||
// 只有当实际模型与请求模型不同时才返回,用于显示映射箭头
|
||||
function getActualModel(record: UsageRecord): string | null {
|
||||
// 优先显示模型映射
|
||||
if (record.target_model) {
|
||||
if (record.target_model && record.target_model !== record.model) {
|
||||
return record.target_model
|
||||
}
|
||||
// 其次显示 Provider 返回的实际版本(如 Gemini 的 modelVersion)
|
||||
if (record.request_metadata?.model_version) {
|
||||
if (record.request_metadata?.model_version && record.request_metadata.model_version !== record.model) {
|
||||
return record.request_metadata.model_version
|
||||
}
|
||||
return null
|
||||
|
||||
@@ -313,7 +313,6 @@ import {
|
||||
Gauge,
|
||||
Layers,
|
||||
FolderTree,
|
||||
Tag,
|
||||
Box,
|
||||
LogOut,
|
||||
SunMoon,
|
||||
@@ -411,7 +410,6 @@ const navigation = computed(() => {
|
||||
{ name: '用户管理', href: '/admin/users', icon: Users },
|
||||
{ name: '提供商', href: '/admin/providers', icon: FolderTree },
|
||||
{ name: '模型管理', href: '/admin/models', icon: Layers },
|
||||
{ name: '别名映射', href: '/admin/aliases', icon: Tag },
|
||||
{ name: '独立密钥', href: '/admin/keys', icon: Key },
|
||||
{ name: '使用记录', href: '/admin/usage', icon: BarChart3 },
|
||||
]
|
||||
|
||||
@@ -91,11 +91,6 @@ const routes: RouteRecordRaw[] = [
|
||||
name: 'ModelManagement',
|
||||
component: () => importWithRetry(() => import('@/views/admin/ModelManagement.vue'))
|
||||
},
|
||||
{
|
||||
path: 'aliases',
|
||||
name: 'AliasManagement',
|
||||
component: () => importWithRetry(() => import('@/views/admin/AliasManagement.vue'))
|
||||
},
|
||||
{
|
||||
path: 'health-monitor',
|
||||
name: 'HealthMonitor',
|
||||
|
||||
@@ -1,500 +0,0 @@
|
||||
<template>
|
||||
<div class="flex flex-col">
|
||||
<Card class="overflow-hidden">
|
||||
<!-- 搜索和过滤区域 -->
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
<div class="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3 sm:gap-4">
|
||||
<h3 class="text-sm sm:text-base font-semibold shrink-0">
|
||||
别名管理
|
||||
</h3>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<!-- 搜索框 -->
|
||||
<div class="relative">
|
||||
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
|
||||
<Input
|
||||
id="alias-search"
|
||||
v-model="aliasesSearch"
|
||||
placeholder="搜索别名或关联模型"
|
||||
class="w-32 sm:w-44 pl-8 pr-3 h-8 text-sm border-border/60 focus-visible:ring-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 提供商过滤器 -->
|
||||
<Select
|
||||
v-model:open="aliasProviderSelectOpen"
|
||||
:model-value="aliasProviderFilter"
|
||||
@update:model-value="aliasProviderFilter = $event"
|
||||
>
|
||||
<SelectTrigger class="w-28 sm:w-40 h-8 text-xs border-border/60">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">
|
||||
全部别名
|
||||
</SelectItem>
|
||||
<SelectItem value="global">
|
||||
仅全局别名
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
:value="provider.id"
|
||||
>
|
||||
{{ provider.display_name }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 操作按钮 -->
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="新建别名"
|
||||
@click="openCreateAliasDialog"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<RefreshButton
|
||||
:loading="loadingAliases"
|
||||
@click="loadAliases"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="loadingAliases"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<Loader2 class="w-10 h-10 animate-spin text-primary" />
|
||||
</div>
|
||||
<div v-else>
|
||||
<Table class="hidden xl:table text-sm">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead class="w-[200px]">
|
||||
别名
|
||||
</TableHead>
|
||||
<TableHead class="w-[280px]">
|
||||
关联模型
|
||||
</TableHead>
|
||||
<TableHead class="w-[70px] text-center">
|
||||
类型
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] text-center">
|
||||
作用域
|
||||
</TableHead>
|
||||
<TableHead class="w-[70px] text-center">
|
||||
状态
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow v-if="filteredAliases.length === 0">
|
||||
<TableCell
|
||||
colspan="6"
|
||||
class="text-center py-8 text-muted-foreground"
|
||||
>
|
||||
{{ aliasProviderFilter === 'global' ? '暂无全局别名' : '暂无别名' }}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
<TableRow
|
||||
v-for="alias in paginatedAliases"
|
||||
:key="alias.id"
|
||||
>
|
||||
<TableCell>
|
||||
<span class="font-mono font-medium">{{ alias.alias }}</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div class="flex flex-col gap-0.5">
|
||||
<span class="font-medium">{{ alias.global_model_display_name || alias.global_model_name }}</span>
|
||||
<span class="text-xs text-muted-foreground font-mono">{{ alias.global_model_name }}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider 特定' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge
|
||||
:variant="alias.is_active ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.is_active ? '活跃' : '停用' }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑别名"
|
||||
@click="openEditAliasDialog(alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="alias.is_active ? '停用别名' : '启用别名'"
|
||||
@click="toggleAliasStatus(alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除别名"
|
||||
@click="confirmDeleteAlias(alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div
|
||||
v-if="filteredAliases.length > 0"
|
||||
class="xl:hidden divide-y divide-border/40"
|
||||
>
|
||||
<div
|
||||
v-for="alias in paginatedAliases"
|
||||
:key="alias.id"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="font-mono font-medium truncate">{{ alias.alias }}</span>
|
||||
<Badge
|
||||
:variant="alias.is_active ? 'default' : 'secondary'"
|
||||
class="text-xs shrink-0"
|
||||
>
|
||||
{{ alias.is_active ? '活跃' : '停用' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="text-xs text-muted-foreground mt-1">
|
||||
<span class="font-medium">{{ alias.global_model_display_name || alias.global_model_name }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-0.5 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="openEditAliasDialog(alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="toggleAliasStatus(alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="confirmDeleteAlias(alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider 特定' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 分页 -->
|
||||
<Pagination
|
||||
v-if="!loadingAliases && filteredAliases.length > 0"
|
||||
:current="aliasesCurrentPage"
|
||||
:total="filteredAliases.length"
|
||||
:page-size="aliasesPageSize"
|
||||
@update:current="aliasesCurrentPage = $event"
|
||||
@update:page-size="aliasesPageSize = $event"
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- 创建/编辑别名对话框 -->
|
||||
<AliasDialog
|
||||
:open="createAliasDialogOpen"
|
||||
:editing-alias="editingAlias"
|
||||
:global-models="globalModels"
|
||||
:providers="providers"
|
||||
@update:open="handleAliasDialogUpdate"
|
||||
@submit="handleAliasSubmit"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import {
|
||||
Edit,
|
||||
Loader2,
|
||||
Plus,
|
||||
Power,
|
||||
Search,
|
||||
Trash2
|
||||
} from 'lucide-vue-next'
|
||||
import {
|
||||
Card,
|
||||
Button,
|
||||
Input,
|
||||
Badge,
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
RefreshButton,
|
||||
Pagination
|
||||
} from '@/components/ui'
|
||||
import AliasDialog from '@/features/models/components/AliasDialog.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useConfirm } from '@/composables/useConfirm'
|
||||
import {
|
||||
getAliases,
|
||||
createAlias,
|
||||
updateAlias,
|
||||
deleteAlias,
|
||||
type ModelAlias,
|
||||
type CreateModelAliasRequest,
|
||||
type UpdateModelAliasRequest
|
||||
} from '@/api/endpoints/aliases'
|
||||
import { listGlobalModels, type GlobalModelResponse } from '@/api/global-models'
|
||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||
import { log } from '@/utils/logger'
|
||||
|
||||
const { success, error: showError } = useToast()
|
||||
const { confirmDanger } = useConfirm()
|
||||
|
||||
// 状态
|
||||
const loadingAliases = ref(false)
|
||||
const submitting = ref(false)
|
||||
const aliasesSearch = ref('')
|
||||
const aliasProviderFilter = ref<string>('all')
|
||||
const aliasProviderSelectOpen = ref(false)
|
||||
const createAliasDialogOpen = ref(false)
|
||||
const editingAliasId = ref<string | null>(null)
|
||||
|
||||
// 数据
|
||||
const allAliases = ref<ModelAlias[]>([])
|
||||
const globalModels = ref<GlobalModelResponse[]>([])
|
||||
const providers = ref<any[]>([])
|
||||
|
||||
// 分页
|
||||
const aliasesCurrentPage = ref(1)
|
||||
const aliasesPageSize = ref(20)
|
||||
|
||||
// 编辑中的别名对象
|
||||
// The alias object currently being edited, resolved from `editingAliasId`;
// null when nothing is being edited or the id is no longer in the list.
const editingAlias = computed(() => {
  if (!editingAliasId.value) {
    return null
  }
  const match = allAliases.value.find(candidate => candidate.id === editingAliasId.value)
  return match || null
})
|
||||
|
||||
// 筛选后的别名列表
|
||||
// Aliases after applying the provider/scope filter and the search keyword.
const filteredAliases = computed(() => {
  const scope = aliasProviderFilter.value
  const keyword = aliasesSearch.value.trim().toLowerCase()

  let result = allAliases.value

  // Scope filter: 'global' keeps entries without a provider, any other
  // value except 'all' is treated as a provider id.
  if (scope === 'global') {
    result = result.filter(entry => !entry.provider_id)
  } else if (scope !== 'all') {
    result = result.filter(entry => entry.provider_id === scope)
  }

  // Keyword filter: case-insensitive match on the alias itself or on either
  // name of the associated global model.
  if (keyword) {
    result = result.filter(entry => {
      if (entry.alias.toLowerCase().includes(keyword)) return true
      if (entry.global_model_name?.toLowerCase().includes(keyword)) return true
      return entry.global_model_display_name?.toLowerCase().includes(keyword)
    })
  }

  return result
})
|
||||
|
||||
// 分页计算
|
||||
// The slice of `filteredAliases` that belongs to the current page.
const paginatedAliases = computed(() => {
  const size = aliasesPageSize.value
  const offset = (aliasesCurrentPage.value - 1) * size
  return filteredAliases.value.slice(offset, offset + size)
})
|
||||
|
||||
// 搜索或筛选变化时重置到第一页
|
||||
// Whenever the search text or the provider filter changes, jump back to page 1.
watch([aliasesSearch, aliasProviderFilter], (): void => {
  aliasesCurrentPage.value = 1
})
|
||||
|
||||
/**
 * Load up to 1000 aliases into `allAliases`.
 * `loadingAliases` is raised while the request runs; failures are shown
 * through the error toast.
 */
async function loadAliases() {
  loadingAliases.value = true
  try {
    const aliases = await getAliases({ limit: 1000 })
    allAliases.value = aliases
  } catch (err: any) {
    const message = err.response?.data?.detail || err.message
    showError(message, '加载别名失败')
  } finally {
    loadingAliases.value = false
  }
}
|
||||
|
||||
/**
 * Load the GlobalModel list into `globalModels` (empty list when the
 * response has no `models`). Failures are only logged, not shown to the user.
 */
async function loadGlobalModelsList() {
  try {
    const response = await listGlobalModels()
    const models = response.models
    globalModels.value = models || []
  } catch (err: any) {
    log.error('加载模型失败:', err)
  }
}
|
||||
|
||||
/**
 * Load the provider summary list into `providers`.
 * Failures are shown through the error toast.
 */
async function loadProviders() {
  try {
    const summary = await getProvidersSummary()
    providers.value = summary
  } catch (err: any) {
    const message = err.response?.data?.detail || err.message
    showError(message, '加载 Provider 列表失败')
  }
}
|
||||
|
||||
/**
 * Open the alias dialog in "create" mode (no alias id selected).
 */
function openCreateAliasDialog(): void {
  editingAliasId.value = null
  createAliasDialogOpen.value = true
}
|
||||
|
||||
/**
 * Open the alias dialog in "edit" mode for the given alias.
 *
 * @param alias The alias to edit; only its id is remembered.
 */
function openEditAliasDialog(alias: ModelAlias): void {
  editingAliasId.value = alias.id
  createAliasDialogOpen.value = true
}
|
||||
|
||||
/**
 * Mirror the dialog's open state into `createAliasDialogOpen`.
 * Closing the dialog also clears the id of the alias being edited.
 *
 * @param value New open state reported by the dialog.
 */
function handleAliasDialogUpdate(value: boolean) {
  createAliasDialogOpen.value = value
  if (value) return
  editingAliasId.value = null
}
|
||||
|
||||
/**
 * Handle a submit coming from the AliasDialog component.
 *
 * Updates the alias identified by `editingAliasId` when in edit mode,
 * otherwise creates a new one. On success the dialog is closed and the
 * alias list is reloaded.
 *
 * @param data   Payload produced by the dialog.
 * @param isEdit Whether the dialog was opened in edit mode.
 */
async function handleAliasSubmit(data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean) {
  submitting.value = true
  try {
    const targetId = editingAliasId.value
    if (isEdit && targetId) {
      await updateAlias(targetId, data as UpdateModelAliasRequest)
      if (data.mapping_type === 'mapping') {
        success('映射已更新')
      } else {
        success('别名已更新')
      }
    } else {
      await createAlias(data as CreateModelAliasRequest)
      if (data.mapping_type === 'mapping') {
        success('映射已创建')
      } else {
        success('别名已创建')
      }
    }

    createAliasDialogOpen.value = false
    editingAliasId.value = null
    await loadAliases()
  } catch (err: any) {
    const detail = err.response?.data?.detail || err.message
    // Replace the backend's "already exists" detail with a more helpful hint.
    const errorMessage = detail === '映射已存在'
      ? '目标作用域已存在同名别名,请先删除冲突的映射或选择其他作用域'
      : detail
    const title = isEdit ? '更新失败' : '创建失败'
    showError(errorMessage, title)
  } finally {
    submitting.value = false
  }
}
|
||||
|
||||
/**
 * Ask for confirmation through `confirmDanger` and delete the alias.
 * The alias list is reloaded after a successful deletion.
 *
 * @param alias The alias to delete.
 */
async function confirmDeleteAlias(alias: ModelAlias) {
  const message = `确定要删除别名 "${alias.alias}" 吗?`
  const confirmed = await confirmDanger(message, '删除别名')
  if (!confirmed) {
    return
  }

  try {
    await deleteAlias(alias.id)
    success('别名已删除')
    await loadAliases()
  } catch (err: any) {
    const detail = err.response?.data?.detail || err.message
    showError(detail, '删除失败')
  }
}
|
||||
|
||||
/**
 * Flip the `is_active` flag of an alias and report the outcome via toast.
 *
 * The target status is captured before the request is sent, so the value
 * written back locally is always the value that was sent to the server,
 * even if `alias.is_active` changes while the request is in flight
 * (for example through overlapping clicks).
 *
 * @param alias The alias whose status should be toggled.
 */
async function toggleAliasStatus(alias: ModelAlias) {
  const newStatus = !alias.is_active
  try {
    await updateAlias(alias.id, { is_active: newStatus })
    alias.is_active = newStatus
    success(newStatus ? '别名已启用' : '别名已停用')
  } catch (err: any) {
    showError(err.response?.data?.detail || err.message, '操作失败')
  }
}
|
||||
|
||||
// On mount, load aliases, global models and providers concurrently.
onMounted(async () => {
  const initialLoads = [
    loadAliases(),
    loadGlobalModelsList(),
    loadProviders()
  ]
  await Promise.all(initialLoads)
})
|
||||
</script>
|
||||
@@ -18,10 +18,10 @@ import SelectContent from '@/components/ui/select-content.vue'
|
||||
import SelectItem from '@/components/ui/select-item.vue'
|
||||
import SelectValue from '@/components/ui/select-value.vue'
|
||||
import ScatterChart from '@/components/charts/ScatterChart.vue'
|
||||
import { Trash2, Eraser, Search, X, BarChart3, ChevronDown, ChevronRight } from 'lucide-vue-next'
|
||||
import { Trash2, Eraser, Search, X, BarChart3, ChevronDown, ChevronRight, Database, ArrowRight } from 'lucide-vue-next'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useConfirm } from '@/composables/useConfirm'
|
||||
import { cacheApi, type CacheStats, type CacheConfig, type UserAffinity } from '@/api/cache'
|
||||
import { cacheApi, modelMappingCacheApi, type CacheStats, type CacheConfig, type UserAffinity, type ModelMappingCacheStats } from '@/api/cache'
|
||||
import type { TTLAnalysisUser } from '@/api/cache'
|
||||
import { formatNumber, formatTokens, formatCost, formatRemainingTime } from '@/utils/format'
|
||||
import {
|
||||
@@ -47,6 +47,13 @@ const currentPage = ref(1)
|
||||
const pageSize = ref(20)
|
||||
const currentTime = ref(Math.floor(Date.now() / 1000))
|
||||
|
||||
// ==================== 模型映射缓存 ====================
|
||||
|
||||
const modelMappingStats = ref<ModelMappingCacheStats | null>(null)
|
||||
const modelMappingLoading = ref(false)
|
||||
const clearingModelMapping = ref(false)
|
||||
const clearingModelName = ref<string | null>(null)
|
||||
|
||||
const { success: showSuccess, error: showError, info: showInfo } = useToast()
|
||||
const { confirm: showConfirm } = useConfirm()
|
||||
|
||||
@@ -241,13 +248,87 @@ function stopCountdown() {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 模型映射缓存方法 ====================
|
||||
|
||||
/**
 * Fetch the model-mapping cache statistics into `modelMappingStats`.
 * `modelMappingLoading` is raised while the request runs; failures are shown
 * via toast and logged.
 */
async function fetchModelMappingStats() {
  modelMappingLoading.value = true
  try {
    const stats = await modelMappingCacheApi.getStats()
    modelMappingStats.value = stats
  } catch (error) {
    const failureText = '获取模型映射缓存统计失败'
    showError(failureText)
    log.error(failureText, error)
  } finally {
    modelMappingLoading.value = false
  }
}
|
||||
|
||||
/**
 * Clear the whole model-mapping cache after an explicit confirmation.
 * On success the number of deleted keys is reported and the statistics are
 * refreshed; failures are shown via toast and logged.
 */
async function clearAllModelMappingCache() {
  const confirmOptions = {
    title: '确认清除',
    message: '确定要清除所有模型映射缓存吗?这会影响所有模型的名称解析。',
    confirmText: '确认清除',
    variant: 'destructive'
  }
  const confirmed = await showConfirm(confirmOptions)
  if (!confirmed) {
    return
  }

  clearingModelMapping.value = true
  try {
    const { deleted_count } = await modelMappingCacheApi.clearAll()
    showSuccess(`已清除 ${deleted_count} 个缓存键`)
    await fetchModelMappingStats()
  } catch (error) {
    const failureText = '清除模型映射缓存失败'
    showError(failureText)
    log.error(failureText, error)
  } finally {
    clearingModelMapping.value = false
  }
}
|
||||
|
||||
/**
 * Clear the mapping cache entry for a single model name.
 * `clearingModelName` holds the name while the request runs; the statistics
 * are refreshed after a successful clear.
 *
 * @param modelName Name of the cached mapping to clear.
 */
async function clearModelMappingByName(modelName: string) {
  clearingModelName.value = modelName
  try {
    await modelMappingCacheApi.clearByName(modelName)
    const doneText = `已清除 ${modelName} 的映射缓存`
    showSuccess(doneText)
    await fetchModelMappingStats()
  } catch (error) {
    showError('清除缓存失败')
    log.error('清除模型映射缓存失败', error)
  } finally {
    clearingModelName.value = null
  }
}
|
||||
|
||||
/**
 * Format a remaining TTL (in seconds) for display.
 *
 * @param ttl Remaining seconds, or null when unknown.
 * @returns '-' for a missing, negative or non-finite value; `<n>s` below one
 *          minute; otherwise `<m>m` or `<m>m<s>s`.
 */
function formatTTL(ttl: number | null): string {
  // `== null` also covers an `undefined` field from the API; the finiteness
  // check keeps NaN/Infinity from rendering as "NaNmNaNs".
  if (ttl == null || !Number.isFinite(ttl) || ttl < 0) return '-'
  if (ttl < 60) return `${ttl}s`

  const minutes = Math.floor(ttl / 60)
  const seconds = ttl % 60
  return seconds === 0 ? `${minutes}m` : `${minutes}m${seconds}s`
}
|
||||
|
||||
/**
 * Map the status of an unmapped cache entry to a badge variant and label.
 *
 * @param status Status string of the cache entry.
 * @returns Badge variant plus display text; unknown statuses get the
 *          'outline' variant and are displayed verbatim.
 */
function getUnmappedStatusBadge(status: string): { variant: 'default' | 'secondary' | 'destructive' | 'outline', text: string } {
  if (status === 'not_found') {
    return { variant: 'secondary', text: '未找到' }
  }
  if (status === 'invalid') {
    return { variant: 'destructive', text: '无效' }
  }
  if (status === 'error') {
    return { variant: 'destructive', text: '错误' }
  }
  return { variant: 'outline', text: status }
}
|
||||
|
||||
// ==================== 刷新所有数据 ====================
|
||||
|
||||
async function refreshData() {
|
||||
await Promise.all([
|
||||
fetchCacheStats(),
|
||||
fetchCacheConfig(),
|
||||
fetchAffinityList()
|
||||
fetchAffinityList(),
|
||||
fetchModelMappingStats()
|
||||
])
|
||||
}
|
||||
|
||||
@@ -272,6 +353,7 @@ onMounted(() => {
|
||||
fetchCacheStats()
|
||||
fetchCacheConfig()
|
||||
fetchAffinityList()
|
||||
fetchModelMappingStats()
|
||||
startCountdown()
|
||||
refreshAnalysis()
|
||||
})
|
||||
@@ -599,6 +681,222 @@ onBeforeUnmount(() => {
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<!-- 模型映射缓存管理 -->
|
||||
<Card class="overflow-hidden">
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
<div class="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3 sm:gap-4">
|
||||
<div class="flex items-center gap-3 shrink-0">
|
||||
<Database class="h-5 w-5 text-muted-foreground hidden sm:block" />
|
||||
<h3 class="text-sm sm:text-base font-semibold">
|
||||
模型映射缓存
|
||||
</h3>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8 text-muted-foreground/70 hover:text-destructive"
|
||||
title="清除全部映射缓存"
|
||||
:disabled="clearingModelMapping || !modelMappingStats?.available"
|
||||
@click="clearAllModelMappingCache"
|
||||
>
|
||||
<Eraser class="h-4 w-4" />
|
||||
</Button>
|
||||
<RefreshButton
|
||||
:loading="modelMappingLoading"
|
||||
@click="fetchModelMappingStats"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 映射缓存表格 -->
|
||||
<Table
|
||||
v-if="modelMappingStats?.available && modelMappingStats.mappings && modelMappingStats.mappings.length > 0"
|
||||
class="hidden md:table"
|
||||
>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead class="w-[25%]">
|
||||
全局模型
|
||||
</TableHead>
|
||||
<TableHead class="w-8 text-center" />
|
||||
<TableHead class="w-[30%]">
|
||||
映射模型
|
||||
</TableHead>
|
||||
<TableHead class="w-[25%]">
|
||||
提供商
|
||||
</TableHead>
|
||||
<TableHead class="w-[10%] text-center">
|
||||
剩余
|
||||
</TableHead>
|
||||
<TableHead class="w-[5%] text-right">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="mapping in modelMappingStats.mappings"
|
||||
:key="mapping.mapping_name"
|
||||
>
|
||||
<TableCell>
|
||||
<div v-if="mapping.global_model_name">
|
||||
<div class="text-sm font-medium">
|
||||
{{ mapping.global_model_display_name || mapping.global_model_name }}
|
||||
</div>
|
||||
<div
|
||||
v-if="mapping.global_model_display_name && mapping.global_model_display_name !== mapping.global_model_name"
|
||||
class="text-xs text-muted-foreground font-mono"
|
||||
>
|
||||
{{ mapping.global_model_name }}
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm text-muted-foreground"
|
||||
>-</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<ArrowRight class="h-4 w-4 text-muted-foreground" />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span class="text-sm font-mono">{{ mapping.mapping_name }}</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div
|
||||
v-if="mapping.providers && mapping.providers.length > 0"
|
||||
class="flex flex-wrap gap-1"
|
||||
>
|
||||
<Badge
|
||||
v-for="provider in mapping.providers.slice(0, 3)"
|
||||
:key="provider"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ provider }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-if="mapping.providers.length > 3"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
+{{ mapping.providers.length - 3 }}
|
||||
</Badge>
|
||||
</div>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm text-muted-foreground"
|
||||
>-</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-right">
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class="h-6 w-6 text-muted-foreground/50 hover:text-destructive"
|
||||
:disabled="clearingModelName === mapping.mapping_name"
|
||||
title="清除缓存"
|
||||
@click="clearModelMappingByName(mapping.mapping_name)"
|
||||
>
|
||||
<X class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div
|
||||
v-if="modelMappingStats?.available && modelMappingStats.mappings && modelMappingStats.mappings.length > 0"
|
||||
class="md:hidden divide-y divide-border/40"
|
||||
>
|
||||
<div
|
||||
v-for="mapping in modelMappingStats.mappings"
|
||||
:key="`m-${mapping.mapping_name}`"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-center justify-between gap-2">
|
||||
<span class="text-sm font-medium truncate">{{ mapping.global_model_display_name || mapping.global_model_name || '-' }}</span>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class="h-6 w-6 text-muted-foreground/50 hover:text-destructive shrink-0"
|
||||
:disabled="clearingModelName === mapping.mapping_name"
|
||||
@click="clearModelMappingByName(mapping.mapping_name)"
|
||||
>
|
||||
<X class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<ArrowRight class="h-3.5 w-3.5 shrink-0" />
|
||||
<span class="font-mono">{{ mapping.mapping_name }}</span>
|
||||
</div>
|
||||
<div
|
||||
v-if="mapping.providers && mapping.providers.length > 0"
|
||||
class="flex flex-wrap gap-1"
|
||||
>
|
||||
<Badge
|
||||
v-for="provider in mapping.providers"
|
||||
:key="provider"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ provider }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 未映射条目(NOT_FOUND 等) -->
|
||||
<div
|
||||
v-if="modelMappingStats?.available && modelMappingStats.unmapped && modelMappingStats.unmapped.length > 0"
|
||||
class="px-6 py-4 border-t border-border/40"
|
||||
>
|
||||
<div class="text-xs text-muted-foreground mb-2">
|
||||
未映射的缓存条目
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1.5">
|
||||
<Badge
|
||||
v-for="entry in modelMappingStats.unmapped"
|
||||
:key="entry.mapping_name"
|
||||
:variant="getUnmappedStatusBadge(entry.status).variant"
|
||||
class="text-xs font-mono cursor-pointer"
|
||||
:title="`${getUnmappedStatusBadge(entry.status).text} - 点击清除`"
|
||||
@click="clearModelMappingByName(entry.mapping_name)"
|
||||
>
|
||||
{{ entry.mapping_name }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 无缓存状态 -->
|
||||
<div
|
||||
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0)"
|
||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||
>
|
||||
暂无模型解析缓存
|
||||
</div>
|
||||
|
||||
<!-- Redis 未启用 -->
|
||||
<div
|
||||
v-else-if="modelMappingStats && !modelMappingStats.available"
|
||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||
>
|
||||
{{ modelMappingStats.message || 'Redis 未启用' }}
|
||||
</div>
|
||||
|
||||
<!-- 加载中 -->
|
||||
<div
|
||||
v-else-if="modelMappingLoading"
|
||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||
>
|
||||
加载中...
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- TTL 分析区域 -->
|
||||
<Card class="overflow-hidden">
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
|
||||
@@ -425,25 +425,12 @@
|
||||
@success="handleModelFormSuccess"
|
||||
/>
|
||||
|
||||
<!-- 创建/编辑别名/映射对话框 -->
|
||||
<AliasDialog
|
||||
:open="createAliasDialogOpen"
|
||||
:editing-alias="editingAlias"
|
||||
:global-models="globalModels"
|
||||
:providers="providers"
|
||||
:fixed-target-model="isTargetModelFixed ? selectedModel : null"
|
||||
@update:open="handleAliasDialogUpdate"
|
||||
@submit="handleAliasSubmit"
|
||||
/>
|
||||
|
||||
<!-- 模型详情抽屉 -->
|
||||
<ModelDetailDrawer
|
||||
:model="selectedModel"
|
||||
:open="!!selectedModel"
|
||||
:providers="selectedModelProviders"
|
||||
:aliases="selectedModelAliases"
|
||||
:loading-providers="loadingModelProviders"
|
||||
:loading-aliases="loadingModelAliases"
|
||||
:has-blocking-dialog-open="hasBlockingDialogOpen"
|
||||
:capabilities="capabilities"
|
||||
@update:open="handleDrawerOpenChange"
|
||||
@@ -454,11 +441,6 @@
|
||||
@delete-provider="confirmDeleteProviderImplementation"
|
||||
@toggle-provider-status="toggleProviderStatus"
|
||||
@refresh-providers="refreshSelectedModelProviders"
|
||||
@add-alias="openAddAliasDialog"
|
||||
@edit-alias="openEditAliasDialog"
|
||||
@toggle-alias-status="toggleAliasStatusFromDrawer"
|
||||
@delete-alias="confirmDeleteAliasFromDrawer"
|
||||
@refresh-aliases="refreshSelectedModelAliases"
|
||||
/>
|
||||
|
||||
<!-- 批量添加关联提供商对话框 -->
|
||||
@@ -736,9 +718,7 @@ import {
|
||||
} from 'lucide-vue-next'
|
||||
import ModelDetailDrawer from '@/features/models/components/ModelDetailDrawer.vue'
|
||||
import GlobalModelFormDialog from '@/features/models/components/GlobalModelFormDialog.vue'
|
||||
import AliasDialog from '@/features/models/components/AliasDialog.vue'
|
||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||
import type { CreateModelAliasRequest, UpdateModelAliasRequest } from '@/api/endpoints/aliases'
|
||||
import type { Model } from '@/api/endpoints'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useConfirm } from '@/composables/useConfirm'
|
||||
@@ -768,13 +748,6 @@ import {
|
||||
type GlobalModelResponse,
|
||||
} from '@/api/global-models'
|
||||
import { log } from '@/utils/logger'
|
||||
import {
|
||||
getAliases,
|
||||
createAlias,
|
||||
updateAlias,
|
||||
deleteAlias,
|
||||
type ModelAlias,
|
||||
} from '@/api/endpoints/aliases'
|
||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
|
||||
|
||||
@@ -788,13 +761,9 @@ const searchQuery = ref('')
|
||||
const selectedModel = ref<GlobalModelResponse | null>(null)
|
||||
const createModelDialogOpen = ref(false)
|
||||
const editingModel = ref<GlobalModelResponse | null>(null)
|
||||
const createAliasDialogOpen = ref(false)
|
||||
const editingAliasId = ref<string | null>(null)
|
||||
const isTargetModelFixed = ref(false) // 目标模型是否固定(从模型详情抽屉打开时为 true)
|
||||
|
||||
// 数据
|
||||
const globalModels = ref<GlobalModelResponse[]>([])
|
||||
const allAliases = ref<ModelAlias[]>([])
|
||||
const providers = ref<any[]>([])
|
||||
const capabilities = ref<CapabilityDefinition[]>([])
|
||||
|
||||
@@ -804,9 +773,7 @@ const catalogPageSize = ref(20)
|
||||
|
||||
// 选中模型的详细数据
|
||||
const selectedModelProviders = ref<any[]>([])
|
||||
const selectedModelAliases = ref<ModelAlias[]>([])
|
||||
const loadingModelProviders = ref(false)
|
||||
const loadingModelAliases = ref(false)
|
||||
|
||||
// 批量添加关联提供商
|
||||
const batchAddProvidersDialogOpen = ref(false)
|
||||
@@ -876,19 +843,10 @@ function hasTieredPricing(model: GlobalModelResponse): boolean {
|
||||
// 检测是否有对话框打开(防止误关闭抽屉)
|
||||
const hasBlockingDialogOpen = computed(() =>
|
||||
createModelDialogOpen.value ||
|
||||
createAliasDialogOpen.value ||
|
||||
batchAddProvidersDialogOpen.value ||
|
||||
editProviderDialogOpen.value
|
||||
)
|
||||
|
||||
// 编辑中的别名对象(用于传递给 AliasDialog)
|
||||
const editingAlias = computed(() => {
|
||||
if (!editingAliasId.value) return null
|
||||
return allAliases.value.find(a => a.id === editingAliasId.value) ||
|
||||
selectedModelAliases.value.find(a => a.id === editingAliasId.value) ||
|
||||
null
|
||||
})
|
||||
|
||||
// 能力筛选
|
||||
const capabilityFilters = ref({
|
||||
streaming: false,
|
||||
@@ -1131,11 +1089,8 @@ async function selectModel(model: GlobalModelResponse) {
|
||||
selectedModel.value = model
|
||||
detailTab.value = 'basic'
|
||||
|
||||
// 加载该模型的关联提供商和别名
|
||||
await Promise.all([
|
||||
loadModelProviders(model.id),
|
||||
loadModelAliases(model.id)
|
||||
])
|
||||
// 加载该模型的关联提供商
|
||||
await loadModelProviders(model.id)
|
||||
}
|
||||
|
||||
// 加载指定模型的关联提供商
|
||||
@@ -1187,27 +1142,6 @@ async function loadModelProviders(_globalModelId: string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 加载指定模型的别名
|
||||
async function loadModelAliases(globalModelId: string) {
|
||||
loadingModelAliases.value = true
|
||||
try {
|
||||
const aliases = await getAliases({ limit: 1000 })
|
||||
selectedModelAliases.value = aliases.filter(a => a.global_model_id === globalModelId)
|
||||
} catch (err: any) {
|
||||
log.error('加载别名失败:', err)
|
||||
selectedModelAliases.value = []
|
||||
} finally {
|
||||
loadingModelAliases.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新当前选中模型的别名
|
||||
async function refreshSelectedModelAliases() {
|
||||
if (selectedModel.value) {
|
||||
await loadModelAliases(selectedModel.value.id)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新当前选中模型的关联提供商
|
||||
async function refreshSelectedModelProviders() {
|
||||
if (selectedModel.value) {
|
||||
@@ -1329,14 +1263,6 @@ async function confirmDeleteProviderImplementation(provider: any) {
|
||||
}
|
||||
}
|
||||
|
||||
// 打开添加别名对话框(从模型详情抽屉)
|
||||
function openAddAliasDialog() {
|
||||
if (!selectedModel.value) return
|
||||
editingAliasId.value = null
|
||||
isTargetModelFixed.value = true // 目标模型固定为当前选中模型
|
||||
createAliasDialogOpen.value = true
|
||||
}
|
||||
|
||||
function openCreateModelDialog() {
|
||||
editingModel.value = null
|
||||
createModelDialogOpen.value = true
|
||||
@@ -1391,106 +1317,6 @@ async function toggleModelStatus(model: GlobalModelResponse) {
|
||||
}
|
||||
}
|
||||
|
||||
async function loadAliases() {
|
||||
try {
|
||||
allAliases.value = await getAliases({ limit: 1000 })
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || err.message, '加载别名失败')
|
||||
}
|
||||
}
|
||||
|
||||
function openEditAliasDialog(alias: ModelAlias) {
|
||||
editingAliasId.value = alias.id
|
||||
isTargetModelFixed.value = false
|
||||
createAliasDialogOpen.value = true
|
||||
}
|
||||
|
||||
// 处理别名对话框关闭事件
|
||||
function handleAliasDialogUpdate(value: boolean) {
|
||||
createAliasDialogOpen.value = value
|
||||
if (!value) {
|
||||
editingAliasId.value = null
|
||||
isTargetModelFixed.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 处理别名提交(来自 AliasDialog 组件)
|
||||
async function handleAliasSubmit(data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean) {
|
||||
submitting.value = true
|
||||
try {
|
||||
if (isEdit && editingAliasId.value) {
|
||||
// 更新
|
||||
await updateAlias(editingAliasId.value, data as UpdateModelAliasRequest)
|
||||
success(data.mapping_type === 'mapping' ? '映射已更新' : '别名已更新')
|
||||
} else {
|
||||
// 创建
|
||||
await createAlias(data as CreateModelAliasRequest)
|
||||
success(data.mapping_type === 'mapping' ? '映射已创建' : '别名已创建')
|
||||
}
|
||||
createAliasDialogOpen.value = false
|
||||
editingAliasId.value = null
|
||||
isTargetModelFixed.value = false
|
||||
|
||||
// 刷新数据
|
||||
await loadAliases()
|
||||
if (selectedModel.value) {
|
||||
await loadModelAliases(selectedModel.value.id)
|
||||
}
|
||||
// 刷新外层模型列表以更新 alias_count
|
||||
await loadGlobalModels()
|
||||
} catch (err: any) {
|
||||
const detail = err.response?.data?.detail || err.message
|
||||
// 优化错误提示文案
|
||||
let errorMessage = detail
|
||||
if (detail === '映射已存在') {
|
||||
errorMessage = '目标作用域已存在同名别名,请先删除冲突的映射或选择其他作用域'
|
||||
}
|
||||
showError(errorMessage, isEdit ? '更新失败' : '创建失败')
|
||||
} finally {
|
||||
submitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function confirmDeleteAlias(alias: ModelAlias) {
|
||||
const confirmed = await confirmDanger(
|
||||
`确定要删除别名 "${alias.alias}" 吗?`,
|
||||
'删除别名'
|
||||
)
|
||||
if (!confirmed) return
|
||||
|
||||
try {
|
||||
await deleteAlias(alias.id)
|
||||
success('别名已删除')
|
||||
await loadAliases()
|
||||
// 刷新外层模型列表以更新 alias_count
|
||||
await loadGlobalModels()
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || err.message, '删除失败')
|
||||
}
|
||||
}
|
||||
|
||||
async function toggleAliasStatus(alias: ModelAlias) {
|
||||
try {
|
||||
await updateAlias(alias.id, { is_active: !alias.is_active })
|
||||
alias.is_active = !alias.is_active
|
||||
success(alias.is_active ? '别名已启用' : '别名已停用')
|
||||
} catch (err: any) {
|
||||
showError(err.response?.data?.detail || err.message, '操作失败')
|
||||
}
|
||||
}
|
||||
|
||||
// 从抽屉中切换别名状态
|
||||
async function toggleAliasStatusFromDrawer(alias: ModelAlias) {
|
||||
await toggleAliasStatus(alias)
|
||||
await refreshSelectedModelAliases()
|
||||
}
|
||||
|
||||
// 从抽屉中删除别名
|
||||
async function confirmDeleteAliasFromDrawer(alias: ModelAlias) {
|
||||
await confirmDeleteAlias(alias)
|
||||
await refreshSelectedModelAliases()
|
||||
}
|
||||
|
||||
async function refreshData() {
|
||||
await loadGlobalModels()
|
||||
}
|
||||
|
||||
@@ -6,11 +6,9 @@ from fastapi import APIRouter
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .global_models import router as global_models_router
|
||||
from .mappings import router as mappings_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(global_models_router)
|
||||
router.include_router(mappings_router)
|
||||
|
||||
@@ -1,38 +1,26 @@
|
||||
"""
|
||||
统一模型目录 Admin API
|
||||
|
||||
阶段一:基于 ModelMapping 和 Model 的聚合视图
|
||||
基于 GlobalModel 的聚合视图
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import func, or_
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignError,
|
||||
BatchAssignModelMappingRequest,
|
||||
BatchAssignModelMappingResponse,
|
||||
BatchAssignProviderResult,
|
||||
DeleteModelMappingResponse,
|
||||
ModelCapabilities,
|
||||
ModelCatalogItem,
|
||||
ModelCatalogProviderDetail,
|
||||
ModelCatalogResponse,
|
||||
ModelPriceRange,
|
||||
OrphanedModel,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
)
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -47,24 +35,13 @@ async def get_model_catalog(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
|
||||
async def batch_assign_model_mappings(
|
||||
request: Request,
|
||||
payload: BatchAssignModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelMappingResponse:
|
||||
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
"""管理员查询统一模型目录
|
||||
|
||||
新架构说明:
|
||||
架构说明:
|
||||
1. 以 GlobalModel 为中心聚合数据
|
||||
2. ModelMapping 表提供别名信息(provider_id=NULL 表示全局)
|
||||
3. Model 表提供关联提供商和价格
|
||||
2. Model 表提供关联提供商和价格
|
||||
"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -75,29 +52,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
|
||||
)
|
||||
|
||||
# 2. 获取所有活跃的别名(含全局和 Provider 特定)
|
||||
aliases_rows: List[ModelMapping] = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 GlobalModel ID 组织别名
|
||||
aliases_by_global_model: Dict[str, List[str]] = {}
|
||||
for alias_row in aliases_rows:
|
||||
if not alias_row.target_global_model_id:
|
||||
continue
|
||||
gm_id = alias_row.target_global_model_id
|
||||
if gm_id not in aliases_by_global_model:
|
||||
aliases_by_global_model[gm_id] = []
|
||||
if alias_row.source_model not in aliases_by_global_model[gm_id]:
|
||||
aliases_by_global_model[gm_id].append(alias_row.source_model)
|
||||
|
||||
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
# 2. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
models: List[Model] = (
|
||||
db.query(Model)
|
||||
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||
@@ -111,7 +66,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
if model.global_model_id:
|
||||
models_by_global_model.setdefault(model.global_model_id, []).append(model)
|
||||
|
||||
# 4. 为每个 GlobalModel 构建 catalog item
|
||||
# 3. 为每个 GlobalModel 构建 catalog item
|
||||
catalog_items: List[ModelCatalogItem] = []
|
||||
|
||||
for gm in global_models:
|
||||
@@ -168,7 +123,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=bool(model.is_active),
|
||||
mapping_id=None, # 新架构中不再有 mapping_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -187,7 +141,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
global_model_name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
aliases=aliases_by_global_model.get(gm_id, []),
|
||||
providers=provider_entries,
|
||||
price_range=price_range,
|
||||
total_providers=len(provider_entries),
|
||||
@@ -195,238 +148,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
)
|
||||
)
|
||||
|
||||
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
|
||||
orphaned_rows = (
|
||||
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
|
||||
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
or_(GlobalModel.id == None, GlobalModel.is_active == False),
|
||||
)
|
||||
.group_by(ModelMapping.source_model, GlobalModel.name)
|
||||
.all()
|
||||
)
|
||||
orphaned_models = [
|
||||
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
|
||||
for row in orphaned_rows
|
||||
if row[0]
|
||||
]
|
||||
|
||||
return ModelCatalogResponse(
|
||||
models=catalog_items,
|
||||
total=len(catalog_items),
|
||||
orphaned_models=orphaned_models,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
|
||||
payload: BatchAssignModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
created: List[BatchAssignProviderResult] = []
|
||||
errors: List[BatchAssignError] = []
|
||||
|
||||
for provider_config in self.payload.providers:
|
||||
provider_id = provider_config.provider_id
|
||||
try:
|
||||
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
|
||||
)
|
||||
continue
|
||||
|
||||
model_id: Optional[str] = None
|
||||
created_model = False
|
||||
|
||||
if provider_config.create_model:
|
||||
model_data = provider_config.model_data
|
||||
if not model_data:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
|
||||
)
|
||||
continue
|
||||
|
||||
existing_model = ModelService.get_model_by_name(
|
||||
db, provider_id, model_data.provider_model_name
|
||||
)
|
||||
if existing_model:
|
||||
model_id = existing_model.id
|
||||
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
|
||||
model_data.provider_model_name,
|
||||
provider.name,
|
||||
)
|
||||
else:
|
||||
model = ModelService.create_model(db, provider_id, model_data)
|
||||
model_id = model.id
|
||||
created_model = True
|
||||
else:
|
||||
model_id = provider_config.model_id
|
||||
if not model_id:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
|
||||
)
|
||||
continue
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == model_id, Model.provider_id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
|
||||
)
|
||||
continue
|
||||
|
||||
# 批量分配功能需要适配 GlobalModel 架构
|
||||
# 参见 docs/optimization-backlog.md 中的待办项
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id,
|
||||
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error("批量添加模型映射失败(需要适配新架构)")
|
||||
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
|
||||
|
||||
return BatchAssignModelMappingResponse(
|
||||
success=len(created) > 0,
|
||||
created_mappings=created,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
|
||||
async def update_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
payload: UpdateModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> UpdateModelMappingResponse:
|
||||
"""更新模型映射"""
|
||||
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
|
||||
async def delete_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> DeleteModelMappingResponse:
|
||||
"""删除模型映射"""
|
||||
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
|
||||
"""更新模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
payload: UpdateModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
update_data = self.payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return UpdateModelMappingResponse(
|
||||
success=True,
|
||||
mapping_id=mapping.id,
|
||||
message="映射更新成功",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
|
||||
"""删除模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return DeleteModelMappingResponse(
|
||||
success=True,
|
||||
message=f"映射 {self.mapping_id} 已删除",
|
||||
)
|
||||
|
||||
@@ -123,7 +123,7 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import func
|
||||
|
||||
from src.models.database import Model, ModelMapping
|
||||
from src.models.database import Model
|
||||
|
||||
models = GlobalModelService.list_global_models(
|
||||
db=context.db,
|
||||
@@ -144,17 +144,8 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计别名数量
|
||||
alias_count = (
|
||||
context.db.query(func.count(ModelMapping.id))
|
||||
.filter(ModelMapping.target_global_model_id == gm.id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
response = GlobalModelResponse.model_validate(gm)
|
||||
response.provider_count = provider_count
|
||||
response.alias_count = alias_count
|
||||
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
|
||||
model_responses.append(response)
|
||||
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
"""模型映射管理 API
|
||||
|
||||
提供模型映射的 CRUD 操作。
|
||||
|
||||
模型映射(Mapping)用于将源模型映射到目标模型,例如:
|
||||
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
|
||||
- 用于处理 Provider 不支持请求模型的情况
|
||||
|
||||
映射必须关联到特定的 Provider。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelMappingCreate,
|
||||
ModelMappingResponse,
|
||||
ModelMappingUpdate,
|
||||
)
|
||||
from src.models.database import GlobalModel, ModelMapping, Provider, User
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
|
||||
|
||||
|
||||
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
|
||||
target = mapping.target_global_model
|
||||
provider = mapping.provider
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
return ModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=target.name if target else None,
|
||||
target_global_model_display_name=target.display_name if target else None,
|
||||
provider_id=mapping.provider_id,
|
||||
provider_name=provider.name if provider else None,
|
||||
scope=scope,
|
||||
mapping_type=mapping.mapping_type,
|
||||
is_active=mapping.is_active,
|
||||
created_at=mapping.created_at,
|
||||
updated_at=mapping.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelMappingResponse])
|
||||
async def list_mappings(
|
||||
provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
|
||||
source_model: Optional[str] = Query(None, description="按源模型名筛选"),
|
||||
target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
|
||||
scope: Optional[str] = Query(None, description="global 或 provider"),
|
||||
mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
|
||||
is_active: Optional[bool] = Query(None, description="按状态筛选"),
|
||||
skip: int = Query(0, ge=0, description="跳过记录数"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取模型映射列表"""
|
||||
query = db.query(ModelMapping).options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
|
||||
if provider_id is not None:
|
||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
||||
if scope == "global":
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
elif scope == "provider":
|
||||
query = query.filter(ModelMapping.provider_id.isnot(None))
|
||||
if mapping_type is not None:
|
||||
query = query.filter(ModelMapping.mapping_type == mapping_type)
|
||||
if source_model:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{source_model}%"))
|
||||
if target_global_model_id is not None:
|
||||
query = query.filter(ModelMapping.target_global_model_id == target_global_model_id)
|
||||
if is_active is not None:
|
||||
query = query.filter(ModelMapping.is_active == is_active)
|
||||
|
||||
mappings = query.offset(skip).limit(limit).all()
|
||||
return [_serialize_mapping(mapping) for mapping in mappings]
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
|
||||
async def get_mapping(
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取单个模型映射"""
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.post("", response_model=ModelMappingResponse, status_code=201)
|
||||
async def create_mapping(
|
||||
data: ModelMappingCreate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建模型映射"""
|
||||
source_model = data.source_model.strip()
|
||||
if not source_model:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
|
||||
# 验证 mapping_type
|
||||
if data.mapping_type not in ("alias", "mapping"):
|
||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
||||
|
||||
# 验证目标 GlobalModel 存在
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
|
||||
)
|
||||
|
||||
# 验证 Provider 存在
|
||||
provider = None
|
||||
provider_id = data.provider_id
|
||||
if provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")
|
||||
|
||||
# 检查映射是否已存在(全局或同一 Provider 下不可重复)
|
||||
existing = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
# 创建映射
|
||||
mapping = ModelMapping(
|
||||
id=str(uuid.uuid4()),
|
||||
source_model=source_model,
|
||||
target_global_model_id=data.target_global_model_id,
|
||||
provider_id=provider_id,
|
||||
mapping_type=data.mapping_type,
|
||||
is_active=data.is_active,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
db.add(mapping)
|
||||
db.commit()
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
|
||||
f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
|
||||
async def update_mapping(
|
||||
mapping_id: str,
|
||||
data: ModelMappingUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""更新模型映射"""
|
||||
mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# 更新 Provider
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
# 更新目标模型
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
|
||||
)
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
# 更新源模型名
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
# 检查唯一约束
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
# 更新映射类型
|
||||
if "mapping_type" in update_data:
|
||||
if update_data["mapping_type"] not in ("alias", "mapping"):
|
||||
raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
|
||||
mapping.mapping_type = update_data["mapping_type"]
|
||||
|
||||
# 更新状态
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
mapping.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
logger.info(f"更新模型映射 (ID: {mapping.id})")
|
||||
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.options(
|
||||
joinedload(ModelMapping.target_global_model),
|
||||
joinedload(ModelMapping.provider),
|
||||
)
|
||||
.filter(ModelMapping.id == mapping.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
    mapping_id: str,
    db: Session = Depends(get_db),
):
    """Delete a model mapping and notify the cache invalidation service.

    Raises:
        HTTPException: 404 when no mapping with ``mapping_id`` exists.
    """
    record = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
    if not record:
        raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")

    # Read these before the row is deleted; they are needed for the cache
    # notification that runs after the commit.
    source_model, provider_id = record.source_model, record.provider_id

    logger.info(
        f"删除模型映射: {source_model} -> {record.target_global_model_id} (ID: {record.id})"
    )

    db.delete(record)
    db.commit()

    get_cache_invalidation_service().on_model_mapping_changed(source_model, provider_id)

    return None
|
||||
@@ -869,3 +869,310 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
dynamic_reservation_enabled=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# ==================== 模型映射缓存管理 ====================
|
||||
|
||||
|
||||
@router.get("/model-mapping/stats")
async def get_model_mapping_cache_stats(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Return statistics about the model-mapping cache.

    The response includes:
    - the number of cache keys
    - the configured cache TTL
    - the number of keys per cache type
    """
    stats_adapter = AdminModelMappingCacheStatsAdapter()
    return await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.delete("/model-mapping")
async def clear_all_model_mapping_cache(
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Clear the entire model-mapping cache.

    Warning: this affects resolution of every model, use with care.
    """
    clear_adapter = AdminClearAllModelMappingCacheAdapter()
    return await pipeline.run(
        adapter=clear_adapter,
        http_request=request,
        db=db,
        mode=clear_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.delete("/model-mapping/{model_name}")
async def clear_model_mapping_cache_by_name(
    model_name: str,
    request: Request,
    db: Session = Depends(get_db),
):
    """
    Clear the mapping cache for a single model name.

    Parameters:
    - model_name: the model name (either a GlobalModel.name or a mapping name)
    """
    by_name_adapter = AdminClearModelMappingCacheByNameAdapter(model_name=model_name)
    return await pipeline.run(
        adapter=by_name_adapter,
        http_request=request,
        db=db,
        mode=by_name_adapter.mode,
    )
|
||||
|
||||
|
||||
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
    """Admin adapter that reports statistics about the model cache keys in Redis."""

    @staticmethod
    def _as_text(value):
        """Return ``value`` decoded to ``str`` when it is ``bytes``, else unchanged."""
        return value.decode() if isinstance(value, bytes) else value

    @staticmethod
    def _provider_names_for(db, global_model_id, mapping_name):
        """List providers whose model for ``global_model_id`` exposes ``mapping_name``.

        Only active models of active providers are considered. A provider is
        included when ``mapping_name`` equals the model's ``provider_model_name``
        or the ``name`` of one of the dict entries in ``provider_model_aliases``.

        Returns:
            Sorted, de-duplicated provider names (``display_name`` when set,
            otherwise ``name``).
        """
        from src.models.database import Model, Provider

        rows = (
            db.query(Model, Provider)
            .join(Provider, Model.provider_id == Provider.id)
            .filter(
                Model.global_model_id == global_model_id,
                Model.is_active,
                Provider.is_active,
            )
            .all()
        )

        names = set()
        for model, provider in rows:
            # Primary provider model name matches.
            if model.provider_model_name == mapping_name:
                names.add(provider.display_name or provider.name)
                continue
            # Otherwise look for the name in the alias list.
            aliases = model.provider_model_aliases
            if aliases and any(
                isinstance(alias, dict) and alias.get("name") == mapping_name
                for alias in aliases
            ):
                names.add(provider.display_name or provider.name)
        return sorted(names)

    async def handle(self, context):  # type: ignore[override]
        """Collect key counts and resolved mappings from the model cache.

        Returns:
            ``{"status": "ok", "data": ...}``. When no Redis client is
            available, ``data`` only carries ``available: False`` and a
            message. Otherwise it carries the TTL setting, the total and
            per-prefix key counts, the resolved mappings sorted by mapping
            name, and the entries that could not be resolved (``None`` when
            there are none).

        Raises:
            HTTPException: 500 when anything outside the per-key handling fails.
        """
        import json

        from src.clients.redis_client import get_redis_client
        from src.config.constants import CacheTTL

        resolve_prefix = "global_model:resolve:"
        db = context.db

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return {
                    "status": "ok",
                    "data": {
                        "available": False,
                        "message": "Redis 未启用,模型映射缓存不可用",
                    },
                }

            # Bucket the cache keys by type.
            model_id_keys = []
            global_model_id_keys = []
            global_model_name_keys = []
            global_model_resolve_keys = []
            provider_global_keys = []

            # Scan every model-related cache key.
            async for key in redis.scan_iter(match="model:*", count=100):
                key_str = self._as_text(key)
                if key_str.startswith("model:id:"):
                    model_id_keys.append(key_str)
                elif key_str.startswith("model:provider_global:"):
                    provider_global_keys.append(key_str)

            async for key in redis.scan_iter(match="global_model:*", count=100):
                key_str = self._as_text(key)
                if key_str.startswith("global_model:id:"):
                    global_model_id_keys.append(key_str)
                elif key_str.startswith("global_model:name:"):
                    global_model_name_keys.append(key_str)
                elif key_str.startswith(resolve_prefix):
                    global_model_resolve_keys.append(key_str)

            total_keys = (
                len(model_id_keys)
                + len(global_model_id_keys)
                + len(global_model_name_keys)
                + len(global_model_resolve_keys)
                + len(provider_global_keys)
            )

            # Parse the cached values and build the mapping list.
            mappings = []
            unmapped_entries = []

            for key in global_model_resolve_keys[:100]:  # inspect at most 100 keys
                # Strip only the leading prefix; str.replace would also remove
                # any later occurrence inside the model name itself.
                mapping_name = key[len(resolve_prefix):]
                try:
                    cached_value = await redis.get(key)
                    ttl = await redis.ttl(key)

                    if not cached_value:
                        continue

                    cached_str = self._as_text(cached_value)

                    if cached_str == "NOT_FOUND":
                        unmapped_entries.append({
                            "mapping_name": mapping_name,
                            "status": "not_found",
                            "ttl": ttl if ttl > 0 else None,
                        })
                        continue

                    try:
                        cached_data = json.loads(cached_str)
                    except json.JSONDecodeError:
                        unmapped_entries.append({
                            "mapping_name": mapping_name,
                            "status": "invalid",
                            "ttl": ttl if ttl > 0 else None,
                        })
                        continue

                    global_model_id = cached_data.get("id")
                    global_model_name = cached_data.get("name")
                    global_model_display_name = cached_data.get("display_name")

                    # A name that resolves to itself is a direct match, not a mapping.
                    if mapping_name == global_model_name:
                        continue

                    # Only list providers that actually configured this mapping name.
                    provider_names = []
                    if global_model_id:
                        provider_names = self._provider_names_for(
                            db, global_model_id, mapping_name
                        )

                    mappings.append({
                        "mapping_name": mapping_name,
                        "global_model_name": global_model_name,
                        "global_model_display_name": global_model_display_name,
                        "providers": provider_names,
                        "ttl": ttl if ttl > 0 else None,
                    })

                except Exception as e:
                    # Best effort per key: record the failure and keep going.
                    logger.warning(f"解析缓存键 {key} 失败: {e}")
                    unmapped_entries.append({
                        "mapping_name": mapping_name,
                        "status": "error",
                        "ttl": None,
                    })

            # Sort by mapping_name.
            mappings.sort(key=lambda x: x["mapping_name"])

            response_data = {
                "available": True,
                "ttl_seconds": CacheTTL.MODEL,
                "total_keys": total_keys,
                "breakdown": {
                    "model_by_id": len(model_id_keys),
                    "model_by_provider_global": len(provider_global_keys),
                    "global_model_by_id": len(global_model_id_keys),
                    "global_model_by_name": len(global_model_name_keys),
                    "global_model_resolve": len(global_model_resolve_keys),
                },
                "mappings": mappings,
                "unmapped": unmapped_entries or None,
            }

            context.add_audit_metadata(
                action="model_mapping_cache_stats",
                total_keys=total_keys,
            )
            return {"status": "ok", "data": response_data}

        except Exception as exc:
            logger.exception(f"获取模型映射缓存统计失败: {exc}")
            raise HTTPException(status_code=500, detail=f"获取统计失败: {exc}") from exc
|
||||
|
||||
|
||||
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
    """Admin adapter that removes every model-related cache key from Redis."""

    async def handle(self, context):  # type: ignore[override]
        """Delete all keys matching ``model:*`` and ``global_model:*``.

        Returns:
            A dict with ``status``, ``message`` and ``deleted_count``.

        Raises:
            HTTPException: 503 when no Redis client is available, 500 on any
                other failure.
        """
        from src.clients.redis_client import get_redis_client

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis 未启用")

            # Gather every model-related cache key, one scan pattern at a time.
            doomed_keys = []
            for pattern in ("model:*", "global_model:*"):
                async for key in redis.scan_iter(match=pattern, count=100):
                    doomed_keys.append(key)

            deleted_count = await redis.delete(*doomed_keys) if doomed_keys else 0

            logger.warning(f"已清除所有模型映射缓存(管理员操作): {deleted_count} 个键")
            context.add_audit_metadata(
                action="model_mapping_cache_clear_all",
                deleted_count=deleted_count,
            )
            return {
                "status": "ok",
                "message": "已清除所有模型映射缓存",
                "deleted_count": deleted_count,
            }

        except HTTPException:
            # Let the deliberate 503 above pass through untouched.
            raise
        except Exception as exc:
            logger.exception(f"清除模型映射缓存失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}")
|
||||
|
||||
|
||||
@dataclass
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
    """Admin adapter that clears the cache entries keyed by one model name."""

    # Model name whose ``global_model:resolve:`` and ``global_model:name:``
    # cache keys are removed.
    model_name: str

    async def handle(self, context):  # type: ignore[override]
        """Delete the resolve and name cache keys for ``self.model_name``.

        Returns:
            A dict with ``status``, ``message``, ``model_name`` and
            ``deleted_keys`` (the keys that were actually removed).

        Raises:
            HTTPException: 503 when no Redis client is available, 500 on any
                other failure.
        """
        from src.clients.redis_client import get_redis_client

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis 未启用")

            deleted_keys = []
            for cache_key in (
                f"global_model:resolve:{self.model_name}",  # resolve cache
                f"global_model:name:{self.model_name}",  # name cache
            ):
                # DEL reports how many keys it removed, so a separate EXISTS
                # round trip (and the race between the two calls) is avoided.
                if await redis.delete(cache_key):
                    deleted_keys.append(cache_key)

            logger.info(f"已清除模型映射缓存: model_name={self.model_name}, 删除键={deleted_keys}")
            context.add_audit_metadata(
                action="model_mapping_cache_clear_by_name",
                model_name=self.model_name,
                deleted_keys=deleted_keys,
            )
            return {
                "status": "ok",
                "message": f"已清除模型 {self.model_name} 的映射缓存",
                "model_name": self.model_name,
                "deleted_keys": deleted_keys,
            }

        except HTTPException:
            # Let the deliberate 503 above pass through untouched.
            raise
        except Exception as exc:
            logger.exception(f"清除模型映射缓存失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}") from exc
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
@@ -26,7 +25,6 @@ from src.models.pydantic_models import (
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
@@ -136,8 +134,7 @@ async def get_provider_available_source_models(
|
||||
获取该 Provider 支持的所有统一模型名(source_model)
|
||||
|
||||
包括:
|
||||
1. 通过 ModelMapping 映射的模型
|
||||
2. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
1. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
"""
|
||||
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -294,10 +291,9 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
"""
|
||||
返回 Provider 支持的所有 GlobalModel
|
||||
|
||||
方案 A 逻辑:
|
||||
逻辑:
|
||||
1. 查询该 Provider 的所有 Model
|
||||
2. 通过 Model.global_model_id 获取 GlobalModel
|
||||
3. 查询所有指向该 GlobalModel 的别名(ModelMapping.alias)
|
||||
"""
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
@@ -324,27 +320,10 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
|
||||
# 如果该 GlobalModel 还未处理,初始化
|
||||
if global_model_name not in global_models_dict:
|
||||
# 查询指向该 GlobalModel 的所有别名/映射
|
||||
alias_rows = (
|
||||
db.query(ModelMapping.source_model)
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id == global_model.id,
|
||||
ModelMapping.is_active == True,
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
alias_list = [alias[0] for alias in alias_rows]
|
||||
|
||||
global_models_dict[global_model_name] = {
|
||||
"global_model_name": global_model_name,
|
||||
"display_name": global_model.display_name,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"has_alias": len(alias_list) > 0,
|
||||
"aliases": alias_list,
|
||||
"model_id": model.id,
|
||||
"price": {
|
||||
"input_price_per_1m": model.get_effective_input_price(),
|
||||
|
||||
350
src/api/base/models_service.py
Normal file
350
src/api/base/models_service.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
公共模型查询服务
|
||||
|
||||
为 Claude/OpenAI/Gemini 的 /models 端点提供统一的查询逻辑
|
||||
|
||||
查询逻辑:
|
||||
1. 找到指定 api_format 的活跃端点
|
||||
2. 端点下有活跃的 Key
|
||||
3. Provider 关联了该模型(Model 表)
|
||||
4. Key 的 allowed_models 允许该模型(null = 允许所有)
|
||||
"""
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, Provider, ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
# 缓存 key 前缀
|
||||
_CACHE_KEY_PREFIX = "models:list"
|
||||
_CACHE_TTL = CacheTTL.MODEL # 300 秒
|
||||
|
||||
|
||||
def _get_cache_key(api_formats: list[str]) -> str:
    """Build the cache key for a list of API formats.

    The formats are sorted before being joined, so the key does not depend on
    the order in which the caller lists them.
    """
    ordered_formats = sorted(api_formats)
    return _CACHE_KEY_PREFIX + ":" + ",".join(ordered_formats)
|
||||
|
||||
|
||||
async def _get_cached_models(api_formats: list[str]) -> Optional[list["ModelInfo"]]:
    """Read the cached model list for ``api_formats``.

    Returns:
        The cached entries rebuilt as ``ModelInfo`` objects, or ``None`` when
        the cache returns a falsy value or reading/rebuilding raises (the
        error is logged as a warning).
    """
    cache_key = _get_cache_key(api_formats)
    try:
        payload = await CacheService.get(cache_key)
        if not payload:
            return None
        logger.debug(f"[ModelsService] 缓存命中: {cache_key}, {len(payload)} 个模型")
        return [ModelInfo(**entry) for entry in payload]
    except Exception as e:
        # The cache is best effort: fall through and report a miss.
        logger.warning(f"[ModelsService] 缓存读取失败: {e}")
    return None
|
||||
|
||||
|
||||
async def _set_cached_models(api_formats: list[str], models: list["ModelInfo"]) -> None:
    """Store ``models`` in the cache under the key derived from ``api_formats``.

    Each ``ModelInfo`` is converted to a plain dict first. Failures are logged
    as warnings and never propagate to the caller.
    """
    cache_key = _get_cache_key(api_formats)
    try:
        serialized = list(map(asdict, models))
        await CacheService.set(cache_key, serialized, ttl_seconds=_CACHE_TTL)
    except Exception as e:
        logger.warning(f"[ModelsService] 缓存写入失败: {e}")
    else:
        logger.debug(f"[ModelsService] 已缓存: {cache_key}, {len(models)} 个模型, TTL={_CACHE_TTL}s")
|
||||
|
||||
|
||||
@dataclass
class ModelInfo:
    """Unified model description shared by the model-listing code."""

    # Model ID (GlobalModel.name or provider_model_name)
    id: str
    display_name: str
    description: Optional[str]
    # Creation time in ISO format
    created_at: Optional[str]
    # Creation time as a Unix timestamp
    created_timestamp: int
    provider_name: str
|
||||
|
||||
|
||||
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
    """
    Return the IDs of providers that have a usable endpoint.

    An endpoint qualifies when:
    - its api_format is one of ``api_formats``
    - the endpoint is active
    - it has at least one active key
    """
    usable_endpoint_filters = (
        ProviderEndpoint.api_format.in_(api_formats),
        ProviderEndpoint.is_active.is_(True),
        ProviderAPIKey.is_active.is_(True),
    )
    provider_query = (
        db.query(ProviderEndpoint.provider_id)
        .join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
        .filter(*usable_endpoint_filters)
        .distinct()
    )
    return {row[0] for row in provider_query.all()}
|
||||
|
||||
|
||||
def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
    """
    Collect the model IDs that are really usable for the given API formats.

    A model is usable when:
    1. an endpoint with a matching api_format exists and is active
    2. that endpoint has an active key
    3. the endpoint's provider has the model attached
    4. the key's allowed_models permits the model (null = every model
       attached to that provider)
    """
    # Active endpoints of a matching format joined with their active keys.
    key_rows = (
        db.query(
            ProviderEndpoint.id.label("endpoint_id"),
            ProviderEndpoint.provider_id,
            ProviderAPIKey.allowed_models,
        )
        .join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
        .filter(
            ProviderEndpoint.api_format.in_(api_formats),
            ProviderEndpoint.is_active.is_(True),
            ProviderAPIKey.is_active.is_(True),
        )
        .all()
    )
    if not key_rows:
        return set()

    # Group the allowed_models lists by provider, because models are attached
    # to providers rather than to endpoints.
    allowed_by_provider: dict[str, list[Optional[list[str]]]] = {}
    for _endpoint_id, provider_id, allowed in key_rows:
        allowed_by_provider.setdefault(provider_id, []).append(allowed)
    provider_ids_with_format: set[str] = set(allowed_by_provider)

    # Only load models of providers that have an endpoint of a matching format.
    candidate_models = (
        db.query(Model)
        .options(joinedload(Model.global_model))
        .join(Provider)
        .filter(
            Model.provider_id.in_(provider_ids_with_format),
            Model.is_active.is_(True),
            Provider.is_active.is_(True),
        )
        .all()
    )

    available_model_ids: set[str] = set()
    for model in candidate_models:
        owner_id = model.provider_id
        global_model = model.global_model
        model_id = global_model.name if global_model else model.provider_model_name  # type: ignore

        if not owner_id or not model_id:
            continue
        # The model's provider must have an endpoint of a matching format.
        if owner_id not in provider_ids_with_format:
            continue

        # One permitting key is enough: null allows everything, otherwise the
        # list must name the model ID or, for models with a GlobalModel, the
        # provider model name.
        for allowed in allowed_by_provider.get(owner_id, []):
            if (
                allowed is None
                or model_id in allowed
                or (global_model and model.provider_model_name in allowed)
            ):
                available_model_ids.add(model_id)
                break

    return available_model_ids
|
||||
|
||||
|
||||
def _extract_model_info(model: Any) -> ModelInfo:
    """Build a ``ModelInfo`` from a Model object.

    ID, display name and description come from the attached global model when
    there is one; otherwise the provider model name serves as both ID and
    display name and the description is ``None``. A missing ``created_at``
    yields ``None`` / ``0``, and a missing provider yields ``"unknown"``.
    """
    global_model = model.global_model
    if global_model:
        model_id: str = global_model.name
        display_name: str = global_model.display_name
        description: Optional[str] = global_model.description
    else:
        model_id = model.provider_model_name
        display_name = model.provider_model_name
        description = None

    if model.created_at:
        created_at: Optional[str] = model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ")
        created_timestamp: int = int(model.created_at.timestamp())
    else:
        created_at = None
        created_timestamp = 0

    return ModelInfo(
        id=model_id,
        display_name=display_name,
        description=description,
        created_at=created_at,
        created_timestamp=created_timestamp,
        provider_name=model.provider.name if model.provider else "unknown",
    )
|
||||
|
||||
|
||||
async def list_available_models(
    db: Session,
    available_provider_ids: set[str],
    api_formats: Optional[list[str]] = None,
) -> list[ModelInfo]:
    """
    List the available models, de-duplicated by ID and cached per format set.

    Args:
        db: database session
        available_provider_ids: IDs of providers that have a usable endpoint
        api_formats: API formats used to check each key's allowed_models;
            the cache is only read and written when this is given

    Returns:
        De-duplicated ModelInfo list, newest first.
    """
    if not available_provider_ids:
        return []

    available_model_ids: Optional[set[str]] = None
    if api_formats:
        # Serve from the cache when possible.
        cached = await _get_cached_models(api_formats)
        if cached is not None:
            return cached

        # Otherwise restrict the result to the models usable for these formats.
        available_model_ids = _get_available_model_ids_for_format(db, api_formats)
        if not available_model_ids:
            return []

    candidates = (
        db.query(Model)
        .options(joinedload(Model.global_model), joinedload(Model.provider))
        .join(Provider)
        .filter(
            Model.is_active.is_(True),
            Provider.is_active.is_(True),
            Model.provider_id.in_(available_provider_ids),
        )
        .order_by(Model.created_at.desc())
        .all()
    )

    result: list[ModelInfo] = []
    seen_model_ids: set[str] = set()
    for candidate in candidates:
        info = _extract_model_info(candidate)
        permitted = available_model_ids is None or info.id in available_model_ids
        # Keep only the first (newest) entry for each permitted ID.
        if permitted and info.id not in seen_model_ids:
            seen_model_ids.add(info.id)
            result.append(info)

    # Store the result in the cache.
    if api_formats:
        await _set_cached_models(api_formats, result)

    return result
|
||||
|
||||
|
||||
def find_model_by_id(
    db: Session,
    model_id: str,
    available_provider_ids: set[str],
    api_formats: Optional[list[str]] = None,
) -> Optional[ModelInfo]:
    """
    Look up a single model by ID.

    Lookup order:
    1. match on GlobalModel.name
    2. only when that yields no candidate at all, match on provider_model_name
    3. when candidates exist but none belongs to an available provider,
       return None (no fallback)

    Args:
        db: database session
        model_id: model ID
        available_provider_ids: IDs of providers that have a usable endpoint
        api_formats: API formats used to check each key's allowed_models

    Returns:
        ModelInfo, or None
    """
    if not available_provider_ids:
        return None

    # With api_formats given, reject early when the ID is not among the models
    # usable for those formats.
    available_model_ids: Optional[set[str]] = None
    if api_formats:
        available_model_ids = _get_available_model_ids_for_format(db, api_formats)
        if available_model_ids is not None and model_id not in available_model_ids:
            return None

    def _base_query():
        # Model query with both relationships eager-loaded, joined to Provider.
        return (
            db.query(Model)
            .options(joinedload(Model.global_model), joinedload(Model.provider))
            .join(Provider)
        )

    def _first_available(candidates):
        # First candidate that belongs to an available provider, else None.
        for candidate in candidates:
            if candidate.provider_id in available_provider_ids:
                return candidate
        return None

    # Step 1: match on GlobalModel.name.
    models_by_global = (
        _base_query()
        .join(GlobalModel, Model.global_model_id == GlobalModel.id)
        .filter(
            GlobalModel.name == model_id,
            Model.is_active.is_(True),
            Provider.is_active.is_(True),
        )
        .order_by(Model.created_at.desc())
        .all()
    )
    chosen = _first_available(models_by_global)

    if not chosen:
        # Candidates exist but none is available: do not fall back.
        if models_by_global:
            return None

        # Step 2: no candidate at all, so match on provider_model_name.
        models_by_provider_name = (
            _base_query()
            .filter(
                Model.provider_model_name == model_id,
                Model.is_active.is_(True),
                Provider.is_active.is_(True),
            )
            .order_by(Model.created_at.desc())
            .all()
        )
        chosen = _first_available(models_by_provider_name)
        if not chosen:
            return None

    return _extract_model_info(chosen)
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Sequence, Tuple, TypeVar
|
||||
from typing import Any, List, Sequence, Tuple, TypeVar
|
||||
|
||||
from sqlalchemy.orm import Query
|
||||
|
||||
@@ -40,10 +40,10 @@ def paginate_sequence(
|
||||
return sliced, meta
|
||||
|
||||
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra: Any) -> dict:
|
||||
"""
|
||||
构建标准分页响应 payload。
|
||||
"""
|
||||
payload = {"items": items, "meta": meta.to_dict()}
|
||||
payload: dict = {"items": items, "meta": meta.to_dict()}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
|
||||
@@ -263,7 +263,10 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||
return mapped_name
|
||||
|
||||
|
||||
@@ -190,14 +190,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"""
|
||||
获取模型映射后的实际模型名
|
||||
|
||||
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名(可能是别名)
|
||||
source_model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
映射后的 provider_model_name,如果没有找到映射则返回 None
|
||||
映射后的 Provider 模型名,如果没有找到映射则返回 None
|
||||
"""
|
||||
from src.services.model.mapper import ModelMapperMiddleware
|
||||
|
||||
@@ -207,7 +210,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||
return mapped_name
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
不再经过 Protocol 抽象层。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
@@ -60,7 +60,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
|
||||
class OpenAIResponseParser(ResponseParser):
|
||||
"""OpenAI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
||||
|
||||
self._parser = OpenAIStreamParser()
|
||||
@@ -146,7 +146,7 @@ class OpenAIResponseParser(ResponseParser):
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return ""
|
||||
|
||||
@@ -158,7 +158,7 @@ class OpenAIResponseParser(ResponseParser):
|
||||
class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
"""OpenAI CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "OPENAI_CLI"
|
||||
self.api_format = "OPENAI_CLI"
|
||||
@@ -167,7 +167,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
class ClaudeResponseParser(ResponseParser):
|
||||
"""Claude 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
||||
|
||||
self._parser = ClaudeStreamParser()
|
||||
@@ -291,7 +291,7 @@ class ClaudeResponseParser(ResponseParser):
|
||||
class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
"""Claude CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "CLAUDE_CLI"
|
||||
self.api_format = "CLAUDE_CLI"
|
||||
@@ -300,7 +300,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
class GeminiResponseParser(ResponseParser):
|
||||
"""Gemini 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
self._parser = GeminiStreamParser()
|
||||
@@ -443,20 +443,20 @@ class GeminiResponseParser(ResponseParser):
|
||||
|
||||
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
||||
"""
|
||||
return self._parser.is_error_event(response)
|
||||
return bool(self._parser.is_error_event(response))
|
||||
|
||||
|
||||
class GeminiCliResponseParser(GeminiResponseParser):
|
||||
"""Gemini CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "GEMINI_CLI"
|
||||
self.api_format = "GEMINI_CLI"
|
||||
|
||||
|
||||
# 解析器注册表
|
||||
_PARSERS = {
|
||||
_PARSERS: Dict[str, Type[ResponseParser]] = {
|
||||
"CLAUDE": ClaudeResponseParser,
|
||||
"CLAUDE_CLI": ClaudeCliResponseParser,
|
||||
"OPENAI": OpenAIResponseParser,
|
||||
@@ -498,6 +498,5 @@ __all__ = [
|
||||
"GeminiResponseParser",
|
||||
"GeminiCliResponseParser",
|
||||
"get_parser_for_format",
|
||||
"get_parser_from_protocol",
|
||||
"is_cli_format",
|
||||
]
|
||||
|
||||
@@ -108,7 +108,10 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
result = json.loads(line)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -147,7 +150,8 @@ class ClaudeStreamParser:
|
||||
Returns:
|
||||
事件类型字符串
|
||||
"""
|
||||
return event.get("type")
|
||||
event_type = event.get("type")
|
||||
return str(event_type) if event_type is not None else None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -164,7 +168,8 @@ class ClaudeStreamParser:
|
||||
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == self.DELTA_TEXT:
|
||||
return delta.get("text")
|
||||
text = delta.get("text")
|
||||
return str(text) if text is not None else None
|
||||
|
||||
return None
|
||||
|
||||
@@ -219,7 +224,8 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
message = event.get("message", {})
|
||||
return message.get("id")
|
||||
msg_id = message.get("id")
|
||||
return str(msg_id) if msg_id is not None else None
|
||||
|
||||
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -235,7 +241,8 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
return delta.get("stop_reason")
|
||||
reason = delta.get("stop_reason")
|
||||
return str(reason) if reason is not None else None
|
||||
|
||||
|
||||
__all__ = ["ClaudeStreamParser"]
|
||||
|
||||
@@ -70,7 +70,7 @@ class ClaudeToGeminiConverter:
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
parts: List[Dict[str, Any]] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append({"text": block})
|
||||
@@ -249,6 +249,8 @@ class GeminiToClaudeConverter:
|
||||
"RECITATION": "content_filtered",
|
||||
"OTHER": "stop_sequence",
|
||||
}
|
||||
if finish_reason is None:
|
||||
return "end_turn"
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
def _create_empty_response(self) -> Dict[str, Any]:
|
||||
@@ -365,7 +367,7 @@ class OpenAIToGeminiConverter:
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
parts: List[Dict[str, Any]] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
@@ -524,7 +526,7 @@ class GeminiToOpenAIConverter:
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> str:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "stop",
|
||||
@@ -533,6 +535,8 @@ class GeminiToOpenAIConverter:
|
||||
"RECITATION": "content_filter",
|
||||
"OTHER": "stop",
|
||||
}
|
||||
if finish_reason is None:
|
||||
return "stop"
|
||||
return mapping.get(finish_reason, "stop")
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ Gemini API 的流式响应格式与 Claude/OpenAI 不同:
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class GeminiStreamParser:
|
||||
@@ -32,18 +32,18 @@ class GeminiStreamParser:
|
||||
FINISH_REASON_RECITATION = "RECITATION"
|
||||
FINISH_REASON_OTHER = "OTHER"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""重置解析器状态"""
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析流式数据块
|
||||
|
||||
@@ -111,7 +111,10 @@ class GeminiStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line.strip().rstrip(","))
|
||||
result = json.loads(line.strip().rstrip(","))
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -216,7 +219,8 @@ class GeminiStreamParser:
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
return candidates[0].get("finishReason")
|
||||
reason = candidates[0].get("finishReason")
|
||||
return str(reason) if reason is not None else None
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
@@ -285,7 +289,8 @@ class GeminiStreamParser:
|
||||
Returns:
|
||||
模型版本,如果没有返回 None
|
||||
"""
|
||||
return event.get("modelVersion")
|
||||
version = event.get("modelVersion")
|
||||
return str(version) if version is not None else None
|
||||
|
||||
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
@@ -301,7 +306,10 @@ class GeminiStreamParser:
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
return candidates[0].get("safetyRatings")
|
||||
ratings = candidates[0].get("safetyRatings")
|
||||
if isinstance(ratings, list):
|
||||
return ratings
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["GeminiStreamParser"]
|
||||
|
||||
@@ -78,7 +78,10 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
result = json.loads(line)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -116,7 +119,8 @@ class OpenAIStreamParser:
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("finish_reason")
|
||||
reason = choices[0].get("finish_reason")
|
||||
return str(reason) if reason is not None else None
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("tool_calls")
|
||||
tool_calls = delta.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
return tool_calls
|
||||
return None
|
||||
|
||||
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("role")
|
||||
role = delta.get("role")
|
||||
return str(role) if role is not None else None
|
||||
|
||||
|
||||
__all__ = ["OpenAIStreamParser"]
|
||||
|
||||
@@ -6,10 +6,13 @@ from .capabilities import router as capabilities_router
|
||||
from .catalog import router as catalog_router
|
||||
from .claude import router as claude_router
|
||||
from .gemini import router as gemini_router
|
||||
from .models import router as models_router
|
||||
from .openai import router as openai_router
|
||||
from .system_catalog import router as system_catalog_router
|
||||
|
||||
router = APIRouter()
|
||||
# Models API 需要在最前面注册,避免被其他路由的 path 参数捕获
|
||||
router.include_router(models_router)
|
||||
router.include_router(claude_router, tags=["Claude API"])
|
||||
router.include_router(openai_router)
|
||||
router.include_router(gemini_router, tags=["Gemini API"])
|
||||
|
||||
@@ -61,15 +61,18 @@ async def get_model_supported_capabilities(
|
||||
获取指定模型支持的能力列表
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514)
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
模型支持的能力列表,以及每个能力的详细定义
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
return {
|
||||
|
||||
@@ -20,14 +20,12 @@ from src.models.api import (
|
||||
ProviderStatsResponse,
|
||||
PublicGlobalModelListResponse,
|
||||
PublicGlobalModelResponse,
|
||||
PublicModelMappingResponse,
|
||||
PublicModelResponse,
|
||||
PublicProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
@@ -72,24 +70,6 @@ async def get_public_models(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
|
||||
async def get_public_model_mappings(
|
||||
request: Request,
|
||||
provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
|
||||
alias: Optional[str] = Query(None, description="别名过滤(原source_model)"),
|
||||
skip: int = Query(0, description="跳过记录数"),
|
||||
limit: int = Query(100, description="返回记录数限制"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
adapter = PublicModelMappingsAdapter(
|
||||
provider_id=provider_id,
|
||||
alias=alias,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ProviderStatsResponse)
|
||||
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = PublicStatsAdapter()
|
||||
@@ -176,13 +156,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
||||
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
mappings_count = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
||||
active_endpoints_count = (
|
||||
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
|
||||
@@ -196,7 +169,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
||||
provider_priority=provider.provider_priority,
|
||||
models_count=models_count,
|
||||
active_models_count=active_models_count,
|
||||
mappings_count=mappings_count,
|
||||
endpoints_count=endpoints_count,
|
||||
active_endpoints_count=active_endpoints_count,
|
||||
)
|
||||
@@ -256,78 +228,6 @@ class PublicModelsAdapter(PublicApiAdapter):
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class PublicModelMappingsAdapter(PublicApiAdapter):
|
||||
provider_id: Optional[str]
|
||||
alias: Optional[str] # 原 source_model,改为 alias
|
||||
skip: int
|
||||
limit: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
logger.debug("公共API请求模型映射列表")
|
||||
|
||||
query = (
|
||||
db.query(ModelMapping, GlobalModel, Provider)
|
||||
.join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.outerjoin(Provider, ModelMapping.provider_id == Provider.id)
|
||||
.filter(
|
||||
and_(
|
||||
ModelMapping.is_active.is_(True),
|
||||
GlobalModel.is_active.is_(True),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self.provider_id is not None:
|
||||
provider_global_model_ids = (
|
||||
db.query(Model.global_model_id)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.filter(
|
||||
Provider.id == self.provider_id,
|
||||
Model.is_active.is_(True),
|
||||
Provider.is_active.is_(True),
|
||||
Model.global_model_id.isnot(None),
|
||||
)
|
||||
.distinct()
|
||||
)
|
||||
query = query.filter(
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
and_(
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.target_global_model_id.in_(provider_global_model_ids),
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
if self.alias is not None:
|
||||
query = query.filter(ModelMapping.source_model.ilike(f"%{self.alias}%"))
|
||||
|
||||
results = query.offset(self.skip).limit(self.limit).all()
|
||||
response = []
|
||||
for mapping, global_model, provider in results:
|
||||
scope = "provider" if mapping.provider_id else "global"
|
||||
mapping_data = PublicModelMappingResponse(
|
||||
id=mapping.id,
|
||||
source_model=mapping.source_model,
|
||||
target_global_model_id=mapping.target_global_model_id,
|
||||
target_global_model_name=global_model.name if global_model else None,
|
||||
target_global_model_display_name=(
|
||||
global_model.display_name if global_model else None
|
||||
),
|
||||
provider_id=mapping.provider_id,
|
||||
scope=scope,
|
||||
is_active=mapping.is_active,
|
||||
)
|
||||
response.append(mapping_data.model_dump())
|
||||
|
||||
logger.debug(f"返回 {len(response)} 个模型映射")
|
||||
return response
|
||||
|
||||
|
||||
class PublicStatsAdapter(PublicApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
@@ -339,9 +239,6 @@ class PublicStatsAdapter(PublicApiAdapter):
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
from ...models.database import ModelMapping
|
||||
|
||||
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
|
||||
formats = (
|
||||
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
|
||||
)
|
||||
@@ -351,7 +248,6 @@ class PublicStatsAdapter(PublicApiAdapter):
|
||||
active_providers=active_providers,
|
||||
total_models=active_models,
|
||||
active_models=active_models,
|
||||
total_mappings=active_mappings,
|
||||
supported_formats=supported_formats,
|
||||
)
|
||||
logger.debug("返回系统统计信息")
|
||||
|
||||
@@ -3,6 +3,8 @@ Claude API 端点
|
||||
|
||||
- /v1/messages - Claude Messages API
|
||||
- /v1/messages/count_tokens - Token Count API
|
||||
|
||||
注意: /v1/models 端点由 models.py 统一处理,根据请求头返回对应格式
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
@@ -5,11 +5,9 @@ Gemini API 专属端点
|
||||
- /v1beta/models/{model}:generateContent
|
||||
- /v1beta/models/{model}:streamGenerateContent
|
||||
|
||||
注意: Gemini API 的 model 在 URL 路径中,而不是请求体中
|
||||
|
||||
路径配置来源: src.core.api_format_metadata.APIFormat.GEMINI
|
||||
- path_prefix: 本站路径前缀(如 /gemini),通过 router prefix 配置
|
||||
- default_path: 标准 API 路径模板
|
||||
注意:
|
||||
- Gemini API 的 model 在 URL 路径中,而不是请求体中
|
||||
- /v1beta/models (列表) 和 /v1beta/models/{model} (详情) 由 models.py 统一处理
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
@@ -109,7 +107,7 @@ async def stream_generate_content(
|
||||
)
|
||||
|
||||
|
||||
# 兼容 v1 路径(部分 SDK 可能使用)
|
||||
# 兼容 v1 路径(部分 SDK 可能使用 generateContent)
|
||||
@router.post("/v1/models/{model}:generateContent")
|
||||
async def generate_content_v1(
|
||||
model: str,
|
||||
|
||||
499
src/api/public/models.py
Normal file
499
src/api/public/models.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
统一的 Models API 端点
|
||||
|
||||
根据请求头认证方式自动返回对应格式:
|
||||
- x-api-key + anthropic-version -> Claude 格式
|
||||
- x-goog-api-key (header) 或 ?key= 参数 -> Gemini 格式
|
||||
- Authorization: Bearer (bearer) -> OpenAI 格式
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.models_service import (
|
||||
ModelInfo,
|
||||
find_model_by_id,
|
||||
get_available_provider_ids,
|
||||
list_available_models,
|
||||
)
|
||||
from src.core.api_format_metadata import API_FORMAT_DEFINITIONS, ApiFormatDefinition
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.auth.service import AuthService
|
||||
|
||||
router = APIRouter(tags=["Models API"])

# Endpoint API formats that each response flavour is allowed to list models for.
# Note: CLI formats are pass-through formats, so the Models API only returns models
# supported by endpoints of the non-CLI formats.
_CLAUDE_FORMATS = [APIFormat.CLAUDE.value]
_OPENAI_FORMATS = [APIFormat.OPENAI.value]
_GEMINI_FORMATS = [APIFormat.GEMINI.value]
|
||||
|
||||
|
||||
def _extract_api_key_from_request(
    request: Request, definition: ApiFormatDefinition
) -> Optional[str]:
    """Extract the caller's API key from ``request`` as described by ``definition``.

    Args:
        request: Incoming HTTP request.
        definition: Format definition naming the auth header (``auth_header``) and how
            its value is encoded (``auth_type``).

    Returns:
        The key string, or ``None`` when no usable credential is present. For
        ``auth_type == "bearer"`` the header must start with ``"Bearer "``
        (case-insensitive) and the remainder is returned stripped; for any other
        auth type the raw header value is returned unchanged.
    """
    header_name = definition.auth_header.lower()
    scheme = definition.auth_type
    raw_value = request.headers.get(header_name)

    if raw_value:
        if scheme != "bearer":
            # Plain header credential: the value itself is the key.
            return raw_value
        has_bearer_prefix = raw_value.lower().startswith("bearer ")
        return raw_value[len("bearer "):].strip() if has_bearer_prefix else None

    # Header missing or empty: the Gemini formats additionally accept a ?key= query parameter.
    gemini_formats = (APIFormat.GEMINI, APIFormat.GEMINI_CLI)
    if definition.api_format in gemini_formats:
        return request.query_params.get("key")
    return None
|
||||
|
||||
|
||||
def _detect_api_format_and_key(request: Request) -> Tuple[str, Optional[str]]:
    """Detect which API flavour the caller speaks and extract its API key.

    Detection order:
        1. ``x-api-key`` together with ``anthropic-version`` -> ``"claude"``
        2. ``x-goog-api-key`` header or ``?key=`` query parameter -> ``"gemini"``
        3. otherwise -> ``"openai"`` (the default), using ``Authorization: Bearer``

    Returns:
        An ``(api_format, api_key)`` tuple; ``api_key`` may be ``None`` (or empty)
        when no credential was supplied.
    """
    # Claude requires BOTH the key header and the anthropic-version header.
    x_api_key = _extract_api_key_from_request(request, API_FORMAT_DEFINITIONS[APIFormat.CLAUDE])
    if x_api_key and request.headers.get("anthropic-version"):
        return "claude", x_api_key

    google_key = _extract_api_key_from_request(request, API_FORMAT_DEFINITIONS[APIFormat.GEMINI])
    if google_key:
        return "gemini", google_key

    bearer_key = _extract_api_key_from_request(request, API_FORMAT_DEFINITIONS[APIFormat.OPENAI])
    if x_api_key and not bearer_key:
        # Compatibility: a lone x-api-key (without anthropic-version) is accepted as the
        # credential for the OpenAI-format response.
        return "openai", x_api_key
    return "openai", bearer_key
|
||||
|
||||
|
||||
def _get_formats_for_api(api_format: str) -> list[str]:
    """Return the endpoint-format list for a detected API flavour.

    ``"claude"`` and ``"gemini"`` map to their dedicated lists; every other value
    falls back to the OpenAI list.
    """
    formats_by_flavour = {
        "claude": _CLAUDE_FORMATS,
        "gemini": _GEMINI_FORMATS,
    }
    return formats_by_flavour.get(api_format, _OPENAI_FORMATS)
|
||||
|
||||
|
||||
def _authenticate(db: Session, api_key: Optional[str]) -> Tuple[Optional[User], Optional[ApiKey]]:
    """Authenticate ``api_key`` through ``AuthService``.

    Args:
        db: Database session passed to ``AuthService.authenticate_api_key``.
        api_key: Key extracted from the request; may be ``None`` or empty.

    Returns:
        The ``(user, api_key_record)`` pair produced by the auth service, or
        ``(None, None)`` when no key was supplied or the service rejected it.
    """
    unauthenticated = (None, None)

    if not api_key:
        logger.debug("[Models] 认证失败: 未提供 API Key")
        return unauthenticated

    auth_result = AuthService.authenticate_api_key(db, api_key)
    if auth_result:
        account, key_row = auth_result
        logger.debug(f"[Models] 认证成功: {account.email} (Key: {key_row.name})")
        return auth_result

    logger.debug("[Models] 认证失败: API Key 无效")
    return unauthenticated
|
||||
|
||||
|
||||
def _build_auth_error_response(api_format: str) -> JSONResponse:
    """Build a 401 response whose body mimics the error shape of ``api_format``.

    ``"claude"`` and ``"gemini"`` get their own envelopes; any other value yields
    the OpenAI-style error body.
    """
    if api_format == "claude":
        payload = {
            "type": "error",
            "error": {
                "type": "authentication_error",
                "message": "Invalid API key provided",
            },
        }
    elif api_format == "gemini":
        payload = {
            "error": {
                "code": 401,
                "message": "API key not valid. Please pass a valid API key.",
                "status": "UNAUTHENTICATED",
            }
        }
    else:
        payload = {
            "error": {
                "message": "Incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.",
                "type": "invalid_request_error",
                "param": None,
                "code": "invalid_api_key",
            }
        }

    return JSONResponse(status_code=401, content=payload)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 响应构建函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _build_claude_list_response(
    models: list[ModelInfo],
    before_id: Optional[str],
    after_id: Optional[str],
    limit: int,
) -> dict:
    """Build the Claude-format list response with cursor pagination.

    Args:
        models: All models visible to the caller, in listing order.
        before_id: Backward cursor. When it matches a model id (and ``after_id`` is
            not given) the page returned is the ``limit`` entries immediately
            *before* that model, matching the Anthropic pagination contract.
        after_id: Forward cursor. The page starts right after the first model whose
            id equals it.
        limit: Maximum number of entries in the page.

    Returns:
        ``{"data", "has_more", "first_id", "last_id"}``. ``has_more`` is true when
        the cursor window holds more than ``limit`` entries. A cursor that matches
        no model id is ignored, so an unknown ``after_id`` starts from the beginning
        and an unknown ``before_id`` does not bound the end.
    """
    ids = [m.id for m in models]

    # Forward cursor: start right after the first entry whose id equals after_id.
    start_idx = 0
    if after_id:
        start_idx = next((i + 1 for i, mid in enumerate(ids) if mid == after_id), 0)

    # Backward cursor: stop right before the first entry whose id equals before_id.
    before_pos = None
    if before_id:
        before_pos = next((i for i, mid in enumerate(ids) if mid == before_id), None)
    end_idx = len(ids) if before_pos is None else before_pos

    window = models[start_idx:end_idx]

    if before_pos is not None and not after_id:
        # Paging backwards: keep the entries adjacent to the cursor (the tail of the
        # window) rather than the first entries of the whole list.
        page = window[max(len(window) - limit, 0):]
    else:
        page = window[:limit]

    # Only the returned page is serialised.
    data = [
        {
            "id": m.id,
            "type": "model",
            "display_name": m.display_name,
            "created_at": m.created_at,
        }
        for m in page
    ]

    return {
        "data": data,
        "has_more": len(window) > limit,
        "first_id": data[0]["id"] if data else None,
        "last_id": data[-1]["id"] if data else None,
    }
|
||||
|
||||
|
||||
def _build_openai_list_response(models: list[ModelInfo]) -> dict:
    """Build the OpenAI-format list response (``{"object": "list", "data": [...]}``).

    Each entry carries the model id, the fixed object type ``"model"``, the model's
    ``created_timestamp`` and its ``provider_name`` as ``owned_by``.
    """
    entries = []
    for info in models:
        entries.append(
            {
                "id": info.id,
                "object": "model",
                "created": info.created_timestamp,
                "owned_by": info.provider_name,
            }
        )
    return {"object": "list", "data": entries}
|
||||
|
||||
|
||||
def _build_gemini_list_response(
    models: list[ModelInfo],
    page_size: int,
    page_token: Optional[str],
) -> dict:
    """Build the Gemini-format list response with offset pagination.

    Args:
        models: All models visible to the caller, in listing order.
        page_size: Maximum number of models in the page.
        page_token: Opaque token from a previous response; here it is the decimal
            start offset. Missing, non-numeric, or negative tokens start at 0.

    Returns:
        ``{"models": [...]}`` plus ``"nextPageToken"`` (the next start offset as a
        string) when more models remain after this page.
    """
    start_idx = 0
    if page_token:
        try:
            # Clamp to 0: a negative offset would make the slice below count from the
            # end of the list and could emit a misleading nextPageToken.
            start_idx = max(int(page_token), 0)
        except ValueError:
            start_idx = 0

    end_idx = start_idx + page_size

    # Reuse the single-model builder so list entries and detail responses cannot drift apart.
    response: dict = {
        "models": [_build_gemini_model_response(m) for m in models[start_idx:end_idx]]
    }
    if end_idx < len(models):
        response["nextPageToken"] = str(end_idx)

    return response
|
||||
|
||||
|
||||
def _build_claude_model_response(model_info: ModelInfo) -> dict:
    """Build the Claude-format detail payload for a single model."""
    return dict(
        id=model_info.id,
        type="model",
        display_name=model_info.display_name,
        created_at=model_info.created_at,
    )
|
||||
|
||||
|
||||
def _build_openai_model_response(model_info: ModelInfo) -> dict:
    """Build the OpenAI-format detail payload for a single model."""
    return dict(
        id=model_info.id,
        object="model",
        created=model_info.created_timestamp,
        owned_by=model_info.provider_name,
    )
|
||||
|
||||
|
||||
def _build_gemini_model_response(model_info: ModelInfo) -> dict:
    """Build the Gemini-format detail payload for a single model.

    Only the name, base model id, display name and description come from
    ``model_info``; the token limits, generation methods and sampling values are
    fixed constants written here.
    """
    model_id = model_info.id
    display_name = model_info.display_name
    # Fall back to a generated description when the model has none.
    description = model_info.description or f"Model {model_id}"

    return dict(
        name=f"models/{model_id}",
        baseModelId=model_id,
        version="001",
        displayName=display_name,
        description=description,
        inputTokenLimit=128000,
        outputTokenLimit=8192,
        supportedGenerationMethods=["generateContent", "countTokens"],
        temperature=1.0,
        maxTemperature=2.0,
        topP=0.95,
        topK=64,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 404 响应
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _build_404_response(model_id: str, api_format: str) -> JSONResponse:
    """Build a 404 "model not found" response in the error shape of ``api_format``.

    ``"claude"`` and ``"gemini"`` get their own envelopes; any other value yields
    the OpenAI-style error body.
    """
    if api_format == "claude":
        payload = {
            "type": "error",
            "error": {"type": "not_found_error", "message": f"Model '{model_id}' not found"},
        }
    elif api_format == "gemini":
        payload = {
            "error": {
                "code": 404,
                "message": f"models/{model_id} is not found",
                "status": "NOT_FOUND",
            }
        }
    else:
        payload = {
            "error": {
                "message": f"The model '{model_id}' does not exist",
                "type": "invalid_request_error",
                "param": "model",
                "code": "model_not_found",
            }
        }

    return JSONResponse(status_code=404, content=payload)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 路由端点
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@router.get("/v1/models", response_model=None)
async def list_models(
    request: Request,
    # Claude pagination parameters
    before_id: Optional[str] = Query(None, description="返回此 ID 之前的结果 (Claude)"),
    after_id: Optional[str] = Query(None, description="返回此 ID 之后的结果 (Claude)"),
    limit: int = Query(20, ge=1, le=1000, description="返回数量限制 (Claude)"),
    # Gemini pagination parameters
    page_size: int = Query(50, alias="pageSize", ge=1, le=1000, description="每页数量 (Gemini)"),
    page_token: Optional[str] = Query(None, alias="pageToken", description="分页 token (Gemini)"),
    db: Session = Depends(get_db),
) -> Union[dict, JSONResponse]:
    """
    List models, answering in the format implied by the request's credentials.

    - x-api-key (with anthropic-version) -> Claude format
    - x-goog-api-key or ?key= -> Gemini format
    - Authorization: Bearer -> OpenAI format
    """
    api_format, api_key = _detect_api_format_and_key(request)
    logger.info(f"[Models] GET /v1/models | format={api_format}")

    # Authenticate before touching the catalogue.
    authenticated_user, _key_record = _authenticate(db, api_key)
    if not authenticated_user:
        return _build_auth_error_response(api_format)

    wants_claude = api_format == "claude"
    wants_gemini = api_format == "gemini"
    endpoint_formats = _get_formats_for_api(api_format)

    provider_ids = get_available_provider_ids(db, endpoint_formats)
    if not provider_ids:
        # No provider serves this format: return the format's empty list shape.
        if wants_claude:
            return {"data": [], "has_more": False, "first_id": None, "last_id": None}
        if wants_gemini:
            return {"models": []}
        return {"object": "list", "data": []}

    catalog = await list_available_models(db, provider_ids, endpoint_formats)
    logger.debug(f"[Models] 返回 {len(catalog)} 个模型")

    if wants_claude:
        return _build_claude_list_response(catalog, before_id, after_id, limit)
    if wants_gemini:
        return _build_gemini_list_response(catalog, page_size, page_token)
    return _build_openai_list_response(catalog)
|
||||
|
||||
|
||||
@router.get("/v1/models/{model_id:path}", response_model=None)
async def retrieve_model(
    model_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> Union[dict, JSONResponse]:
    """
    Retrieve one model, answering in the format implied by the request's credentials.
    """
    api_format, api_key = _detect_api_format_and_key(request)
    wants_gemini = api_format == "gemini"

    # Gemini model names carry a "models/" prefix that is not part of the id.
    gemini_prefix = "models/"
    if wants_gemini and model_id.startswith(gemini_prefix):
        model_id = model_id[len(gemini_prefix):]

    logger.info(f"[Models] GET /v1/models/{model_id} | format={api_format}")

    # Authenticate before looking anything up.
    authenticated_user, _key_record = _authenticate(db, api_key)
    if not authenticated_user:
        return _build_auth_error_response(api_format)

    endpoint_formats = _get_formats_for_api(api_format)
    provider_ids = get_available_provider_ids(db, endpoint_formats)
    found = find_model_by_id(db, model_id, provider_ids, endpoint_formats)
    if not found:
        return _build_404_response(model_id, api_format)

    if api_format == "claude":
        return _build_claude_model_response(found)
    if wants_gemini:
        return _build_gemini_model_response(found)
    return _build_openai_model_response(found)
|
||||
|
||||
|
||||
# Gemini 专用路径 /v1beta/models
|
||||
@router.get("/v1beta/models", response_model=None)
async def list_models_gemini(
    request: Request,
    page_size: int = Query(50, alias="pageSize", ge=1, le=1000),
    page_token: Optional[str] = Query(None, alias="pageToken"),
    db: Session = Depends(get_db),
) -> Union[dict, JSONResponse]:
    """List models (Gemini v1beta endpoint); always answers in Gemini format."""
    logger.info("[Models] GET /v1beta/models | format=gemini")

    # The key comes from the x-goog-api-key header or the ?key= query parameter.
    api_key = _extract_api_key_from_request(request, API_FORMAT_DEFINITIONS[APIFormat.GEMINI])

    authenticated_user, _key_record = _authenticate(db, api_key)
    if not authenticated_user:
        return _build_auth_error_response("gemini")

    provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
    if not provider_ids:
        return {"models": []}

    catalog = await list_available_models(db, provider_ids, _GEMINI_FORMATS)
    logger.debug(f"[Models] 返回 {len(catalog)} 个模型")

    payload = _build_gemini_list_response(catalog, page_size, page_token)
    logger.debug(f"[Models] Gemini 响应: {payload}")
    return payload
|
||||
|
||||
|
||||
@router.get("/v1beta/models/{model_name:path}", response_model=None)
async def get_model_gemini(
    request: Request,
    model_name: str,
    db: Session = Depends(get_db),
) -> Union[dict, JSONResponse]:
    """Get one model (Gemini v1beta endpoint); always answers in Gemini format."""
    # Drop the "models/" prefix when the caller included it in the path.
    gemini_prefix = "models/"
    if model_name.startswith(gemini_prefix):
        model_id = model_name[len(gemini_prefix):]
    else:
        model_id = model_name
    logger.info(f"[Models] GET /v1beta/models/{model_id} | format=gemini")

    # The key comes from the x-goog-api-key header or the ?key= query parameter.
    api_key = _extract_api_key_from_request(request, API_FORMAT_DEFINITIONS[APIFormat.GEMINI])

    authenticated_user, _key_record = _authenticate(db, api_key)
    if not authenticated_user:
        return _build_auth_error_response("gemini")

    provider_ids = get_available_provider_ids(db, _GEMINI_FORMATS)
    found = find_model_by_id(db, model_id, provider_ids, _GEMINI_FORMATS)
    if found:
        return _build_gemini_model_response(found)
    return _build_404_response(model_id, "gemini")
|
||||
@@ -3,6 +3,8 @@ OpenAI API 端点
|
||||
|
||||
- /v1/chat/completions - OpenAI Chat API
|
||||
- /v1/responses - OpenAI Responses API (CLI)
|
||||
|
||||
注意: /v1/models 端点由 models.py 统一处理,根据请求头返回对应格式
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
@@ -713,7 +713,7 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from src.models.database import Model, ModelMapping, ProviderEndpoint
|
||||
from src.models.database import Model, ProviderEndpoint
|
||||
|
||||
db = context.db
|
||||
|
||||
@@ -765,53 +765,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
}
|
||||
)
|
||||
|
||||
# 查询该 Provider 所有 Model 对应的 GlobalModel 的别名/映射
|
||||
provider_model_global_ids = {
|
||||
m.global_model_id for m in provider.models if m.global_model_id
|
||||
}
|
||||
if provider_model_global_ids:
|
||||
# 查询全局别名 + Provider 特定映射
|
||||
alias_mappings = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id.in_(provider_model_global_ids),
|
||||
ModelMapping.is_active == True,
|
||||
(ModelMapping.provider_id == provider.id)
|
||||
| (ModelMapping.provider_id == None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for alias_obj in alias_mappings:
|
||||
# 为这个别名找到该 Provider 的 Model 实现
|
||||
model = next(
|
||||
(
|
||||
m
|
||||
for m in provider.models
|
||||
if m.global_model_id == alias_obj.target_global_model_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if model:
|
||||
models_data.append(
|
||||
{
|
||||
"id": alias_obj.id,
|
||||
"name": alias_obj.source_model,
|
||||
"display_name": (
|
||||
alias_obj.target_global_model.display_name
|
||||
if alias_obj.target_global_model
|
||||
else alias_obj.source_model
|
||||
),
|
||||
"input_price_per_1m": model.input_price_per_1m,
|
||||
"output_price_per_1m": model.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": model.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": model.cache_read_price_per_1m,
|
||||
"supports_vision": model.supports_vision,
|
||||
"supports_function_calling": model.supports_function_calling,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
}
|
||||
)
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": provider.id,
|
||||
|
||||
@@ -14,7 +14,6 @@ class CacheTTL:
|
||||
# Provider/Model 缓存 - 配置变更不频繁
|
||||
PROVIDER = 300 # 5分钟
|
||||
MODEL = 300 # 5分钟
|
||||
MODEL_MAPPING = 300 # 5分钟
|
||||
|
||||
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
|
||||
CACHE_AFFINITY = 300 # 5分钟
|
||||
@@ -33,9 +32,6 @@ class CacheSize:
|
||||
# 默认 LRU 缓存大小
|
||||
DEFAULT = 1000
|
||||
|
||||
# ModelMapping 缓存(可能有较多别名)
|
||||
MODEL_MAPPING = 2000
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 并发和限流常量
|
||||
|
||||
@@ -59,7 +59,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
||||
api_format=APIFormat.CLAUDE,
|
||||
aliases=("claude", "anthropic", "claude_compatible"),
|
||||
default_path="/v1/messages",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/claude"
|
||||
path_prefix="", # 通过请求头区分格式,不使用路径前缀
|
||||
auth_header="x-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
@@ -85,7 +85,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
||||
"openai_compatible",
|
||||
),
|
||||
default_path="/v1/chat/completions",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/openai"
|
||||
path_prefix="", # 默认格式
|
||||
auth_header="Authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
@@ -93,7 +93,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
||||
api_format=APIFormat.OPENAI_CLI,
|
||||
aliases=("openai_cli", "responses"),
|
||||
default_path="/responses",
|
||||
path_prefix="",
|
||||
path_prefix="", # 与 OPENAI 共享入口
|
||||
auth_header="Authorization",
|
||||
auth_type="bearer",
|
||||
),
|
||||
@@ -101,7 +101,7 @@ _DEFINITIONS: Dict[APIFormat, ApiFormatDefinition] = {
|
||||
api_format=APIFormat.GEMINI,
|
||||
aliases=("gemini", "google", "vertex"),
|
||||
default_path="/v1beta/models/{model}:{action}",
|
||||
path_prefix="", # 本站路径前缀,可配置如 "/gemini"
|
||||
path_prefix="", # 通过请求头区分格式
|
||||
auth_header="x-goog-api-key",
|
||||
auth_type="header",
|
||||
),
|
||||
|
||||
@@ -115,7 +115,7 @@ class SyncLRUCache:
|
||||
"""删除缓存值(通过索引)"""
|
||||
self.delete(key)
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> list:
|
||||
"""返回所有未过期的 key"""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
|
||||
@@ -67,7 +67,7 @@ FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:
|
||||
logger.remove()
|
||||
|
||||
|
||||
def _log_filter(record):
|
||||
def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
|
||||
return "watchfiles" not in record["name"]
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ if IS_DOCKER:
|
||||
sys.stdout,
|
||||
format=CONSOLE_FORMAT_PROD,
|
||||
level=LOG_LEVEL,
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
colorize=False,
|
||||
)
|
||||
else:
|
||||
@@ -84,7 +84,7 @@ else:
|
||||
sys.stdout,
|
||||
format=CONSOLE_FORMAT_DEV,
|
||||
level=LOG_LEVEL,
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
colorize=True,
|
||||
)
|
||||
|
||||
@@ -97,7 +97,7 @@ if not DISABLE_FILE_LOG:
|
||||
log_dir / "app.log",
|
||||
format=FILE_FORMAT,
|
||||
level="DEBUG",
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
@@ -110,7 +110,7 @@ if not DISABLE_FILE_LOG:
|
||||
log_dir / "error.log",
|
||||
format=FILE_FORMAT,
|
||||
level="ERROR",
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
|
||||
@@ -44,3 +44,24 @@ health_open_circuits = Gauge(
|
||||
"health_open_circuits",
|
||||
"Number of provider keys currently in circuit breaker open state",
|
||||
)
|
||||
|
||||
# 模型映射解析相关
|
||||
model_mapping_resolution_total = Counter(
|
||||
"model_mapping_resolution_total",
|
||||
"Total number of model mapping resolutions",
|
||||
["method", "cache_hit"],
|
||||
# method: direct_match, provider_model_name, alias, not_found
|
||||
# cache_hit: true, false
|
||||
)
|
||||
|
||||
model_mapping_resolution_duration_seconds = Histogram(
|
||||
"model_mapping_resolution_duration_seconds",
|
||||
"Duration of model mapping resolution in seconds",
|
||||
["method"],
|
||||
buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0], # 1ms 到 1s
|
||||
)
|
||||
|
||||
model_mapping_conflict_total = Counter(
|
||||
"model_mapping_conflict_total",
|
||||
"Total number of mapping conflicts detected (same name maps to multiple GlobalModels)",
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ProviderHealthTracker:
|
||||
@@ -32,7 +32,7 @@ class ProviderHealthTracker:
|
||||
# 存储优先级调整
|
||||
self.priority_adjustments: Dict[str, int] = {}
|
||||
|
||||
def record_success(self, provider_name: str):
|
||||
def record_success(self, provider_name: str) -> None:
|
||||
"""记录成功的请求"""
|
||||
current_time = time.time()
|
||||
|
||||
@@ -47,7 +47,7 @@ class ProviderHealthTracker:
|
||||
if self.priority_adjustments.get(provider_name, 0) < 0:
|
||||
self.priority_adjustments[provider_name] += 1
|
||||
|
||||
def record_failure(self, provider_name: str):
|
||||
def record_failure(self, provider_name: str) -> None:
|
||||
"""记录失败的请求"""
|
||||
current_time = time.time()
|
||||
|
||||
@@ -93,7 +93,7 @@ class ProviderHealthTracker:
|
||||
"status": self._get_status_label(failure_rate, recent_failures),
|
||||
}
|
||||
|
||||
def _cleanup_old_records(self, provider_name: str, current_time: float):
|
||||
def _cleanup_old_records(self, provider_name: str, current_time: float) -> None:
|
||||
"""清理超出时间窗口的记录"""
|
||||
# 清理失败记录
|
||||
self.failures[provider_name] = [
|
||||
@@ -130,7 +130,7 @@ class ProviderHealthTracker:
|
||||
adjustment = self.get_priority_adjustment(provider_name)
|
||||
return adjustment > -3
|
||||
|
||||
def reset_provider_health(self, provider_name: str):
|
||||
def reset_provider_health(self, provider_name: str) -> None:
|
||||
"""重置提供商的健康状态(管理员手动操作)"""
|
||||
self.failures[provider_name] = []
|
||||
self.successes[provider_name] = []
|
||||
@@ -146,7 +146,7 @@ class SimpleProviderSelector:
|
||||
def __init__(self, health_tracker: ProviderHealthTracker):
|
||||
self.health_tracker = health_tracker
|
||||
|
||||
def select_provider(self, providers: list, specified_provider: Optional[str] = None):
|
||||
def select_provider(self, providers: list, specified_provider: Optional[str] = None) -> Any:
|
||||
"""
|
||||
选择提供商
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
async def run_in_executor(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
"""
|
||||
在线程池中运行同步函数,避免阻塞事件循环
|
||||
|
||||
@@ -21,7 +21,7 @@ async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
||||
|
||||
|
||||
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
||||
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||
"""
|
||||
装饰器:包装同步数据库函数为异步函数
|
||||
|
||||
@@ -35,7 +35,7 @@ def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
return await run_in_executor(func, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -347,7 +347,7 @@ def init_default_models(db: Session):
|
||||
"""初始化默认模型配置"""
|
||||
|
||||
# 注意:作为中转代理服务,不再预设模型配置
|
||||
# 模型配置应该通过 Model 和 ModelMapping 表动态管理
|
||||
# 模型配置应该通过 GlobalModel 和 Model 表动态管理
|
||||
# 这个函数保留用于未来可能的默认模型初始化
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
中间件模块
|
||||
"""
|
||||
|
||||
__all__ = []
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -334,7 +334,6 @@ class ProviderResponse(BaseModel):
|
||||
updated_at: datetime
|
||||
models_count: int = 0
|
||||
active_models_count: int = 0
|
||||
model_mappings_count: int = 0
|
||||
api_keys_count: int = 0
|
||||
|
||||
class Config:
|
||||
@@ -346,7 +345,11 @@ class ModelCreate(BaseModel):
|
||||
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
|
||||
|
||||
provider_model_name: str = Field(
|
||||
..., min_length=1, max_length=200, description="Provider 侧的模型名称"
|
||||
..., min_length=1, max_length=200, description="Provider 侧的主模型名称"
|
||||
)
|
||||
provider_model_aliases: Optional[List[dict]] = Field(
|
||||
None,
|
||||
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
)
|
||||
global_model_id: str = Field(..., description="关联的 GlobalModel ID(必填)")
|
||||
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
|
||||
@@ -374,6 +377,10 @@ class ModelUpdate(BaseModel):
|
||||
"""更新模型请求"""
|
||||
|
||||
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
provider_model_aliases: Optional[List[dict]] = Field(
|
||||
None,
|
||||
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
)
|
||||
global_model_id: Optional[str] = None
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
@@ -398,6 +405,7 @@ class ModelResponse(BaseModel):
|
||||
provider_id: str
|
||||
global_model_id: Optional[str]
|
||||
provider_model_name: str
|
||||
provider_model_aliases: Optional[List[dict]] = None
|
||||
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = None
|
||||
@@ -465,54 +473,6 @@ class ModelDetailResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 模型映射 ==========
|
||||
class ModelMappingCreate(BaseModel):
|
||||
"""创建模型映射请求(源模型到目标模型的映射)"""
|
||||
|
||||
source_model: str = Field(..., min_length=1, max_length=200, description="源模型名或别名")
|
||||
target_global_model_id: str = Field(..., description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
|
||||
mapping_type: str = Field(
|
||||
"alias",
|
||||
description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)",
|
||||
)
|
||||
is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class ModelMappingUpdate(BaseModel):
|
||||
"""更新模型映射请求"""
|
||||
|
||||
source_model: Optional[str] = Field(
|
||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
||||
)
|
||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
|
||||
mapping_type: Optional[str] = Field(
|
||||
None, description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)"
|
||||
)
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class ModelMappingResponse(BaseModel):
|
||||
"""模型映射响应"""
|
||||
|
||||
id: str
|
||||
source_model: str
|
||||
target_global_model_id: str
|
||||
target_global_model_name: Optional[str]
|
||||
target_global_model_display_name: Optional[str]
|
||||
provider_id: Optional[str]
|
||||
provider_name: Optional[str]
|
||||
scope: str = Field(..., description="global 或 provider")
|
||||
mapping_type: str = Field(..., description="映射类型:alias 或 mapping")
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 系统设置 ==========
|
||||
class SystemSettingsRequest(BaseModel):
|
||||
"""系统设置请求"""
|
||||
@@ -558,7 +518,6 @@ class PublicProviderResponse(BaseModel):
|
||||
# 统计信息
|
||||
models_count: int
|
||||
active_models_count: int
|
||||
mappings_count: int
|
||||
endpoints_count: int # 端点总数
|
||||
active_endpoints_count: int # 活跃端点数
|
||||
|
||||
@@ -587,19 +546,6 @@ class PublicModelResponse(BaseModel):
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class PublicModelMappingResponse(BaseModel):
|
||||
"""公开的模型映射信息响应"""
|
||||
|
||||
id: str
|
||||
source_model: str
|
||||
target_global_model_id: str
|
||||
target_global_model_name: Optional[str]
|
||||
target_global_model_display_name: Optional[str]
|
||||
provider_id: Optional[str] = None
|
||||
scope: str = Field(..., description="global 或 provider")
|
||||
is_active: bool
|
||||
|
||||
|
||||
class ProviderStatsResponse(BaseModel):
|
||||
"""提供商统计信息响应"""
|
||||
|
||||
@@ -607,7 +553,6 @@ class ProviderStatsResponse(BaseModel):
|
||||
active_providers: int
|
||||
total_models: int
|
||||
active_models: int
|
||||
total_mappings: int
|
||||
supported_formats: List[str]
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
from sqlalchemy import (
|
||||
@@ -491,9 +492,6 @@ class Provider(Base):
|
||||
|
||||
# 关系
|
||||
models = relationship("Model", back_populates="provider", cascade="all, delete-orphan")
|
||||
model_mappings = relationship(
|
||||
"ModelMapping", back_populates="provider", cascade="all, delete-orphan"
|
||||
)
|
||||
endpoints = relationship(
|
||||
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -656,7 +654,11 @@ class Model(Base):
|
||||
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
|
||||
|
||||
# Provider 映射配置
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
|
||||
# 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
|
||||
# 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}]
|
||||
# 为空时只使用 provider_model_name
|
||||
provider_model_aliases = Column(JSON, nullable=True, default=None)
|
||||
|
||||
# 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值
|
||||
price_per_request = Column(Float, nullable=True) # 每次请求固定费用
|
||||
@@ -786,60 +788,83 @@ class Model(Base):
|
||||
def get_effective_supports_image_generation(self) -> bool:
|
||||
return self._get_effective_capability("supports_image_generation", False)
|
||||
|
||||
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
|
||||
"""按优先级选择要使用的 Provider 模型名称
|
||||
|
||||
class ModelMapping(Base):
|
||||
"""模型映射表 - 统一处理别名与降级策略
|
||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||
相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致);
|
||||
否则返回 provider_model_name。
|
||||
|
||||
设计原则:
|
||||
- source_model 接收用户请求的原始模型名/别名
|
||||
- target_global_model_id 指向真实的 GlobalModel
|
||||
- provider_id 为空表示全局别名,非空表示 Provider 特定映射/降级
|
||||
- 一个 (source_model, provider_id) 组合唯一
|
||||
Args:
|
||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
映射类型 (mapping_type):
|
||||
- alias: 别名模式,按目标模型计费(只是名称简写)
|
||||
- mapping: 映射模式,按源模型计费(模型降级/替代)
|
||||
"""
|
||||
if not self.provider_model_aliases:
|
||||
return self.provider_model_name
|
||||
|
||||
__tablename__ = "model_mappings"
|
||||
raw_aliases = self.provider_model_aliases
|
||||
if not isinstance(raw_aliases, list) or len(raw_aliases) == 0:
|
||||
return self.provider_model_name
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
aliases: list[dict] = []
|
||||
for raw in raw_aliases:
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
name = raw.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
continue
|
||||
|
||||
# 源模型名称(可能是别名或真实 GlobalModel.name)
|
||||
source_model = Column(String(200), nullable=False, index=True)
|
||||
raw_priority = raw.get("priority", 1)
|
||||
try:
|
||||
priority = int(raw_priority)
|
||||
except Exception:
|
||||
priority = 1
|
||||
if priority < 1:
|
||||
priority = 1
|
||||
|
||||
# 目标 GlobalModel
|
||||
target_global_model_id = Column(
|
||||
String(36), ForeignKey("global_models.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
aliases.append({"name": name.strip(), "priority": priority})
|
||||
|
||||
# Provider 关联:NULL 代表全局别名
|
||||
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=True, index=True)
|
||||
if not aliases:
|
||||
return self.provider_model_name
|
||||
|
||||
# 映射类型:alias=按目标模型计费,mapping=按源模型计费
|
||||
mapping_type = Column(String(20), nullable=False, default="alias", index=True)
|
||||
# 按优先级排序(数字越小越优先)
|
||||
sorted_aliases = sorted(aliases, key=lambda x: x["priority"])
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
# 获取最高优先级(最小数字)
|
||||
highest_priority = sorted_aliases[0]["priority"]
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
# 获取所有最高优先级的别名
|
||||
top_priority_aliases = [
|
||||
alias for alias in sorted_aliases
|
||||
if alias["priority"] == highest_priority
|
||||
]
|
||||
|
||||
# 关系
|
||||
target_global_model = relationship("GlobalModel", foreign_keys=[target_global_model_id])
|
||||
provider = relationship("Provider", back_populates="model_mappings")
|
||||
# 如果有多个相同优先级的别名,通过哈希分散选择
|
||||
if len(top_priority_aliases) > 1 and affinity_key:
|
||||
# 为每个别名计算哈希得分,选择得分最小的
|
||||
def hash_score(alias: dict) -> int:
|
||||
combined = f"{affinity_key}:{alias['name']}"
|
||||
return int(hashlib.md5(combined.encode()).hexdigest(), 16)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"),
|
||||
)
|
||||
selected = min(top_priority_aliases, key=hash_score)
|
||||
elif len(top_priority_aliases) > 1:
|
||||
# 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个)
|
||||
# 避免随机选择导致同一请求重试时选择不同的模型名称
|
||||
selected = min(top_priority_aliases, key=lambda x: x["name"])
|
||||
else:
|
||||
selected = top_priority_aliases[0]
|
||||
|
||||
return selected["name"]
|
||||
|
||||
def get_all_provider_model_names(self) -> list[str]:
|
||||
"""获取所有可用的 Provider 模型名称(主名称 + 别名)"""
|
||||
names = [self.provider_model_name]
|
||||
if self.provider_model_aliases:
|
||||
for alias in self.provider_model_aliases:
|
||||
if isinstance(alias, dict) and alias.get("name"):
|
||||
names.append(alias["name"])
|
||||
return names
|
||||
|
||||
|
||||
class ProviderAPIKey(Base):
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .api import ModelCreate
|
||||
|
||||
|
||||
# ========== 阶梯计费相关模型 ==========
|
||||
|
||||
@@ -131,24 +129,14 @@ class ModelCatalogProviderDetail(BaseModel):
|
||||
supports_function_calling: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
is_active: bool
|
||||
mapping_id: Optional[str]
|
||||
|
||||
|
||||
class OrphanedModel(BaseModel):
|
||||
"""孤立的统一模型(Mapping 存在但 GlobalModel 缺失)"""
|
||||
|
||||
alias: str # 别名
|
||||
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
|
||||
mapping_count: int
|
||||
|
||||
|
||||
class ModelCatalogItem(BaseModel):
|
||||
"""统一模型目录条目(方案 A:基于 GlobalModel)"""
|
||||
"""统一模型目录条目(基于 GlobalModel)"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
description: Optional[str] # GlobalModel.description
|
||||
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
|
||||
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
|
||||
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
|
||||
total_providers: int
|
||||
@@ -160,7 +148,6 @@ class ModelCatalogResponse(BaseModel):
|
||||
|
||||
models: List[ModelCatalogItem]
|
||||
total: int
|
||||
orphaned_models: List[OrphanedModel]
|
||||
|
||||
|
||||
class ProviderModelPriceInfo(BaseModel):
|
||||
@@ -174,13 +161,11 @@ class ProviderModelPriceInfo(BaseModel):
|
||||
|
||||
|
||||
class ProviderAvailableSourceModel(BaseModel):
|
||||
"""Provider 支持的统一模型条目(方案 A)"""
|
||||
"""Provider 支持的统一模型条目"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
|
||||
has_alias: bool # 是否有别名指向该 GlobalModel
|
||||
aliases: List[str] # 别名列表
|
||||
model_id: Optional[str] # Model.id
|
||||
price: ProviderModelPriceInfo
|
||||
capabilities: ModelCapabilities
|
||||
@@ -194,50 +179,7 @@ class ProviderAvailableSourceModelsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class BatchAssignProviderConfig(BaseModel):
|
||||
"""批量添加映射的 Provider 配置"""
|
||||
|
||||
provider_id: str
|
||||
create_model: bool = Field(False, description="是否需要创建新的 Model")
|
||||
model_data: Optional[ModelCreate] = Field(
|
||||
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
|
||||
)
|
||||
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
|
||||
|
||||
|
||||
class BatchAssignModelMappingRequest(BaseModel):
|
||||
"""批量添加模型映射请求(方案 A:暂不支持,需要重构)"""
|
||||
|
||||
global_model_id: str # 要分配的 GlobalModel ID
|
||||
providers: List[BatchAssignProviderConfig]
|
||||
|
||||
|
||||
class BatchAssignProviderResult(BaseModel):
|
||||
"""批量映射结果条目"""
|
||||
|
||||
provider_id: str
|
||||
mapping_id: Optional[str]
|
||||
created_model: bool
|
||||
model_id: Optional[str]
|
||||
updated: bool = False
|
||||
|
||||
|
||||
class BatchAssignError(BaseModel):
|
||||
"""批量映射错误信息"""
|
||||
|
||||
provider_id: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchAssignModelMappingResponse(BaseModel):
|
||||
"""批量映射响应"""
|
||||
|
||||
success: bool
|
||||
created_mappings: List[BatchAssignProviderResult]
|
||||
errors: List[BatchAssignError]
|
||||
|
||||
|
||||
# ========== 阶段二:GlobalModel 相关模型 ==========
|
||||
# ========== GlobalModel 相关模型 ==========
|
||||
|
||||
|
||||
class GlobalModelCreate(BaseModel):
|
||||
@@ -328,7 +270,6 @@ class GlobalModelResponse(BaseModel):
|
||||
)
|
||||
# 统计数据(可选)
|
||||
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
|
||||
alias_count: Optional[int] = Field(default=0, description="别名数量")
|
||||
usage_count: Optional[int] = Field(default=0, description="调用次数")
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
@@ -355,7 +296,7 @@ class GlobalModelListResponse(BaseModel):
|
||||
class BatchAssignToProvidersRequest(BaseModel):
|
||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||
|
||||
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表")
|
||||
provider_ids: List[str] = Field(..., min_length=1, description="Provider ID 列表")
|
||||
create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
|
||||
|
||||
|
||||
@@ -379,43 +320,11 @@ class BatchAssignModelsToProviderResponse(BaseModel):
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class UpdateModelMappingRequest(BaseModel):
|
||||
"""更新模型映射请求"""
|
||||
|
||||
source_model: Optional[str] = Field(
|
||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
||||
)
|
||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时为全局别名)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class UpdateModelMappingResponse(BaseModel):
|
||||
"""更新模型映射响应"""
|
||||
|
||||
success: bool
|
||||
mapping_id: str
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class DeleteModelMappingResponse(BaseModel):
|
||||
"""删除模型映射响应"""
|
||||
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchAssignError",
|
||||
"BatchAssignModelMappingRequest",
|
||||
"BatchAssignModelMappingResponse",
|
||||
"BatchAssignModelsToProviderRequest",
|
||||
"BatchAssignModelsToProviderResponse",
|
||||
"BatchAssignProviderConfig",
|
||||
"BatchAssignProviderResult",
|
||||
"BatchAssignToProvidersRequest",
|
||||
"BatchAssignToProvidersResponse",
|
||||
"DeleteModelMappingResponse",
|
||||
"GlobalModelCreate",
|
||||
"GlobalModelListResponse",
|
||||
"GlobalModelResponse",
|
||||
@@ -426,10 +335,7 @@ __all__ = [
|
||||
"ModelCatalogProviderDetail",
|
||||
"ModelCatalogResponse",
|
||||
"ModelPriceRange",
|
||||
"OrphanedModel",
|
||||
"ProviderAvailableSourceModel",
|
||||
"ProviderAvailableSourceModelsResponse",
|
||||
"ProviderModelPriceInfo",
|
||||
"UpdateModelMappingRequest",
|
||||
"UpdateModelMappingResponse",
|
||||
]
|
||||
|
||||
184
src/services/cache/aware_scheduler.py
vendored
184
src/services/cache/aware_scheduler.py
vendored
@@ -28,15 +28,17 @@
|
||||
- 失效缓存亲和性,避免重复选择故障资源
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.exceptions import ProviderNotAvailableException
|
||||
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
Model,
|
||||
@@ -44,10 +46,15 @@ from src.models.database import (
|
||||
ProviderAPIKey,
|
||||
ProviderEndpoint,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
from src.services.cache.affinity_manager import (
|
||||
CacheAffinityManager,
|
||||
get_affinity_manager,
|
||||
)
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
from src.services.health.monitor import health_monitor
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.rate_limit.adaptive_reservation import (
|
||||
@@ -227,19 +234,6 @@ class CacheAwareScheduler:
|
||||
if provider_offset == 0:
|
||||
# 没有找到任何候选,提供友好的错误提示
|
||||
error_msg = f"模型 '{model_name}' 不可用"
|
||||
|
||||
# 查找相似模型
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
resolver = get_model_mapping_resolver()
|
||||
similar_models = resolver.find_similar_models(db, model_name, limit=3)
|
||||
|
||||
if similar_models:
|
||||
suggestions = [
|
||||
f"{name} (相似度: {score:.0%})" for name, score in similar_models
|
||||
]
|
||||
error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions)
|
||||
|
||||
raise ProviderNotAvailableException(error_msg)
|
||||
break
|
||||
|
||||
@@ -272,9 +266,11 @@ class CacheAwareScheduler:
|
||||
self._metrics["concurrency_denied"] += 1
|
||||
continue
|
||||
|
||||
logger.debug(f" └─ 选择 Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
logger.debug(
|
||||
f" └─ 选择 Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"Key=***{key.api_key[-4:]}, 缓存命中={is_cached_user}, "
|
||||
f"并发状态[{snapshot.describe()}]")
|
||||
f"并发状态[{snapshot.describe()}]"
|
||||
)
|
||||
|
||||
if key.cache_ttl_minutes > 0 and global_model_id:
|
||||
ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
|
||||
@@ -362,7 +358,9 @@ class CacheAwareScheduler:
|
||||
logger.debug(f" -> 无并发管理器,直接通过")
|
||||
snapshot = ConcurrencySnapshot(
|
||||
endpoint_current=0,
|
||||
endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None,
|
||||
endpoint_limit=(
|
||||
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
|
||||
),
|
||||
key_current=0,
|
||||
key_limit=effective_key_limit,
|
||||
is_cached_user=is_cached_user,
|
||||
@@ -497,10 +495,12 @@ class CacheAwareScheduler:
|
||||
user = None
|
||||
|
||||
# 调试日志
|
||||
logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
|
||||
logger.debug(
|
||||
f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
|
||||
f"User={user.id[:8] if user else 'None'}..., "
|
||||
f"ApiKey.allowed_models={user_api_key.allowed_models}, "
|
||||
f"User.allowed_models={user.allowed_models if user else 'N/A'}")
|
||||
f"User.allowed_models={user.allowed_models if user else 'N/A'}"
|
||||
)
|
||||
|
||||
def merge_restrictions(key_restriction, user_restriction):
|
||||
"""合并两个限制列表,返回有效的限制集合"""
|
||||
@@ -579,13 +579,17 @@ class CacheAwareScheduler:
|
||||
|
||||
target_format = normalize_api_format(api_format)
|
||||
|
||||
# 0. 解析 model_name 到 GlobalModel(用于缓存亲和性的规范化标识)
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
# 0. 解析 model_name 到 GlobalModel(支持直接匹配和别名匹配,使用 ModelCacheService)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
|
||||
|
||||
if not global_model:
|
||||
logger.warning(f"GlobalModel not found: {model_name}")
|
||||
raise ModelNotSupportedException(model=model_name)
|
||||
|
||||
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保别名和规范名都能命中同一个缓存
|
||||
global_model_id: str = str(global_model.id) if global_model else model_name
|
||||
global_model_id: str = str(global_model.id)
|
||||
requested_model_name = model_name
|
||||
resolved_model_name = str(global_model.name)
|
||||
|
||||
# 获取合并后的访问限制(ApiKey + User)
|
||||
restrictions = self._get_effective_restrictions(user_api_key)
|
||||
@@ -595,16 +599,29 @@ class CacheAwareScheduler:
|
||||
allowed_models = restrictions["allowed_models"]
|
||||
|
||||
# 0.1 检查 API 格式是否被允许
|
||||
if allowed_api_formats:
|
||||
if allowed_api_formats is not None:
|
||||
if target_format.value not in allowed_api_formats:
|
||||
logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
|
||||
f"允许的格式: {allowed_api_formats}")
|
||||
logger.debug(
|
||||
f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
|
||||
f"允许的格式: {allowed_api_formats}"
|
||||
)
|
||||
return [], global_model_id
|
||||
|
||||
# 0.2 检查模型是否被允许
|
||||
if allowed_models:
|
||||
if model_name not in allowed_models:
|
||||
logger.debug(f"用户/API Key 不允许使用模型 {model_name}, " f"允许的模型: {allowed_models}")
|
||||
if allowed_models is not None:
|
||||
if (
|
||||
requested_model_name not in allowed_models
|
||||
and resolved_model_name not in allowed_models
|
||||
):
|
||||
resolved_note = (
|
||||
f" (解析为 {resolved_model_name})"
|
||||
if resolved_model_name != requested_model_name
|
||||
else ""
|
||||
)
|
||||
logger.debug(
|
||||
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
|
||||
f"允许的模型: {allowed_models}"
|
||||
)
|
||||
return [], global_model_id
|
||||
|
||||
# 1. 查询 Providers
|
||||
@@ -618,7 +635,7 @@ class CacheAwareScheduler:
|
||||
return [], global_model_id
|
||||
|
||||
# 1.5 根据 allowed_providers 过滤(合并 ApiKey 和 User 的限制)
|
||||
if allowed_providers:
|
||||
if allowed_providers is not None:
|
||||
original_count = len(providers)
|
||||
# 同时支持 provider id 和 name 匹配
|
||||
providers = [
|
||||
@@ -635,7 +652,8 @@ class CacheAwareScheduler:
|
||||
db=db,
|
||||
providers=providers,
|
||||
target_format=target_format,
|
||||
model_name=model_name,
|
||||
model_name=requested_model_name,
|
||||
resolved_model_name=resolved_model_name,
|
||||
affinity_key=affinity_key,
|
||||
max_candidates=max_candidates,
|
||||
allowed_endpoints=allowed_endpoints,
|
||||
@@ -650,8 +668,10 @@ class CacheAwareScheduler:
|
||||
self._metrics["total_candidates"] += len(candidates)
|
||||
self._metrics["last_candidate_count"] = len(candidates)
|
||||
|
||||
logger.debug(f"预先获取到 {len(candidates)} 个可用组合 "
|
||||
f"(api_format={target_format.value}, model={model_name})")
|
||||
logger.debug(
|
||||
f"预先获取到 {len(candidates)} 个可用组合 "
|
||||
f"(api_format={target_format.value}, model={model_name})"
|
||||
)
|
||||
|
||||
# 4. 应用缓存亲和性排序(使用 global_model_id 作为模型标识)
|
||||
if affinity_key and candidates:
|
||||
@@ -685,7 +705,6 @@ class CacheAwareScheduler:
|
||||
db.query(Provider)
|
||||
.options(
|
||||
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
|
||||
selectinload(Provider.model_mappings),
|
||||
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值
|
||||
selectinload(Provider.models).selectinload(Model.global_model),
|
||||
)
|
||||
@@ -715,63 +734,31 @@ class CacheAwareScheduler:
|
||||
- 模型支持的能力是全局的,与具体的 Key 无关
|
||||
- 如果模型不支持某能力,整个 Provider 的所有 Key 都应该被跳过
|
||||
|
||||
映射回退逻辑:
|
||||
- 如果存在模型映射(mapping),先尝试映射后的模型
|
||||
- 如果映射后的模型因能力不满足而失败,回退尝试原模型
|
||||
- 其他失败原因(如模型不存在、Provider 未实现等)不触发回退
|
||||
支持两种匹配方式:
|
||||
1. 直接匹配 GlobalModel.name
|
||||
2. 通过 ModelCacheService 匹配别名(全局查找)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: Provider 对象
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或别名)
|
||||
is_stream: 是否是流式请求,如果为 True 则同时检查流式支持
|
||||
capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力
|
||||
|
||||
Returns:
|
||||
(is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
|
||||
# 获取映射后的模型,同时检查是否发生了映射
|
||||
global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
|
||||
db, model_name, str(provider.id)
|
||||
)
|
||||
# 使用 ModelCacheService 解析模型名称(支持别名)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
|
||||
|
||||
if not global_model:
|
||||
return False, "模型不存在", None
|
||||
# 完全未找到匹配
|
||||
return False, "模型不存在或 Provider 未配置此模型", None
|
||||
|
||||
# 尝试检查映射后的模型
|
||||
# 找到 GlobalModel 后,检查当前 Provider 是否支持
|
||||
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
|
||||
db, provider, global_model, model_name, is_stream, capability_requirements
|
||||
)
|
||||
|
||||
# 如果映射后的模型因能力不满足而失败,且存在映射,则回退尝试原模型
|
||||
if not is_supported and is_mapped and skip_reason and "不支持能力" in skip_reason:
|
||||
logger.debug(
|
||||
f"Provider {provider.name} 映射模型 {global_model.name} 能力不满足,"
|
||||
f"回退尝试原模型 {model_name}"
|
||||
)
|
||||
|
||||
# 获取原模型(不应用映射)
|
||||
original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
|
||||
|
||||
if original_global_model and original_global_model.id != global_model.id:
|
||||
# 尝试原模型
|
||||
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
|
||||
db,
|
||||
provider,
|
||||
original_global_model,
|
||||
model_name,
|
||||
is_stream,
|
||||
capability_requirements,
|
||||
)
|
||||
if is_supported:
|
||||
logger.debug(
|
||||
f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力"
|
||||
)
|
||||
|
||||
return is_supported, skip_reason, caps
|
||||
|
||||
async def _check_model_support_for_global_model(
|
||||
@@ -798,7 +785,16 @@ class CacheAwareScheduler:
|
||||
(is_supported, skip_reason, supported_capabilities)
|
||||
"""
|
||||
# 确保 global_model 附加到当前 Session
|
||||
global_model = db.merge(global_model, load=False)
|
||||
# 注意:从缓存重建的对象是 transient 状态,不能使用 load=False
|
||||
# 使用 load=True(默认)允许 SQLAlchemy 正确处理 transient 对象
|
||||
from sqlalchemy import inspect
|
||||
insp = inspect(global_model)
|
||||
if insp.transient or insp.detached:
|
||||
# transient/detached 对象:使用默认 merge(会查询 DB 检查是否存在)
|
||||
global_model = db.merge(global_model)
|
||||
else:
|
||||
# persistent 对象:已经附加到 session,无需 merge
|
||||
pass
|
||||
|
||||
# 获取模型支持的能力列表
|
||||
model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
|
||||
@@ -806,6 +802,11 @@ class CacheAwareScheduler:
|
||||
# 查询该 Provider 是否有实现这个 GlobalModel
|
||||
for model in provider.models:
|
||||
if model.global_model_id == global_model.id and model.is_active:
|
||||
logger.debug(
|
||||
f"[_check_model_support_for_global_model] Provider={provider.name}, "
|
||||
f"GlobalModel={global_model.name}, "
|
||||
f"provider_model_name={model.provider_model_name}"
|
||||
)
|
||||
# 检查流式支持
|
||||
if is_stream:
|
||||
supports_streaming = model.get_effective_supports_streaming()
|
||||
@@ -833,6 +834,7 @@ class CacheAwareScheduler:
|
||||
key: ProviderAPIKey,
|
||||
model_name: str,
|
||||
capability_requirements: Optional[Dict[str, bool]] = None,
|
||||
resolved_model_name: Optional[str] = None,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检查 API Key 的可用性
|
||||
@@ -844,6 +846,7 @@ class CacheAwareScheduler:
|
||||
key: API Key 对象
|
||||
model_name: 模型名称
|
||||
capability_requirements: 能力需求(可选)
|
||||
resolved_model_name: 解析后的 GlobalModel.name(可选)
|
||||
|
||||
Returns:
|
||||
(is_available, skip_reason)
|
||||
@@ -855,7 +858,10 @@ class CacheAwareScheduler:
|
||||
|
||||
# 模型权限检查:使用 allowed_models 白名单
|
||||
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
|
||||
if key.allowed_models is not None and model_name not in key.allowed_models:
|
||||
if key.allowed_models is not None and (
|
||||
model_name not in key.allowed_models
|
||||
and (not resolved_model_name or resolved_model_name not in key.allowed_models)
|
||||
):
|
||||
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(无)"
|
||||
suffix = "..." if len(key.allowed_models) > 3 else ""
|
||||
return False, f"模型权限不匹配(允许: {allowed_preview}{suffix})"
|
||||
@@ -880,6 +886,7 @@ class CacheAwareScheduler:
|
||||
target_format: APIFormat,
|
||||
model_name: str,
|
||||
affinity_key: Optional[str],
|
||||
resolved_model_name: Optional[str] = None,
|
||||
max_candidates: Optional[int] = None,
|
||||
allowed_endpoints: Optional[set] = None,
|
||||
is_stream: bool = False,
|
||||
@@ -892,8 +899,9 @@ class CacheAwareScheduler:
|
||||
db: 数据库会话
|
||||
providers: Provider 列表
|
||||
target_format: 目标 API 格式
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(用户请求的名称,可能是别名)
|
||||
affinity_key: 亲和性标识符(通常为API Key ID)
|
||||
resolved_model_name: 解析后的 GlobalModel.name(用于 Key.allowed_models 校验)
|
||||
max_candidates: 最大候选数
|
||||
allowed_endpoints: 允许的 Endpoint ID 集合(None 表示不限制)
|
||||
is_stream: 是否是流式请求,如果为 True 则过滤不支持流式的 Provider
|
||||
@@ -925,7 +933,7 @@ class CacheAwareScheduler:
|
||||
continue
|
||||
|
||||
# 检查 Endpoint 是否在允许列表中
|
||||
if allowed_endpoints and endpoint.id not in allowed_endpoints:
|
||||
if allowed_endpoints is not None and endpoint.id not in allowed_endpoints:
|
||||
logger.debug(
|
||||
f"Endpoint {endpoint.id[:8]}... 不在用户/API Key 的允许列表中,跳过"
|
||||
)
|
||||
@@ -938,7 +946,10 @@ class CacheAwareScheduler:
|
||||
for key in keys:
|
||||
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
||||
is_available, skip_reason = self._check_key_availability(
|
||||
key, model_name, capability_requirements
|
||||
key,
|
||||
model_name,
|
||||
capability_requirements,
|
||||
resolved_model_name=resolved_model_name,
|
||||
)
|
||||
|
||||
candidate = ProviderCandidate(
|
||||
@@ -1007,11 +1018,13 @@ class CacheAwareScheduler:
|
||||
candidate.is_cached = True
|
||||
cached_candidates.append(candidate)
|
||||
matched = True
|
||||
logger.debug(f"检测到缓存亲和性: affinity_key={affinity_key[:8]}..., "
|
||||
logger.debug(
|
||||
f"检测到缓存亲和性: affinity_key={affinity_key[:8]}..., "
|
||||
f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
|
||||
f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
|
||||
f"provider_key=***{key.api_key[-4:]}, "
|
||||
f"使用次数={affinity.request_count}")
|
||||
f"使用次数={affinity.request_count}"
|
||||
)
|
||||
else:
|
||||
candidate.is_cached = False
|
||||
other_candidates.append(candidate)
|
||||
@@ -1117,6 +1130,7 @@ class CacheAwareScheduler:
|
||||
c.key.internal_priority if c.key else 999999,
|
||||
c.key.id if c.key else "",
|
||||
)
|
||||
|
||||
result.extend(sorted(group, key=secondary_sort))
|
||||
|
||||
return result
|
||||
|
||||
3
src/services/cache/backend.py
vendored
3
src/services/cache/backend.py
vendored
@@ -6,8 +6,7 @@
|
||||
2. RedisCache: Redis 缓存(分布式)
|
||||
|
||||
使用场景:
|
||||
- ModelMappingResolver: 模型映射与别名解析缓存
|
||||
- ModelMapper: 模型映射缓存
|
||||
- ModelCacheService: 模型解析缓存
|
||||
- 其他需要缓存的服务
|
||||
"""
|
||||
|
||||
|
||||
42
src/services/cache/invalidation.py
vendored
42
src/services/cache/invalidation.py
vendored
@@ -3,18 +3,14 @@
|
||||
|
||||
统一管理各种缓存的失效逻辑,支持:
|
||||
1. GlobalModel 变更时失效相关缓存
|
||||
2. ModelMapping 变更时失效别名/降级缓存
|
||||
3. Model 变更时失效模型映射缓存
|
||||
4. 支持同步和异步缓存后端
|
||||
2. Model 变更时失效模型映射缓存
|
||||
3. 支持同步和异步缓存后端
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class CacheInvalidationService:
|
||||
"""
|
||||
@@ -25,14 +21,8 @@ class CacheInvalidationService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化缓存失效服务"""
|
||||
self._mapping_resolver = None
|
||||
self._model_mappers = [] # 可能有多个 ModelMapperMiddleware 实例
|
||||
|
||||
def set_mapping_resolver(self, mapping_resolver):
|
||||
"""设置模型映射解析器实例"""
|
||||
self._mapping_resolver = mapping_resolver
|
||||
logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")
|
||||
|
||||
def register_model_mapper(self, model_mapper):
|
||||
"""注册 ModelMapper 实例"""
|
||||
if model_mapper not in self._model_mappers:
|
||||
@@ -48,37 +38,12 @@ class CacheInvalidationService:
|
||||
"""
|
||||
logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
|
||||
|
||||
# 异步失效模型解析器中的缓存
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(self._mapping_resolver.invalidate_global_model_cache())
|
||||
|
||||
# 失效所有 ModelMapper 中与此模型相关的缓存
|
||||
for mapper in self._model_mappers:
|
||||
# 清空所有缓存(因为不知道哪些 provider 使用了这个模型)
|
||||
mapper.clear_cache()
|
||||
logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存")
|
||||
|
||||
def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
|
||||
"""
|
||||
ModelMapping 变更时的缓存失效
|
||||
|
||||
Args:
|
||||
source_model: 变更的源模型名
|
||||
provider_id: 相关 Provider(None 表示全局)
|
||||
"""
|
||||
logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")
|
||||
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(
|
||||
self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
|
||||
)
|
||||
|
||||
for mapper in self._model_mappers:
|
||||
if provider_id:
|
||||
mapper.refresh_cache(provider_id)
|
||||
else:
|
||||
mapper.clear_cache()
|
||||
|
||||
def on_model_changed(self, provider_id: str, global_model_id: str):
|
||||
"""
|
||||
Model 变更时的缓存失效
|
||||
@@ -98,9 +63,6 @@ class CacheInvalidationService:
|
||||
"""清空所有缓存"""
|
||||
logger.info("[CacheInvalidation] 清空所有缓存")
|
||||
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(self._mapping_resolver.clear_cache())
|
||||
|
||||
for mapper in self._model_mappers:
|
||||
mapper.clear_cache()
|
||||
|
||||
|
||||
348
src/services/cache/model_cache.py
vendored
348
src/services/cache/model_cache.py
vendored
@@ -1,16 +1,23 @@
|
||||
"""
|
||||
Model 映射缓存服务 - 减少模型映射和别名查询
|
||||
Model 映射缓存服务 - 减少模型查询
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
|
||||
from src.core.metrics import (
|
||||
model_mapping_conflict_total,
|
||||
model_mapping_resolution_duration_seconds,
|
||||
model_mapping_resolution_total,
|
||||
)
|
||||
from src.models.database import GlobalModel, Model
|
||||
|
||||
|
||||
class ModelCacheService:
|
||||
@@ -103,7 +110,9 @@ class ModelCacheService:
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
|
||||
logger.debug(
|
||||
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||
)
|
||||
return ModelCacheService._dict_to_model(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
@@ -121,7 +130,9 @@ class ModelCacheService:
|
||||
if model:
|
||||
model_dict = ModelCacheService._model_to_dict(model)
|
||||
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
logger.debug(f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
|
||||
logger.debug(
|
||||
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@@ -158,66 +169,22 @@ class ModelCacheService:
|
||||
|
||||
return global_model
|
||||
|
||||
@staticmethod
|
||||
async def resolve_alias(
|
||||
db: Session, source_model: str, provider_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
解析模型别名(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 源模型名称或别名
|
||||
provider_id: Provider ID(可选,用于 Provider 特定别名)
|
||||
|
||||
Returns:
|
||||
目标 GlobalModel ID 或 None
|
||||
"""
|
||||
# 构造缓存键
|
||||
if provider_id:
|
||||
cache_key = f"alias:provider:{provider_id}:{source_model}"
|
||||
else:
|
||||
cache_key = f"alias:global:{source_model}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_result = await CacheService.get(cache_key)
|
||||
if cached_result:
|
||||
logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
|
||||
return cached_result
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)
|
||||
|
||||
if provider_id:
|
||||
# Provider 特定别名优先
|
||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
||||
else:
|
||||
# 全局别名
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
mapping = query.first()
|
||||
|
||||
# 3. 写入缓存
|
||||
target_global_model_id = mapping.target_global_model_id if mapping else None
|
||||
await CacheService.set(
|
||||
cache_key, target_global_model_id, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"别名已缓存: {source_model} → {target_global_model_id}")
|
||||
|
||||
return target_global_model_id
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_model_cache(
|
||||
model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
|
||||
):
|
||||
model_id: str,
|
||||
provider_id: Optional[str] = None,
|
||||
global_model_id: Optional[str] = None,
|
||||
provider_model_name: Optional[str] = None,
|
||||
provider_model_aliases: Optional[list] = None,
|
||||
) -> None:
|
||||
"""清除 Model 缓存
|
||||
|
||||
Args:
|
||||
model_id: Model ID
|
||||
provider_id: Provider ID(用于清除 provider_global 缓存)
|
||||
global_model_id: GlobalModel ID(用于清除 provider_global 缓存)
|
||||
provider_model_name: provider_model_name(用于清除 resolve 缓存)
|
||||
provider_model_aliases: 映射名称列表(用于清除 resolve 缓存)
|
||||
"""
|
||||
# 清除 model:id 缓存
|
||||
await CacheService.delete(f"model:id:{model_id}")
|
||||
@@ -225,28 +192,241 @@ class ModelCacheService:
|
||||
# 清除 provider_global 缓存(如果提供了必要参数)
|
||||
if provider_id and global_model_id:
|
||||
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
|
||||
logger.debug(f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}...")
|
||||
logger.debug(
|
||||
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Model 缓存已清除: {model_id}")
|
||||
|
||||
# 清除 resolve 缓存(provider_model_name 和 aliases 可能都被用作解析 key)
|
||||
resolve_keys_to_clear = []
|
||||
if provider_model_name:
|
||||
resolve_keys_to_clear.append(provider_model_name)
|
||||
if provider_model_aliases:
|
||||
for alias_entry in provider_model_aliases:
|
||||
if isinstance(alias_entry, dict):
|
||||
alias_name = alias_entry.get("name", "").strip()
|
||||
if alias_name:
|
||||
resolve_keys_to_clear.append(alias_name)
|
||||
|
||||
for key in resolve_keys_to_clear:
|
||||
await CacheService.delete(f"global_model:resolve:{key}")
|
||||
|
||||
if resolve_keys_to_clear:
|
||||
logger.debug(f"Model resolve 缓存已清除: {resolve_keys_to_clear}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None):
|
||||
async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None) -> None:
|
||||
"""清除 GlobalModel 缓存"""
|
||||
await CacheService.delete(f"global_model:id:{global_model_id}")
|
||||
if name:
|
||||
await CacheService.delete(f"global_model:name:{name}")
|
||||
# 同时清除 resolve 缓存,因为 GlobalModel.name 也是一个 resolve key
|
||||
await CacheService.delete(f"global_model:resolve:{name}")
|
||||
logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
|
||||
"""清除别名缓存"""
|
||||
if provider_id:
|
||||
cache_key = f"alias:provider:{provider_id}:{source_model}"
|
||||
else:
|
||||
cache_key = f"alias:global:{source_model}"
|
||||
async def resolve_global_model_by_name_or_alias(
|
||||
db: Session, model_name: str
|
||||
) -> Optional[GlobalModel]:
|
||||
"""
|
||||
通过名称或映射解析 GlobalModel(带缓存,支持映射匹配)
|
||||
|
||||
await CacheService.delete(cache_key)
|
||||
logger.debug(f"别名缓存已清除: {source_model}")
|
||||
查找顺序:
|
||||
1. 检查缓存
|
||||
2. 通过映射匹配(查询 Model 表的 provider_model_name 和 provider_model_aliases)
|
||||
3. 直接匹配 GlobalModel.name(兜底)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或映射名称)
|
||||
|
||||
Returns:
|
||||
GlobalModel 对象或 None
|
||||
"""
|
||||
start_time = time.time()
|
||||
resolution_method = "not_found"
|
||||
cache_hit = False
|
||||
|
||||
normalized_name = model_name.strip()
|
||||
if not normalized_name:
|
||||
return None
|
||||
|
||||
cache_key = f"global_model:resolve:{normalized_name}"
|
||||
|
||||
try:
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
if cached_data == "NOT_FOUND":
|
||||
# 缓存的负结果
|
||||
cache_hit = True
|
||||
resolution_method = "not_found"
|
||||
logger.debug(f"GlobalModel 缓存命中(映射解析-未找到): {normalized_name}")
|
||||
return None
|
||||
if isinstance(cached_data, dict) and "supported_capabilities" not in cached_data:
|
||||
# 兼容旧缓存:字段不全时视为未命中,走 DB 刷新
|
||||
logger.debug(f"GlobalModel 缓存命中但 schema 过旧,刷新: {normalized_name}")
|
||||
else:
|
||||
cache_hit = True
|
||||
resolution_method = "direct_match" # 缓存命中时无法区分原始解析方式
|
||||
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
|
||||
return ModelCacheService._dict_to_global_model(cached_data)
|
||||
|
||||
# 2. 优先通过 provider_model_name 和映射名称匹配(Provider 配置优先级最高)
|
||||
from sqlalchemy import or_
|
||||
|
||||
from src.models.database import Provider
|
||||
|
||||
# 构建精确的映射匹配条件
|
||||
# 注意:provider_model_aliases 是 JSONB 数组,需要使用 PostgreSQL 的 JSONB 操作符
|
||||
# 对于 SQLite,会在 Python 层面进行过滤
|
||||
try:
|
||||
# 尝试使用 PostgreSQL 的 JSONB 查询(更高效)
|
||||
# 使用 json.dumps 确保正确转义特殊字符,避免 SQL 注入
|
||||
jsonb_pattern = json.dumps([{"name": normalized_name}])
|
||||
models_with_global = (
|
||||
db.query(Model, GlobalModel)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
Provider.is_active == True,
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
or_(
|
||||
Model.provider_model_name == normalized_name,
|
||||
# PostgreSQL JSONB 查询:检查数组中是否有包含 {"name": "xxx"} 的元素
|
||||
Model.provider_model_aliases.op("@>")(jsonb_pattern),
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
except (OperationalError, ProgrammingError) as e:
|
||||
# JSONB 操作符不支持(如 SQLite),回退到加载匹配 provider_model_name 的 Model
|
||||
# 并在 Python 层过滤 aliases
|
||||
logger.debug(
|
||||
f"JSONB 查询失败,回退到 Python 过滤: {e}",
|
||||
)
|
||||
# 优化:先用 provider_model_name 缩小范围,再加载其他可能匹配的记录
|
||||
models_with_global = (
|
||||
db.query(Model, GlobalModel)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
Provider.is_active == True,
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 用于存储匹配结果:{(model_id, global_model_id): (GlobalModel, match_type, priority)}
|
||||
# 使用字典去重,同一个 Model 只保留优先级最高的匹配
|
||||
matched_models_dict = {}
|
||||
|
||||
# 遍历查询结果进行匹配
|
||||
for model, gm in models_with_global:
|
||||
key = (model.id, gm.id)
|
||||
|
||||
# 检查 provider_model_aliases 是否匹配(优先级更高)
|
||||
if model.provider_model_aliases:
|
||||
for alias_entry in model.provider_model_aliases:
|
||||
if isinstance(alias_entry, dict):
|
||||
alias_name = alias_entry.get("name", "").strip()
|
||||
if alias_name == normalized_name:
|
||||
# alias 优先级为 0(最高),覆盖任何已存在的匹配
|
||||
matched_models_dict[key] = (gm, "alias", 0)
|
||||
logger.debug(
|
||||
f"模型名称 '{normalized_name}' 通过映射名称匹配到 "
|
||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||
)
|
||||
break
|
||||
|
||||
# 如果还没有匹配(或只有 provider_model_name 匹配),检查 provider_model_name
|
||||
if key not in matched_models_dict or matched_models_dict[key][1] != "alias":
|
||||
if model.provider_model_name == normalized_name:
|
||||
# provider_model_name 优先级为 1(兜底),只在没有 alias 匹配时使用
|
||||
if key not in matched_models_dict:
|
||||
matched_models_dict[key] = (gm, "provider_model_name", 1)
|
||||
logger.debug(
|
||||
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
|
||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||
)
|
||||
|
||||
# 如果通过 provider_model_name/alias 找到了,直接返回
|
||||
if matched_models_dict:
|
||||
# 转换为列表并排序:按 priority(alias=0 优先)、然后按 GlobalModel.name
|
||||
matched_global_models = [
|
||||
(gm, match_type) for gm, match_type, priority in matched_models_dict.values()
|
||||
]
|
||||
matched_global_models.sort(
|
||||
key=lambda item: (
|
||||
0 if item[1] == "alias" else 1, # alias 优先
|
||||
item[0].name # 同优先级按名称排序(确定性)
|
||||
)
|
||||
)
|
||||
|
||||
# 记录解析方式
|
||||
resolution_method = matched_global_models[0][1]
|
||||
|
||||
if len(matched_global_models) > 1:
|
||||
# 检测到冲突
|
||||
unique_models = {gm.id: gm for gm, _ in matched_global_models}
|
||||
if len(unique_models) > 1:
|
||||
model_names = [gm.name for gm in unique_models.values()]
|
||||
logger.warning(
|
||||
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
|
||||
f"{', '.join(model_names)},使用第一个匹配结果"
|
||||
)
|
||||
# 记录冲突指标
|
||||
model_mapping_conflict_total.inc()
|
||||
|
||||
# 返回第一个匹配的 GlobalModel
|
||||
result_global_model: GlobalModel = matched_global_models[0][0]
|
||||
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
|
||||
await CacheService.set(
|
||||
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(
|
||||
f"GlobalModel 已缓存(映射解析-{resolution_method}): {normalized_name} -> {result_global_model.name}"
|
||||
)
|
||||
return result_global_model
|
||||
|
||||
# 3. 如果通过 provider 映射没找到,最后尝试直接通过 GlobalModel.name 查找
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == normalized_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if global_model:
|
||||
resolution_method = "direct_match"
|
||||
# 缓存结果
|
||||
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
|
||||
await CacheService.set(
|
||||
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"GlobalModel 已缓存(映射解析-直接匹配): {normalized_name}")
|
||||
return global_model
|
||||
|
||||
# 4. 完全未找到
|
||||
resolution_method = "not_found"
|
||||
# 未找到匹配,缓存负结果
|
||||
await CacheService.set(
|
||||
cache_key, "NOT_FOUND", ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"GlobalModel 未找到(映射解析): {normalized_name}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
# 记录监控指标
|
||||
duration = time.time() - start_time
|
||||
model_mapping_resolution_total.labels(
|
||||
method=resolution_method, cache_hit=str(cache_hit).lower()
|
||||
).inc()
|
||||
model_mapping_resolution_duration_seconds.labels(method=resolution_method).observe(
|
||||
duration
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _model_to_dict(model: Model) -> dict:
|
||||
@@ -256,6 +436,7 @@ class ModelCacheService:
|
||||
"provider_id": model.provider_id,
|
||||
"global_model_id": model.global_model_id,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
|
||||
"is_active": model.is_active,
|
||||
"is_available": model.is_available if hasattr(model, "is_available") else True,
|
||||
"price_per_request": (
|
||||
@@ -266,6 +447,7 @@ class ModelCacheService:
|
||||
"supports_function_calling": model.supports_function_calling,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
"supports_extended_thinking": model.supports_extended_thinking,
|
||||
"supports_image_generation": getattr(model, "supports_image_generation", None),
|
||||
"config": model.config,
|
||||
}
|
||||
|
||||
@@ -277,6 +459,7 @@ class ModelCacheService:
|
||||
provider_id=model_dict["provider_id"],
|
||||
global_model_id=model_dict["global_model_id"],
|
||||
provider_model_name=model_dict["provider_model_name"],
|
||||
provider_model_aliases=model_dict.get("provider_model_aliases"),
|
||||
is_active=model_dict["is_active"],
|
||||
is_available=model_dict.get("is_available", True),
|
||||
price_per_request=model_dict.get("price_per_request"),
|
||||
@@ -285,6 +468,7 @@ class ModelCacheService:
|
||||
supports_function_calling=model_dict.get("supports_function_calling"),
|
||||
supports_streaming=model_dict.get("supports_streaming"),
|
||||
supports_extended_thinking=model_dict.get("supports_extended_thinking"),
|
||||
supports_image_generation=model_dict.get("supports_image_generation"),
|
||||
config=model_dict.get("config"),
|
||||
)
|
||||
return model
|
||||
@@ -296,12 +480,12 @@ class ModelCacheService:
|
||||
"id": global_model.id,
|
||||
"name": global_model.name,
|
||||
"display_name": global_model.display_name,
|
||||
"family": global_model.family,
|
||||
"group_id": global_model.group_id,
|
||||
"supports_vision": global_model.supports_vision,
|
||||
"supports_thinking": global_model.supports_thinking,
|
||||
"context_window": global_model.context_window,
|
||||
"max_output_tokens": global_model.max_output_tokens,
|
||||
"default_supports_vision": global_model.default_supports_vision,
|
||||
"default_supports_function_calling": global_model.default_supports_function_calling,
|
||||
"default_supports_streaming": global_model.default_supports_streaming,
|
||||
"default_supports_extended_thinking": global_model.default_supports_extended_thinking,
|
||||
"default_supports_image_generation": global_model.default_supports_image_generation,
|
||||
"supported_capabilities": global_model.supported_capabilities,
|
||||
"is_active": global_model.is_active,
|
||||
"description": global_model.description,
|
||||
}
|
||||
@@ -313,12 +497,18 @@ class ModelCacheService:
|
||||
id=global_model_dict["id"],
|
||||
name=global_model_dict["name"],
|
||||
display_name=global_model_dict.get("display_name"),
|
||||
family=global_model_dict.get("family"),
|
||||
group_id=global_model_dict.get("group_id"),
|
||||
supports_vision=global_model_dict.get("supports_vision", False),
|
||||
supports_thinking=global_model_dict.get("supports_thinking", False),
|
||||
context_window=global_model_dict.get("context_window"),
|
||||
max_output_tokens=global_model_dict.get("max_output_tokens"),
|
||||
default_supports_vision=global_model_dict.get("default_supports_vision", False),
|
||||
default_supports_function_calling=global_model_dict.get(
|
||||
"default_supports_function_calling", False
|
||||
),
|
||||
default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
|
||||
default_supports_extended_thinking=global_model_dict.get(
|
||||
"default_supports_extended_thinking", False
|
||||
),
|
||||
default_supports_image_generation=global_model_dict.get(
|
||||
"default_supports_image_generation", False
|
||||
),
|
||||
supported_capabilities=global_model_dict.get("supported_capabilities") or [],
|
||||
is_active=global_model_dict.get("is_active", True),
|
||||
description=global_model_dict.get("description"),
|
||||
)
|
||||
|
||||
14
src/services/cache/sync.py
vendored
14
src/services/cache/sync.py
vendored
@@ -6,7 +6,7 @@
|
||||
|
||||
使用场景:
|
||||
1. 多实例部署时,确保所有实例的缓存一致性
|
||||
2. GlobalModel/ModelMapping 变更时,同步失效所有实例的缓存
|
||||
2. GlobalModel/Model 变更时,同步失效所有实例的缓存
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -29,7 +29,6 @@ class CacheSyncService:
|
||||
|
||||
# Redis 频道名称
|
||||
CHANNEL_GLOBAL_MODEL = "cache:invalidate:global_model"
|
||||
CHANNEL_MODEL_MAPPING = "cache:invalidate:model_mapping"
|
||||
CHANNEL_MODEL = "cache:invalidate:model"
|
||||
CHANNEL_CLEAR_ALL = "cache:invalidate:clear_all"
|
||||
|
||||
@@ -58,7 +57,6 @@ class CacheSyncService:
|
||||
# 订阅所有缓存失效频道
|
||||
await self._pubsub.subscribe(
|
||||
self.CHANNEL_GLOBAL_MODEL,
|
||||
self.CHANNEL_MODEL_MAPPING,
|
||||
self.CHANNEL_MODEL,
|
||||
self.CHANNEL_CLEAR_ALL,
|
||||
)
|
||||
@@ -68,7 +66,7 @@ class CacheSyncService:
|
||||
self._running = True
|
||||
|
||||
logger.info("[CacheSync] 缓存同步服务已启动,订阅频道: "
|
||||
f"{self.CHANNEL_GLOBAL_MODEL}, {self.CHANNEL_MODEL_MAPPING}, "
|
||||
f"{self.CHANNEL_GLOBAL_MODEL}, "
|
||||
f"{self.CHANNEL_MODEL}, {self.CHANNEL_CLEAR_ALL}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CacheSync] 启动失败: {e}")
|
||||
@@ -141,14 +139,6 @@ class CacheSyncService:
|
||||
"""发布 GlobalModel 变更通知"""
|
||||
await self._publish(self.CHANNEL_GLOBAL_MODEL, {"model_name": model_name})
|
||||
|
||||
async def publish_model_mapping_changed(
|
||||
self, source_model: str, provider_id: Optional[str] = None
|
||||
):
|
||||
"""发布 ModelMapping 变更通知"""
|
||||
await self._publish(
|
||||
self.CHANNEL_MODEL_MAPPING, {"source_model": source_model, "provider_id": provider_id}
|
||||
)
|
||||
|
||||
async def publish_model_changed(self, provider_id: str, global_model_id: str):
|
||||
"""发布 Model 变更通知"""
|
||||
await self._publish(
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
"""
|
||||
模型服务模块
|
||||
|
||||
包含模型管理、模型映射、成本计算等功能。
|
||||
包含模型管理、成本计算等功能。
|
||||
"""
|
||||
|
||||
from src.services.model.cost import ModelCostService
|
||||
from src.services.model.global_model import GlobalModelService
|
||||
from src.services.model.mapper import ModelMapperMiddleware
|
||||
from src.services.model.mapping_resolver import ModelMappingResolver
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
__all__ = [
|
||||
"ModelService",
|
||||
"GlobalModelService",
|
||||
"ModelMapperMiddleware",
|
||||
"ModelMappingResolver",
|
||||
"ModelCostService",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Dict, Optional, Tuple, Union
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model, Provider
|
||||
|
||||
|
||||
ProviderRef = Union[str, Provider, None]
|
||||
@@ -161,16 +161,11 @@ class ModelCostService:
|
||||
result = None
|
||||
|
||||
if provider_obj:
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(
|
||||
self.db, model, provider_obj.id
|
||||
)
|
||||
|
||||
# 直接通过 GlobalModel.name 查找
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
@@ -226,17 +221,14 @@ class ModelCostService:
|
||||
注意:如果模型配置了阶梯计费,此方法返回第一个阶梯的价格作为默认值。
|
||||
实际计费时应使用 compute_cost_with_tiered_pricing 方法。
|
||||
|
||||
计费逻辑(基于 mapping_type):
|
||||
1. 查找 ModelMapping(如果存在)
|
||||
2. 如果 mapping_type='alias':使用目标 GlobalModel 的价格
|
||||
3. 如果 mapping_type='mapping':尝试使用 source_model 对应的 GlobalModel 价格
|
||||
- 如果 source_model 对应的 GlobalModel 存在且有 Model 实现,使用那个价格
|
||||
- 否则回退到目标 GlobalModel 的价格
|
||||
4. 如果没有找到任何 ModelMapping,尝试直接匹配 GlobalModel.name
|
||||
计费逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
@@ -253,136 +245,37 @@ class ModelCostService:
|
||||
output_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 查找 ModelMapping 以确定 mapping_type
|
||||
from src.models.database import ModelMapping
|
||||
|
||||
mapping = None
|
||||
# 先查 Provider 特定映射
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id == provider_obj.id,
|
||||
ModelMapping.is_active == True,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# 再查全局映射
|
||||
if not mapping:
|
||||
mapping = (
|
||||
self.db.query(ModelMapping)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
ModelMapping.source_model == model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.is_active == True,
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# 有映射,根据 mapping_type 决定计费模型
|
||||
if mapping.mapping_type == "mapping":
|
||||
# mapping 模式:尝试使用 source_model 对应的 GlobalModel 价格
|
||||
source_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_global_model:
|
||||
source_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == source_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if source_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = source_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = source_model_obj.get_effective_input_price()
|
||||
output_price = source_model_obj.get_effective_output_price()
|
||||
logger.debug(f"[mapping模式] 使用源模型价格: {model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# alias 模式或 mapping 模式未找到源模型价格:使用目标 GlobalModel 价格
|
||||
if input_price is None:
|
||||
target_global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == mapping.target_global_model_id,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_global_model:
|
||||
target_model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == target_global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if target_model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = target_model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = target_model_obj.get_effective_input_price()
|
||||
output_price = target_model_obj.get_effective_output_price()
|
||||
mode_label = (
|
||||
"alias模式"
|
||||
if mapping.mapping_type == "alias"
|
||||
else "mapping模式(回退)"
|
||||
)
|
||||
logger.debug(f"[{mode_label}] 使用目标模型价格: {model} -> {target_global_model.name} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
else:
|
||||
# 没有映射,尝试直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_obj.id,
|
||||
Model.global_model_id == global_model.id,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
if model_obj:
|
||||
# 检查是否有阶梯计费
|
||||
tiered = model_obj.get_effective_tiered_pricing()
|
||||
if tiered and tiered.get("tiers"):
|
||||
first_tier = tiered["tiers"][0]
|
||||
input_price = first_tier.get("input_price_per_1m", 0)
|
||||
output_price = first_tier.get("output_price_per_1m", 0)
|
||||
else:
|
||||
input_price = model_obj.get_effective_input_price()
|
||||
output_price = model_obj.get_effective_output_price()
|
||||
logger.debug(f"找到模型价格配置: {provider_name}/{model} "
|
||||
f"(输入: ${input_price}/M, 输出: ${output_price}/M)")
|
||||
|
||||
# 如果没有找到价格配置,使用 0.0 并记录警告
|
||||
if input_price is None:
|
||||
@@ -404,15 +297,14 @@ class ModelCostService:
|
||||
"""
|
||||
返回给定 provider/model 的 (input_price, output_price)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取价格配置
|
||||
逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
(input_price, output_price) 元组
|
||||
@@ -434,15 +326,9 @@ class ModelCostService:
|
||||
"""
|
||||
异步版本: 返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
@@ -460,22 +346,17 @@ class ModelCostService:
|
||||
cache_read_price = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
# 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
@@ -517,15 +398,9 @@ class ModelCostService:
|
||||
"""
|
||||
异步版本: 返回按次计费价格(每次请求的固定费用)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取按次计费价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
按次计费价格,如果没有配置则返回 None
|
||||
@@ -534,22 +409,17 @@ class ModelCostService:
|
||||
price_per_request = None
|
||||
|
||||
if provider_obj:
|
||||
# 步骤 1: 检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model, provider_obj.id)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接通过 GlobalModel.name 查找(用户必须使用标准模型名称)
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == global_model_name,
|
||||
GlobalModel.name == model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# 步骤 3: 查找该 Provider 的 Model 实现
|
||||
# 查找该 Provider 的 Model 实现
|
||||
if global_model:
|
||||
model_obj = (
|
||||
self.db.query(Model)
|
||||
@@ -595,15 +465,14 @@ class ModelCostService:
|
||||
"""
|
||||
返回缓存创建/读取价格(每 1M tokens)。
|
||||
|
||||
新架构逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现
|
||||
4. 获取缓存价格配置
|
||||
逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 获取缓存价格配置
|
||||
|
||||
Args:
|
||||
provider: Provider 对象或提供商名称
|
||||
model: 用户请求的模型名(可能是 GlobalModel.name 或别名)
|
||||
model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
input_price: 基础输入价格(用于 Claude 模型的默认估算)
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.exceptions import InvalidRequestException, NotFoundException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
from src.models.database import GlobalModel, Model
|
||||
from src.models.pydantic_models import GlobalModelUpdate
|
||||
|
||||
|
||||
|
||||
@@ -5,18 +5,13 @@
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.cache_utils import SyncLRUCache
|
||||
from src.core.logger import logger
|
||||
from src.models.claude import ClaudeMessagesRequest
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider, ProviderEndpoint
|
||||
from src.models.database import GlobalModel, Model, Provider, ProviderEndpoint
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
from src.services.model.mapping_resolver import (
|
||||
get_model_mapping_resolver,
|
||||
resolve_model_to_global_name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class ModelMapperMiddleware:
|
||||
@@ -71,10 +66,10 @@ class ModelMapperMiddleware:
|
||||
if mapping:
|
||||
# 应用映射
|
||||
original_model = request.model
|
||||
request.model = mapping.model.provider_model_name
|
||||
request.model = mapping.model.select_provider_model_name()
|
||||
|
||||
logger.debug(f"Applied model mapping for provider {provider.name}: "
|
||||
f"{original_model} -> {mapping.model.provider_model_name}")
|
||||
f"{original_model} -> {request.model}")
|
||||
else:
|
||||
# 没有找到映射,使用原始模型名
|
||||
logger.debug(f"No model mapping found for {source_model} with provider {provider.name}, "
|
||||
@@ -84,17 +79,16 @@ class ModelMapperMiddleware:
|
||||
|
||||
async def get_mapping(
|
||||
self, source_model: str, provider_id: str
|
||||
) -> Optional[ModelMapping]: # UUID
|
||||
) -> Optional[object]:
|
||||
"""
|
||||
获取模型映射
|
||||
|
||||
优化后逻辑:
|
||||
1. 使用统一的 ModelMappingResolver 解析别名(带缓存)
|
||||
2. 通过 GlobalModel 找到该 Provider 的 Model 实现
|
||||
3. 使用独立的映射缓存
|
||||
简化后的逻辑:
|
||||
1. 通过 GlobalModel.name 或别名解析 GlobalModel
|
||||
2. 找到 GlobalModel 后,查找该 Provider 的 Model 实现
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名或别名
|
||||
source_model: 用户请求的模型名(可以是 GlobalModel.name 或别名)
|
||||
provider_id: 提供商ID (UUID)
|
||||
|
||||
Returns:
|
||||
@@ -107,62 +101,57 @@ class ModelMapperMiddleware:
|
||||
|
||||
mapping = None
|
||||
|
||||
# 步骤 1 & 2: 通过统一的模型映射解析服务
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(
|
||||
self.db, source_model, provider_id
|
||||
# 步骤 1: 解析 GlobalModel(支持别名)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(
|
||||
self.db, source_model
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
logger.debug(f"GlobalModel not found: {source_model} (provider={provider_id[:8]}...)")
|
||||
logger.debug(f"GlobalModel not found: {source_model}")
|
||||
self._cache[cache_key] = None
|
||||
return None
|
||||
|
||||
# 步骤 3: 查找该 Provider 是否有实现这个 GlobalModel 的 Model(使用缓存)
|
||||
# 步骤 2: 查找该 Provider 是否有实现这个 GlobalModel 的 Model(使用缓存)
|
||||
model = await ModelCacheService.get_model_by_provider_and_global_model(
|
||||
self.db, provider_id, global_model.id
|
||||
)
|
||||
|
||||
if model:
|
||||
# 只有当模型名发生变化时才返回映射
|
||||
if model.provider_model_name != source_model:
|
||||
mapping = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"source_model": source_model,
|
||||
"model": model,
|
||||
"is_active": True,
|
||||
"provider_id": provider_id,
|
||||
},
|
||||
)()
|
||||
# 创建映射对象
|
||||
mapping = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"source_model": source_model,
|
||||
"model": model,
|
||||
"is_active": True,
|
||||
"provider_id": provider_id,
|
||||
},
|
||||
)()
|
||||
|
||||
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
|
||||
f"(provider={provider_id[:8]}...)")
|
||||
else:
|
||||
logger.debug(f"Model found but no name change: {source_model} (provider={provider_id[:8]}...)")
|
||||
logger.debug(f"Found model mapping: {source_model} -> {model.provider_model_name} "
|
||||
f"(provider={provider_id[:8]}...)")
|
||||
|
||||
# 缓存结果
|
||||
self._cache[cache_key] = mapping
|
||||
|
||||
return mapping
|
||||
|
||||
def get_all_mappings(self, provider_id: str) -> List[ModelMapping]: # UUID
|
||||
def get_all_mappings(self, provider_id: str) -> List[object]:
|
||||
"""
|
||||
获取提供商的所有可用模型(通过 GlobalModel)
|
||||
|
||||
方案 A: 返回该 Provider 所有可用的 GlobalModel
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID (UUID)
|
||||
|
||||
Returns:
|
||||
模型映射列表(模拟的 ModelMapping 对象列表)
|
||||
模型映射列表
|
||||
"""
|
||||
# 查询该 Provider 的所有活跃 Model
|
||||
# 查询该 Provider 的所有活跃 Model(使用 joinedload 避免 N+1)
|
||||
models = (
|
||||
self.db.query(Model)
|
||||
.join(GlobalModel)
|
||||
.options(joinedload(Model.global_model))
|
||||
.filter(
|
||||
Model.provider_id == provider_id,
|
||||
Model.is_active == True,
|
||||
@@ -171,7 +160,7 @@ class ModelMapperMiddleware:
|
||||
.all()
|
||||
)
|
||||
|
||||
# 构造兼容的 ModelMapping 对象列表
|
||||
# 构造兼容的映射对象列表
|
||||
mappings = []
|
||||
for model in models:
|
||||
mapping = type(
|
||||
@@ -188,7 +177,7 @@ class ModelMapperMiddleware:
|
||||
|
||||
return mappings
|
||||
|
||||
def get_supported_models(self, provider_id: str) -> List[str]: # UUID
|
||||
def get_supported_models(self, provider_id: str) -> List[str]:
|
||||
"""
|
||||
获取提供商支持的所有源模型名
|
||||
|
||||
@@ -223,15 +212,6 @@ class ModelMapperMiddleware:
|
||||
if not mapping.is_active:
|
||||
return False, f"Model mapping for {request.model} is disabled"
|
||||
|
||||
# 不限制max_tokens,作为中转服务不应该限制用户的请求
|
||||
# if request.max_tokens and request.max_tokens > mapping.max_output_tokens:
|
||||
# return False, (
|
||||
# f"Requested max_tokens {request.max_tokens} exceeds limit "
|
||||
# f"{mapping.max_output_tokens} for model {request.model}"
|
||||
# )
|
||||
|
||||
# 可以添加更多验证逻辑,比如检查输入长度等
|
||||
|
||||
return True, None
|
||||
|
||||
def clear_cache(self):
|
||||
@@ -239,7 +219,7 @@ class ModelMapperMiddleware:
|
||||
self._cache.clear()
|
||||
logger.debug("Model mapping cache cleared")
|
||||
|
||||
def refresh_cache(self, provider_id: Optional[str] = None): # UUID
|
||||
def refresh_cache(self, provider_id: Optional[str] = None):
|
||||
"""
|
||||
刷新缓存
|
||||
|
||||
@@ -285,16 +265,10 @@ class ModelRoutingMiddleware:
|
||||
"""
|
||||
根据模型名选择提供商
|
||||
|
||||
逻辑:
|
||||
1. 如果指定了提供商,使用指定的提供商
|
||||
2. 如果没指定,使用默认提供商
|
||||
3. 选定提供商后,会检查该提供商的模型映射(在apply_mapping中处理)
|
||||
4. 如果指定了allowed_api_formats,只选择符合格式的提供商
|
||||
|
||||
Args:
|
||||
model_name: 请求的模型名
|
||||
preferred_provider: 首选提供商名称
|
||||
allowed_api_formats: 允许的API格式列表(如 ['CLAUDE', 'CLAUDE_CLI'])
|
||||
allowed_api_formats: 允许的API格式列表
|
||||
request_id: 请求ID(用于日志关联)
|
||||
|
||||
Returns:
|
||||
@@ -313,14 +287,12 @@ class ModelRoutingMiddleware:
|
||||
if provider:
|
||||
# 检查API格式 - 从 endpoints 中检查
|
||||
if allowed_api_formats:
|
||||
# 检查是否有符合要求的活跃端点
|
||||
has_matching_endpoint = any(
|
||||
ep.is_active and ep.api_format and ep.api_format in allowed_api_formats
|
||||
for ep in provider.endpoints
|
||||
)
|
||||
if not has_matching_endpoint:
|
||||
logger.warning(f"Specified provider {provider.name} has no active endpoints with allowed API formats ({allowed_api_formats})")
|
||||
# 不返回该提供商,继续查找
|
||||
else:
|
||||
logger.debug(f" └─ {request_prefix}使用指定提供商: {provider.name} | 模型:{model_name}")
|
||||
return provider
|
||||
@@ -330,10 +302,9 @@ class ModelRoutingMiddleware:
|
||||
else:
|
||||
logger.warning(f"Specified provider {preferred_provider} not found or inactive")
|
||||
|
||||
# 2. 查找优先级最高的活动提供商(provider_priority 最小)
|
||||
# 2. 查找优先级最高的活动提供商
|
||||
query = self.db.query(Provider).filter(Provider.is_active == True)
|
||||
|
||||
# 如果指定了API格式过滤,添加过滤条件 - 检查是否有符合要求的 endpoint
|
||||
if allowed_api_formats:
|
||||
query = (
|
||||
query.join(ProviderEndpoint)
|
||||
@@ -344,32 +315,27 @@ class ModelRoutingMiddleware:
|
||||
.distinct()
|
||||
)
|
||||
|
||||
# 按 provider_priority 排序,优先级最高(数字最小)的在前
|
||||
best_provider = query.order_by(Provider.provider_priority.asc(), Provider.id.asc()).first()
|
||||
|
||||
if best_provider:
|
||||
logger.debug(f" └─ {request_prefix}使用优先级最高提供商: {best_provider.name} (priority:{best_provider.provider_priority}) | 模型:{model_name}")
|
||||
return best_provider
|
||||
|
||||
# 3. 没有任何活动提供商
|
||||
if allowed_api_formats:
|
||||
logger.error(f"No active providers found with allowed API formats {allowed_api_formats}. Please configure at least one provider.")
|
||||
logger.error(f"No active providers found with allowed API formats {allowed_api_formats}.")
|
||||
else:
|
||||
logger.error("No active providers found. Please configure at least one provider.")
|
||||
logger.error("No active providers found.")
|
||||
return None
|
||||
|
||||
def get_available_models(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
获取所有可用的模型及其提供商
|
||||
|
||||
方案 A: 基于 GlobalModel 查询
|
||||
|
||||
Returns:
|
||||
字典,键为 GlobalModel.name,值为支持该模型的提供商名列表
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 查询所有活跃的 GlobalModel 及其 Provider
|
||||
models = (
|
||||
self.db.query(GlobalModel.name, Provider.name)
|
||||
.join(Model, GlobalModel.id == Model.global_model_id)
|
||||
@@ -392,28 +358,23 @@ class ModelRoutingMiddleware:
|
||||
"""
|
||||
获取某个模型最便宜的提供商
|
||||
|
||||
方案 A: 通过 GlobalModel 查找
|
||||
|
||||
Args:
|
||||
model_name: GlobalModel 名称或别名
|
||||
model_name: GlobalModel 名称
|
||||
|
||||
Returns:
|
||||
最便宜的提供商
|
||||
"""
|
||||
# 步骤 1: 解析模型名
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model_name)
|
||||
|
||||
# 步骤 2: 查找 GlobalModel
|
||||
# 直接查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
return None
|
||||
|
||||
# 步骤 3: 查询所有支持该模型的 Provider 及其价格
|
||||
# 查询所有支持该模型的 Provider 及其价格
|
||||
models_with_providers = (
|
||||
self.db.query(Provider, Model)
|
||||
.join(Model, Provider.id == Model.provider_id)
|
||||
@@ -428,15 +389,16 @@ class ModelRoutingMiddleware:
|
||||
if not models_with_providers:
|
||||
return None
|
||||
|
||||
# 按总价格排序(输入+输出价格)
|
||||
# 按总价格排序
|
||||
cheapest = min(
|
||||
models_with_providers, key=lambda x: x[1].input_price_per_1m + x[1].output_price_per_1m
|
||||
models_with_providers,
|
||||
key=lambda x: x[1].get_effective_input_price() + x[1].get_effective_output_price()
|
||||
)
|
||||
|
||||
provider = cheapest[0]
|
||||
model = cheapest[1]
|
||||
|
||||
logger.debug(f"Selected cheapest provider {provider.name} for model {model_name} "
|
||||
f"(input: ${model.input_price_per_1m}/M, output: ${model.output_price_per_1m}/M)")
|
||||
f"(input: ${model.get_effective_input_price()}/M, output: ${model.get_effective_output_price()}/M)")
|
||||
|
||||
return provider
|
||||
|
||||
@@ -1,432 +0,0 @@
|
||||
"""
|
||||
模型映射解析服务
|
||||
|
||||
负责统一的模型别名/降级解析,按优先级顺序:
|
||||
1. 映射(mapping):Provider 特定 → 全局
|
||||
2. 别名(alias):Provider 特定 → 全局
|
||||
3. 直接匹配 GlobalModel.name
|
||||
|
||||
支持特性:
|
||||
- 带缓存(本地或 Redis),减少数据库访问
|
||||
- 提供模糊匹配能力,用于提示相似模型
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheSize, CacheTTL
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, ModelMapping
|
||||
from src.services.cache.backend import BaseCacheBackend, get_cache_backend
|
||||
|
||||
|
||||
class ModelMappingResolver:
|
||||
"""统一的 ModelMapping 解析服务(可跨进程共享缓存)。"""
|
||||
|
||||
def __init__(self, cache_ttl: int = CacheTTL.MODEL_MAPPING, cache_backend_type: str = "auto"):
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cache_backend_type = cache_backend_type
|
||||
self._mapping_cache: Optional[BaseCacheBackend] = None
|
||||
self._global_model_cache: Optional[BaseCacheBackend] = None
|
||||
self._initialized = False
|
||||
self._stats = {
|
||||
"mapping_hits": 0,
|
||||
"mapping_misses": 0,
|
||||
"global_hits": 0,
|
||||
"global_misses": 0,
|
||||
}
|
||||
|
||||
async def _ensure_initialized(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._mapping_cache = await get_cache_backend(
|
||||
name="model_mapping_resolver:mapping",
|
||||
backend_type=self._cache_backend_type,
|
||||
max_size=CacheSize.MODEL_MAPPING,
|
||||
ttl=self._cache_ttl,
|
||||
)
|
||||
self._global_model_cache = await get_cache_backend(
|
||||
name="model_mapping_resolver:global",
|
||||
backend_type=self._cache_backend_type,
|
||||
max_size=CacheSize.MODEL_MAPPING,
|
||||
ttl=self._cache_ttl,
|
||||
)
|
||||
self._initialized = True
|
||||
logger.debug(f"[ModelMappingResolver] 缓存后端已初始化: {self._mapping_cache.get_stats()['backend']}")
|
||||
|
||||
def _cache_key(self, source_model: str, provider_id: Optional[str]) -> str:
|
||||
return f"{provider_id or 'global'}:{source_model}"
|
||||
|
||||
async def _lookup_target_global_model_id(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
按优先级查找目标 GlobalModel ID:
|
||||
1. 映射(mapping_type='mapping'):Provider 特定 → 全局
|
||||
2. 别名(mapping_type='alias'):Provider 特定 → 全局
|
||||
3. 直接匹配 GlobalModel.name
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
cache_key = self._cache_key(source_model, provider_id)
|
||||
cached = await self._mapping_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
self._stats["mapping_hits"] += 1
|
||||
return cached or None
|
||||
|
||||
self._stats["mapping_misses"] += 1
|
||||
|
||||
target_id: Optional[str] = None
|
||||
|
||||
# 优先级 1:查找映射(mapping_type='mapping')
|
||||
# 1.1 Provider 特定映射
|
||||
if provider_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
target_id = mapping.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中 Provider 映射: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 1.2 全局映射
|
||||
if not target_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
target_id = mapping.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中全局映射: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 优先级 2:查找别名(mapping_type='alias')
|
||||
# 2.1 Provider 特定别名
|
||||
if not target_id and provider_id:
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
target_id = alias.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中 Provider 别名: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 2.2 全局别名
|
||||
if not target_id:
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
target_id = alias.target_global_model_id
|
||||
logger.debug(f"[MappingResolver] 命中全局别名: {source_model} -> {target_id[:8]}...")
|
||||
|
||||
# 优先级 3:直接匹配 GlobalModel.name
|
||||
if not target_id:
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == source_model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
target_id = global_model.id
|
||||
logger.debug(f"[MappingResolver] 直接匹配 GlobalModel: {source_model}")
|
||||
|
||||
cached_value = target_id if target_id is not None else ""
|
||||
await self._mapping_cache.set(cache_key, cached_value, self._cache_ttl)
|
||||
return target_id
|
||||
|
||||
async def resolve_to_global_model_name(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""解析模型名/别名为 GlobalModel.name。未找到时返回原始输入。"""
|
||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
||||
if not target_id:
|
||||
return source_model
|
||||
|
||||
await self._ensure_initialized()
|
||||
cached_name = await self._global_model_cache.get(target_id)
|
||||
if cached_name:
|
||||
self._stats["global_hits"] += 1
|
||||
return cached_name
|
||||
|
||||
self._stats["global_misses"] += 1
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
await self._global_model_cache.set(target_id, global_model.name, self._cache_ttl)
|
||||
return global_model.name
|
||||
|
||||
return source_model
|
||||
|
||||
async def get_global_model_by_request(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Optional[GlobalModel]:
|
||||
"""解析并返回 GlobalModel 对象(绑定当前 Session)。"""
|
||||
target_id = await self._lookup_target_global_model_id(db, source_model, provider_id)
|
||||
if not target_id:
|
||||
return None
|
||||
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == target_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
return global_model
|
||||
|
||||
async def get_global_model_with_mapping_info(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
provider_id: Optional[str] = None,
|
||||
) -> Tuple[Optional[GlobalModel], bool]:
|
||||
"""
|
||||
解析并返回 GlobalModel 对象,同时返回是否发生了映射。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 用户请求的模型名
|
||||
provider_id: Provider ID(可选)
|
||||
|
||||
Returns:
|
||||
(global_model, is_mapped) - GlobalModel 对象和是否发生了映射
|
||||
is_mapped=True 表示 source_model 通过 mapping 规则映射到了不同的模型
|
||||
is_mapped=False 表示 source_model 直接匹配或通过 alias 匹配
|
||||
"""
|
||||
await self._ensure_initialized()
|
||||
|
||||
# 先检查是否存在 mapping 类型的映射规则
|
||||
has_mapping = False
|
||||
|
||||
# 检查 Provider 特定映射
|
||||
if provider_id:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id == provider_id,
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
has_mapping = True
|
||||
|
||||
# 检查全局映射
|
||||
if not has_mapping:
|
||||
mapping = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "mapping",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if mapping:
|
||||
has_mapping = True
|
||||
|
||||
# 获取 GlobalModel
|
||||
global_model = await self.get_global_model_by_request(db, source_model, provider_id)
|
||||
|
||||
return global_model, has_mapping
|
||||
|
||||
async def get_global_model_direct(
|
||||
self,
|
||||
db: Session,
|
||||
source_model: str,
|
||||
) -> Optional[GlobalModel]:
|
||||
"""
|
||||
直接通过模型名获取 GlobalModel,不应用任何映射规则。
|
||||
仅查找 alias 和直接匹配。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 模型名
|
||||
|
||||
Returns:
|
||||
GlobalModel 对象或 None
|
||||
"""
|
||||
# 优先级 1:查找别名(alias)
|
||||
# 全局别名
|
||||
alias = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == source_model,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
ModelMapping.mapping_type == "alias",
|
||||
ModelMapping.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if alias:
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.id == alias.target_global_model_id, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
if global_model:
|
||||
return global_model
|
||||
|
||||
# 优先级 2:直接匹配 GlobalModel.name
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.name == source_model,
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return global_model
|
||||
|
||||
def find_similar_models(
|
||||
self,
|
||||
db: Session,
|
||||
invalid_model: str,
|
||||
limit: int = 3,
|
||||
threshold: float = 0.4,
|
||||
) -> List[Tuple[str, float]]:
|
||||
"""用于提示相似的 GlobalModel.name。"""
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
all_models = db.query(GlobalModel.name).filter(GlobalModel.is_active == True).all()
|
||||
similarities: List[Tuple[str, float]] = []
|
||||
invalid_lower = invalid_model.lower()
|
||||
|
||||
for model in all_models:
|
||||
model_name = model.name
|
||||
ratio = SequenceMatcher(None, invalid_lower, model_name.lower()).ratio()
|
||||
if invalid_lower in model_name.lower() or model_name.lower() in invalid_lower:
|
||||
ratio += 0.2
|
||||
if ratio >= threshold:
|
||||
similarities.append((model_name, ratio))
|
||||
|
||||
similarities.sort(key=lambda item: item[1], reverse=True)
|
||||
return similarities[:limit]
|
||||
|
||||
async def invalidate_mapping_cache(self, source_model: str, provider_id: Optional[str] = None):
    """Drop the cached mapping entries for ``source_model``.

    The entry keyed by ``(source_model, provider_id)`` is always removed.
    When ``provider_id`` is truthy, the entry keyed by
    ``(source_model, None)`` is removed as well.

    Args:
        source_model: Model name whose cached resolution should be dropped.
        provider_id: Optional provider scope of the cached entry.
    """
    await self._ensure_initialized()

    scopes = [provider_id, None] if provider_id else [provider_id]
    stale_keys = [self._cache_key(source_model, scope) for scope in scopes]

    for stale_key in stale_keys:
        await self._mapping_cache.delete(stale_key)
|
||||
|
||||
async def invalidate_global_model_cache(self, global_model_id: Optional[str] = None):
    """Invalidate one cached GlobalModel entry, or the whole cache.

    Args:
        global_model_id: Key of the entry to delete. When it is ``None``
            or otherwise falsy, the entire global-model cache is cleared.
    """
    await self._ensure_initialized()

    cache = self._global_model_cache
    if not global_model_id:
        await cache.clear()
        return

    await cache.delete(global_model_id)
|
||||
|
||||
async def clear_cache(self):
    """Clear the mapping cache and then the global-model cache."""
    await self._ensure_initialized()

    mapping_cache = self._mapping_cache
    await mapping_cache.clear()

    global_cache = self._global_model_cache
    await global_cache.clear()
|
||||
|
||||
def get_stats(self) -> dict:
    """Build a snapshot of cache hit rates and raw counters.

    Returns:
        A dict with ``mapping_hit_rate`` and ``global_hit_rate`` (each
        ``hits / (hits + misses)``, or ``0.0`` when nothing was counted)
        and ``stats`` (the resolver's own counter dict, not a copy). When
        the resolver is initialised, ``mapping_cache_backend`` and
        ``global_cache_backend`` hold each cache backend's ``get_stats()``.
    """
    counters = self._stats

    def _hit_rate(hits_key: str, misses_key: str) -> float:
        # Guard against division by zero before any lookup was recorded.
        hits = counters[hits_key]
        total = hits + counters[misses_key]
        return hits / total if total else 0.0

    report = {
        "mapping_hit_rate": _hit_rate("mapping_hits", "mapping_misses"),
        "global_hit_rate": _hit_rate("global_hits", "global_misses"),
        "stats": counters,
    }

    if not self._initialized:
        return report

    report["mapping_cache_backend"] = self._mapping_cache.get_stats()
    report["global_cache_backend"] = self._global_model_cache.get_stats()
    return report
|
||||
|
||||
|
||||
# Module-level slot for the shared ModelMappingResolver instance; it stays None
# until get_model_mapping_resolver() creates the resolver on first use.
_model_mapping_resolver: Optional[ModelMappingResolver] = None
|
||||
|
||||
|
||||
def get_model_mapping_resolver(
    cache_ttl: int = 300, cache_backend_type: Optional[str] = None
) -> ModelMappingResolver:
    """Return the shared ModelMappingResolver, creating it on first call.

    The arguments are used only when the resolver is first created; once
    an instance exists it is returned unchanged and they are ignored.

    Args:
        cache_ttl: ``cache_ttl`` value passed to the resolver constructor.
        cache_backend_type: Cache backend selector. When ``None``, the
            ``ALIAS_CACHE_BACKEND`` environment variable is read, falling
            back to ``"auto"``.

    Returns:
        The module-wide resolver instance.
    """
    global _model_mapping_resolver

    if _model_mapping_resolver is not None:
        return _model_mapping_resolver

    if cache_backend_type is None:
        cache_backend_type = os.getenv("ALIAS_CACHE_BACKEND", "auto")

    _model_mapping_resolver = ModelMappingResolver(
        cache_ttl=cache_ttl,
        cache_backend_type=cache_backend_type,
    )
    logger.debug(f"[ModelMappingResolver] 初始化(cache_ttl={cache_ttl}s, backend={cache_backend_type})")

    # Register with the cache invalidation service. This is best-effort:
    # a failure is logged and the resolver is still returned.
    try:
        from src.services.cache.invalidation import get_cache_invalidation_service

        get_cache_invalidation_service().set_mapping_resolver(_model_mapping_resolver)
    except Exception as exc:
        logger.warning(f"[ModelMappingResolver] 注册缓存失效服务失败: {exc}")

    return _model_mapping_resolver
|
||||
|
||||
|
||||
async def resolve_model_to_global_name(
    db: Session,
    source_model: str,
    provider_id: Optional[str] = None,
) -> str:
    """Resolve ``source_model`` to a global model name via the shared resolver.

    Convenience wrapper that forwards to
    ``ModelMappingResolver.resolve_to_global_model_name``.

    Args:
        db: Database session handed to the resolver.
        source_model: Model name to resolve.
        provider_id: Optional provider scope for the lookup.

    Returns:
        Whatever name the resolver produces for the request.
    """
    return await get_model_mapping_resolver().resolve_to_global_model_name(
        db, source_model, provider_id
    )
|
||||
|
||||
|
||||
async def get_global_model_by_request(
    db: Session,
    source_model: str,
    provider_id: Optional[str] = None,
) -> Optional[GlobalModel]:
    """Look up the GlobalModel for a requested model name via the shared resolver.

    Convenience wrapper that forwards to
    ``ModelMappingResolver.get_global_model_by_request``.

    Args:
        db: Database session handed to the resolver.
        source_model: Model name to look up.
        provider_id: Optional provider scope for the lookup.

    Returns:
        The resolver's result: a ``GlobalModel`` or ``None``.
    """
    return await get_model_mapping_resolver().get_global_model_by_request(
        db, source_model, provider_id
    )
|
||||
@@ -50,6 +50,7 @@ class ModelService:
|
||||
provider_id=provider_id,
|
||||
global_model_id=model_data.global_model_id,
|
||||
provider_model_name=model_data.provider_model_name,
|
||||
provider_model_aliases=model_data.provider_model_aliases,
|
||||
price_per_request=model_data.price_per_request,
|
||||
tiered_pricing=model_data.tiered_pricing,
|
||||
supports_vision=model_data.supports_vision,
|
||||
@@ -147,6 +148,10 @@ class ModelService:
|
||||
if not model:
|
||||
raise NotFoundException(f"模型 {model_id} 不存在")
|
||||
|
||||
# 保存旧的别名,用于清除缓存
|
||||
old_provider_model_name = model.provider_model_name
|
||||
old_provider_model_aliases = model.provider_model_aliases
|
||||
|
||||
# 更新字段
|
||||
update_data = model_data.model_dump(exclude_unset=True)
|
||||
|
||||
@@ -164,13 +169,28 @@ class ModelService:
|
||||
db.refresh(model)
|
||||
|
||||
# 清除 Redis 缓存(异步执行,不阻塞返回)
|
||||
# 先清除旧的别名缓存
|
||||
asyncio.create_task(
|
||||
ModelCacheService.invalidate_model_cache(
|
||||
model_id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=old_provider_model_name,
|
||||
provider_model_aliases=old_provider_model_aliases,
|
||||
)
|
||||
)
|
||||
# 再清除新的别名缓存(如果有变化)
|
||||
if (model.provider_model_name != old_provider_model_name or
|
||||
model.provider_model_aliases != old_provider_model_aliases):
|
||||
asyncio.create_task(
|
||||
ModelCacheService.invalidate_model_cache(
|
||||
model_id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
)
|
||||
)
|
||||
|
||||
# 清除内存缓存(ModelMapperMiddleware 实例)
|
||||
if model.provider_id and model.global_model_id:
|
||||
@@ -191,7 +211,6 @@ class ModelService:
|
||||
新架构删除逻辑:
|
||||
- Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel
|
||||
- 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除)
|
||||
- 不检查 ModelMapping(映射是 GlobalModel 之间的关系,别名也统一存储在此表中)
|
||||
"""
|
||||
model = db.query(Model).filter(Model.id == model_id).first()
|
||||
if not model:
|
||||
@@ -213,11 +232,37 @@ class ModelService:
|
||||
logger.warning(f"警告:删除模型 {model_id}(Provider: {model.provider_id[:8]}...)后,"
|
||||
f"GlobalModel '{model.global_model_id}' 将没有任何活跃的关联提供商")
|
||||
|
||||
# 保存缓存清除所需的信息(删除后无法访问)
|
||||
cache_info = {
|
||||
"model_id": model.id,
|
||||
"provider_id": model.provider_id,
|
||||
"global_model_id": model.global_model_id,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": model.provider_model_aliases,
|
||||
}
|
||||
|
||||
try:
|
||||
db.delete(model)
|
||||
db.commit()
|
||||
logger.info(f"删除模型成功: id={model_id}, provider_model_name={model.provider_model_name}, "
|
||||
f"global_model_id={model.global_model_id[:8] if model.global_model_id else 'None'}...")
|
||||
|
||||
# 清除 Redis 缓存
|
||||
asyncio.create_task(
|
||||
ModelCacheService.invalidate_model_cache(
|
||||
model_id=cache_info["model_id"],
|
||||
provider_id=cache_info["provider_id"],
|
||||
global_model_id=cache_info["global_model_id"],
|
||||
provider_model_name=cache_info["provider_model_name"],
|
||||
provider_model_aliases=cache_info["provider_model_aliases"],
|
||||
)
|
||||
)
|
||||
|
||||
# 清除内存缓存
|
||||
if cache_info["provider_id"] and cache_info["global_model_id"]:
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_changed(cache_info["provider_id"], cache_info["global_model_id"])
|
||||
|
||||
logger.info(f"删除模型成功: id={model_id}, provider_model_name={cache_info['provider_model_name']}, "
|
||||
f"global_model_id={cache_info['global_model_id'][:8] if cache_info['global_model_id'] else 'None'}...")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"删除模型失败: {str(e)}")
|
||||
@@ -240,6 +285,8 @@ class ModelService:
|
||||
model_id=model.id,
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -326,6 +373,7 @@ class ModelService:
|
||||
provider_id=model.provider_id,
|
||||
global_model_id=model.global_model_id,
|
||||
provider_model_name=model.provider_model_name,
|
||||
provider_model_aliases=model.provider_model_aliases,
|
||||
# 原始配置值(可能为空)
|
||||
price_per_request=model.price_per_request,
|
||||
tiered_pricing=model.tiered_pricing,
|
||||
|
||||
@@ -7,13 +7,11 @@ from typing import Dict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, Provider
|
||||
from src.services.model.cost import ModelCostService
|
||||
from src.services.model.mapper import ModelMapperMiddleware, ModelRoutingMiddleware
|
||||
|
||||
|
||||
|
||||
class ProviderService:
|
||||
"""提供商服务类"""
|
||||
|
||||
@@ -34,30 +32,15 @@ class ProviderService:
|
||||
检查模型是否可用(严格白名单模式)
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
Model对象如果存在且激活,否则None
|
||||
"""
|
||||
# 首先检查是否有直接的模型记录
|
||||
model = (
|
||||
self.db.query(Model)
|
||||
.filter(Model.provider_model_name == model_name, Model.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model:
|
||||
return model
|
||||
|
||||
# 方案 A:检查是否是别名(全局别名系统)
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model_name)
|
||||
|
||||
# 查找 GlobalModel
|
||||
# 直接查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -79,34 +62,15 @@ class ProviderService:
|
||||
|
||||
Args:
|
||||
provider_id: 提供商ID
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
Model对象如果该提供商支持该模型且激活,否则None
|
||||
"""
|
||||
# 首先检查该提供商下是否有直接的模型记录
|
||||
model = (
|
||||
self.db.query(Model)
|
||||
.filter(
|
||||
Model.provider_id == provider_id,
|
||||
Model.provider_model_name == model_name,
|
||||
Model.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if model:
|
||||
return model
|
||||
|
||||
# 方案 A:检查是否是别名
|
||||
from src.services.model.mapping_resolver import resolve_model_to_global_name
|
||||
|
||||
global_model_name = await resolve_model_to_global_name(self.db, model_name, provider_id)
|
||||
|
||||
# 查找 GlobalModel
|
||||
# 直接查找 GlobalModel
|
||||
global_model = (
|
||||
self.db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == global_model_name, GlobalModel.is_active == True)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -148,12 +112,19 @@ class ProviderService:
|
||||
获取所有可用的模型
|
||||
|
||||
Returns:
|
||||
模型和支持的提供商映射
|
||||
字典,键为模型名,值为提供商列表
|
||||
"""
|
||||
return self.router.get_available_models()
|
||||
|
||||
def clear_cache(self):
    """Clear the mapper cache and the cost-service cache, then log it."""
    mapper = self.mapper
    mapper.clear_cache()

    cost_service = self.cost_service
    cost_service.clear_cache()

    logger.info("Provider service cache cleared")
|
||||
def select_provider(self, model_name: str, preferred_provider=None):
    """Select a provider for a model by delegating to the router.

    Args:
        model_name: Name of the model to select a provider for.
        preferred_provider: Preferred provider, forwarded to the router
            unchanged.

    Returns:
        Whatever ``self.router.select_provider`` returns (documented in
        the original as a Provider object).
    """
    router = self.router
    chosen = router.select_provider(model_name, preferred_provider)
    return chosen
|
||||
|
||||
@@ -26,11 +26,10 @@ class UsageService:
|
||||
) -> tuple[float, float]:
|
||||
"""异步获取模型价格(输入价格,输出价格)每1M tokens
|
||||
|
||||
新架构查找逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现并获取价格
|
||||
4. 如果找不到则使用系统默认价格
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现并获取价格
|
||||
3. 如果找不到则使用系统默认价格
|
||||
"""
|
||||
|
||||
service = ModelCostService(db)
|
||||
@@ -40,11 +39,10 @@ class UsageService:
|
||||
def get_model_price(cls, db: Session, provider: str, model: str) -> tuple[float, float]:
|
||||
"""获取模型价格(输入价格,输出价格)每1M tokens
|
||||
|
||||
新架构查找逻辑:
|
||||
1. 使用 ModelMappingResolver 解析别名(如果是)
|
||||
2. 解析为 GlobalModel.name
|
||||
3. 查找该 Provider 的 Model 实现并获取价格
|
||||
4. 如果找不到则使用系统默认价格
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现并获取价格
|
||||
3. 如果找不到则使用系统默认价格
|
||||
"""
|
||||
|
||||
service = ModelCostService(db)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
|
||||
class SSEEventParser:
|
||||
@@ -8,7 +8,7 @@ class SSEEventParser:
|
||||
self._reset_buffer()
|
||||
|
||||
def _reset_buffer(self) -> None:
|
||||
self._buffer: Dict[str, Optional[str] | List[str]] = {
|
||||
self._buffer: Dict[str, Union[Optional[str], List[str]]] = {
|
||||
"event": None,
|
||||
"data": [],
|
||||
"id": None,
|
||||
@@ -17,16 +17,19 @@ class SSEEventParser:
|
||||
|
||||
def _finalize_event(self) -> Optional[Dict[str, Optional[str]]]:
|
||||
data_lines = self._buffer.get("data", [])
|
||||
if not data_lines:
|
||||
if not isinstance(data_lines, list) or not data_lines:
|
||||
self._reset_buffer()
|
||||
return None
|
||||
|
||||
data_str = "\n".join(data_lines)
|
||||
event = {
|
||||
"event": self._buffer.get("event"),
|
||||
event_val = self._buffer.get("event")
|
||||
id_val = self._buffer.get("id")
|
||||
retry_val = self._buffer.get("retry")
|
||||
event: Dict[str, Optional[str]] = {
|
||||
"event": event_val if isinstance(event_val, str) else None,
|
||||
"data": data_str,
|
||||
"id": self._buffer.get("id"),
|
||||
"retry": self._buffer.get("retry"),
|
||||
"id": id_val if isinstance(id_val, str) else None,
|
||||
"retry": retry_val if isinstance(retry_val, str) else None,
|
||||
}
|
||||
|
||||
self._reset_buffer()
|
||||
|
||||
Reference in New Issue
Block a user