mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21587449c8 | ||
|
|
3d0ab353d3 | ||
|
|
b2a857c164 | ||
|
|
4d1d863916 | ||
|
|
b579420690 | ||
|
|
9d5c84f9d3 | ||
|
|
53e6a82480 | ||
|
|
bd11ebdbd5 | ||
|
|
1dac4cb156 | ||
|
|
50abb55c94 | ||
|
|
73d3c9d3e4 | ||
|
|
d24c3885ab | ||
|
|
d696c575e6 | ||
|
|
46ff5a1a50 | ||
|
|
edce43d45f | ||
|
|
33265b4b13 | ||
|
|
a94aeca2d3 |
@@ -60,8 +60,11 @@ python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||
# 3. 部署
|
||||
docker-compose up -d
|
||||
|
||||
# 4. 更新
|
||||
docker-compose pull && docker-compose up -d
|
||||
# 4. 首次部署时, 初始化数据库
|
||||
./migrate.sh
|
||||
|
||||
# 5. 更新
|
||||
docker-compose pull && docker-compose up -d && ./migrate.sh
|
||||
```
|
||||
|
||||
### Docker Compose(本地构建镜像)
|
||||
|
||||
@@ -20,10 +20,10 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create ENUM types
|
||||
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
|
||||
# Create ENUM types (with IF NOT EXISTS for idempotency)
|
||||
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
|
||||
op.execute(
|
||||
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
|
||||
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
|
||||
)
|
||||
|
||||
# ==================== users ====================
|
||||
@@ -35,7 +35,7 @@ def upgrade() -> None:
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("admin", "user", name="userrole", create_type=False),
|
||||
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
),
|
||||
@@ -67,7 +67,7 @@ def upgrade() -> None:
|
||||
sa.Column("website", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"billing_type",
|
||||
sa.Enum(
|
||||
postgresql.ENUM(
|
||||
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
||||
),
|
||||
nullable=False,
|
||||
|
||||
@@ -26,16 +26,66 @@ branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(bind, table_name: str, column_name: str) -> bool:
    """Return whether ``table_name`` already has a column named ``column_name``.

    The lookup goes through ``information_schema.columns`` and is limited to
    the schemas on the connection's search path, so an identically named
    table in an unrelated schema cannot produce a false positive (which would
    make the caller skip a column it still needs to add).

    Args:
        bind: Connection the migration runs on (as returned by ``op.get_bind()``).
        table_name: Unqualified table name.
        column_name: Column to look for.

    Returns:
        ``True`` if the column exists, ``False`` otherwise.
    """
    result = bind.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = :table_name
                  AND column_name = :column_name
                  AND table_schema = ANY (current_schemas(false))
            )
            """
        ),
        {"table_name": table_name, "column_name": column_name},
    )
    # EXISTS always yields one row; bool() pins the annotated return type.
    return bool(result.scalar())
|
||||
|
||||
|
||||
def table_exists(bind, table_name: str) -> bool:
    """Return whether a table named ``table_name`` exists.

    The lookup goes through ``information_schema.tables`` and is limited to
    the schemas on the connection's search path, so a table with the same
    name in an unrelated schema is not mistaken for the one being migrated.

    Args:
        bind: Connection the migration runs on (as returned by ``op.get_bind()``).
        table_name: Unqualified table name.

    Returns:
        ``True`` if the table exists, ``False`` otherwise.
    """
    result = bind.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.tables
                WHERE table_name = :table_name
                  AND table_schema = ANY (current_schemas(false))
            )
            """
        ),
        {"table_name": table_name},
    )
    # EXISTS always yields one row; bool() pins the annotated return type.
    return bool(result.scalar())
|
||||
|
||||
|
||||
def index_exists(bind, index_name: str) -> bool:
    """Return whether an index named ``index_name`` exists.

    Queries the PostgreSQL ``pg_indexes`` view, limited to the schemas on the
    connection's search path so that an index of the same name in an
    unrelated schema does not cause index creation to be skipped.

    Args:
        bind: Connection the migration runs on (as returned by ``op.get_bind()``).
        index_name: Name of the index to look for.

    Returns:
        ``True`` if the index exists, ``False`` otherwise.
    """
    result = bind.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM pg_indexes
                WHERE indexname = :index_name
                  AND schemaname = ANY (current_schemas(false))
            )
            """
        ),
        {"index_name": index_name},
    )
    # EXISTS always yields one row; bool() pins the annotated return type.
    return bool(result.scalar())
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""添加 provider_model_aliases 字段,迁移数据,删除 model_mappings 表"""
|
||||
# 1. 添加 provider_model_aliases 字段
|
||||
op.add_column(
|
||||
'models',
|
||||
sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
|
||||
)
|
||||
|
||||
# 2. 迁移 model_mappings 数据
|
||||
bind = op.get_bind()
|
||||
|
||||
# 1. 添加 provider_model_aliases 字段(如果不存在)
|
||||
if not column_exists(bind, "models", "provider_model_aliases"):
|
||||
op.add_column(
|
||||
'models',
|
||||
sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
|
||||
)
|
||||
|
||||
# 2. 迁移 model_mappings 数据(如果表存在)
|
||||
session = Session(bind=bind)
|
||||
|
||||
model_mappings_table = sa.table(
|
||||
@@ -96,104 +146,118 @@ def upgrade() -> None:
|
||||
|
||||
# 查询所有活跃的 provider 级别 alias(只迁移 is_active=True 且 mapping_type='alias' 的)
|
||||
# 全局别名/映射不迁移(新架构不再支持 source_model -> GlobalModel.name 的解析)
|
||||
mappings = session.execute(
|
||||
sa.select(
|
||||
model_mappings_table.c.source_model,
|
||||
model_mappings_table.c.target_global_model_id,
|
||||
model_mappings_table.c.provider_id,
|
||||
)
|
||||
.where(
|
||||
model_mappings_table.c.is_active.is_(True),
|
||||
model_mappings_table.c.provider_id.isnot(None),
|
||||
model_mappings_table.c.mapping_type == "alias",
|
||||
)
|
||||
.order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
|
||||
).all()
|
||||
|
||||
# 按 (provider_id, target_global_model_id) 分组,收集别名
|
||||
alias_groups: dict = {}
|
||||
for source_model, target_global_model_id, provider_id in mappings:
|
||||
if not isinstance(source_model, str):
|
||||
continue
|
||||
source_model = source_model.strip()
|
||||
if not source_model:
|
||||
continue
|
||||
if not isinstance(provider_id, str) or not provider_id:
|
||||
continue
|
||||
if not isinstance(target_global_model_id, str) or not target_global_model_id:
|
||||
continue
|
||||
|
||||
key = (provider_id, target_global_model_id)
|
||||
if key not in alias_groups:
|
||||
alias_groups[key] = []
|
||||
priority = len(alias_groups[key]) + 1
|
||||
alias_groups[key].append({"name": source_model, "priority": priority})
|
||||
|
||||
# 更新对应的 models 记录
|
||||
for (provider_id, global_model_id), aliases in alias_groups.items():
|
||||
model_row = session.execute(
|
||||
sa.select(models_table.c.id, models_table.c.provider_model_aliases)
|
||||
# 仅当 model_mappings 表存在时执行迁移
|
||||
if table_exists(bind, "model_mappings"):
|
||||
mappings = session.execute(
|
||||
sa.select(
|
||||
model_mappings_table.c.source_model,
|
||||
model_mappings_table.c.target_global_model_id,
|
||||
model_mappings_table.c.provider_id,
|
||||
)
|
||||
.where(
|
||||
models_table.c.provider_id == provider_id,
|
||||
models_table.c.global_model_id == global_model_id,
|
||||
model_mappings_table.c.is_active.is_(True),
|
||||
model_mappings_table.c.provider_id.isnot(None),
|
||||
model_mappings_table.c.mapping_type == "alias",
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
.order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
|
||||
).all()
|
||||
|
||||
if model_row:
|
||||
model_id = model_row[0]
|
||||
existing_aliases = normalize_alias_list(model_row[1])
|
||||
# 按 (provider_id, target_global_model_id) 分组,收集别名
|
||||
alias_groups: dict = {}
|
||||
for source_model, target_global_model_id, provider_id in mappings:
|
||||
if not isinstance(source_model, str):
|
||||
continue
|
||||
source_model = source_model.strip()
|
||||
if not source_model:
|
||||
continue
|
||||
if not isinstance(provider_id, str) or not provider_id:
|
||||
continue
|
||||
if not isinstance(target_global_model_id, str) or not target_global_model_id:
|
||||
continue
|
||||
|
||||
existing_names = {a["name"] for a in existing_aliases}
|
||||
merged_aliases = list(existing_aliases)
|
||||
for alias in aliases:
|
||||
name = alias.get("name")
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
name = name.strip()
|
||||
if not name or name in existing_names:
|
||||
continue
|
||||
key = (provider_id, target_global_model_id)
|
||||
if key not in alias_groups:
|
||||
alias_groups[key] = []
|
||||
priority = len(alias_groups[key]) + 1
|
||||
alias_groups[key].append({"name": source_model, "priority": priority})
|
||||
|
||||
merged_aliases.append(
|
||||
{
|
||||
"name": name,
|
||||
"priority": len(merged_aliases) + 1,
|
||||
}
|
||||
# 更新对应的 models 记录
|
||||
for (provider_id, global_model_id), aliases in alias_groups.items():
|
||||
model_row = session.execute(
|
||||
sa.select(models_table.c.id, models_table.c.provider_model_aliases)
|
||||
.where(
|
||||
models_table.c.provider_id == provider_id,
|
||||
models_table.c.global_model_id == global_model_id,
|
||||
)
|
||||
existing_names.add(name)
|
||||
.limit(1)
|
||||
).first()
|
||||
|
||||
session.execute(
|
||||
models_table.update()
|
||||
.where(models_table.c.id == model_id)
|
||||
.values(
|
||||
provider_model_aliases=merged_aliases if merged_aliases else None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
if model_row:
|
||||
model_id = model_row[0]
|
||||
existing_aliases = normalize_alias_list(model_row[1])
|
||||
|
||||
existing_names = {a["name"] for a in existing_aliases}
|
||||
merged_aliases = list(existing_aliases)
|
||||
for alias in aliases:
|
||||
name = alias.get("name")
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
name = name.strip()
|
||||
if not name or name in existing_names:
|
||||
continue
|
||||
|
||||
merged_aliases.append(
|
||||
{
|
||||
"name": name,
|
||||
"priority": len(merged_aliases) + 1,
|
||||
}
|
||||
)
|
||||
existing_names.add(name)
|
||||
|
||||
session.execute(
|
||||
models_table.update()
|
||||
.where(models_table.c.id == model_id)
|
||||
.values(
|
||||
provider_model_aliases=merged_aliases if merged_aliases else None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.commit()
|
||||
|
||||
# 3. 删除 model_mappings 表
|
||||
op.drop_table('model_mappings')
|
||||
# 3. 删除 model_mappings 表
|
||||
op.drop_table('model_mappings')
|
||||
|
||||
# 4. 添加索引优化别名解析性能
|
||||
# provider_model_name 索引(支持精确匹配)
|
||||
op.create_index(
|
||||
"idx_model_provider_model_name",
|
||||
"models",
|
||||
["provider_model_name"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
)
|
||||
# provider_model_name 索引(支持精确匹配,如果不存在)
|
||||
if not index_exists(bind, "idx_model_provider_model_name"):
|
||||
op.create_index(
|
||||
"idx_model_provider_model_name",
|
||||
"models",
|
||||
["provider_model_name"],
|
||||
unique=False,
|
||||
postgresql_where=sa.text("is_active = true"),
|
||||
)
|
||||
|
||||
# provider_model_aliases GIN 索引(支持 JSONB 查询,仅 PostgreSQL)
|
||||
if bind.dialect.name == "postgresql":
|
||||
# 将 json 列转为 jsonb(jsonb 性能更好且支持 GIN 索引)
|
||||
# 使用 IF NOT EXISTS 风格的检查来避免重复转换
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE models
|
||||
ALTER COLUMN provider_model_aliases TYPE jsonb
|
||||
USING provider_model_aliases::jsonb
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_name = 'models'
|
||||
AND column_name = 'provider_model_aliases'
|
||||
AND data_type = 'json'
|
||||
) THEN
|
||||
ALTER TABLE models
|
||||
ALTER COLUMN provider_model_aliases TYPE jsonb
|
||||
USING provider_model_aliases::jsonb;
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
# 创建 GIN 索引
|
||||
|
||||
@@ -5,8 +5,8 @@ Revises: e9b3d63f0cbf
|
||||
Create Date: 2025-12-15 17:07:44.631032+00:00
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -16,10 +16,29 @@ branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(bind, table_name: str, column_name: str) -> bool:
    """Return whether ``table_name`` already has a column named ``column_name``.

    The lookup goes through ``information_schema.columns`` and is limited to
    the schemas on the connection's search path, so an identically named
    table in an unrelated schema cannot produce a false positive (which would
    make the caller skip a column it still needs to add).

    Args:
        bind: Connection the migration runs on (as returned by ``op.get_bind()``).
        table_name: Unqualified table name.
        column_name: Column to look for.

    Returns:
        ``True`` if the column exists, ``False`` otherwise.
    """
    result = bind.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = :table_name
                  AND column_name = :column_name
                  AND table_schema = ANY (current_schemas(false))
            )
            """
        ),
        {"table_name": table_name, "column_name": column_name},
    )
    # EXISTS always yields one row; bool() pins the annotated return type.
    return bool(result.scalar())
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""应用迁移:升级到新版本"""
|
||||
# 添加首字时间字段到 usage 表
|
||||
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
|
||||
bind = op.get_bind()
|
||||
|
||||
# 添加首字时间字段到 usage 表(如果不存在)
|
||||
if not column_exists(bind, "usage", "first_byte_time_ms"):
|
||||
op.add_column('usage', sa.Column('first_byte_time_ms', sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""refactor global_model to use config json field
|
||||
|
||||
Revision ID: 1cc6942cf06f
|
||||
Revises: 180e63a9c83a
|
||||
Create Date: 2025-12-16 03:11:32.480976+00:00
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1cc6942cf06f'
|
||||
down_revision = '180e63a9c83a'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(bind, table_name: str, column_name: str) -> bool:
    """Return whether ``table_name`` already has a column named ``column_name``.

    The lookup goes through ``information_schema.columns`` and is limited to
    the schemas on the connection's search path, so an identically named
    table in an unrelated schema cannot produce a false positive (which would
    make the caller skip a column it still needs to add).

    Args:
        bind: Connection the migration runs on (as returned by ``op.get_bind()``).
        table_name: Unqualified table name.
        column_name: Column to look for.

    Returns:
        ``True`` if the column exists, ``False`` otherwise.
    """
    result = bind.execute(
        sa.text(
            """
            SELECT EXISTS (
                SELECT 1 FROM information_schema.columns
                WHERE table_name = :table_name
                  AND column_name = :column_name
                  AND table_schema = ANY (current_schemas(false))
            )
            """
        ),
        {"table_name": table_name, "column_name": column_name},
    )
    # EXISTS always yields one row; bool() pins the annotated return type.
    return bool(result.scalar())
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the migration: fold the legacy ``global_models`` columns into ``config``.

    Steps:
        1. Add the ``config`` JSONB column if it is missing.
        2. Copy the legacy capability/metadata columns into ``config``
           (only while the legacy columns still exist).
        3. Drop the legacy columns.

    The migration can be re-run: it returns early when ``config`` exists and
    the probe column ``default_supports_streaming`` is gone, and every legacy
    column is dropped only if it is still present.
    """
    bind = op.get_bind()

    # Already migrated when config exists and the legacy probe column is gone.
    has_config = column_exists(bind, "global_models", "config")
    has_old_columns = column_exists(bind, "global_models", "default_supports_streaming")

    if has_config and not has_old_columns:
        return

    # 1. Add the config column (JSONB: indexable and cheaper to query than JSON).
    if not has_config:
        op.add_column('global_models', sa.Column('config', postgresql.JSONB(), nullable=True))

    # 2. Merge the legacy columns into the config JSON. Capability flags that are
    #    false/NULL become NULL and are removed by jsonb_strip_nulls, so only
    #    enabled capabilities (plus 'streaming', which defaults to true) are stored.
    if has_old_columns:
        op.execute("""
            UPDATE global_models
            SET config = jsonb_strip_nulls(jsonb_build_object(
                'streaming', COALESCE(default_supports_streaming, true),
                'vision', CASE WHEN COALESCE(default_supports_vision, false) THEN true ELSE NULL END,
                'function_calling', CASE WHEN COALESCE(default_supports_function_calling, false) THEN true ELSE NULL END,
                'extended_thinking', CASE WHEN COALESCE(default_supports_extended_thinking, false) THEN true ELSE NULL END,
                'image_generation', CASE WHEN COALESCE(default_supports_image_generation, false) THEN true ELSE NULL END,
                'description', description,
                'icon_url', icon_url,
                'official_url', official_url
            ))
        """)

    # 3. Drop the legacy columns. Each drop is guarded on its own so that a
    #    database where some of them are already gone does not abort the run.
    legacy_columns = (
        'default_supports_streaming',
        'default_supports_vision',
        'default_supports_function_calling',
        'default_supports_extended_thinking',
        'default_supports_image_generation',
        'description',
        'icon_url',
        'official_url',
    )
    for legacy_column in legacy_columns:
        if column_exists(bind, "global_models", legacy_column):
            op.drop_column('global_models', legacy_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration: restore the legacy columns from ``config``.

    Re-creates the legacy ``global_models`` columns as nullable, fills them
    from the ``config`` JSON (missing capability flags fall back to the
    defaults in the UPDATE below), then drops ``config``.
    """
    # 1. Re-create the legacy columns, in the same order as before.
    legacy_columns = (
        ('icon_url', sa.VARCHAR(length=500)),
        ('official_url', sa.VARCHAR(length=500)),
        ('description', sa.TEXT()),
        ('default_supports_streaming', sa.BOOLEAN()),
        ('default_supports_vision', sa.BOOLEAN()),
        ('default_supports_function_calling', sa.BOOLEAN()),
        ('default_supports_extended_thinking', sa.BOOLEAN()),
        ('default_supports_image_generation', sa.BOOLEAN()),
    )
    for column_name, column_type in legacy_columns:
        op.add_column('global_models', sa.Column(column_name, column_type, nullable=True))

    # 2. Restore the values from config.
    op.execute("""
        UPDATE global_models
        SET
            default_supports_streaming = COALESCE((config->>'streaming')::boolean, true),
            default_supports_vision = COALESCE((config->>'vision')::boolean, false),
            default_supports_function_calling = COALESCE((config->>'function_calling')::boolean, false),
            default_supports_extended_thinking = COALESCE((config->>'extended_thinking')::boolean, false),
            default_supports_image_generation = COALESCE((config->>'image_generation')::boolean, false),
            description = config->>'description',
            icon_url = config->>'icon_url',
            official_url = config->>'official_url'
    """)

    # 3. Drop the config column.
    op.drop_column('global_models', 'config')
|
||||
@@ -1,5 +1,179 @@
|
||||
import apiClient from './client'
|
||||
|
||||
// 配置导出数据结构
|
||||
export interface ConfigExportData {
|
||||
version: string
|
||||
exported_at: string
|
||||
global_models: GlobalModelExport[]
|
||||
providers: ProviderExport[]
|
||||
}
|
||||
|
||||
// 用户导出数据结构
|
||||
export interface UsersExportData {
|
||||
version: string
|
||||
exported_at: string
|
||||
users: UserExport[]
|
||||
}
|
||||
|
||||
export interface UserExport {
|
||||
email: string
|
||||
username: string
|
||||
password_hash: string
|
||||
role: string
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
model_capability_settings?: any
|
||||
quota_usd?: number | null
|
||||
used_usd?: number
|
||||
total_usd?: number
|
||||
is_active: boolean
|
||||
api_keys: UserApiKeyExport[]
|
||||
}
|
||||
|
||||
export interface UserApiKeyExport {
|
||||
key_hash: string
|
||||
key_encrypted?: string | null
|
||||
name?: string | null
|
||||
is_standalone: boolean
|
||||
balance_used_usd?: number
|
||||
current_balance_usd?: number | null
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
rate_limit?: number
|
||||
concurrent_limit?: number | null
|
||||
force_capabilities?: any
|
||||
is_active: boolean
|
||||
auto_delete_on_expiry?: boolean
|
||||
total_requests?: number
|
||||
total_cost_usd?: number
|
||||
}
|
||||
|
||||
export interface GlobalModelExport {
|
||||
name: string
|
||||
display_name: string
|
||||
default_price_per_request?: number | null
|
||||
default_tiered_pricing: any
|
||||
supported_capabilities?: string[] | null
|
||||
config?: any
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
export interface ProviderExport {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string | null
|
||||
website?: string | null
|
||||
billing_type?: string | null
|
||||
monthly_quota_usd?: number | null
|
||||
quota_reset_day?: number
|
||||
rpm_limit?: number | null
|
||||
provider_priority?: number
|
||||
is_active: boolean
|
||||
rate_limit?: number | null
|
||||
concurrent_limit?: number | null
|
||||
config?: any
|
||||
endpoints: EndpointExport[]
|
||||
models: ModelExport[]
|
||||
}
|
||||
|
||||
export interface EndpointExport {
|
||||
api_format: string
|
||||
base_url: string
|
||||
headers?: any
|
||||
timeout?: number
|
||||
max_retries?: number
|
||||
max_concurrent?: number | null
|
||||
rate_limit?: number | null
|
||||
is_active: boolean
|
||||
custom_path?: string | null
|
||||
config?: any
|
||||
keys: KeyExport[]
|
||||
}
|
||||
|
||||
export interface KeyExport {
|
||||
api_key: string
|
||||
name?: string | null
|
||||
note?: string | null
|
||||
rate_multiplier?: number
|
||||
internal_priority?: number
|
||||
global_priority?: number | null
|
||||
max_concurrent?: number | null
|
||||
rate_limit?: number | null
|
||||
daily_limit?: number | null
|
||||
monthly_limit?: number | null
|
||||
allowed_models?: string[] | null
|
||||
capabilities?: any
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
export interface ModelExport {
|
||||
global_model_name: string | null
|
||||
provider_model_name: string
|
||||
provider_model_aliases?: any
|
||||
price_per_request?: number | null
|
||||
tiered_pricing?: any
|
||||
supports_vision?: boolean | null
|
||||
supports_function_calling?: boolean | null
|
||||
supports_streaming?: boolean | null
|
||||
supports_extended_thinking?: boolean | null
|
||||
supports_image_generation?: boolean | null
|
||||
is_active: boolean
|
||||
config?: any
|
||||
}
|
||||
|
||||
// Provider 模型查询响应
|
||||
export interface ProviderModelsQueryResponse {
|
||||
success: boolean
|
||||
data: {
|
||||
models: Array<{
|
||||
id: string
|
||||
object?: string
|
||||
created?: number
|
||||
owned_by?: string
|
||||
display_name?: string
|
||||
api_format?: string
|
||||
}>
|
||||
error?: string
|
||||
}
|
||||
provider: {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ConfigImportRequest extends ConfigExportData {
|
||||
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||
}
|
||||
|
||||
export interface UsersImportRequest extends UsersExportData {
|
||||
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||
}
|
||||
|
||||
export interface UsersImportResponse {
|
||||
message: string
|
||||
stats: {
|
||||
users: { created: number; updated: number; skipped: number }
|
||||
api_keys: { created: number; skipped: number }
|
||||
errors: string[]
|
||||
}
|
||||
}
|
||||
|
||||
export interface ConfigImportResponse {
|
||||
message: string
|
||||
stats: {
|
||||
global_models: { created: number; updated: number; skipped: number }
|
||||
providers: { created: number; updated: number; skipped: number }
|
||||
endpoints: { created: number; updated: number; skipped: number }
|
||||
keys: { created: number; updated: number; skipped: number }
|
||||
models: { created: number; updated: number; skipped: number }
|
||||
errors: string[]
|
||||
}
|
||||
}
|
||||
|
||||
// API密钥管理相关接口定义
|
||||
export interface AdminApiKey {
|
||||
id: string // UUID
|
||||
@@ -173,5 +347,44 @@ export const adminApi = {
|
||||
'/api/admin/system/api-formats'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 导出配置
|
||||
async exportConfig(): Promise<ConfigExportData> {
  // GET the exported configuration (ConfigExportData) from the admin API.
  const { data } = await apiClient.get<ConfigExportData>('/api/admin/system/config/export')
  return data
},
|
||||
|
||||
// 导入配置
|
||||
async importConfig(data: ConfigImportRequest): Promise<ConfigImportResponse> {
  // POST the configuration payload (including merge_mode) and return the response body.
  const endpoint = '/api/admin/system/config/import'
  const { data: result } = await apiClient.post<ConfigImportResponse>(endpoint, data)
  return result
},
|
||||
|
||||
// 导出用户数据
|
||||
async exportUsers(): Promise<UsersExportData> {
  // GET the exported user data (UsersExportData) from the admin API.
  const { data } = await apiClient.get<UsersExportData>('/api/admin/system/users/export')
  return data
},
|
||||
|
||||
// 导入用户数据
|
||||
async importUsers(data: UsersImportRequest): Promise<UsersImportResponse> {
  // POST the user payload (including merge_mode) and return the response body.
  const endpoint = '/api/admin/system/users/import'
  const { data: result } = await apiClient.post<UsersImportResponse>(endpoint, data)
  return result
},
|
||||
|
||||
// 查询 Provider 可用模型(从上游 API 获取)
|
||||
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
  // POST the provider id (and optional API key id) to the provider-query endpoint
  // and return the response body.
  const payload = { provider_id: providerId, api_key_id: apiKeyId }
  const { data } = await apiClient.post<ProviderModelsQueryResponse>(
    '/api/admin/provider-query/models',
    payload
  )
  return data
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,25 @@
|
||||
// API 格式常量
|
||||
export const API_FORMATS = {
|
||||
CLAUDE: 'CLAUDE',
|
||||
CLAUDE_CLI: 'CLAUDE_CLI',
|
||||
OPENAI: 'OPENAI',
|
||||
OPENAI_CLI: 'OPENAI_CLI',
|
||||
GEMINI: 'GEMINI',
|
||||
GEMINI_CLI: 'GEMINI_CLI',
|
||||
} as const
|
||||
|
||||
export type APIFormat = typeof API_FORMATS[keyof typeof API_FORMATS]
|
||||
|
||||
// API 格式显示名称映射(按品牌分组:API 在前,CLI 在后)
|
||||
export const API_FORMAT_LABELS: Record<string, string> = {
|
||||
[API_FORMATS.CLAUDE]: 'Claude',
|
||||
[API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
|
||||
[API_FORMATS.OPENAI]: 'OpenAI',
|
||||
[API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
|
||||
[API_FORMATS.GEMINI]: 'Gemini',
|
||||
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
|
||||
}
|
||||
|
||||
export interface ProviderEndpoint {
|
||||
id: string
|
||||
provider_id: string
|
||||
@@ -214,6 +236,7 @@ export interface ConcurrencyStatus {
|
||||
export interface ProviderModelAlias {
|
||||
name: string
|
||||
priority: number // 优先级(数字越小优先级越高)
|
||||
api_formats?: string[] // 作用域(适用的 API 格式),为空表示对所有格式生效
|
||||
}
|
||||
|
||||
export interface Model {
|
||||
@@ -407,67 +430,45 @@ export interface TieredPricingConfig {
|
||||
export interface GlobalModelCreate {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
official_url?: string
|
||||
icon_url?: string
|
||||
// 按次计费配置(可选,与阶梯计费叠加)
|
||||
default_price_per_request?: number
|
||||
// 阶梯计费配置(必填,固定价格用单阶梯表示)
|
||||
default_tiered_pricing: TieredPricingConfig
|
||||
// 默认能力配置
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
// Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities?: string[]
|
||||
// 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config?: Record<string, any>
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface GlobalModelUpdate {
|
||||
display_name?: string
|
||||
description?: string
|
||||
official_url?: string
|
||||
icon_url?: string
|
||||
is_active?: boolean
|
||||
// 按次计费配置
|
||||
default_price_per_request?: number | null // null 表示清空
|
||||
// 阶梯计费配置
|
||||
default_tiered_pricing?: TieredPricingConfig
|
||||
// 默认能力配置
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
// Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities?: string[] | null
|
||||
// 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config?: Record<string, any> | null
|
||||
}
|
||||
|
||||
export interface GlobalModelResponse {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
official_url?: string
|
||||
icon_url?: string
|
||||
is_active: boolean
|
||||
// 按次计费配置
|
||||
default_price_per_request?: number
|
||||
// 阶梯计费配置(必填)
|
||||
default_tiered_pricing: TieredPricingConfig
|
||||
// 默认能力配置
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
// Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities?: string[] | null
|
||||
// 模型配置(JSON格式)
|
||||
config?: Record<string, any> | null
|
||||
// 统计数据
|
||||
provider_count?: number
|
||||
alias_count?: number
|
||||
usage_count?: number
|
||||
created_at: string
|
||||
updated_at?: string
|
||||
|
||||
288
frontend/src/api/models-dev.ts
Normal file
288
frontend/src/api/models-dev.ts
Normal file
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* Models.dev API 服务
|
||||
* 通过后端代理获取 models.dev 数据(解决跨域问题)
|
||||
*/
|
||||
|
||||
import api from './client'
|
||||
|
||||
// 缓存配置
|
||||
const CACHE_KEY = 'models_dev_cache'
|
||||
const CACHE_DURATION = 15 * 60 * 1000 // 15 分钟
|
||||
|
||||
// Models.dev API 数据结构
|
||||
export interface ModelsDevCost {
|
||||
input?: number
|
||||
output?: number
|
||||
reasoning?: number
|
||||
cache_read?: number
|
||||
}
|
||||
|
||||
export interface ModelsDevLimit {
|
||||
context?: number
|
||||
output?: number
|
||||
}
|
||||
|
||||
export interface ModelsDevModel {
|
||||
id: string
|
||||
name: string
|
||||
family?: string
|
||||
reasoning?: boolean
|
||||
tool_call?: boolean
|
||||
structured_output?: boolean
|
||||
temperature?: boolean
|
||||
attachment?: boolean
|
||||
knowledge?: string
|
||||
release_date?: string
|
||||
last_updated?: string
|
||||
input?: string[] // 输入模态: text, image, audio, video, pdf
|
||||
output?: string[] // 输出模态: text, image, audio
|
||||
open_weights?: boolean
|
||||
cost?: ModelsDevCost
|
||||
limit?: ModelsDevLimit
|
||||
deprecated?: boolean
|
||||
}
|
||||
|
||||
export interface ModelsDevProvider {
|
||||
id: string
|
||||
env?: string[]
|
||||
npm?: string
|
||||
api?: string
|
||||
name: string
|
||||
doc?: string
|
||||
models: Record<string, ModelsDevModel>
|
||||
official?: boolean // 是否为官方提供商
|
||||
}
|
||||
|
||||
export type ModelsDevData = Record<string, ModelsDevProvider>
|
||||
|
||||
// 扁平化的模型列表项(用于搜索和选择)
|
||||
export interface ModelsDevModelItem {
|
||||
providerId: string
|
||||
providerName: string
|
||||
modelId: string
|
||||
modelName: string
|
||||
family?: string
|
||||
inputPrice?: number
|
||||
outputPrice?: number
|
||||
contextLimit?: number
|
||||
outputLimit?: number
|
||||
supportsVision?: boolean
|
||||
supportsToolCall?: boolean
|
||||
supportsReasoning?: boolean
|
||||
supportsStructuredOutput?: boolean
|
||||
supportsTemperature?: boolean
|
||||
supportsAttachment?: boolean
|
||||
openWeights?: boolean
|
||||
deprecated?: boolean
|
||||
official?: boolean // 是否来自官方提供商
|
||||
// 用于 display_metadata 的额外字段
|
||||
knowledgeCutoff?: string
|
||||
releaseDate?: string
|
||||
inputModalities?: string[]
|
||||
outputModalities?: string[]
|
||||
}
|
||||
|
||||
interface CacheData {
|
||||
timestamp: number
|
||||
data: ModelsDevData
|
||||
}
|
||||
|
||||
// 内存缓存
|
||||
let memoryCache: CacheData | null = null
|
||||
|
||||
// True when at least one provider entry carries a boolean `official` field.
function hasOfficialFlag(data: ModelsDevData): boolean {
  for (const provider of Object.values(data)) {
    if (typeof provider?.official === 'boolean') {
      return true
    }
  }
  return false
}
|
||||
|
||||
/**
 * Load the models.dev data set, going through two cache layers first.
 *
 * Lookup order: in-memory cache, then localStorage, then the backend proxy.
 * A cached entry is used only while it is younger than CACHE_DURATION and
 * contains the `official` flag; entries without the flag are discarded.
 */
export async function getModelsDevData(): Promise<ModelsDevData> {
  const isFresh = (entry: CacheData): boolean =>
    Date.now() - entry.timestamp < CACHE_DURATION

  // 1. In-memory cache
  if (memoryCache && isFresh(memoryCache)) {
    if (hasOfficialFlag(memoryCache.data)) {
      return memoryCache.data
    }
    // Old cache shape without the `official` field: drop it and refresh once.
    memoryCache = null
  }

  // 2. localStorage cache
  try {
    const raw = localStorage.getItem(CACHE_KEY)
    if (raw) {
      const stored: CacheData = JSON.parse(raw)
      if (isFresh(stored)) {
        if (hasOfficialFlag(stored.data)) {
          memoryCache = stored
          return stored.data
        }
        // Old cache shape without the `official` field: drop it and refresh once.
        localStorage.removeItem(CACHE_KEY)
      }
    }
  } catch {
    // Unreadable or unparsable cache entry: ignore it and fetch fresh data.
  }

  // 3. Fetch fresh data through the backend proxy
  const { data } = await api.get<ModelsDevData>('/api/admin/models/external')

  // 4. Refresh both cache layers
  const freshEntry: CacheData = { timestamp: Date.now(), data }
  memoryCache = freshEntry
  try {
    localStorage.setItem(CACHE_KEY, JSON.stringify(freshEntry))
  } catch {
    // Writing to localStorage failed: the memory cache alone is enough.
  }

  return data
}
|
||||
|
||||
// 模型列表缓存(避免重复转换)
|
||||
let modelsListCache: ModelsDevModelItem[] | null = null
|
||||
let modelsListCacheTimestamp: number | null = null
|
||||
|
||||
/**
 * Return the models.dev data as a flat list of models, sorted by provider
 * name and then model name.
 *
 * The flat list is rebuilt only when it is missing or the memory cache
 * timestamp has changed; `officialOnly` merely selects a filtered view of it.
 */
export async function getModelsDevList(officialOnly: boolean = true): Promise<ModelsDevModelItem[]> {
  const data = await getModelsDevData()
  const snapshotTimestamp = memoryCache?.timestamp ?? 0

  // Flatten one provider/model pair into a list item.
  const toItem = (
    providerId: string,
    provider: ModelsDevProvider,
    modelId: string,
    model: ModelsDevModel
  ): ModelsDevModelItem => ({
    providerId,
    providerName: provider.name,
    modelId,
    modelName: model.name || modelId,
    family: model.family,
    inputPrice: model.cost?.input,
    outputPrice: model.cost?.output,
    contextLimit: model.limit?.context,
    outputLimit: model.limit?.output,
    supportsVision: model.input?.includes('image'),
    supportsToolCall: model.tool_call,
    supportsReasoning: model.reasoning,
    supportsStructuredOutput: model.structured_output,
    supportsTemperature: model.temperature,
    supportsAttachment: model.attachment,
    openWeights: model.open_weights,
    deprecated: model.deprecated,
    official: provider.official,
    // Fields used for display_metadata
    knowledgeCutoff: model.knowledge,
    releaseDate: model.release_date,
    inputModalities: model.input,
    outputModalities: model.output,
  })

  let list = modelsListCache

  // Build the list once when it is missing or the underlying data was refreshed.
  if (!list || modelsListCacheTimestamp !== snapshotTimestamp) {
    list = []

    for (const [providerId, provider] of Object.entries(data)) {
      if (!provider.models) continue

      for (const [modelId, model] of Object.entries(provider.models)) {
        list.push(toItem(providerId, provider, modelId, model))
      }
    }

    // Sort by provider name, then by model name.
    list.sort(
      (a, b) =>
        a.providerName.localeCompare(b.providerName) ||
        a.modelName.localeCompare(b.modelName)
    )

    modelsListCache = list
    modelsListCacheTimestamp = snapshotTimestamp
  }

  // Apply the official-only filter on demand.
  return officialOnly ? list.filter(item => item.official) : list
}
|
||||
|
||||
/**
 * Search the flattened model list by case-insensitive substring.
 *
 * Every provider is searched, not just the official ones. A model matches when
 * the query occurs in its id, its name, its provider name or its family.
 * Models whose id or name equals the query are moved to the front; the
 * relative order of all other results is left as it was.
 *
 * @param query text to look for
 * @param options.limit maximum number of results (default 50)
 * @param options.excludeDeprecated drop deprecated models (default true)
 * @returns at most `limit` matching models
 */
export async function searchModelsDevModels(
  query: string,
  options?: {
    limit?: number
    excludeDeprecated?: boolean
  }
): Promise<ModelsDevModelItem[]> {
  const { limit = 50, excludeDeprecated = true } = options ?? {}
  const needle = query.toLowerCase()

  // Search across all providers, third-party ones included
  const candidates = await getModelsDevList(false)

  const contains = (text: string): boolean => text.toLowerCase().includes(needle)
  const isExact = (item: ModelsDevModelItem): boolean =>
    item.modelId.toLowerCase() === needle || item.modelName.toLowerCase() === needle

  const matches = candidates.filter((item) => {
    if (excludeDeprecated && item.deprecated) return false
    return (
      contains(item.modelId) ||
      contains(item.modelName) ||
      contains(item.providerName) ||
      (item.family != null && contains(item.family))
    )
  })

  // Exact id/name hits first; ties keep their existing order
  matches.sort((a, b) => Number(isExact(b)) - Number(isExact(a)))

  return matches.slice(0, limit)
}
|
||||
|
||||
/**
 * Look up a single model in the models.dev catalogue.
 *
 * @param providerId key of the provider in the catalogue
 * @param modelId key of the model inside that provider's `models` map
 * @returns the model entry, or null when the provider or the model is absent
 */
export async function getModelsDevModel(
  providerId: string,
  modelId: string
): Promise<ModelsDevModel | null> {
  const catalogue = await getModelsDevData()
  const entry = catalogue[providerId]?.models?.[modelId]
  return entry ? entry : null
}
|
||||
|
||||
/**
 * Build the URL of a provider's logo on models.dev.
 *
 * The id is percent-encoded before being placed in the path, so an id that
 * contains URL-reserved characters (`/`, `?`, `#`, spaces, ...) cannot change
 * the path or add a query string. Plain slug ids are returned unchanged.
 *
 * @param providerId provider key as used in the models.dev catalogue
 * @returns absolute URL of the provider's SVG logo
 */
export function getProviderLogoUrl(providerId: string): string {
  return `https://models.dev/logos/${encodeURIComponent(providerId)}.svg`
}
|
||||
|
||||
/**
 * Drop every cached copy of the models.dev data: the memoised flat list, the
 * in-memory raw data and the localStorage entry stored under CACHE_KEY.
 */
export function clearModelsDevCache(): void {
  // Derived list and the snapshot marker it was built from
  modelsListCacheTimestamp = null
  modelsListCache = null

  // Raw data held in memory
  memoryCache = null

  // Persisted copy; a failing removal is ignored
  try {
    localStorage.removeItem(CACHE_KEY)
  } catch {
    // best-effort
  }
}
|
||||
@@ -9,20 +9,14 @@ export interface PublicGlobalModel {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string | null
|
||||
description: string | null
|
||||
icon_url: string | null
|
||||
is_active: boolean
|
||||
// 阶梯计费配置
|
||||
default_tiered_pricing: TieredPricingConfig
|
||||
default_price_per_request: number | null // 按次计费价格
|
||||
// 能力
|
||||
default_supports_vision: boolean
|
||||
default_supports_function_calling: boolean
|
||||
default_supports_streaming: boolean
|
||||
default_supports_extended_thinking: boolean
|
||||
default_supports_image_generation: boolean
|
||||
// Key 能力支持
|
||||
supported_capabilities: string[] | null
|
||||
// 模型配置(JSON)
|
||||
config: Record<string, any> | null
|
||||
}
|
||||
|
||||
export interface PublicGlobalModelListResponse {
|
||||
|
||||
@@ -299,7 +299,7 @@ function formatDuration(ms: number): string {
|
||||
const hours = Math.floor(ms / (1000 * 60 * 60))
|
||||
const minutes = Math.floor((ms % (1000 * 60 * 60)) / (1000 * 60))
|
||||
if (hours > 0) {
|
||||
return `${hours}h${minutes > 0 ? minutes + 'm' : ''}`
|
||||
return `${hours}h${minutes > 0 ? `${minutes}m` : ''}`
|
||||
}
|
||||
return `${minutes}m`
|
||||
}
|
||||
|
||||
@@ -34,11 +34,10 @@ const buttonClass = computed(() => {
|
||||
'inline-flex items-center justify-center rounded-xl text-sm font-semibold transition-all duration-200 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 active:scale-[0.98]'
|
||||
|
||||
const variantClasses = {
|
||||
default:
|
||||
'bg-primary text-white shadow-[0_20px_35px_rgba(204,120,92,0.35)] hover:bg-primary/90 hover:shadow-[0_25px_45px_rgba(204,120,92,0.45)]',
|
||||
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85 shadow-sm',
|
||||
default: 'bg-primary text-white hover:bg-primary/90',
|
||||
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85',
|
||||
outline:
|
||||
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 shadow-sm backdrop-blur transition-all',
|
||||
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 backdrop-blur transition-all',
|
||||
secondary:
|
||||
'bg-secondary text-secondary-foreground shadow-inner hover:bg-secondary/80',
|
||||
ghost: 'hover:bg-accent hover:text-accent-foreground',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="fixed inset-0 overflow-y-auto"
|
||||
class="fixed inset-0 overflow-y-auto pointer-events-none"
|
||||
:style="{ zIndex: containerZIndex }"
|
||||
>
|
||||
<!-- 背景遮罩 -->
|
||||
@@ -16,13 +16,13 @@
|
||||
>
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity"
|
||||
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity pointer-events-auto"
|
||||
:style="{ zIndex: backdropZIndex }"
|
||||
@click="handleClose"
|
||||
/>
|
||||
</Transition>
|
||||
|
||||
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0">
|
||||
<div class="relative flex min-h-full items-end justify-center p-4 text-center sm:items-center sm:p-0 pointer-events-none">
|
||||
<!-- 对话框内容 -->
|
||||
<Transition
|
||||
enter-active-class="duration-300 ease-out"
|
||||
@@ -34,7 +34,7 @@
|
||||
>
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border"
|
||||
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border pointer-events-auto"
|
||||
:style="{ zIndex: contentZIndex }"
|
||||
:class="maxWidthClass"
|
||||
@click.stop
|
||||
|
||||
@@ -45,7 +45,7 @@ const props = withDefaults(defineProps<Props>(), {
|
||||
|
||||
const contentClass = computed(() =>
|
||||
cn(
|
||||
'z-[100] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
|
||||
'z-[200] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
|
||||
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
|
||||
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
|
||||
props.class
|
||||
|
||||
@@ -2,174 +2,304 @@
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
:title="isEditMode ? '编辑模型' : '创建统一模型'"
|
||||
:description="isEditMode ? '修改模型配置和价格信息' : '添加一个新的全局模型定义'"
|
||||
:description="isEditMode ? '修改模型配置和价格信息' : ''"
|
||||
:icon="isEditMode ? SquarePen : Layers"
|
||||
size="xl"
|
||||
size="3xl"
|
||||
@update:model-value="handleDialogUpdate"
|
||||
>
|
||||
<form
|
||||
class="space-y-5 max-h-[70vh] overflow-y-auto pr-1"
|
||||
@submit.prevent="handleSubmit"
|
||||
<div
|
||||
class="flex gap-4"
|
||||
:class="isEditMode ? '' : 'h-[500px]'"
|
||||
>
|
||||
<!-- 基本信息 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
基本信息
|
||||
</h4>
|
||||
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-name"
|
||||
class="text-xs"
|
||||
>模型名称 *</Label>
|
||||
<Input
|
||||
id="model-name"
|
||||
v-model="form.name"
|
||||
placeholder="claude-3-5-sonnet-20241022"
|
||||
:disabled="isEditMode"
|
||||
required
|
||||
/>
|
||||
<p
|
||||
v-if="!isEditMode"
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
创建后不可修改
|
||||
</p>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-display-name"
|
||||
class="text-xs"
|
||||
>显示名称 *</Label>
|
||||
<Input
|
||||
id="model-display-name"
|
||||
v-model="form.display_name"
|
||||
placeholder="Claude 3.5 Sonnet"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-description"
|
||||
class="text-xs"
|
||||
>描述</Label>
|
||||
<Input
|
||||
id="model-description"
|
||||
v-model="form.description"
|
||||
placeholder="简短描述此模型的特点"
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 能力配置 -->
|
||||
<section class="space-y-2">
|
||||
<h4 class="font-medium text-sm">
|
||||
默认能力
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_streaming"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>流式输出</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_vision"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>视觉理解</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_function_calling"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>工具调用</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_extended_thinking"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>深度思考</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_image_generation"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Image class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>图像生成</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Key 能力配置 -->
|
||||
<section
|
||||
v-if="availableCapabilities.length > 0"
|
||||
class="space-y-2"
|
||||
<!-- 左侧:模型选择(仅创建模式) -->
|
||||
<div
|
||||
v-if="!isEditMode"
|
||||
class="w-[260px] shrink-0 flex flex-col h-full"
|
||||
>
|
||||
<h4 class="font-medium text-sm">
|
||||
Key 能力支持
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label
|
||||
v-for="cap in availableCapabilities"
|
||||
:key="cap.name"
|
||||
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.supported_capabilities?.includes(cap.name)"
|
||||
class="rounded"
|
||||
@change="toggleCapability(cap.name)"
|
||||
>
|
||||
<span>{{ cap.display_name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 价格配置 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
价格配置
|
||||
</h4>
|
||||
<TieredPricingEditor
|
||||
ref="tieredPricingEditorRef"
|
||||
v-model="tieredPricing"
|
||||
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
|
||||
/>
|
||||
|
||||
<!-- 按次计费 -->
|
||||
<div class="flex items-center gap-3 pt-2 border-t">
|
||||
<Label class="text-xs whitespace-nowrap">按次计费 ($/次)</Label>
|
||||
<!-- 搜索框 -->
|
||||
<div class="relative mb-3">
|
||||
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||
<Input
|
||||
:model-value="form.default_price_per_request ?? ''"
|
||||
type="number"
|
||||
step="0.001"
|
||||
min="0"
|
||||
class="w-32"
|
||||
placeholder="留空不启用"
|
||||
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
|
||||
v-model="searchQuery"
|
||||
type="text"
|
||||
placeholder="搜索模型、提供商..."
|
||||
class="pl-8 h-8 text-sm"
|
||||
/>
|
||||
<span class="text-xs text-muted-foreground">每次请求固定费用,可与 Token 计费叠加</span>
|
||||
</div>
|
||||
</section>
|
||||
</form>
|
||||
|
||||
<!-- 模型列表(两级结构) -->
|
||||
<div class="flex-1 overflow-y-auto border rounded-lg min-h-0 scrollbar-thin">
|
||||
<div
|
||||
v-if="loading"
|
||||
class="flex items-center justify-center h-32"
|
||||
>
|
||||
<Loader2 class="w-5 h-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
<template v-else>
|
||||
<!-- 提供商分组 -->
|
||||
<div
|
||||
v-for="group in groupedModels"
|
||||
:key="group.providerId"
|
||||
class="border-b last:border-b-0"
|
||||
>
|
||||
<!-- 提供商标题行 -->
|
||||
<div
|
||||
class="flex items-center gap-2 px-2.5 py-2 cursor-pointer hover:bg-muted text-sm"
|
||||
@click="toggleProvider(group.providerId)"
|
||||
>
|
||||
<ChevronRight
|
||||
class="w-3.5 h-3.5 text-muted-foreground transition-transform shrink-0"
|
||||
:class="expandedProvider === group.providerId ? 'rotate-90' : ''"
|
||||
/>
|
||||
<img
|
||||
:src="getProviderLogoUrl(group.providerId)"
|
||||
:alt="group.providerName"
|
||||
class="w-4 h-4 rounded shrink-0 dark:invert dark:brightness-90"
|
||||
@error="handleLogoError"
|
||||
>
|
||||
<span class="truncate font-medium text-xs flex-1">{{ group.providerName }}</span>
|
||||
<span class="text-[10px] text-muted-foreground shrink-0">{{ group.models.length }}</span>
|
||||
</div>
|
||||
<!-- 模型列表 -->
|
||||
<div
|
||||
v-if="expandedProvider === group.providerId"
|
||||
class="bg-muted/30"
|
||||
>
|
||||
<div
|
||||
v-for="model in group.models"
|
||||
:key="model.modelId"
|
||||
class="flex flex-col gap-0.5 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
|
||||
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
||||
? 'bg-primary text-primary-foreground'
|
||||
: 'hover:bg-muted'"
|
||||
@click="selectModel(model)"
|
||||
>
|
||||
<span class="truncate font-medium">{{ model.modelName }}</span>
|
||||
<span
|
||||
class="truncate text-[10px]"
|
||||
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
||||
? 'text-primary-foreground/70'
|
||||
: 'text-muted-foreground'"
|
||||
>{{ model.modelId }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="groupedModels.length === 0"
|
||||
class="text-center py-8 text-sm text-muted-foreground"
|
||||
>
|
||||
{{ searchQuery ? '未找到模型' : '加载中...' }}
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 右侧:表单 -->
|
||||
<div
|
||||
class="flex-1 overflow-y-auto h-full scrollbar-thin"
|
||||
:class="isEditMode ? 'max-h-[70vh]' : ''"
|
||||
>
|
||||
<form
|
||||
class="space-y-5"
|
||||
@submit.prevent="handleSubmit"
|
||||
>
|
||||
<!-- 基本信息 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
基本信息
|
||||
</h4>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-name"
|
||||
class="text-xs"
|
||||
>模型名称 *</Label>
|
||||
<Input
|
||||
id="model-name"
|
||||
v-model="form.name"
|
||||
placeholder="claude-3-5-sonnet-20241022"
|
||||
:disabled="isEditMode"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-display-name"
|
||||
class="text-xs"
|
||||
>显示名称 *</Label>
|
||||
<Input
|
||||
id="model-display-name"
|
||||
v-model="form.display_name"
|
||||
placeholder="Claude 3.5 Sonnet"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-description"
|
||||
class="text-xs"
|
||||
>描述</Label>
|
||||
<Input
|
||||
id="model-description"
|
||||
:model-value="form.config?.description || ''"
|
||||
placeholder="简短描述此模型的特点"
|
||||
@update:model-value="(v) => setConfigField('description', v || undefined)"
|
||||
/>
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-3">
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-family"
|
||||
class="text-xs"
|
||||
>模型系列</Label>
|
||||
<Input
|
||||
id="model-family"
|
||||
:model-value="form.config?.family || ''"
|
||||
placeholder="如 GPT-4、Claude 3"
|
||||
@update:model-value="(v) => setConfigField('family', v || undefined)"
|
||||
/>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-context-limit"
|
||||
class="text-xs"
|
||||
>上下文限制</Label>
|
||||
<Input
|
||||
id="model-context-limit"
|
||||
type="number"
|
||||
:model-value="form.config?.context_limit ?? ''"
|
||||
placeholder="如 128000"
|
||||
@update:model-value="(v) => setConfigField('context_limit', v ? Number(v) : undefined)"
|
||||
/>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-output-limit"
|
||||
class="text-xs"
|
||||
>输出限制</Label>
|
||||
<Input
|
||||
id="model-output-limit"
|
||||
type="number"
|
||||
:model-value="form.config?.output_limit ?? ''"
|
||||
placeholder="如 8192"
|
||||
@update:model-value="(v) => setConfigField('output_limit', v ? Number(v) : undefined)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 能力配置 -->
|
||||
<section class="space-y-2">
|
||||
<h4 class="font-medium text-sm">
|
||||
默认能力
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.streaming !== false"
|
||||
class="rounded"
|
||||
@change="setConfigField('streaming', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>流式</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.vision === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('vision', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>视觉</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.function_calling === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('function_calling', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>工具</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.extended_thinking === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('extended_thinking', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>思考</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.image_generation === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('image_generation', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Image class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>生图</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Key 能力配置 -->
|
||||
<section
|
||||
v-if="availableCapabilities.length > 0"
|
||||
class="space-y-2"
|
||||
>
|
||||
<h4 class="font-medium text-sm">
|
||||
Key 能力支持
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label
|
||||
v-for="cap in availableCapabilities"
|
||||
:key="cap.name"
|
||||
class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.supported_capabilities?.includes(cap.name)"
|
||||
class="rounded"
|
||||
@change="toggleCapability(cap.name)"
|
||||
>
|
||||
<span>{{ cap.display_name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 价格配置 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
价格配置
|
||||
</h4>
|
||||
<TieredPricingEditor
|
||||
ref="tieredPricingEditorRef"
|
||||
v-model="tieredPricing"
|
||||
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
|
||||
/>
|
||||
<div class="flex items-center gap-3 pt-2 border-t">
|
||||
<Label class="text-xs whitespace-nowrap">按次计费</Label>
|
||||
<Input
|
||||
:model-value="form.default_price_per_request ?? ''"
|
||||
type="number"
|
||||
step="0.001"
|
||||
min="0"
|
||||
class="w-24"
|
||||
placeholder="$/次"
|
||||
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
|
||||
/>
|
||||
<span class="text-xs text-muted-foreground">可与 Token 计费叠加</span>
|
||||
</div>
|
||||
</section>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
@@ -180,7 +310,7 @@
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting"
|
||||
:disabled="submitting || !form.name || !form.display_name"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
@@ -189,19 +319,35 @@
|
||||
/>
|
||||
{{ isEditMode ? '保存' : '创建' }}
|
||||
</Button>
|
||||
<Button
|
||||
v-if="selectedModel && !isEditMode"
|
||||
type="button"
|
||||
variant="ghost"
|
||||
@click="clearSelection"
|
||||
>
|
||||
清空
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen } from 'lucide-vue-next'
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import {
|
||||
Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen,
|
||||
Search, ChevronRight
|
||||
} from 'lucide-vue-next'
|
||||
import { Dialog, Button, Input, Label } from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
import { parseNumberInput } from '@/utils/form'
|
||||
import { log } from '@/utils/logger'
|
||||
import TieredPricingEditor from './TieredPricingEditor.vue'
|
||||
import {
|
||||
getModelsDevList,
|
||||
getProviderLogoUrl,
|
||||
type ModelsDevModelItem,
|
||||
} from '@/api/models-dev'
|
||||
import {
|
||||
createGlobalModel,
|
||||
updateGlobalModel,
|
||||
@@ -226,42 +372,147 @@ const { success, error: showError } = useToast()
|
||||
const submitting = ref(false)
|
||||
const tieredPricingEditorRef = ref<InstanceType<typeof TieredPricingEditor> | null>(null)
|
||||
|
||||
// 阶梯计费配置(统一使用,固定价格就是单阶梯)
|
||||
// 模型列表相关
|
||||
const loading = ref(false)
|
||||
const searchQuery = ref('')
|
||||
const allModelsCache = ref<ModelsDevModelItem[]>([]) // 全部模型(缓存)
|
||||
const selectedModel = ref<ModelsDevModelItem | null>(null)
|
||||
const expandedProvider = ref<string | null>(null)
|
||||
|
||||
// Models offered in the picker: every provider while a search is active,
// official providers only otherwise. The query is trimmed first so that a
// whitespace-only input is not mistaken for a search and does not suddenly
// widen the list to third-party providers.
const allModels = computed(() => {
  if (searchQuery.value.trim()) {
    return allModelsCache.value
  }
  return allModelsCache.value.filter(m => m.official)
})
|
||||
|
||||
// One provider together with the models listed beneath it
interface ProviderGroup {
  providerId: string
  providerName: string
  models: ModelsDevModelItem[]
}

// Non-deprecated models from `allModels`, narrowed by the search box and
// grouped by provider.
//  - The query is split on whitespace and every keyword must occur in the
//    model's provider id/name, model id/name or family (AND search).
//  - While searching, providers whose own id or name contains any keyword are
//    listed first; otherwise, and within each of those two sets, providers are
//    ordered by name.
const groupedModels = computed(() => {
  const query = searchQuery.value
  const searching = Boolean(query)
  const keywords = searching
    ? query.toLowerCase().split(/\s+/).filter(word => word.length > 0)
    : []

  // Drop deprecated models and, while searching, anything missing a keyword
  const visible = allModels.value.filter(model => {
    if (model.deprecated) return false
    if (!searching) return true
    const haystack = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
    return keywords.every(word => haystack.includes(word))
  })

  // Bucket the remaining models per provider, keeping their incoming order
  const byProvider = new Map<string, ProviderGroup>()
  for (const model of visible) {
    let bucket = byProvider.get(model.providerId)
    if (!bucket) {
      bucket = {
        providerId: model.providerId,
        providerName: model.providerName,
        models: []
      }
      byProvider.set(model.providerId, bucket)
    }
    bucket.models.push(model)
  }

  const providerHit = (group: ProviderGroup): boolean => {
    const label = `${group.providerId} ${group.providerName}`.toLowerCase()
    return keywords.some(word => label.includes(word))
  }

  const ordered = Array.from(byProvider.values())
  ordered.sort((a, b) => {
    if (searching) {
      const aHit = providerHit(a)
      const bHit = providerHit(b)
      if (aHit !== bHit) return aHit ? -1 : 1
    }
    return a.providerName.localeCompare(b.providerName)
  })

  return ordered
})
|
||||
|
||||
// While a search is active, expand the provider automatically as soon as the
// results collapse to exactly one group
watch(groupedModels, (currentGroups) => {
  if (!searchQuery.value) return
  if (currentGroups.length !== 1) return
  expandedProvider.value = currentGroups[0].providerId
})
|
||||
|
||||
// Expand the given provider, or collapse it when it is the one already open
// (only one provider is expanded at a time)
function toggleProvider(providerId: string) {
  const alreadyOpen = expandedProvider.value === providerId
  expandedProvider.value = alreadyOpen ? null : providerId
}
|
||||
|
||||
// 阶梯计费配置
|
||||
const tieredPricing = ref<TieredPricingConfig | null>(null)
|
||||
|
||||
interface FormData {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
default_price_per_request?: number
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
supported_capabilities?: string[]
|
||||
config?: Record<string, any>
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
const defaultForm = (): FormData => ({
|
||||
name: '',
|
||||
display_name: '',
|
||||
description: '',
|
||||
default_price_per_request: undefined,
|
||||
default_supports_streaming: true,
|
||||
default_supports_image_generation: false,
|
||||
default_supports_vision: false,
|
||||
default_supports_function_calling: false,
|
||||
default_supports_extended_thinking: false,
|
||||
supported_capabilities: [],
|
||||
config: { streaming: true },
|
||||
is_active: true,
|
||||
})
|
||||
|
||||
const form = ref<FormData>(defaultForm())
|
||||
|
||||
const KEEP_FALSE_CONFIG_KEYS = new Set(['streaming'])
|
||||
|
||||
// Write one entry of `form.config`, creating the object on first use.
// `undefined`, the empty string and `false` remove the key instead of storing
// it, except that `false` is stored for keys listed in KEEP_FALSE_CONFIG_KEYS.
function setConfigField(key: string, value: any) {
  if (!form.value.config) {
    form.value.config = {}
  }
  // Re-read through the ref so the write below goes through the reactive object
  const config = form.value.config

  const isBlank = value === undefined || value === ''
  const isDroppableFalse = value === false && !KEEP_FALSE_CONFIG_KEYS.has(key)

  if (isBlank || isDroppableFalse) {
    delete config[key]
    return
  }
  config[key] = value
}
|
||||
|
||||
// Key 能力选项
|
||||
const availableCapabilities = ref<CapabilityDefinition[]>([])
|
||||
|
||||
// Fetch the full model list once and keep it in `allModelsCache`; the
// official/all split is applied later by the computed properties. A failure
// is logged and leaves the list empty. `loading` is true for the duration.
async function loadModels() {
  const alreadyLoaded = allModelsCache.value.length > 0
  if (alreadyLoaded) return

  loading.value = true
  try {
    const everyModel = await getModelsDevList(false)
    allModelsCache.value = everyModel
  } catch (err) {
    log.error('Failed to load models:', err)
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
// Load the model list when the dialog opens without a model to edit
watch(() => props.open, (isOpen) => {
  if (!isOpen) return
  if (props.model) return
  loadModels()
})
|
||||
|
||||
// 加载可用能力列表
|
||||
async function loadCapabilities() {
|
||||
try {
|
||||
@@ -284,38 +535,92 @@ function toggleCapability(capName: string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 组件挂载时加载能力列表
|
||||
onMounted(() => {
|
||||
loadCapabilities()
|
||||
})
|
||||
|
||||
// Pick a model from the list and pre-fill the form from it.
//  - name / display name come from the model id / model name,
//  - `form.config` is replaced by a config built from the model's metadata,
//  - pricing becomes a single unbounded tier when the model carries an input
//    or output price, and is cleared otherwise.
function selectModel(model: ModelsDevModelItem) {
  selectedModel.value = model
  expandedProvider.value = model.providerId
  form.value.name = model.modelId
  form.value.display_name = model.modelName

  // Build the config; a key is written only when there is a value to record
  const config: Record<string, any> = {
    streaming: true,
  }
  if (model.supportsVision) config.vision = true
  if (model.supportsToolCall) config.function_calling = true
  if (model.supportsReasoning) config.extended_thinking = true
  if (model.supportsStructuredOutput) config.structured_output = true
  // Record the flag only when it is actually set. The previous form also ran
  // for `undefined` and left a `temperature: undefined` key behind, which
  // setConfigField would treat as "absent".
  if (model.supportsTemperature === true) config.temperature = true
  if (model.supportsAttachment) config.attachment = true
  if (model.openWeights) config.open_weights = true
  if (model.contextLimit) config.context_limit = model.contextLimit
  if (model.outputLimit) config.output_limit = model.outputLimit
  if (model.knowledgeCutoff) config.knowledge_cutoff = model.knowledgeCutoff
  if (model.family) config.family = model.family
  if (model.releaseDate) config.release_date = model.releaseDate
  if (model.inputModalities?.length) config.input_modalities = model.inputModalities
  if (model.outputModalities?.length) config.output_modalities = model.outputModalities
  form.value.config = config

  if (model.inputPrice !== undefined || model.outputPrice !== undefined) {
    // One tier without an upper bound; a missing side of the price becomes 0
    tieredPricing.value = {
      tiers: [{
        up_to: null,
        input_price_per_1m: model.inputPrice || 0,
        output_price_per_1m: model.outputPrice || 0,
      }]
    }
  } else {
    tieredPricing.value = null
  }
}
|
||||
|
||||
// 清除选择(手动填写)
|
||||
function clearSelection() {
|
||||
selectedModel.value = null
|
||||
form.value = defaultForm()
|
||||
tieredPricing.value = null
|
||||
}
|
||||
|
||||
// Logo 加载失败处理
|
||||
function handleLogoError(event: Event) {
|
||||
const img = event.target as HTMLImageElement
|
||||
img.style.display = 'none'
|
||||
}
|
||||
|
||||
// 重置表单
|
||||
function resetForm() {
|
||||
form.value = defaultForm()
|
||||
tieredPricing.value = null
|
||||
searchQuery.value = ''
|
||||
selectedModel.value = null
|
||||
expandedProvider.value = null
|
||||
}
|
||||
|
||||
// 加载模型数据(编辑模式)
|
||||
function loadModelData() {
|
||||
if (!props.model) return
|
||||
// 先重置创建模式的残留状态
|
||||
selectedModel.value = null
|
||||
searchQuery.value = ''
|
||||
expandedProvider.value = null
|
||||
|
||||
form.value = {
|
||||
name: props.model.name,
|
||||
display_name: props.model.display_name,
|
||||
description: props.model.description,
|
||||
default_price_per_request: props.model.default_price_per_request,
|
||||
default_supports_streaming: props.model.default_supports_streaming,
|
||||
default_supports_image_generation: props.model.default_supports_image_generation,
|
||||
default_supports_vision: props.model.default_supports_vision,
|
||||
default_supports_function_calling: props.model.default_supports_function_calling,
|
||||
default_supports_extended_thinking: props.model.default_supports_extended_thinking,
|
||||
supported_capabilities: [...(props.model.supported_capabilities || [])],
|
||||
config: props.model.config ? { ...props.model.config } : { streaming: true },
|
||||
is_active: props.model.is_active,
|
||||
}
|
||||
|
||||
// 加载阶梯计费配置(深拷贝)
|
||||
if (props.model.default_tiered_pricing) {
|
||||
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
|
||||
}
|
||||
// 确保 tieredPricing 也被正确设置或重置
|
||||
tieredPricing.value = props.model.default_tiered_pricing
|
||||
? JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
|
||||
: null
|
||||
}
|
||||
|
||||
// 使用 useFormDialog 统一处理对话框逻辑
|
||||
@@ -339,24 +644,22 @@ async function handleSubmit() {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取包含自动计算缓存价格的最终数据
|
||||
const finalTiers = tieredPricingEditorRef.value?.getFinalTiers()
|
||||
const finalTieredPricing = finalTiers ? { tiers: finalTiers } : tieredPricing.value
|
||||
|
||||
// 清理空的 config
|
||||
const cleanConfig = form.value.config && Object.keys(form.value.config).length > 0
|
||||
? form.value.config
|
||||
: undefined
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
if (isEditMode.value && props.model) {
|
||||
const updateData: GlobalModelUpdate = {
|
||||
display_name: form.value.display_name,
|
||||
description: form.value.description,
|
||||
// 使用 null 而不是 undefined 来显式清空字段
|
||||
config: cleanConfig || null,
|
||||
default_price_per_request: form.value.default_price_per_request ?? null,
|
||||
default_tiered_pricing: finalTieredPricing,
|
||||
default_supports_streaming: form.value.default_supports_streaming,
|
||||
default_supports_image_generation: form.value.default_supports_image_generation,
|
||||
default_supports_vision: form.value.default_supports_vision,
|
||||
default_supports_function_calling: form.value.default_supports_function_calling,
|
||||
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
|
||||
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : null,
|
||||
is_active: form.value.is_active,
|
||||
}
|
||||
@@ -366,14 +669,9 @@ async function handleSubmit() {
|
||||
const createData: GlobalModelCreate = {
|
||||
name: form.value.name!,
|
||||
display_name: form.value.display_name!,
|
||||
description: form.value.description,
|
||||
default_price_per_request: form.value.default_price_per_request || undefined,
|
||||
config: cleanConfig,
|
||||
default_price_per_request: form.value.default_price_per_request ?? undefined,
|
||||
default_tiered_pricing: finalTieredPricing,
|
||||
default_supports_streaming: form.value.default_supports_streaming,
|
||||
default_supports_image_generation: form.value.default_supports_image_generation,
|
||||
default_supports_vision: form.value.default_supports_vision,
|
||||
default_supports_function_calling: form.value.default_supports_function_calling,
|
||||
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
|
||||
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : undefined,
|
||||
is_active: form.value.is_active,
|
||||
}
|
||||
|
||||
@@ -38,12 +38,12 @@
|
||||
>
|
||||
<Copy class="w-3 h-3" />
|
||||
</button>
|
||||
<template v-if="model.description">
|
||||
<template v-if="model.config?.description">
|
||||
<span class="shrink-0">·</span>
|
||||
<span
|
||||
class="text-xs truncate"
|
||||
:title="model.description"
|
||||
>{{ model.description }}</span>
|
||||
:title="model.config?.description"
|
||||
>{{ model.config?.description }}</span>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
@@ -143,10 +143,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -160,10 +160,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -177,10 +177,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.vision === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.vision === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -194,10 +194,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -211,10 +211,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
@@ -396,11 +396,11 @@
|
||||
</div>
|
||||
<div class="p-3 rounded-lg border bg-muted/20">
|
||||
<div class="flex items-center justify-between">
|
||||
<Label class="text-xs text-muted-foreground">别名数量</Label>
|
||||
<Tag class="w-4 h-4 text-muted-foreground" />
|
||||
<Label class="text-xs text-muted-foreground">调用次数</Label>
|
||||
<BarChart3 class="w-4 h-4 text-muted-foreground" />
|
||||
</div>
|
||||
<p class="text-2xl font-bold mt-1">
|
||||
{{ model.alias_count || 0 }}
|
||||
{{ model.usage_count || 0 }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -455,105 +455,153 @@
|
||||
<template v-else-if="providers.length > 0">
|
||||
<!-- 桌面端表格 -->
|
||||
<Table class="hidden sm:table">
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="h-10 font-semibold">
|
||||
Provider
|
||||
</TableHead>
|
||||
<TableHead class="w-[120px] h-10 font-semibold">
|
||||
能力
|
||||
</TableHead>
|
||||
<TableHead class="w-[180px] h-10 font-semibold">
|
||||
价格 ($/M)
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] h-10 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="h-10 font-semibold">
|
||||
Provider
|
||||
</TableHead>
|
||||
<TableHead class="w-[120px] h-10 font-semibold">
|
||||
能力
|
||||
</TableHead>
|
||||
<TableHead class="w-[180px] h-10 font-semibold">
|
||||
价格 ($/M)
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] h-10 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="provider.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex gap-0.5">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="流式输出"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="视觉理解"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="工具调用"
|
||||
/>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="text-xs font-mono space-y-0.5">
|
||||
<!-- Token 计费:输入/输出 -->
|
||||
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
|
||||
<span class="text-muted-foreground">输入/输出:</span>
|
||||
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
|
||||
<!-- 阶梯标记 -->
|
||||
<span
|
||||
v-if="(provider.tier_count || 1) > 1"
|
||||
class="ml-1 text-muted-foreground"
|
||||
title="阶梯计费"
|
||||
>[阶梯]</span>
|
||||
</div>
|
||||
<!-- 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 1h 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>1h 缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 按次计费 -->
|
||||
<div v-if="(provider.price_per_request || 0) > 0">
|
||||
<span class="text-muted-foreground">按次:</span>
|
||||
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/次</span>
|
||||
</div>
|
||||
<!-- 无定价 -->
|
||||
<span
|
||||
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3 text-center">
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑此关联"
|
||||
@click="$emit('editProvider', provider)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="provider.is_active ? '停用此关联' : '启用此关联'"
|
||||
@click="$emit('toggleProviderStatus', provider)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除此关联"
|
||||
@click="$emit('deleteProvider', provider)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div class="sm:hidden divide-y divide-border/40">
|
||||
<div
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
class="p-4 space-y-3"
|
||||
>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex items-center gap-2 min-w-0">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="provider.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex gap-0.5">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="流式输出"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="视觉理解"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="工具调用"
|
||||
/>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="text-xs font-mono space-y-0.5">
|
||||
<!-- Token 计费:输入/输出 -->
|
||||
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
|
||||
<span class="text-muted-foreground">输入/输出:</span>
|
||||
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
|
||||
<!-- 阶梯标记 -->
|
||||
<span
|
||||
v-if="(provider.tier_count || 1) > 1"
|
||||
class="ml-1 text-muted-foreground"
|
||||
title="阶梯计费"
|
||||
>[阶梯]</span>
|
||||
</div>
|
||||
<!-- 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 1h 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>1h 缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 按次计费 -->
|
||||
<div v-if="(provider.price_per_request || 0) > 0">
|
||||
<span class="text-muted-foreground">按次:</span>
|
||||
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/次</span>
|
||||
</div>
|
||||
<!-- 无定价 -->
|
||||
<span
|
||||
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3 text-center">
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
<div class="flex items-center gap-1 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑此关联"
|
||||
@click="$emit('editProvider', provider)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
@@ -562,7 +610,6 @@
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="provider.is_active ? '停用此关联' : '启用此关联'"
|
||||
@click="$emit('toggleProviderStatus', provider)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
@@ -571,82 +618,35 @@
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除此关联"
|
||||
@click="$emit('deleteProvider', provider)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div class="sm:hidden divide-y divide-border/40">
|
||||
<div
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
class="p-4 space-y-3"
|
||||
>
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex items-center gap-2 min-w-0">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
/>
|
||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-1 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('editProvider', provider)"
|
||||
<div class="flex items-center gap-3 text-xs">
|
||||
<div class="flex gap-1">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground font-mono"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('toggleProviderStatus', provider)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('deleteProvider', provider)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3 text-xs">
|
||||
<div class="flex gap-1">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground font-mono"
|
||||
>
|
||||
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -695,7 +695,8 @@ import {
|
||||
Loader2,
|
||||
RefreshCw,
|
||||
Copy,
|
||||
Layers
|
||||
Layers,
|
||||
BarChart3
|
||||
} from 'lucide-vue-next'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
|
||||
@@ -117,8 +117,12 @@
|
||||
class="text-center py-6 text-muted-foreground border rounded-lg border-dashed"
|
||||
>
|
||||
<Tag class="w-8 h-8 mx-auto mb-2 opacity-50" />
|
||||
<p class="text-sm">未配置映射</p>
|
||||
<p class="text-xs mt-1">将只使用主模型名称</p>
|
||||
<p class="text-sm">
|
||||
未配置映射
|
||||
</p>
|
||||
<p class="text-xs mt-1">
|
||||
将只使用主模型名称
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -312,8 +312,41 @@
|
||||
|
||||
<template #footer>
|
||||
<div class="flex items-center justify-between w-full">
|
||||
<div class="text-xs text-muted-foreground">
|
||||
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
|
||||
<div class="flex items-center gap-4">
|
||||
<div class="text-xs text-muted-foreground">
|
||||
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 pl-4 border-l border-border">
|
||||
<span class="text-xs text-muted-foreground">调度:</span>
|
||||
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'fixed_order'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="严格按优先级顺序,不考虑缓存"
|
||||
@click="schedulingMode = 'fixed_order'"
|
||||
>
|
||||
固定顺序
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'cache_affinity'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="优先使用已缓存的Provider,利用Prompt Cache"
|
||||
@click="schedulingMode = 'cache_affinity'"
|
||||
>
|
||||
缓存亲和
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-2">
|
||||
<Button
|
||||
@@ -410,6 +443,9 @@ const saving = ref(false)
|
||||
// Key 优先级编辑状态
|
||||
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
||||
|
||||
// 调度模式状态
|
||||
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
|
||||
|
||||
// 可用的 API 格式
|
||||
const availableFormats = computed(() => {
|
||||
return Object.keys(keysByFormat.value).sort()
|
||||
@@ -433,11 +469,18 @@ watch(internalOpen, async (open) => {
|
||||
// 加载当前的优先级模式配置
|
||||
async function loadCurrentPriorityMode() {
|
||||
try {
|
||||
const response = await adminApi.getSystemConfig('provider_priority_mode')
|
||||
const currentMode = response.value || 'provider'
|
||||
const [priorityResponse, schedulingResponse] = await Promise.all([
|
||||
adminApi.getSystemConfig('provider_priority_mode'),
|
||||
adminApi.getSystemConfig('scheduling_mode')
|
||||
])
|
||||
const currentMode = priorityResponse.value || 'provider'
|
||||
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
||||
|
||||
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
|
||||
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
|
||||
} catch {
|
||||
activeMainTab.value = 'provider'
|
||||
schedulingMode.value = 'cache_affinity'
|
||||
}
|
||||
}
|
||||
|
||||
@@ -611,11 +654,19 @@ async function save() {
|
||||
|
||||
const newMode = activeMainTab.value === 'key' ? 'global_key' : 'provider'
|
||||
|
||||
await adminApi.updateSystemConfig(
|
||||
'provider_priority_mode',
|
||||
newMode,
|
||||
'Provider/Key 优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)'
|
||||
)
|
||||
// 保存优先级模式和调度模式
|
||||
await Promise.all([
|
||||
adminApi.updateSystemConfig(
|
||||
'provider_priority_mode',
|
||||
newMode,
|
||||
'Provider/Key 优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)'
|
||||
),
|
||||
adminApi.updateSystemConfig(
|
||||
'scheduling_mode',
|
||||
schedulingMode.value,
|
||||
'调度模式:fixed_order(固定顺序模式) 或 cache_affinity(缓存亲和模式)'
|
||||
)
|
||||
])
|
||||
|
||||
const providerUpdates = sortedProviders.value.map((provider, index) =>
|
||||
updateProvider(provider.id, { provider_priority: index + 1 })
|
||||
|
||||
@@ -526,7 +526,14 @@
|
||||
@edit-model="handleEditModel"
|
||||
@delete-model="handleDeleteModel"
|
||||
@batch-assign="handleBatchAssign"
|
||||
@manage-alias="handleManageAlias"
|
||||
/>
|
||||
|
||||
<!-- 模型名称映射 -->
|
||||
<ModelAliasesTab
|
||||
v-if="provider"
|
||||
:key="`aliases-${provider.id}`"
|
||||
:provider="provider"
|
||||
@refresh="handleRelatedDataRefresh"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -629,16 +636,6 @@
|
||||
@update:open="batchAssignDialogOpen = $event"
|
||||
@changed="handleBatchAssignChanged"
|
||||
/>
|
||||
|
||||
<!-- 模型别名管理对话框 -->
|
||||
<ModelAliasDialog
|
||||
v-if="open && provider"
|
||||
:open="aliasDialogOpen"
|
||||
:provider-id="provider.id"
|
||||
:model="aliasEditingModel"
|
||||
@update:open="aliasDialogOpen = $event"
|
||||
@saved="handleAliasSaved"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
@@ -667,8 +664,8 @@ import {
|
||||
KeyFormDialog,
|
||||
KeyAllowedModelsDialog,
|
||||
ModelsTab,
|
||||
BatchAssignModelsDialog,
|
||||
ModelAliasDialog
|
||||
ModelAliasesTab,
|
||||
BatchAssignModelsDialog
|
||||
} from '@/features/providers/components'
|
||||
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||
@@ -737,10 +734,6 @@ const deleteModelConfirmOpen = ref(false)
|
||||
const modelToDelete = ref<Model | null>(null)
|
||||
const batchAssignDialogOpen = ref(false)
|
||||
|
||||
// 别名管理相关状态
|
||||
const aliasDialogOpen = ref(false)
|
||||
const aliasEditingModel = ref<Model | null>(null)
|
||||
|
||||
// 拖动排序相关状态
|
||||
const dragState = ref({
|
||||
isDragging: false,
|
||||
@@ -762,8 +755,7 @@ const hasBlockingDialogOpen = computed(() =>
|
||||
deleteKeyConfirmOpen.value ||
|
||||
modelFormDialogOpen.value ||
|
||||
deleteModelConfirmOpen.value ||
|
||||
batchAssignDialogOpen.value ||
|
||||
aliasDialogOpen.value
|
||||
batchAssignDialogOpen.value
|
||||
)
|
||||
|
||||
// 监听 providerId 变化
|
||||
@@ -792,7 +784,6 @@ watch(() => props.open, (newOpen) => {
|
||||
keyAllowedModelsDialogOpen.value = false
|
||||
deleteKeyConfirmOpen.value = false
|
||||
batchAssignDialogOpen.value = false
|
||||
aliasDialogOpen.value = false
|
||||
|
||||
// 重置临时数据
|
||||
endpointToEdit.value = null
|
||||
@@ -1030,19 +1021,6 @@ async function handleBatchAssignChanged() {
|
||||
emit('refresh')
|
||||
}
|
||||
|
||||
// 处理管理映射 - 打开别名对话框
|
||||
function handleManageAlias(model: Model) {
|
||||
aliasEditingModel.value = model
|
||||
aliasDialogOpen.value = true
|
||||
}
|
||||
|
||||
// 处理别名保存完成
|
||||
async function handleAliasSaved() {
|
||||
aliasEditingModel.value = null
|
||||
await loadProvider()
|
||||
emit('refresh')
|
||||
}
|
||||
|
||||
// 处理模型保存完成
|
||||
async function handleModelSaved() {
|
||||
editingModel.value = null
|
||||
|
||||
@@ -10,3 +10,4 @@ export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vu
|
||||
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
||||
|
||||
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
||||
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -165,15 +165,6 @@
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="管理映射"
|
||||
@click="openAliasDialog(model)"
|
||||
>
|
||||
<Tag class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
@@ -218,7 +209,7 @@
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image, Tag } from 'lucide-vue-next'
|
||||
import { Box, Edit, Trash2, Layers, Eye, Wrench, Zap, Brain, Power, Copy, Image } from 'lucide-vue-next'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
@@ -233,7 +224,6 @@ const emit = defineEmits<{
|
||||
'editModel': [model: Model]
|
||||
'deleteModel': [model: Model]
|
||||
'batchAssign': []
|
||||
'manageAlias': [model: Model]
|
||||
}>()
|
||||
|
||||
const { error: showError, success: showSuccess } = useToast()
|
||||
@@ -373,11 +363,6 @@ function openBatchAssignDialog() {
|
||||
emit('batchAssign')
|
||||
}
|
||||
|
||||
// 打开别名管理对话框
|
||||
function openAliasDialog(model: Model) {
|
||||
emit('manageAlias', model)
|
||||
}
|
||||
|
||||
// 切换模型启用状态
|
||||
async function toggleModelActive(model: Model) {
|
||||
if (togglingModelId.value) return
|
||||
|
||||
@@ -611,41 +611,42 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
|
||||
id: 'gm-001',
|
||||
name: 'claude-haiku-4-5-20251001',
|
||||
display_name: 'claude-haiku-4-5',
|
||||
description: 'Anthropic 最快速的 Claude 4 系列模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.00, output_price_per_1m: 5.00, cache_creation_price_per_1m: 1.25, cache_read_price_per_1m: 0.1 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Anthropic 最快速的 Claude 4 系列模型'
|
||||
},
|
||||
provider_count: 3,
|
||||
alias_count: 2,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-002',
|
||||
name: 'claude-opus-4-5-20251101',
|
||||
display_name: 'claude-opus-4-5',
|
||||
description: 'Anthropic 最强大的模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 5.00, output_price_per_1m: 25.00, cache_creation_price_per_1m: 6.25, cache_read_price_per_1m: 0.5 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Anthropic 最强大的模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 1,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-003',
|
||||
name: 'claude-sonnet-4-5-20250929',
|
||||
display_name: 'claude-sonnet-4-5',
|
||||
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [
|
||||
@@ -677,116 +678,124 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
|
||||
}
|
||||
]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文'
|
||||
},
|
||||
supported_capabilities: ['cache_1h', 'cli_1m'],
|
||||
provider_count: 3,
|
||||
alias_count: 2,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-004',
|
||||
name: 'gemini-3-pro-image-preview',
|
||||
display_name: 'gemini-3-pro-image-preview',
|
||||
description: 'Google Gemini 3 Pro 图像生成预览版',
|
||||
is_active: true,
|
||||
default_price_per_request: 0.300,
|
||||
default_tiered_pricing: {
|
||||
tiers: []
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: false,
|
||||
default_supports_streaming: true,
|
||||
default_supports_image_generation: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: false,
|
||||
image_generation: true,
|
||||
description: 'Google Gemini 3 Pro 图像生成预览版'
|
||||
},
|
||||
provider_count: 1,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-005',
|
||||
name: 'gemini-3-pro-preview',
|
||||
display_name: 'gemini-3-pro-preview',
|
||||
description: 'Google Gemini 3 Pro 预览版',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 2.00, output_price_per_1m: 12.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Google Gemini 3 Pro 预览版'
|
||||
},
|
||||
provider_count: 1,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-006',
|
||||
name: 'gpt-5.1',
|
||||
display_name: 'gpt-5.1',
|
||||
description: 'OpenAI GPT-5.1 模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 1,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-007',
|
||||
name: 'gpt-5.1-codex',
|
||||
display_name: 'gpt-5.1-codex',
|
||||
description: 'OpenAI GPT-5.1 Codex 代码专用模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 Codex 代码专用模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-008',
|
||||
name: 'gpt-5.1-codex-max',
|
||||
display_name: 'gpt-5.1-codex-max',
|
||||
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-009',
|
||||
name: 'gpt-5.1-codex-mini',
|
||||
display_name: 'gpt-5.1-codex-mini',
|
||||
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1000,17 +1000,11 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
display_name: m.display_name,
|
||||
description: m.description,
|
||||
icon_url: null,
|
||||
is_active: m.is_active,
|
||||
default_tiered_pricing: m.default_tiered_pricing,
|
||||
default_price_per_request: null,
|
||||
default_supports_vision: m.default_supports_vision,
|
||||
default_supports_function_calling: m.default_supports_function_calling,
|
||||
default_supports_streaming: m.default_supports_streaming,
|
||||
default_supports_extended_thinking: m.default_supports_extended_thinking || false,
|
||||
default_supports_image_generation: false,
|
||||
supported_capabilities: null
|
||||
default_price_per_request: m.default_price_per_request,
|
||||
supported_capabilities: m.supported_capabilities,
|
||||
config: m.config
|
||||
})),
|
||||
total: MOCK_GLOBAL_MODELS.length
|
||||
})
|
||||
|
||||
@@ -1169,4 +1169,26 @@ body[theme-mode='dark'] .literary-annotation {
|
||||
.scrollbar-hide::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
|
||||
/* Thin scrollbar utility: standard properties (scrollbar-width / scrollbar-color). */
.scrollbar-thin {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: hsl(var(--border)) transparent;
|
||||
}
|
||||
|
||||
/* The ::-webkit-scrollbar rules below style the same scrollbar through WebKit's pseudo-elements. */
.scrollbar-thin::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-thumb {
|
||||
background-color: hsl(var(--border));
|
||||
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
/* Thumb switches to the muted foreground colour at 50% alpha while hovered. */
.scrollbar-thin::-webkit-scrollbar-thumb:hover {
|
||||
background-color: hsl(var(--muted-foreground) / 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
|
||||
const filteredApiKeys = computed(() => {
|
||||
let result = apiKeys.value
|
||||
|
||||
// 搜索筛选
|
||||
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(key =>
|
||||
(key.name && key.name.toLowerCase().includes(query)) ||
|
||||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
|
||||
(key.username && key.username.toLowerCase().includes(query)) ||
|
||||
(key.user_email && key.user_email.toLowerCase().includes(query))
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(key => {
|
||||
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
|
||||
@@ -935,7 +935,10 @@ onBeforeUnmount(() => {
|
||||
:key="`${index}-${aliasIndex}`"
|
||||
>
|
||||
<TableCell>
|
||||
<Badge variant="outline" class="text-xs">
|
||||
<Badge
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ mapping.provider_name }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
@@ -981,7 +984,10 @@ onBeforeUnmount(() => {
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
<Badge variant="outline" class="text-xs">
|
||||
<Badge
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ mapping.provider_name }}
|
||||
</Badge>
|
||||
<div class="flex items-center gap-2">
|
||||
|
||||
@@ -111,9 +111,6 @@
|
||||
<TableHead class="w-[80px] text-center">
|
||||
提供商
|
||||
</TableHead>
|
||||
<TableHead class="w-[70px] text-center">
|
||||
别名/映射
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] text-center">
|
||||
调用次数
|
||||
</TableHead>
|
||||
@@ -128,7 +125,7 @@
|
||||
<TableBody>
|
||||
<TableRow v-if="loading">
|
||||
<TableCell
|
||||
colspan="8"
|
||||
colspan="7"
|
||||
class="text-center py-8"
|
||||
>
|
||||
<Loader2 class="w-6 h-6 animate-spin mx-auto" />
|
||||
@@ -136,7 +133,7 @@
|
||||
</TableRow>
|
||||
<TableRow v-else-if="filteredGlobalModels.length === 0">
|
||||
<TableCell
|
||||
colspan="8"
|
||||
colspan="7"
|
||||
class="text-center py-8 text-muted-foreground"
|
||||
>
|
||||
没有找到匹配的模型
|
||||
@@ -171,27 +168,27 @@
|
||||
<div class="space-y-1 w-fit">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
<Zap
|
||||
v-if="model.default_supports_streaming"
|
||||
v-if="model.config?.streaming !== false"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="流式输出"
|
||||
/>
|
||||
<Image
|
||||
v-if="model.default_supports_image_generation"
|
||||
v-if="model.config?.image_generation === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="图像生成"
|
||||
/>
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="视觉理解"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="工具调用"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="深度思考"
|
||||
/>
|
||||
@@ -244,11 +241,6 @@
|
||||
{{ model.provider_count || 0 }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge variant="secondary">
|
||||
{{ model.alias_count || 0 }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<span class="text-sm font-mono">{{ formatUsageCount(model.usage_count || 0) }}</span>
|
||||
</TableCell>
|
||||
@@ -369,23 +361,23 @@
|
||||
<!-- 第二行:能力图标 -->
|
||||
<div class="flex flex-wrap gap-1.5">
|
||||
<Zap
|
||||
v-if="model.default_supports_streaming"
|
||||
v-if="model.config?.streaming !== false"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Image
|
||||
v-if="model.default_supports_image_generation"
|
||||
v-if="model.config?.image_generation === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
@@ -393,7 +385,6 @@
|
||||
<!-- 第三行:统计信息 -->
|
||||
<div class="flex flex-wrap items-center gap-3 text-xs text-muted-foreground">
|
||||
<span>提供商 {{ model.provider_count || 0 }}</span>
|
||||
<span>别名 {{ model.alias_count || 0 }}</span>
|
||||
<span>调用 {{ formatUsageCount(model.usage_count || 0) }}</span>
|
||||
<span
|
||||
v-if="getFirstTierPrice(model, 'input') || getFirstTierPrice(model, 'output')"
|
||||
@@ -1011,30 +1002,30 @@ async function batchRemoveSelectedProviders() {
|
||||
const filteredGlobalModels = computed(() => {
|
||||
let result = globalModels.value
|
||||
|
||||
// 搜索
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(m =>
|
||||
m.name.toLowerCase().includes(query) ||
|
||||
m.display_name?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 能力筛选
|
||||
if (capabilityFilters.value.streaming) {
|
||||
result = result.filter(m => m.default_supports_streaming)
|
||||
result = result.filter(m => m.config?.streaming !== false)
|
||||
}
|
||||
if (capabilityFilters.value.imageGeneration) {
|
||||
result = result.filter(m => m.default_supports_image_generation)
|
||||
result = result.filter(m => m.config?.image_generation === true)
|
||||
}
|
||||
if (capabilityFilters.value.vision) {
|
||||
result = result.filter(m => m.default_supports_vision)
|
||||
result = result.filter(m => m.config?.vision === true)
|
||||
}
|
||||
if (capabilityFilters.value.toolUse) {
|
||||
result = result.filter(m => m.default_supports_function_calling)
|
||||
result = result.filter(m => m.config?.function_calling === true)
|
||||
}
|
||||
if (capabilityFilters.value.extendedThinking) {
|
||||
result = result.filter(m => m.default_supports_extended_thinking)
|
||||
result = result.filter(m => m.config?.extended_thinking === true)
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
|
||||
const filteredProviders = computed(() => {
|
||||
let result = [...providers.value]
|
||||
|
||||
// 搜索筛选
|
||||
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value.trim()) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(p =>
|
||||
p.display_name.toLowerCase().includes(query) ||
|
||||
p.name.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(p => {
|
||||
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 排序
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
<template #actions>
|
||||
<Button
|
||||
:disabled="loading"
|
||||
class="shadow-none hover:shadow-none"
|
||||
@click="saveSystemConfig"
|
||||
>
|
||||
{{ loading ? '保存中...' : '保存所有配置' }}
|
||||
@@ -15,6 +16,94 @@
|
||||
</PageHeader>
|
||||
|
||||
<div class="mt-6 space-y-6">
|
||||
<!-- 配置导出/导入 -->
|
||||
<CardSection
|
||||
title="配置管理"
|
||||
description="导出或导入提供商和模型配置,便于备份或迁移"
|
||||
>
|
||||
<div class="flex flex-wrap gap-4">
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
导出当前所有提供商、端点、API Key 和模型配置到 JSON 文件
|
||||
</p>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="exportLoading"
|
||||
@click="handleExportConfig"
|
||||
>
|
||||
<Download class="w-4 h-4 mr-2" />
|
||||
{{ exportLoading ? '导出中...' : '导出配置' }}
|
||||
</Button>
|
||||
</div>
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
从 JSON 文件导入配置,支持跳过、覆盖或报错三种冲突处理模式
|
||||
</p>
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
ref="configFileInput"
|
||||
type="file"
|
||||
accept=".json"
|
||||
class="hidden"
|
||||
@change="handleConfigFileSelect"
|
||||
>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="importLoading"
|
||||
@click="triggerConfigFileSelect"
|
||||
>
|
||||
<Upload class="w-4 h-4 mr-2" />
|
||||
{{ importLoading ? '导入中...' : '导入配置' }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- 用户数据导出/导入 -->
|
||||
<CardSection
|
||||
title="用户数据管理"
|
||||
description="导出或导入用户及其 API Keys 数据(不含管理员)"
|
||||
>
|
||||
<div class="flex flex-wrap gap-4">
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
导出所有普通用户及其 API Keys 到 JSON 文件
|
||||
</p>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="exportUsersLoading"
|
||||
@click="handleExportUsers"
|
||||
>
|
||||
<Download class="w-4 h-4 mr-2" />
|
||||
{{ exportUsersLoading ? '导出中...' : '导出用户数据' }}
|
||||
</Button>
|
||||
</div>
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
从 JSON 文件导入用户数据(需相同 ENCRYPTION_KEY)
|
||||
</p>
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
ref="usersFileInput"
|
||||
type="file"
|
||||
accept=".json"
|
||||
class="hidden"
|
||||
@change="handleUsersFileSelect"
|
||||
>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="importUsersLoading"
|
||||
@click="triggerUsersFileSelect"
|
||||
>
|
||||
<Upload class="w-4 h-4 mr-2" />
|
||||
{{ importUsersLoading ? '导入中...' : '导入用户数据' }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- 基础配置 -->
|
||||
<CardSection
|
||||
title="基础配置"
|
||||
@@ -375,11 +464,312 @@
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
|
||||
<!-- 导入配置对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importDialogOpen"
|
||||
title="导入配置"
|
||||
description="选择冲突处理模式并确认导入"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
配置预览
|
||||
</p>
|
||||
<ul class="space-y-1 text-muted-foreground">
|
||||
<li>全局模型: {{ importPreview.global_models?.length || 0 }} 个</li>
|
||||
<li>提供商: {{ importPreview.providers?.length || 0 }} 个</li>
|
||||
<li>
|
||||
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 个
|
||||
</li>
|
||||
<li>
  <!-- Providers without an endpoints list contribute 0; otherwise `sum + undefined` would turn the total into NaN. -->
  API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0) || 0), 0) || 0 }} 个
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||
<Select
|
||||
v-model="mergeMode"
|
||||
v-model:open="mergeModeSelectOpen"
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="skip">
|
||||
跳过 - 保留现有配置
|
||||
</SelectItem>
|
||||
<SelectItem value="overwrite">
|
||||
覆盖 - 用导入配置替换
|
||||
</SelectItem>
|
||||
<SelectItem value="error">
|
||||
报错 - 遇到冲突时中止
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
<template v-if="mergeMode === 'skip'">
|
||||
已存在的配置将被保留,仅导入新配置
|
||||
</template>
|
||||
<template v-else-if="mergeMode === 'overwrite'">
|
||||
已存在的配置将被导入的配置覆盖
|
||||
</template>
|
||||
<template v-else>
|
||||
如果发现任何冲突,导入将中止并回滚
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-muted-foreground">
|
||||
注意:相同的 API Keys 会自动跳过,不会创建重复记录。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="importDialogOpen = false; mergeModeSelectOpen = false"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="importLoading"
|
||||
@click="confirmImport"
|
||||
>
|
||||
{{ importLoading ? '导入中...' : '确认导入' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 导入结果对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importResultDialogOpen"
|
||||
title="导入完成"
|
||||
>
|
||||
<div
|
||||
v-if="importResult"
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
全局模型
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.global_models.created }},
|
||||
更新: {{ importResult.stats.global_models.updated }},
|
||||
跳过: {{ importResult.stats.global_models.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
提供商
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.providers.created }},
|
||||
更新: {{ importResult.stats.providers.updated }},
|
||||
跳过: {{ importResult.stats.providers.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
端点
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.endpoints.created }},
|
||||
更新: {{ importResult.stats.endpoints.updated }},
|
||||
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.keys.created }},
|
||||
跳过: {{ importResult.stats.keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg col-span-2">
|
||||
<p class="font-medium">
|
||||
模型配置
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.models.created }},
|
||||
更新: {{ importResult.stats.models.updated }},
|
||||
跳过: {{ importResult.stats.models.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="importResult.stats.errors.length > 0"
|
||||
class="p-3 bg-destructive/10 rounded-lg"
|
||||
>
|
||||
<p class="font-medium text-destructive mb-2">
|
||||
警告信息
|
||||
</p>
|
||||
<ul class="text-sm text-destructive space-y-1">
|
||||
<li
|
||||
v-for="(err, index) in importResult.stats.errors"
|
||||
:key="index"
|
||||
>
|
||||
{{ err }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button @click="importResultDialogOpen = false">
|
||||
确定
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 用户数据导入对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importUsersDialogOpen"
|
||||
title="导入用户数据"
|
||||
description="选择冲突处理模式并确认导入"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importUsersPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
数据预览
|
||||
</p>
|
||||
<ul class="space-y-1 text-muted-foreground">
|
||||
<li>用户: {{ importUsersPreview.users?.length || 0 }} 个</li>
|
||||
<li>
|
||||
API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) }} 个
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||
<Select
|
||||
v-model="usersMergeMode"
|
||||
v-model:open="usersMergeModeSelectOpen"
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="skip">
|
||||
跳过 - 保留现有用户
|
||||
</SelectItem>
|
||||
<SelectItem value="overwrite">
|
||||
覆盖 - 用导入数据替换
|
||||
</SelectItem>
|
||||
<SelectItem value="error">
|
||||
报错 - 遇到冲突时中止
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
<template v-if="usersMergeMode === 'skip'">
|
||||
已存在的用户将被保留,仅导入新用户
|
||||
</template>
|
||||
<template v-else-if="usersMergeMode === 'overwrite'">
|
||||
已存在的用户将被导入的数据覆盖
|
||||
</template>
|
||||
<template v-else>
|
||||
如果发现任何冲突,导入将中止并回滚
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-muted-foreground">
|
||||
注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="importUsersDialogOpen = false; usersMergeModeSelectOpen = false"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="importUsersLoading"
|
||||
@click="confirmImportUsers"
|
||||
>
|
||||
{{ importUsersLoading ? '导入中...' : '确认导入' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 用户数据导入结果对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importUsersResultDialogOpen"
|
||||
title="用户数据导入完成"
|
||||
>
|
||||
<div
|
||||
v-if="importUsersResult"
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
用户
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importUsersResult.stats.users.created }},
|
||||
更新: {{ importUsersResult.stats.users.updated }},
|
||||
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importUsersResult.stats.api_keys.created }},
|
||||
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="importUsersResult.stats.errors.length > 0"
|
||||
class="p-3 bg-destructive/10 rounded-lg"
|
||||
>
|
||||
<p class="font-medium text-destructive mb-2">
|
||||
警告信息
|
||||
</p>
|
||||
<ul class="text-sm text-destructive space-y-1">
|
||||
<li
|
||||
v-for="(err, index) in importUsersResult.stats.errors"
|
||||
:key="index"
|
||||
>
|
||||
{{ err }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button @click="importUsersResultDialogOpen = false">
|
||||
确定
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</PageContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { Download, Upload } from 'lucide-vue-next'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Input from '@/components/ui/input.vue'
|
||||
import Label from '@/components/ui/label.vue'
|
||||
@@ -389,9 +779,12 @@ import SelectTrigger from '@/components/ui/select-trigger.vue'
|
||||
import SelectValue from '@/components/ui/select-value.vue'
|
||||
import SelectContent from '@/components/ui/select-content.vue'
|
||||
import SelectItem from '@/components/ui/select-item.vue'
|
||||
import {
|
||||
Dialog,
|
||||
} from '@/components/ui'
|
||||
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { adminApi } from '@/api/admin'
|
||||
import { adminApi, type ConfigExportData, type ConfigImportResponse, type UsersExportData, type UsersImportResponse } from '@/api/admin'
|
||||
import { log } from '@/utils/logger'
|
||||
|
||||
const { success, error } = useToast()
|
||||
@@ -423,6 +816,28 @@ interface SystemConfig {
|
||||
const loading = ref(false)
|
||||
const logLevelSelectOpen = ref(false)
|
||||
|
||||
// Config export/import state
|
||||
const exportLoading = ref(false)
|
||||
const importLoading = ref(false)
|
||||
const importDialogOpen = ref(false)
|
||||
const importResultDialogOpen = ref(false)
|
||||
const configFileInput = ref<HTMLInputElement | null>(null)
|
||||
const importPreview = ref<ConfigExportData | null>(null)
|
||||
const importResult = ref<ConfigImportResponse | null>(null)
|
||||
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||
const mergeModeSelectOpen = ref(false)
|
||||
|
||||
// User-data export/import state
|
||||
const exportUsersLoading = ref(false)
|
||||
const importUsersLoading = ref(false)
|
||||
const importUsersDialogOpen = ref(false)
|
||||
const importUsersResultDialogOpen = ref(false)
|
||||
const usersFileInput = ref<HTMLInputElement | null>(null)
|
||||
const importUsersPreview = ref<UsersExportData | null>(null)
|
||||
const importUsersResult = ref<UsersImportResponse | null>(null)
|
||||
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||
const usersMergeModeSelectOpen = ref(false)
|
||||
|
||||
const systemConfig = ref<SystemConfig>({
|
||||
// 基础配置
|
||||
default_user_quota_usd: 10.0,
|
||||
@@ -623,4 +1038,185 @@ async function saveSystemConfig() {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Export the configuration as a downloadable JSON file.
 *
 * Fetches the payload from adminApi.exportConfig(), serializes it with
 * 2-space indentation and triggers a browser download named
 * `aether-config-<YYYY-MM-DD>.json`. exportLoading is true for the duration
 * of the call; failures are reported via toast and logged.
 */
async function handleExportConfig() {
  exportLoading.value = true
  try {
    const payload = await adminApi.exportConfig()
    const serialized = JSON.stringify(payload, null, 2)
    const objectUrl = URL.createObjectURL(new Blob([serialized], { type: 'application/json' }))

    // A temporary anchor with a `download` attribute gives the saved file its name.
    const link = Object.assign(document.createElement('a'), {
      href: objectUrl,
      download: `aether-config-${new Date().toISOString().slice(0, 10)}.json`
    })
    document.body.appendChild(link)
    link.click()
    document.body.removeChild(link)
    URL.revokeObjectURL(objectUrl)

    success('配置已导出')
  } catch (err) {
    error('导出配置失败')
    log.error('导出配置失败:', err)
  } finally {
    exportLoading.value = false
  }
}
|
||||
|
||||
/** Open the file picker by clicking the hidden config file input, if it is mounted. */
function triggerConfigFileSelect() {
  const picker = configFileInput.value
  if (picker) {
    picker.click()
  }
}
|
||||
|
||||
// Upper bound for an import file, in bytes (10 MB); larger files are rejected before being read.
|
||||
const MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
/**
 * Handle the change event of the hidden config file input.
 *
 * Rejects files larger than MAX_FILE_SIZE, reads the chosen file as text and
 * parses it as JSON. When the parsed data has version '1.0' it is stored in
 * importPreview, the merge mode is reset to 'skip' and the import dialog is
 * opened. Read failures and parse failures are both reported via toast and
 * logged.
 *
 * @param event change event fired by the file input
 */
function handleConfigFileSelect(event: Event) {
  const input = event.target as HTMLInputElement
  const file = input.files?.[0]
  if (!file) return

  if (file.size > MAX_FILE_SIZE) {
    error('文件大小不能超过 10MB')
    input.value = ''
    return
  }

  const reader = new FileReader()
  reader.onload = (e) => {
    try {
      const content = e.target?.result as string
      const data = JSON.parse(content) as ConfigExportData

      // Only export format version 1.0 is accepted.
      if (data.version !== '1.0') {
        error(`不支持的配置版本: ${data.version}`)
        return
      }

      importPreview.value = data
      mergeMode.value = 'skip'
      importDialogOpen.value = true
    } catch (err) {
      error('解析配置文件失败,请确保是有效的 JSON 文件')
      log.error('解析配置文件失败:', err)
    }
  }
  // Without an error handler a failed read would leave the user with no feedback at all.
  reader.onerror = () => {
    error('读取配置文件失败')
    log.error('读取配置文件失败:', reader.error)
  }
  reader.readAsText(file)

  // Reset the input so that picking the same file again still fires `change`.
  input.value = ''
}
|
||||
|
||||
/**
 * Send the previewed configuration to the backend with the selected merge mode.
 *
 * Does nothing when no preview is loaded. On success the result is stored,
 * the import dialog (and its select popup) is closed and the result dialog is
 * opened. On failure the backend's `detail` message is shown when present.
 */
async function confirmImport() {
  const pending = importPreview.value
  if (!pending) {
    return
  }

  importLoading.value = true
  try {
    importResult.value = await adminApi.importConfig({
      ...pending,
      merge_mode: mergeMode.value
    })
    importDialogOpen.value = false
    mergeModeSelectOpen.value = false
    importResultDialogOpen.value = true
    success('配置导入成功')
  } catch (err: any) {
    const detail = err.response?.data?.detail
    error(detail || '导入配置失败')
    log.error('导入配置失败:', err)
  } finally {
    importLoading.value = false
  }
}
|
||||
|
||||
/**
 * Export user data as a downloadable JSON file.
 *
 * Fetches the payload from adminApi.exportUsers(), serializes it with 2-space
 * indentation and triggers a browser download named
 * `aether-users-<YYYY-MM-DD>.json`. exportUsersLoading is true for the
 * duration of the call; failures are reported via toast and logged.
 */
async function handleExportUsers() {
  exportUsersLoading.value = true
  try {
    const payload = await adminApi.exportUsers()
    const serialized = JSON.stringify(payload, null, 2)
    const objectUrl = URL.createObjectURL(new Blob([serialized], { type: 'application/json' }))

    // A temporary anchor with a `download` attribute gives the saved file its name.
    const link = Object.assign(document.createElement('a'), {
      href: objectUrl,
      download: `aether-users-${new Date().toISOString().slice(0, 10)}.json`
    })
    document.body.appendChild(link)
    link.click()
    document.body.removeChild(link)
    URL.revokeObjectURL(objectUrl)

    success('用户数据已导出')
  } catch (err) {
    error('导出用户数据失败')
    log.error('导出用户数据失败:', err)
  } finally {
    exportUsersLoading.value = false
  }
}
|
||||
|
||||
/** Open the file picker by clicking the hidden user-data file input, if it is mounted. */
function triggerUsersFileSelect() {
  const picker = usersFileInput.value
  if (picker) {
    picker.click()
  }
}
|
||||
|
||||
/**
 * Handle the change event of the hidden user-data file input.
 *
 * Rejects files larger than MAX_FILE_SIZE, reads the chosen file as text and
 * parses it as JSON. When the parsed data has version '1.0' it is stored in
 * importUsersPreview, the merge mode is reset to 'skip' and the user import
 * dialog is opened. Read failures and parse failures are both reported via
 * toast and logged.
 *
 * @param event change event fired by the file input
 */
function handleUsersFileSelect(event: Event) {
  const input = event.target as HTMLInputElement
  const file = input.files?.[0]
  if (!file) return

  if (file.size > MAX_FILE_SIZE) {
    error('文件大小不能超过 10MB')
    input.value = ''
    return
  }

  const reader = new FileReader()
  reader.onload = (e) => {
    try {
      const content = e.target?.result as string
      const data = JSON.parse(content) as UsersExportData

      // Only export format version 1.0 is accepted.
      if (data.version !== '1.0') {
        error(`不支持的配置版本: ${data.version}`)
        return
      }

      importUsersPreview.value = data
      usersMergeMode.value = 'skip'
      importUsersDialogOpen.value = true
    } catch (err) {
      error('解析用户数据文件失败,请确保是有效的 JSON 文件')
      log.error('解析用户数据文件失败:', err)
    }
  }
  // Without an error handler a failed read would leave the user with no feedback at all.
  reader.onerror = () => {
    error('读取用户数据文件失败')
    log.error('读取用户数据文件失败:', reader.error)
  }
  reader.readAsText(file)

  // Reset the input so that picking the same file again still fires `change`.
  input.value = ''
}
|
||||
|
||||
/**
 * Confirm the user-data import.
 *
 * Sends the previewed user data plus the selected merge mode to
 * `adminApi.importUsers`. On success the import dialog is closed, the
 * merge-mode select is reset and the result dialog is opened. On failure the
 * backend `detail` message is shown when present, otherwise a generic one.
 * Does nothing when there is no preview loaded.
 */
async function confirmImportUsers() {
  const preview = importUsersPreview.value
  if (!preview) return

  importUsersLoading.value = true
  try {
    const payload = { ...preview, merge_mode: usersMergeMode.value }
    importUsersResult.value = await adminApi.importUsers(payload)

    importUsersDialogOpen.value = false
    usersMergeModeSelectOpen.value = false
    importUsersResultDialogOpen.value = true
    success('用户数据导入成功')
  } catch (err: any) {
    const detail = err.response?.data?.detail
    error(detail || '导入用户数据失败')
    log.error('导入用户数据失败:', err)
  } finally {
    // Always release the loading flag, whether the request succeeded or not.
    importUsersLoading.value = false
  }
}
|
||||
</script>
|
||||
|
||||
@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
|
||||
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
||||
})
|
||||
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
filtered = filtered.filter(
|
||||
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
filtered = filtered.filter(u => {
|
||||
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
if (filterRole.value !== 'all') {
|
||||
|
||||
@@ -103,7 +103,7 @@
|
||||
</div>
|
||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
||||
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
平均响应
|
||||
@@ -114,7 +114,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
||||
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
错误率
|
||||
@@ -128,7 +128,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
||||
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
转移次数
|
||||
@@ -142,7 +142,7 @@
|
||||
v-if="costStats"
|
||||
class="relative p-3 sm:p-4 border-manilla/40"
|
||||
>
|
||||
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
实际成本
|
||||
@@ -180,7 +180,7 @@
|
||||
</div>
|
||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
缓存命中率
|
||||
@@ -191,7 +191,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
缓存读取
|
||||
@@ -202,7 +202,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
缓存创建
|
||||
@@ -216,7 +216,7 @@
|
||||
v-if="tokenBreakdown"
|
||||
class="relative p-3 sm:p-4 border-manilla/40"
|
||||
>
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
总Token
|
||||
@@ -254,16 +254,16 @@
|
||||
<Card class="overflow-hidden p-4 flex flex-col flex-1 min-h-0 h-full max-h-[280px] sm:max-h-none">
|
||||
<div
|
||||
v-if="loadingAnnouncements"
|
||||
class="py-8 text-center"
|
||||
class="flex-1 flex items-center justify-center"
|
||||
>
|
||||
<Loader2 class="h-5 w-5 animate-spin mx-auto text-muted-foreground" />
|
||||
<Loader2 class="h-5 w-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-else-if="announcements.length === 0"
|
||||
class="py-8 text-center"
|
||||
class="flex-1 flex flex-col items-center justify-center"
|
||||
>
|
||||
<Bell class="h-8 w-8 mx-auto text-muted-foreground/40" />
|
||||
<Bell class="h-8 w-8 text-muted-foreground/40" />
|
||||
<p class="mt-2 text-xs text-muted-foreground">
|
||||
暂无公告
|
||||
</p>
|
||||
@@ -793,9 +793,8 @@ const statCardGlows = [
|
||||
'bg-kraft/30'
|
||||
]
|
||||
|
||||
const getStatIconColor = (index: number): string => {
|
||||
const colors = ['text-book-cloth', 'text-kraft', 'text-book-cloth', 'text-kraft']
|
||||
return colors[index % colors.length]
|
||||
const getStatIconColor = (_index: number): string => {
|
||||
return 'text-muted-foreground'
|
||||
}
|
||||
|
||||
// 统计数据
|
||||
|
||||
@@ -226,8 +226,8 @@
|
||||
<div
|
||||
v-for="announcement in announcements"
|
||||
:key="announcement.id"
|
||||
class="p-4 space-y-2 cursor-pointer transition-colors"
|
||||
:class="[
|
||||
'p-4 space-y-2 cursor-pointer transition-colors',
|
||||
announcement.is_read ? 'hover:bg-muted/30' : 'bg-primary/5 hover:bg-primary/10'
|
||||
]"
|
||||
@click="viewAnnouncementDetail(announcement)"
|
||||
|
||||
@@ -165,17 +165,17 @@
|
||||
<TableCell class="py-4">
|
||||
<div class="flex gap-1.5">
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="Vision"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="Tool Use"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="Extended Thinking"
|
||||
/>
|
||||
@@ -253,15 +253,15 @@
|
||||
<!-- 第二行:能力图标 -->
|
||||
<div class="flex gap-1.5">
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
@@ -474,24 +474,24 @@ async function toggleCapability(modelName: string, capName: string) {
|
||||
const filteredModels = computed(() => {
|
||||
let result = models.value
|
||||
|
||||
// 搜索
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(m =>
|
||||
m.name.toLowerCase().includes(query) ||
|
||||
m.display_name?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 能力筛选
|
||||
if (capabilityFilters.value.vision) {
|
||||
result = result.filter(m => m.default_supports_vision)
|
||||
result = result.filter(m => m.config?.vision === true)
|
||||
}
|
||||
if (capabilityFilters.value.toolUse) {
|
||||
result = result.filter(m => m.default_supports_function_calling)
|
||||
result = result.filter(m => m.config?.function_calling === true)
|
||||
}
|
||||
if (capabilityFilters.value.extendedThinking) {
|
||||
result = result.filter(m => m.default_supports_extended_thinking)
|
||||
result = result.filter(m => m.config?.extended_thinking === true)
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -62,6 +62,7 @@
|
||||
<Button
|
||||
type="submit"
|
||||
:disabled="savingProfile"
|
||||
class="shadow-none hover:shadow-none"
|
||||
>
|
||||
{{ savingProfile ? '保存中...' : '保存修改' }}
|
||||
</Button>
|
||||
@@ -107,6 +108,7 @@
|
||||
<Button
|
||||
type="submit"
|
||||
:disabled="changingPassword"
|
||||
class="shadow-none hover:shadow-none"
|
||||
>
|
||||
{{ changingPassword ? '修改中...' : '修改密码' }}
|
||||
</Button>
|
||||
@@ -320,6 +322,7 @@
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { meApi, type Profile } from '@/api/me'
|
||||
import { useDarkMode, type ThemeMode } from '@/composables/useDarkMode'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
@@ -338,6 +341,7 @@ import { log } from '@/utils/logger'
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { success, error: showError } = useToast()
|
||||
const { setThemeMode } = useDarkMode()
|
||||
|
||||
const profile = ref<Profile | null>(null)
|
||||
|
||||
@@ -375,20 +379,8 @@ function handleThemeChange(value: string) {
|
||||
themeSelectOpen.value = false
|
||||
updatePreferences()
|
||||
|
||||
// 应用主题
|
||||
if (value === 'dark') {
|
||||
document.documentElement.classList.add('dark')
|
||||
} else if (value === 'light') {
|
||||
document.documentElement.classList.remove('dark')
|
||||
} else {
|
||||
// system: 跟随系统
|
||||
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches
|
||||
if (prefersDark) {
|
||||
document.documentElement.classList.add('dark')
|
||||
} else {
|
||||
document.documentElement.classList.remove('dark')
|
||||
}
|
||||
}
|
||||
// 使用 useDarkMode 统一切换主题
|
||||
setThemeMode(value as ThemeMode)
|
||||
}
|
||||
|
||||
function handleLanguageChange(value: string) {
|
||||
@@ -418,10 +410,16 @@ async function loadProfile() {
|
||||
async function loadPreferences() {
|
||||
try {
|
||||
const prefs = await meApi.getPreferences()
|
||||
|
||||
// 主题以本地 localStorage 为准(useDarkMode 在应用启动时已初始化)
|
||||
// 这样可以避免刷新页面时主题被服务端旧值覆盖
|
||||
const { themeMode: currentThemeMode } = useDarkMode()
|
||||
const localTheme = currentThemeMode.value
|
||||
|
||||
preferencesForm.value = {
|
||||
avatar_url: prefs.avatar_url || '',
|
||||
bio: prefs.bio || '',
|
||||
theme: prefs.theme || 'light',
|
||||
theme: localTheme, // 使用本地主题,而非服务端返回值
|
||||
language: prefs.language || 'zh-CN',
|
||||
timezone: prefs.timezone || 'Asia/Shanghai',
|
||||
notifications: {
|
||||
@@ -431,11 +429,12 @@ async function loadPreferences() {
|
||||
}
|
||||
}
|
||||
|
||||
// 应用主题
|
||||
if (preferencesForm.value.theme === 'dark') {
|
||||
document.documentElement.classList.add('dark')
|
||||
} else if (preferencesForm.value.theme === 'light') {
|
||||
document.documentElement.classList.remove('dark')
|
||||
// 如果本地主题和服务端不一致,同步到服务端(静默更新,不提示用户)
|
||||
const serverTheme = prefs.theme || 'light'
|
||||
if (localTheme !== serverTheme) {
|
||||
meApi.updatePreferences({ theme: localTheme }).catch(() => {
|
||||
// 静默失败,不影响用户体验
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('加载偏好设置失败:', error)
|
||||
|
||||
@@ -38,10 +38,10 @@
|
||||
</button>
|
||||
</div>
|
||||
<p
|
||||
v-if="model.description"
|
||||
v-if="model.config?.description"
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
{{ model.description }}
|
||||
{{ model.config?.description }}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
@@ -73,10 +73,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -90,10 +90,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -107,10 +107,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.vision === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.vision === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -124,10 +124,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -141,10 +141,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
12
migrate.sh
Executable file
12
migrate.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
# Database migration script - runs the Alembic migrations inside the Docker container.
#
# The target container defaults to "aether-app"; export CONTAINER_NAME to run
# the migrations against a differently named container.

# Abort on the first failing command so a failed migration is never reported
# as successful.
set -e

CONTAINER_NAME="${CONTAINER_NAME:-aether-app}"

echo "Running database migrations in container: $CONTAINER_NAME"

# Quoted so an overridden name is always passed to docker as a single argument.
docker exec "$CONTAINER_NAME" alembic upgrade head

echo "Database migration completed successfully"
|
||||
@@ -3,10 +3,8 @@
|
||||
A proxy server that enables AI models to work with multiple API providers.
|
||||
"""
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
# 注意: dotenv 加载已统一移至 src/config/settings.py
|
||||
# 不要在此处重复加载
|
||||
|
||||
try:
|
||||
from src._version import __version__
|
||||
|
||||
@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_query import router as provider_query_router
|
||||
from .provider_strategy import router as provider_strategy_router
|
||||
from .providers import router as providers_router
|
||||
from .security import router as security_router
|
||||
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
|
||||
router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
router.include_router(provider_query_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .external import router as external_router
|
||||
from .global_models import router as global_models_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
@@ -12,3 +13,4 @@ router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"]
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(global_models_router)
|
||||
router.include_router(external_router)
|
||||
|
||||
@@ -72,10 +72,12 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
for gm in global_models:
|
||||
gm_id = gm.id
|
||||
provider_entries: List[ModelCatalogProviderDetail] = []
|
||||
# 从 config JSON 读取能力标志
|
||||
gm_config = gm.config or {}
|
||||
capability_flags = {
|
||||
"supports_vision": gm.default_supports_vision or False,
|
||||
"supports_function_calling": gm.default_supports_function_calling or False,
|
||||
"supports_streaming": gm.default_supports_streaming or False,
|
||||
"supports_vision": gm_config.get("vision", False),
|
||||
"supports_function_calling": gm_config.get("function_calling", False),
|
||||
"supports_streaming": gm_config.get("streaming", True),
|
||||
}
|
||||
|
||||
# 遍历该 GlobalModel 的所有关联提供商
|
||||
@@ -140,7 +142,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
ModelCatalogItem(
|
||||
global_model_name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
description=gm_config.get("description"),
|
||||
providers=provider_entries,
|
||||
price_range=price_range,
|
||||
total_providers=len(provider_entries),
|
||||
|
||||
141
src/api/admin/models/external.py
Normal file
141
src/api/admin/models/external.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
models.dev 外部模型数据代理
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.clients import get_redis_client
|
||||
from src.core.logger import logger
|
||||
from src.models.database import User
|
||||
from src.utils.auth_utils import require_admin
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Redis key under which the models.dev payload is cached.
CACHE_KEY = "aether:external:models_dev"
# Cache lifetime in seconds (15 minutes).
CACHE_TTL = 15 * 60

# Provider ids flagged as official / first-party, so the frontend can use the
# flag to filter out third-party resellers.
OFFICIAL_PROVIDERS = {
    "anthropic",  # Claude (official)
    "openai",  # OpenAI (official)
    "google",  # Gemini (official)
    "google-vertex",  # Google Vertex AI
    "azure",  # Azure OpenAI
    "amazon-bedrock",  # AWS Bedrock
    "xai",  # Grok (official)
    "meta",  # Llama (official)
    "deepseek",  # DeepSeek (official)
    "mistral",  # Mistral (official)
    "cohere",  # Cohere (official)
    "zhipuai",  # Zhipu AI (official)
    "alibaba",  # Alibaba Cloud (Tongyi Qianwen)
    "minimax",  # MiniMax (official)
    "moonshot",  # Moonshot AI (Kimi)
    "baichuan",  # Baichuan AI
    "ai21",  # AI21 Labs
}
|
||||
|
||||
|
||||
async def _get_cached_data() -> Optional[dict[str, Any]]:
    """Return the cached models.dev payload from Redis, if any.

    Returns:
        The JSON-decoded cache entry, or ``None`` when no Redis client is
        available, the key is missing/empty, or reading or decoding fails
        (the failure is logged as a warning).
    """
    client = await get_redis_client()
    if client is None:
        return None

    try:
        raw = await client.get(CACHE_KEY)
        payload: Optional[dict[str, Any]] = json.loads(raw) if raw else None
    except Exception as exc:
        # Best-effort cache: a broken entry is treated as a cache miss.
        logger.warning(f"读取 models.dev 缓存失败: {exc}")
        return None
    return payload
|
||||
|
||||
|
||||
async def _set_cached_data(data: dict) -> None:
    """Store ``data`` as JSON in the Redis cache with a ``CACHE_TTL`` expiry.

    Does nothing when no Redis client is available. Serialisation and write
    failures are logged as warnings and otherwise ignored.
    """
    client = await get_redis_client()
    if client is not None:
        try:
            # ensure_ascii=False keeps non-ASCII text unescaped in the cache.
            serialized = json.dumps(data, ensure_ascii=False)
            await client.setex(CACHE_KEY, CACHE_TTL, serialized)
        except Exception as exc:
            logger.warning(f"写入 models.dev 缓存失败: {exc}")
|
||||
|
||||
|
||||
def _mark_official_providers(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""为每个提供商标记是否为官方"""
|
||||
result = {}
|
||||
for provider_id, provider_data in data.items():
|
||||
result[provider_id] = {
|
||||
**provider_data,
|
||||
"official": provider_id in OFFICIAL_PROVIDERS,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/external")
async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
    """Return the models.dev model data (proxied to avoid cross-origin issues).

    The data is cached for 15 minutes in Redis, so it is shared between
    workers. Every provider entry carries an ``official`` field that the
    frontend can use for filtering.

    Raises:
        HTTPException: 504 when the request to models.dev times out; 502 when
            models.dev answers with an error status or fetching otherwise
            fails.
    """
    # Serve from the cache when possible.
    cached = await _get_cached_data()
    if cached is not None:
        # Backward compatibility with older cache entries: if any provider
        # lacks the ``official`` field, add it and write the result back.
        try:
            needs_mark = any(
                not isinstance(provider_data, dict) or "official" not in provider_data
                for provider_data in cached.values()
            )
            if needs_mark:
                marked_cached = _mark_official_providers(cached)
                await _set_cached_data(marked_cached)
                return JSONResponse(content=marked_cached)
        except Exception as e:
            # Upgrading the entry is best-effort; fall back to the raw cache.
            logger.warning(f"处理 models.dev 缓存数据失败,将直接返回原缓存: {e}")
        return JSONResponse(content=cached)

    # Cache miss: fetch from models.dev.
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get("https://models.dev/api.json")
            response.raise_for_status()
            data = response.json()

        # Flag official providers, then cache the flagged payload.
        marked_data = _mark_official_providers(data)
        await _set_cached_data(marked_data)

        return JSONResponse(content=marked_data)
    # Each handler chains the original exception so the root cause stays
    # attached to the HTTPException in tracebacks and logs.
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求 models.dev 超时") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(
            status_code=502, detail=f"models.dev 返回错误: {e.response.status_code}"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"获取外部模型数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/external/cache")
async def clear_external_models_cache(_: User = Depends(require_admin)) -> dict:
    """Delete the cached models.dev payload.

    Returns:
        ``{"cleared": True}`` after the key was deleted, or
        ``{"cleared": False, "message": ...}`` when no Redis client is
        available.

    Raises:
        HTTPException: 500 when the Redis delete fails.
    """
    redis = await get_redis_client()
    if redis is None:
        return {"cleared": False, "message": "Redis 未启用"}

    try:
        await redis.delete(CACHE_KEY)
    except Exception as e:
        # Chain the original error so the Redis failure stays visible in logs.
        raise HTTPException(status_code=500, detail=f"清除缓存失败: {str(e)}") from e
    return {"cleared": True}
|
||||
@@ -187,21 +187,15 @@ class AdminCreateGlobalModelAdapter(AdminApiAdapter):
|
||||
db=context.db,
|
||||
name=self.payload.name,
|
||||
display_name=self.payload.display_name,
|
||||
description=self.payload.description,
|
||||
official_url=self.payload.official_url,
|
||||
icon_url=self.payload.icon_url,
|
||||
is_active=self.payload.is_active,
|
||||
# 按次计费配置
|
||||
default_price_per_request=self.payload.default_price_per_request,
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing=tiered_pricing_dict,
|
||||
# 默认能力配置
|
||||
default_supports_vision=self.payload.default_supports_vision,
|
||||
default_supports_function_calling=self.payload.default_supports_function_calling,
|
||||
default_supports_streaming=self.payload.default_supports_streaming,
|
||||
default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
|
||||
# Key 能力配置
|
||||
supported_capabilities=self.payload.supported_capabilities,
|
||||
# 模型配置(JSON)
|
||||
config=self.payload.config,
|
||||
)
|
||||
|
||||
logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")
|
||||
|
||||
@@ -21,7 +21,8 @@ from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.cache.affinity_manager import get_affinity_manager
|
||||
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
|
||||
from src.services.cache.aware_scheduler import CacheAwareScheduler, get_cache_aware_scheduler
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -250,7 +251,22 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||
# 读取系统配置,确保监控接口与编排器使用一致的模式
|
||||
priority_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"provider_priority_mode",
|
||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||
)
|
||||
scheduling_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"scheduling_mode",
|
||||
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
)
|
||||
scheduler = await get_cache_aware_scheduler(
|
||||
redis_client,
|
||||
priority_mode=priority_mode,
|
||||
scheduling_mode=scheduling_mode,
|
||||
)
|
||||
stats = await scheduler.get_stats()
|
||||
logger.info("缓存统计信息查询成功")
|
||||
context.add_audit_metadata(
|
||||
@@ -270,7 +286,22 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||
# 读取系统配置,确保监控接口与编排器使用一致的模式
|
||||
priority_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"provider_priority_mode",
|
||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||
)
|
||||
scheduling_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"scheduling_mode",
|
||||
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
)
|
||||
scheduler = await get_cache_aware_scheduler(
|
||||
redis_client,
|
||||
priority_mode=priority_mode,
|
||||
scheduling_mode=scheduling_mode,
|
||||
)
|
||||
stats = await scheduler.get_stats()
|
||||
payload = self._format_prometheus(stats)
|
||||
context.add_audit_metadata(
|
||||
|
||||
@@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
@@ -52,8 +52,7 @@ class CandidateResponse(BaseModel):
|
||||
started_at: Optional[datetime] = None
|
||||
finished_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class RequestTraceResponse(BaseModel):
|
||||
|
||||
@@ -1,46 +1,28 @@
|
||||
"""
|
||||
Provider Query API 端点
|
||||
用于查询提供商的余额、使用记录等信息
|
||||
用于查询提供商的模型列表等信息
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||
|
||||
# 初始化适配器注册
|
||||
from src.plugins.provider_query import init # noqa
|
||||
from src.plugins.provider_query import get_query_registry
|
||||
from src.plugins.provider_query.base import QueryCapability
|
||||
from src.models.database import Provider, ProviderEndpoint, User
|
||||
from src.utils.auth_utils import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
||||
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||
|
||||
|
||||
# ============ Request/Response Models ============
|
||||
|
||||
|
||||
class BalanceQueryRequest(BaseModel):
|
||||
"""余额查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
|
||||
|
||||
|
||||
class UsageSummaryQueryRequest(BaseModel):
|
||||
"""使用汇总查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None
|
||||
period: str = "month" # day, week, month, year
|
||||
|
||||
|
||||
class ModelsQueryRequest(BaseModel):
|
||||
"""模型列表查询请求"""
|
||||
|
||||
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
|
||||
# ============ API Endpoints ============
|
||||
|
||||
|
||||
@router.get("/adapters")
|
||||
async def list_adapters(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取所有可用的查询适配器
|
||||
async def _fetch_openai_models(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
api_format: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 OpenAI 格式的模型列表
|
||||
|
||||
Returns:
|
||||
适配器列表
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
registry = get_query_registry()
|
||||
adapters = registry.list_adapters()
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if extra_headers:
|
||||
# 防止 extra_headers 覆盖 Authorization
|
||||
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||
headers.update(safe_headers)
|
||||
|
||||
return {"success": True, "data": adapters}
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
# 记录详细的错误信息
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
@router.get("/capabilities/{provider_id}")
|
||||
async def get_provider_capabilities(
|
||||
provider_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取提供商支持的查询能力
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
async def _fetch_claude_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Claude 格式的模型列表
|
||||
|
||||
Returns:
|
||||
支持的查询能力列表
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
# 获取提供商
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
capabilities = registry.get_capabilities_for_provider(provider.name)
|
||||
|
||||
if capabilities is None:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [],
|
||||
"has_adapter": False,
|
||||
"message": "No query adapter available for this provider",
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [c.name for c in capabilities],
|
||||
"has_adapter": True,
|
||||
},
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
@router.post("/balance")
|
||||
async def query_balance(
|
||||
request: BalanceQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商余额
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
async def _fetch_gemini_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Gemini 格式的模型列表
|
||||
|
||||
Returns:
|
||||
余额信息
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
# 兼容 base_url 已包含 /v1beta 的情况
|
||||
base_url_clean = base_url.rstrip("/")
|
||||
if base_url_clean.endswith("/v1beta"):
|
||||
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||
else:
|
||||
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
# 查找指定的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
try:
|
||||
response = await client.get(models_url)
|
||||
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "models" in data:
|
||||
# 转换为统一格式
|
||||
return [
|
||||
{
|
||||
"id": m.get("name", "").replace("models/", ""),
|
||||
"owned_by": "google",
|
||||
"display_name": m.get("displayName", ""),
|
||||
"api_format": api_format,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 使用第一个可用的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询余额
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_balance(
|
||||
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
|
||||
if not query_result.success:
|
||||
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/usage-summary")
async def query_usage_summary(
    request: UsageSummaryQueryRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Query the usage summary of a provider.

    Args:
        request: Query request; ``provider_id``, ``api_key_id`` and ``period``
            are read from it.

    Returns:
        Dict with the query's success flag, its serialized result and the
        provider's id, name and display name.

    Raises:
        HTTPException: 404 when the provider or the requested API key is not
            found; 400 when no API key was requested and the provider has no
            active key on an active endpoint.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # Load the provider together with its endpoints and their API keys.
    stmt = (
        select(Provider)
        .options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
        .where(Provider.id == request.provider_id)
    )
    provider = (await db.execute(stmt)).scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    wanted_key_id = request.api_key_id
    api_key_value = None
    endpoint_config = None

    # One pass over the endpoints: either look for the requested key id, or
    # take the first active key of an active endpoint.
    for endpoint in provider.endpoints:
        if wanted_key_id:
            candidate = next((k for k in endpoint.api_keys if k.id == wanted_key_id), None)
        elif endpoint.is_active and endpoint.api_keys:
            candidate = next((k for k in endpoint.api_keys if k.is_active), None)
        else:
            candidate = None

        if candidate is None:
            continue

        api_key_value = candidate.api_key
        endpoint_config = {"base_url": endpoint.base_url}
        # A candidate with an empty key value does not end the search.
        if api_key_value:
            break

    if not api_key_value:
        if wanted_key_id:
            raise HTTPException(status_code=404, detail="API Key not found")
        raise HTTPException(status_code=400, detail="No active API Key found for this provider")

    # Delegate the actual usage query to the registry.
    query_result = await get_query_registry().query_provider_usage(
        provider_type=provider.name,
        api_key=api_key_value,
        period=request.period,
        endpoint_config=endpoint_config,
    )

    provider_info = {
        "id": provider.id,
        "name": provider.name,
        "display_name": provider.display_name,
    }
    return {
        "success": query_result.success,
        "data": query_result.to_dict(),
        "provider": provider_info,
    }
|
||||
for m in data["models"]
|
||||
], None
|
||||
return [], None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
@router.post("/models")
|
||||
async def query_available_models(
|
||||
request: ModelsQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商可用模型
|
||||
|
||||
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
||||
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
||||
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
||||
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
所有端点的模型列表(合并)
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||
.filter(Provider.id == request.provider_id)
|
||||
.first()
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
# 收集所有活跃端点的配置
|
||||
endpoint_configs: list[dict] = []
|
||||
|
||||
if request.api_key_id:
|
||||
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break
|
||||
if api_key_value:
|
||||
if endpoint_configs:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
if not endpoint.is_active or not endpoint.api_keys:
|
||||
continue
|
||||
|
||||
if not api_key_value:
|
||||
# 找第一个可用的 Key
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
continue # 尝试下一个 Key
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break # 只取第一个可用的 Key
|
||||
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询模型
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
# 并发请求所有端点的模型列表
|
||||
all_models: list = []
|
||||
errors: list[str] = []
|
||||
|
||||
if not adapter:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
|
||||
async def fetch_endpoint_models(
|
||||
client: httpx.AsyncClient, config: dict
|
||||
) -> tuple[list, Optional[str]]:
|
||||
base_url = config["base_url"]
|
||||
if not base_url:
|
||||
return [], None
|
||||
base_url = base_url.rstrip("/")
|
||||
api_format = config["api_format"]
|
||||
api_key_value = config["api_key"]
|
||||
extra_headers = config["extra_headers"]
|
||||
|
||||
try:
|
||||
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
||||
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
||||
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
||||
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
||||
else:
|
||||
return await _fetch_openai_models(
|
||||
client, base_url, api_key_value, api_format, extra_headers
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||
return [], f"{api_format}: {str(e)}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
results = await asyncio.gather(
|
||||
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||
)
|
||||
for models, error in results:
|
||||
all_models.extend(models)
|
||||
if error:
|
||||
errors.append(error)
|
||||
|
||||
query_result = await adapter.query_available_models(
|
||||
api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
# 按 model id 去重(保留第一个)
|
||||
seen_ids: set[str] = set()
|
||||
unique_models: list = []
|
||||
for model in all_models:
|
||||
model_id = model.get("id")
|
||||
if model_id and model_id not in seen_ids:
|
||||
seen_ids.add(model_id)
|
||||
unique_models.append(model)
|
||||
|
||||
error = "; ".join(errors) if errors else None
|
||||
if not unique_models and not error:
|
||||
error = "No models returned from any endpoint"
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"success": len(unique_models) > 0,
|
||||
"data": {"models": unique_models, "error": error},
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cache/{provider_id}")
async def clear_query_cache(
    provider_id: str,
    api_key_id: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Clear the query cache of a provider's adapter.

    Args:
        provider_id: Provider ID.
        api_key_id: Optional; when given, only the cache entry for that API
            key is cleared, otherwise the adapter's whole cache is cleared.

    Returns:
        ``{"success": True, "message": ...}``. This is also returned when the
        provider has no query adapter, or when ``api_key_id`` does not match
        any stored key (nothing is cleared in those cases).

    Raises:
        HTTPException: 404 when the provider does not exist.
    """
    from sqlalchemy import select

    # Look up the provider.
    result = await db.execute(select(Provider).where(Provider.id == provider_id))
    provider = result.scalar_one_or_none()

    if not provider:
        raise HTTPException(status_code=404, detail="Provider not found")

    registry = get_query_registry()
    adapter = registry.get_adapter_for_provider(provider.name)

    if adapter:
        if api_key_id:
            # The stored key value is passed to the adapter to select which
            # cache entry to clear.
            result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
            api_key = result.scalar_one_or_none()
            if api_key:
                adapter.clear_cache(api_key.api_key)
        else:
            adapter.clear_cache()

    return {"success": True, "message": "Cache cleared successfully"}
|
||||
|
||||
@@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)):
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/config/export")
async def export_config(request: Request, db: Session = Depends(get_db)):
    """Export provider and model configuration (admin)."""
    export_adapter = AdminExportConfigAdapter()
    return await pipeline.run(
        adapter=export_adapter,
        http_request=request,
        db=db,
        mode=export_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.post("/config/import")
async def import_config(request: Request, db: Session = Depends(get_db)):
    """Import provider and model configuration (admin)."""
    import_adapter = AdminImportConfigAdapter()
    return await pipeline.run(
        adapter=import_adapter,
        http_request=request,
        db=db,
        mode=import_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.get("/users/export")
async def export_users(request: Request, db: Session = Depends(get_db)):
    """Export user data (admin)."""
    export_adapter = AdminExportUsersAdapter()
    return await pipeline.run(
        adapter=export_adapter,
        http_request=request,
        db=db,
        mode=export_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.post("/users/import")
async def import_users(request: Request, db: Session = Depends(get_db)):
    """Import user data (admin)."""
    import_adapter = AdminImportUsersAdapter()
    return await pipeline.run(
        adapter=import_adapter,
        http_request=request,
        db=db,
        mode=import_adapter.mode,
    )
|
||||
|
||||
|
||||
# -------- 系统设置适配器 --------
|
||||
|
||||
|
||||
@@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter):
|
||||
)
|
||||
|
||||
return {"formats": formats}
|
||||
|
||||
|
||||
class AdminExportConfigAdapter(AdminApiAdapter):
    """Admin adapter that serializes global models and providers into an export document."""

    async def handle(self, context):  # type: ignore[override]
        """Export provider and model configuration, with provider API keys decrypted.

        Args:
            context: Request context; only ``context.db`` is used.

        Returns:
            Dict with ``version``, ``exported_at`` (UTC ISO-8601 timestamp),
            ``global_models`` and ``providers``. Each provider entry nests its
            ``endpoints`` (each with its ``keys``) and its ``models``.
        """
        from datetime import datetime, timezone

        from src.core.crypto import crypto_service
        from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint

        db = context.db

        # Export GlobalModels.
        global_models = db.query(GlobalModel).all()
        global_models_data = [
            {
                "name": gm.name,
                "display_name": gm.display_name,
                "default_price_per_request": gm.default_price_per_request,
                "default_tiered_pricing": gm.default_tiered_pricing,
                "supported_capabilities": gm.supported_capabilities,
                "config": gm.config,
                "is_active": gm.is_active,
            }
            for gm in global_models
        ]
        # Resolve each provider model's global model name from the rows loaded
        # above instead of issuing one extra query per provider model.
        global_model_names = {gm.id: gm.name for gm in global_models}

        # Export Providers with their related data.
        providers_data = []
        for provider in db.query(Provider).all():
            endpoints = (
                db.query(ProviderEndpoint)
                .filter(ProviderEndpoint.provider_id == provider.id)
                .all()
            )
            endpoints_data = []
            for ep in endpoints:
                keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
                keys_data = []
                for key in keys:
                    try:
                        decrypted_key = crypto_service.decrypt(key.api_key)
                    except Exception:
                        # Best effort: a key that cannot be decrypted is still
                        # exported, with an empty value.
                        decrypted_key = ""

                    keys_data.append(
                        {
                            "api_key": decrypted_key,
                            "name": key.name,
                            "note": key.note,
                            "rate_multiplier": key.rate_multiplier,
                            "internal_priority": key.internal_priority,
                            "global_priority": key.global_priority,
                            "max_concurrent": key.max_concurrent,
                            "rate_limit": key.rate_limit,
                            "daily_limit": key.daily_limit,
                            "monthly_limit": key.monthly_limit,
                            "allowed_models": key.allowed_models,
                            "capabilities": key.capabilities,
                            "is_active": key.is_active,
                        }
                    )

                endpoints_data.append(
                    {
                        "api_format": ep.api_format,
                        "base_url": ep.base_url,
                        "headers": ep.headers,
                        "timeout": ep.timeout,
                        "max_retries": ep.max_retries,
                        "max_concurrent": ep.max_concurrent,
                        "rate_limit": ep.rate_limit,
                        "is_active": ep.is_active,
                        "custom_path": ep.custom_path,
                        "config": ep.config,
                        "keys": keys_data,
                    }
                )

            # Export the provider's models; a model whose global model is not
            # found is exported with ``global_model_name`` set to None.
            models = db.query(Model).filter(Model.provider_id == provider.id).all()
            models_data = [
                {
                    "global_model_name": global_model_names.get(model.global_model_id),
                    "provider_model_name": model.provider_model_name,
                    "provider_model_aliases": model.provider_model_aliases,
                    "price_per_request": model.price_per_request,
                    "tiered_pricing": model.tiered_pricing,
                    "supports_vision": model.supports_vision,
                    "supports_function_calling": model.supports_function_calling,
                    "supports_streaming": model.supports_streaming,
                    "supports_extended_thinking": model.supports_extended_thinking,
                    "supports_image_generation": model.supports_image_generation,
                    "is_active": model.is_active,
                    "config": model.config,
                }
                for model in models
            ]

            providers_data.append(
                {
                    "name": provider.name,
                    "display_name": provider.display_name,
                    "description": provider.description,
                    "website": provider.website,
                    "billing_type": provider.billing_type.value if provider.billing_type else None,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "quota_reset_day": provider.quota_reset_day,
                    "rpm_limit": provider.rpm_limit,
                    "provider_priority": provider.provider_priority,
                    "is_active": provider.is_active,
                    "rate_limit": provider.rate_limit,
                    "concurrent_limit": provider.concurrent_limit,
                    "config": provider.config,
                    "endpoints": endpoints_data,
                    "models": models_data,
                }
            )

        return {
            "version": "1.0",
            "exported_at": datetime.now(timezone.utc).isoformat(),
            "global_models": global_models_data,
            "providers": providers_data,
        }
|
||||
|
||||
|
||||
# Largest raw request body, in bytes, that the config import handler accepts.
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
||||
class AdminImportConfigAdapter(AdminApiAdapter):
    """Admin adapter that imports a version "1.0" provider/model configuration document."""

    async def handle(self, context):  # type: ignore[override]
        """Import global models, providers, endpoints, API keys and provider models.

        The JSON body must carry ``version == "1.0"``. ``merge_mode`` decides
        how a record that already exists is treated:

        * ``"skip"`` (default): leave the existing record untouched. Children
          of an existing provider or endpoint are still processed.
        * ``"overwrite"``: update the existing record from the payload.
        * ``"error"``: abort the import.

        API keys are de-duplicated per endpoint by plaintext value and are
        never updated, whatever the merge mode. All writes are committed
        together; on any failure the session is rolled back.

        Args:
            context: Request context providing ``raw_body``, ``db`` and
                ``ensure_json_body()``.

        Returns:
            Dict with a ``message`` and per-entity ``stats`` (created /
            updated / skipped counters plus a list of non-fatal ``errors``).

        Raises:
            InvalidRequestException: Body larger than ``MAX_IMPORT_SIZE``,
                unsupported version or merge mode, a conflict in ``"error"``
                mode, or any other failure during the import.
        """
        import uuid
        from datetime import datetime, timezone

        from src.core.crypto import crypto_service
        from src.core.enums import ProviderBillingType
        from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint

        # Reject oversized request bodies.
        if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
            raise InvalidRequestException("请求体大小不能超过 10MB")

        db = context.db
        payload = context.ensure_json_body()

        # Validate the configuration version.
        version = payload.get("version")
        if version != "1.0":
            raise InvalidRequestException(f"不支持的配置版本: {version}")

        # An unknown merge mode used to fall through every conflict branch
        # silently (no skip, no update, no error); reject it up front instead.
        merge_mode = payload.get("merge_mode", "skip")
        if merge_mode not in ("skip", "overwrite", "error"):
            raise InvalidRequestException(f"不支持的合并模式: {merge_mode}")

        # A section that is missing or explicitly null is treated as empty.
        global_models_data = payload.get("global_models") or []
        providers_data = payload.get("providers") or []

        stats = {
            "global_models": {"created": 0, "updated": 0, "skipped": 0},
            "providers": {"created": 0, "updated": 0, "skipped": 0},
            "endpoints": {"created": 0, "updated": 0, "skipped": 0},
            "keys": {"created": 0, "updated": 0, "skipped": 0},
            "models": {"created": 0, "updated": 0, "skipped": 0},
            "errors": [],
        }

        try:
            # ---- GlobalModels ----
            global_model_map = {}  # name -> id
            for gm_data in global_models_data:
                gm_name = gm_data["name"]
                existing = db.query(GlobalModel).filter(GlobalModel.name == gm_name).first()

                if existing:
                    global_model_map[gm_name] = existing.id
                    if merge_mode == "skip":
                        stats["global_models"]["skipped"] += 1
                    elif merge_mode == "error":
                        raise InvalidRequestException(f"GlobalModel '{gm_name}' 已存在")
                    else:  # overwrite
                        existing.display_name = gm_data.get("display_name", existing.display_name)
                        existing.default_price_per_request = gm_data.get(
                            "default_price_per_request"
                        )
                        existing.default_tiered_pricing = gm_data.get(
                            "default_tiered_pricing", existing.default_tiered_pricing
                        )
                        existing.supported_capabilities = gm_data.get("supported_capabilities")
                        existing.config = gm_data.get("config")
                        existing.is_active = gm_data.get("is_active", True)
                        existing.updated_at = datetime.now(timezone.utc)
                        stats["global_models"]["updated"] += 1
                    continue

                new_gm = GlobalModel(
                    id=str(uuid.uuid4()),
                    name=gm_name,
                    display_name=gm_data.get("display_name", gm_name),
                    default_price_per_request=gm_data.get("default_price_per_request"),
                    default_tiered_pricing=gm_data.get(
                        "default_tiered_pricing",
                        {"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
                    ),
                    supported_capabilities=gm_data.get("supported_capabilities"),
                    config=gm_data.get("config"),
                    is_active=gm_data.get("is_active", True),
                )
                db.add(new_gm)
                db.flush()
                global_model_map[gm_name] = new_gm.id
                stats["global_models"]["created"] += 1

            # ---- Providers ----
            for prov_data in providers_data:
                prov_name = prov_data["name"]
                billing_type_raw = prov_data.get("billing_type")
                # Fields written identically on create and on overwrite.
                provider_fields = {
                    "description": prov_data.get("description"),
                    "website": prov_data.get("website"),
                    "monthly_quota_usd": prov_data.get("monthly_quota_usd"),
                    "quota_reset_day": prov_data.get("quota_reset_day", 30),
                    "rpm_limit": prov_data.get("rpm_limit"),
                    "provider_priority": prov_data.get("provider_priority", 100),
                    "is_active": prov_data.get("is_active", True),
                    "rate_limit": prov_data.get("rate_limit"),
                    "concurrent_limit": prov_data.get("concurrent_limit"),
                    "config": prov_data.get("config"),
                }

                existing_provider = db.query(Provider).filter(Provider.name == prov_name).first()

                if existing_provider:
                    provider_id = existing_provider.id
                    if merge_mode == "skip":
                        # Endpoints and models below are still processed.
                        stats["providers"]["skipped"] += 1
                    elif merge_mode == "error":
                        raise InvalidRequestException(f"Provider '{prov_name}' 已存在")
                    else:  # overwrite
                        existing_provider.display_name = prov_data.get(
                            "display_name", existing_provider.display_name
                        )
                        # The billing type is only replaced when the payload provides one.
                        if billing_type_raw:
                            existing_provider.billing_type = ProviderBillingType(billing_type_raw)
                        for field_name, value in provider_fields.items():
                            setattr(existing_provider, field_name, value)
                        existing_provider.updated_at = datetime.now(timezone.utc)
                        stats["providers"]["updated"] += 1
                else:
                    billing_type = (
                        ProviderBillingType(billing_type_raw)
                        if billing_type_raw
                        else ProviderBillingType.PAY_AS_YOU_GO
                    )
                    new_provider = Provider(
                        id=str(uuid.uuid4()),
                        name=prov_name,
                        display_name=prov_data.get("display_name", prov_name),
                        billing_type=billing_type,
                        **provider_fields,
                    )
                    db.add(new_provider)
                    db.flush()
                    provider_id = new_provider.id
                    stats["providers"]["created"] += 1

                # ---- Endpoints (matched by provider + api_format) ----
                for ep_data in prov_data.get("endpoints", []):
                    ep_format = ep_data["api_format"]
                    endpoint_fields = {
                        "headers": ep_data.get("headers"),
                        "timeout": ep_data.get("timeout", 300),
                        "max_retries": ep_data.get("max_retries", 3),
                        "max_concurrent": ep_data.get("max_concurrent"),
                        "rate_limit": ep_data.get("rate_limit"),
                        "is_active": ep_data.get("is_active", True),
                        "custom_path": ep_data.get("custom_path"),
                        "config": ep_data.get("config"),
                    }

                    existing_ep = (
                        db.query(ProviderEndpoint)
                        .filter(
                            ProviderEndpoint.provider_id == provider_id,
                            ProviderEndpoint.api_format == ep_format,
                        )
                        .first()
                    )

                    if existing_ep:
                        endpoint_id = existing_ep.id
                        if merge_mode == "skip":
                            stats["endpoints"]["skipped"] += 1
                        elif merge_mode == "error":
                            raise InvalidRequestException(
                                f"Endpoint '{ep_format}' 已存在于 Provider '{prov_name}'"
                            )
                        else:  # overwrite
                            existing_ep.base_url = ep_data.get("base_url", existing_ep.base_url)
                            for field_name, value in endpoint_fields.items():
                                setattr(existing_ep, field_name, value)
                            existing_ep.updated_at = datetime.now(timezone.utc)
                            stats["endpoints"]["updated"] += 1
                    else:
                        new_ep = ProviderEndpoint(
                            id=str(uuid.uuid4()),
                            provider_id=provider_id,
                            api_format=ep_format,
                            base_url=ep_data["base_url"],
                            **endpoint_fields,
                        )
                        db.add(new_ep)
                        db.flush()
                        endpoint_id = new_ep.id
                        stats["endpoints"]["created"] += 1

                    # ---- Keys ----
                    # Decrypt the keys already stored on this endpoint so that
                    # incoming keys can be de-duplicated by plaintext value.
                    existing_key_values = set()
                    stored_keys = (
                        db.query(ProviderAPIKey)
                        .filter(ProviderAPIKey.endpoint_id == endpoint_id)
                        .all()
                    )
                    for stored_key in stored_keys:
                        try:
                            existing_key_values.add(crypto_service.decrypt(stored_key.api_key))
                        except Exception:
                            # Best effort: a stored key that cannot be decrypted
                            # simply does not take part in de-duplication.
                            pass

                    for key_data in ep_data.get("keys", []):
                        plain_key = key_data.get("api_key")
                        if not plain_key:
                            stats["errors"].append(f"跳过空 API Key (Endpoint: {ep_format})")
                            continue

                        if plain_key in existing_key_values:
                            stats["keys"]["skipped"] += 1
                            continue

                        new_key = ProviderAPIKey(
                            id=str(uuid.uuid4()),
                            endpoint_id=endpoint_id,
                            api_key=crypto_service.encrypt(plain_key),
                            name=key_data.get("name"),
                            note=key_data.get("note"),
                            rate_multiplier=key_data.get("rate_multiplier", 1.0),
                            internal_priority=key_data.get("internal_priority", 100),
                            global_priority=key_data.get("global_priority"),
                            max_concurrent=key_data.get("max_concurrent"),
                            rate_limit=key_data.get("rate_limit"),
                            daily_limit=key_data.get("daily_limit"),
                            monthly_limit=key_data.get("monthly_limit"),
                            allowed_models=key_data.get("allowed_models"),
                            capabilities=key_data.get("capabilities"),
                            is_active=key_data.get("is_active", True),
                        )
                        db.add(new_key)
                        # Remember the value so duplicates within this same
                        # import batch are skipped as well.
                        existing_key_values.add(plain_key)
                        stats["keys"]["created"] += 1

                # ---- Provider models ----
                for model_data in prov_data.get("models", []):
                    global_model_name = model_data.get("global_model_name")
                    if not global_model_name:
                        stats["errors"].append(
                            f"跳过无 global_model_name 的模型 (Provider: {prov_name})"
                        )
                        continue

                    global_model_id = global_model_map.get(global_model_name)
                    if not global_model_id:
                        # Not part of this payload: fall back to the database.
                        existing_gm = (
                            db.query(GlobalModel)
                            .filter(GlobalModel.name == global_model_name)
                            .first()
                        )
                        if not existing_gm:
                            stats["errors"].append(
                                f"GlobalModel '{global_model_name}' 不存在,跳过模型"
                            )
                            continue
                        global_model_id = existing_gm.id

                    model_name = model_data["provider_model_name"]
                    model_fields = {
                        "provider_model_aliases": model_data.get("provider_model_aliases"),
                        "price_per_request": model_data.get("price_per_request"),
                        "tiered_pricing": model_data.get("tiered_pricing"),
                        "supports_vision": model_data.get("supports_vision"),
                        "supports_function_calling": model_data.get("supports_function_calling"),
                        "supports_streaming": model_data.get("supports_streaming"),
                        "supports_extended_thinking": model_data.get("supports_extended_thinking"),
                        "supports_image_generation": model_data.get("supports_image_generation"),
                        "is_active": model_data.get("is_active", True),
                        "config": model_data.get("config"),
                    }

                    existing_model = (
                        db.query(Model)
                        .filter(
                            Model.provider_id == provider_id,
                            Model.provider_model_name == model_name,
                        )
                        .first()
                    )

                    if existing_model:
                        if merge_mode == "skip":
                            stats["models"]["skipped"] += 1
                        elif merge_mode == "error":
                            raise InvalidRequestException(
                                f"Model '{model_name}' 已存在于 Provider '{prov_name}'"
                            )
                        else:  # overwrite
                            existing_model.global_model_id = global_model_id
                            for field_name, value in model_fields.items():
                                setattr(existing_model, field_name, value)
                            existing_model.updated_at = datetime.now(timezone.utc)
                            stats["models"]["updated"] += 1
                    else:
                        new_model = Model(
                            id=str(uuid.uuid4()),
                            provider_id=provider_id,
                            global_model_id=global_model_id,
                            provider_model_name=model_name,
                            **model_fields,
                        )
                        db.add(new_model)
                        stats["models"]["created"] += 1

            db.commit()

            # Invalidate caches after the import has been committed.
            from src.services.cache.invalidation import get_cache_invalidation_service

            cache_service = get_cache_invalidation_service()
            cache_service.clear_all_caches()

            return {
                "message": "配置导入成功",
                "stats": stats,
            }

        except InvalidRequestException:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            # Keep the original exception as the cause for debugging.
            raise InvalidRequestException(f"导入失败: {str(e)}") from e
|
||||
|
||||
|
||||
class AdminExportUsersAdapter(AdminApiAdapter):
    """Admin endpoint adapter that exports non-admin users with their API keys."""

    async def handle(self, context):  # type: ignore[override]
        """Export user data (stored hashes/encrypted values kept as-is, admins excluded).

        Args:
            context: Request context; only ``context.db`` is read here.

        Returns:
            A dict with ``version``, ``exported_at`` (current UTC time in ISO
            format) and ``users`` (one dict per exported user, each carrying
            an ``api_keys`` list).
        """
        from collections import defaultdict
        from datetime import datetime, timezone

        from src.core.enums import UserRole
        from src.models.database import ApiKey, User

        db = context.db

        # Users eligible for export: not soft-deleted and not administrators.
        exportable = (User.is_deleted.is_(False), User.role != UserRole.ADMIN)
        users = db.query(User).filter(*exportable).all()

        # Load the API keys of all exported users with ONE joined query instead
        # of one query per user (avoids the N+1 pattern), then group by owner.
        keys_by_user = defaultdict(list)
        if users:
            owned_keys = (
                db.query(ApiKey)
                .join(User, ApiKey.user_id == User.id)
                .filter(*exportable)
                .all()
            )
            for key in owned_keys:
                keys_by_user[key.user_id].append(
                    {
                        "key_hash": key.key_hash,
                        "key_encrypted": key.key_encrypted,
                        "name": key.name,
                        "is_standalone": key.is_standalone,
                        "balance_used_usd": key.balance_used_usd,
                        "current_balance_usd": key.current_balance_usd,
                        "allowed_providers": key.allowed_providers,
                        "allowed_endpoints": key.allowed_endpoints,
                        "allowed_api_formats": key.allowed_api_formats,
                        "allowed_models": key.allowed_models,
                        "rate_limit": key.rate_limit,
                        "concurrent_limit": key.concurrent_limit,
                        "force_capabilities": key.force_capabilities,
                        "is_active": key.is_active,
                        "auto_delete_on_expiry": key.auto_delete_on_expiry,
                        "total_requests": key.total_requests,
                        "total_cost_usd": key.total_cost_usd,
                    }
                )

        users_data = [
            {
                "email": user.email,
                "username": user.username,
                "password_hash": user.password_hash,
                # Fall back to the plain "user" role when no role is set.
                "role": user.role.value if user.role else "user",
                "allowed_providers": user.allowed_providers,
                "allowed_endpoints": user.allowed_endpoints,
                "allowed_models": user.allowed_models,
                "model_capability_settings": user.model_capability_settings,
                "quota_usd": user.quota_usd,
                "used_usd": user.used_usd,
                "total_usd": user.total_usd,
                "is_active": user.is_active,
                "api_keys": keys_by_user.get(user.id, []),
            }
            for user in users
        ]

        return {
            "version": "1.0",
            "exported_at": datetime.now(timezone.utc).isoformat(),
            "users": users_data,
        }
|
||||
|
||||
|
||||
class AdminImportUsersAdapter(AdminApiAdapter):
    """Admin endpoint adapter that imports users and their API keys from an export payload."""

    async def handle(self, context):  # type: ignore[override]
        """Import user data.

        The JSON body must carry ``version == "1.0"``. ``merge_mode`` selects
        what happens when a user with the same email already exists:
        ``"skip"`` (default) leaves the user untouched, ``"overwrite"``
        replaces the user's fields, ``"error"`` aborts the whole import.
        In ``skip`` and ``overwrite`` mode the API keys listed for the user
        are still imported; keys whose ``key_hash`` already exists are skipped.

        Records with an admin role, and records whose email matches an
        existing administrator account, are never imported or modified.

        Args:
            context: Request context providing ``raw_body``, ``db`` and
                ``ensure_json_body()``.

        Returns:
            A dict with a success ``message`` and the ``stats`` counters.

        Raises:
            InvalidRequestException: Oversized body, unsupported version or
                merge mode, a conflict in ``error`` mode, or any other failure
                during the import (the transaction is rolled back first).
        """
        import uuid
        from datetime import datetime, timezone

        from src.core.enums import UserRole
        from src.models.database import ApiKey, User

        # Reject oversized request bodies before parsing them.
        if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
            raise InvalidRequestException("请求体大小不能超过 10MB")

        db = context.db
        payload = context.ensure_json_body()

        # Validate the export format version.
        version = payload.get("version")
        if version != "1.0":
            raise InvalidRequestException(f"不支持的配置版本: {version}")

        # Import options. A missing/null mode falls back to "skip"; any other
        # unknown value is rejected instead of being silently ignored.
        merge_mode = payload.get("merge_mode") or "skip"
        if merge_mode not in ("skip", "overwrite", "error"):
            raise InvalidRequestException(f"不支持的合并模式: {merge_mode}")
        users_data = payload.get("users", [])

        stats = {
            "users": {"created": 0, "updated": 0, "skipped": 0},
            "api_keys": {"created": 0, "skipped": 0},
            "errors": [],
        }

        try:
            for user_data in users_data:
                # Never import records that carry the admin role (case-insensitive).
                role_str = str(user_data.get("role", "")).lower()
                if role_str == "admin":
                    stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
                    stats["users"]["skipped"] += 1
                    continue

                existing_user = (
                    db.query(User).filter(User.email == user_data["email"]).first()
                )

                if existing_user:
                    if merge_mode == "error":
                        raise InvalidRequestException(
                            f"用户 '{user_data['email']}' 已存在"
                        )

                    # Protect existing administrator accounts: an imported
                    # record must not overwrite them or attach keys to them.
                    if existing_user.role == UserRole.ADMIN:
                        stats["errors"].append(
                            f"跳过管理员用户: {user_data.get('email')}"
                        )
                        stats["users"]["skipped"] += 1
                        continue

                    user_id = existing_user.id
                    if merge_mode == "skip":
                        stats["users"]["skipped"] += 1
                    else:
                        # merge_mode == "overwrite": update the existing user.
                        existing_user.username = user_data.get(
                            "username", existing_user.username
                        )
                        if user_data.get("password_hash"):
                            existing_user.password_hash = user_data["password_hash"]
                        if user_data.get("role"):
                            existing_user.role = UserRole(user_data["role"])
                        existing_user.allowed_providers = user_data.get("allowed_providers")
                        existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
                        existing_user.allowed_models = user_data.get("allowed_models")
                        existing_user.model_capability_settings = user_data.get(
                            "model_capability_settings"
                        )
                        existing_user.quota_usd = user_data.get("quota_usd")
                        existing_user.used_usd = user_data.get("used_usd", 0.0)
                        existing_user.total_usd = user_data.get("total_usd", 0.0)
                        existing_user.is_active = user_data.get("is_active", True)
                        existing_user.updated_at = datetime.now(timezone.utc)
                        stats["users"]["updated"] += 1
                else:
                    # Create a new user.
                    role = UserRole.USER
                    if user_data.get("role"):
                        role = UserRole(user_data["role"])

                    new_user = User(
                        id=str(uuid.uuid4()),
                        email=user_data["email"],
                        # Default username: local part of the email address.
                        username=user_data.get("username", user_data["email"].split("@")[0]),
                        password_hash=user_data.get("password_hash", ""),
                        role=role,
                        allowed_providers=user_data.get("allowed_providers"),
                        allowed_endpoints=user_data.get("allowed_endpoints"),
                        allowed_models=user_data.get("allowed_models"),
                        model_capability_settings=user_data.get("model_capability_settings"),
                        quota_usd=user_data.get("quota_usd"),
                        used_usd=user_data.get("used_usd", 0.0),
                        total_usd=user_data.get("total_usd", 0.0),
                        is_active=user_data.get("is_active", True),
                    )
                    db.add(new_user)
                    db.flush()
                    user_id = new_user.id
                    stats["users"]["created"] += 1

                # Import the user's API keys.
                for key_data in user_data.get("api_keys", []):
                    # Skip keys whose hash is already present in the database.
                    if key_data.get("key_hash"):
                        existing_key = (
                            db.query(ApiKey)
                            .filter(ApiKey.key_hash == key_data["key_hash"])
                            .first()
                        )
                        if existing_key:
                            stats["api_keys"]["skipped"] += 1
                            continue

                    new_key = ApiKey(
                        id=str(uuid.uuid4()),
                        user_id=user_id,
                        key_hash=key_data.get("key_hash", ""),
                        key_encrypted=key_data.get("key_encrypted"),
                        name=key_data.get("name"),
                        is_standalone=key_data.get("is_standalone", False),
                        balance_used_usd=key_data.get("balance_used_usd", 0.0),
                        current_balance_usd=key_data.get("current_balance_usd"),
                        allowed_providers=key_data.get("allowed_providers"),
                        allowed_endpoints=key_data.get("allowed_endpoints"),
                        allowed_api_formats=key_data.get("allowed_api_formats"),
                        allowed_models=key_data.get("allowed_models"),
                        rate_limit=key_data.get("rate_limit", 100),
                        concurrent_limit=key_data.get("concurrent_limit", 5),
                        force_capabilities=key_data.get("force_capabilities"),
                        is_active=key_data.get("is_active", True),
                        auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
                        total_requests=key_data.get("total_requests", 0),
                        total_cost_usd=key_data.get("total_cost_usd", 0.0),
                    )
                    db.add(new_key)
                    stats["api_keys"]["created"] += 1

            db.commit()

            return {
                "message": "用户数据导入成功",
                "stats": stats,
            }

        except InvalidRequestException:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            # Chain the original error so the root cause stays in the traceback.
            raise InvalidRequestException(f"导入失败: {e}") from e
|
||||
|
||||
@@ -142,7 +142,7 @@ class AnnouncementOptionalAuthAdapter(ApiAdapter):
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
@@ -211,7 +211,7 @@ class AuthRefreshAdapter(AuthPublicAdapter):
|
||||
|
||||
class AuthRegisterAdapter(AuthPublicAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from ..models.database import SystemConfig
|
||||
from src.models.database import SystemConfig
|
||||
|
||||
db = context.db
|
||||
payload = context.ensure_json_body()
|
||||
|
||||
@@ -65,6 +65,21 @@ class ModelInfo:
|
||||
created_at: Optional[str] # ISO 格式
|
||||
created_timestamp: int # Unix 时间戳
|
||||
provider_name: str
|
||||
# 能力配置
|
||||
streaming: bool = True
|
||||
vision: bool = False
|
||||
function_calling: bool = False
|
||||
extended_thinking: bool = False
|
||||
image_generation: bool = False
|
||||
structured_output: bool = False
|
||||
# 规格参数
|
||||
context_limit: Optional[int] = None
|
||||
output_limit: Optional[int] = None
|
||||
# 元信息
|
||||
family: Optional[str] = None
|
||||
knowledge_cutoff: Optional[str] = None
|
||||
input_modalities: Optional[list[str]] = None
|
||||
output_modalities: Optional[list[str]] = None
|
||||
|
||||
|
||||
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
|
||||
@@ -181,13 +196,19 @@ def _extract_model_info(model: Any) -> ModelInfo:
|
||||
global_model = model.global_model
|
||||
model_id: str = global_model.name if global_model else model.provider_model_name
|
||||
display_name: str = global_model.display_name if global_model else model.provider_model_name
|
||||
description: Optional[str] = global_model.description if global_model else None
|
||||
created_at: Optional[str] = (
|
||||
model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
|
||||
)
|
||||
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
|
||||
provider_name: str = model.provider.name if model.provider else "unknown"
|
||||
|
||||
# 从 GlobalModel.config 提取配置信息
|
||||
config: dict = {}
|
||||
description: Optional[str] = None
|
||||
if global_model:
|
||||
config = global_model.config or {}
|
||||
description = config.get("description")
|
||||
|
||||
return ModelInfo(
|
||||
id=model_id,
|
||||
display_name=display_name,
|
||||
@@ -195,6 +216,21 @@ def _extract_model_info(model: Any) -> ModelInfo:
|
||||
created_at=created_at,
|
||||
created_timestamp=created_timestamp,
|
||||
provider_name=provider_name,
|
||||
# 能力配置
|
||||
streaming=config.get("streaming", True),
|
||||
vision=config.get("vision", False),
|
||||
function_calling=config.get("function_calling", False),
|
||||
extended_thinking=config.get("extended_thinking", False),
|
||||
image_generation=config.get("image_generation", False),
|
||||
structured_output=config.get("structured_output", False),
|
||||
# 规格参数
|
||||
context_limit=config.get("context_limit"),
|
||||
output_limit=config.get("output_limit"),
|
||||
# 元信息
|
||||
family=config.get("family"),
|
||||
knowledge_cutoff=config.get("knowledge_cutoff"),
|
||||
input_modalities=config.get("input_modalities"),
|
||||
output_modalities=config.get("output_modalities"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,13 +5,12 @@ from enum import Enum
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.exceptions import QuotaExceededException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import ApiKey, AuditEventType, User, UserRole
|
||||
from src.services.auth.service import AuthService
|
||||
from src.services.cache.user_cache import UserCacheService
|
||||
from src.services.system.audit import AuditService
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
@@ -180,7 +179,7 @@ class ApiRequestPipeline:
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -191,8 +190,8 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的管理员令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
@@ -207,7 +206,7 @@ class ApiRequestPipeline:
|
||||
|
||||
token = authorization.replace("Bearer ", "").strip()
|
||||
try:
|
||||
payload = await self.auth_service.verify_token(token)
|
||||
payload = await self.auth_service.verify_token(token, token_type="access")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -218,8 +217,8 @@ class ApiRequestPipeline:
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的用户令牌")
|
||||
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
# 直接查询数据库,确保返回的是当前 Session 绑定的对象
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="用户不存在或已禁用")
|
||||
|
||||
@@ -242,11 +241,15 @@ class ApiRequestPipeline:
|
||||
status_code: Optional[int] = None,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""记录审计事件
|
||||
|
||||
事务策略:复用请求级 Session,不单独提交。
|
||||
审计记录随主事务一起提交,由中间件统一管理。
|
||||
"""
|
||||
if not getattr(adapter, "audit_log_enabled", True):
|
||||
return
|
||||
|
||||
bind = context.db.get_bind()
|
||||
if bind is None:
|
||||
if context.db is None:
|
||||
return
|
||||
|
||||
event_type = adapter.audit_success_event if success else adapter.audit_failure_event
|
||||
@@ -266,11 +269,11 @@ class ApiRequestPipeline:
|
||||
error=error,
|
||||
)
|
||||
|
||||
SessionMaker = sessionmaker(bind=bind)
|
||||
audit_session = SessionMaker()
|
||||
try:
|
||||
# 复用请求级 Session,不创建新的连接
|
||||
# 审计记录随主事务一起提交,由中间件统一管理
|
||||
self.audit_service.log_event(
|
||||
db=audit_session,
|
||||
db=context.db,
|
||||
event_type=event_type,
|
||||
description=f"{context.request.method} {context.request.url.path} via {adapter.name}",
|
||||
user_id=context.user.id if context.user else None,
|
||||
@@ -282,12 +285,9 @@ class ApiRequestPipeline:
|
||||
error_message=error,
|
||||
metadata=metadata,
|
||||
)
|
||||
audit_session.commit()
|
||||
except Exception as exc:
|
||||
audit_session.rollback()
|
||||
# 审计失败不应影响主请求,仅记录警告
|
||||
logger.warning(f"[Audit] Failed to record event for adapter={adapter.name}: {exc}")
|
||||
finally:
|
||||
audit_session.close()
|
||||
|
||||
def _build_audit_metadata(
|
||||
self,
|
||||
|
||||
@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
)
|
||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||
# 需要转回业务时区再取日期,才能与日期序列匹配
|
||||
def _to_business_date_str(value: datetime) -> str:
|
||||
if value.tzinfo is None:
|
||||
value_utc = value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
value_utc = value.astimezone(timezone.utc)
|
||||
return value_utc.astimezone(app_tz).date().isoformat()
|
||||
|
||||
stats_map = {
|
||||
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
|
||||
_to_business_date_str(stat.date): {
|
||||
"requests": stat.total_requests,
|
||||
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
||||
"cost": stat.total_cost,
|
||||
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
"unique_providers": today_unique_providers,
|
||||
"fallback_count": today_fallback_count,
|
||||
}
|
||||
|
||||
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
|
||||
yesterday_date = today_local.date() - timedelta(days=1)
|
||||
historical_end = min(end_date_local.date(), yesterday_date)
|
||||
missing_dates: list[str] = []
|
||||
cursor = start_date_local.date()
|
||||
while cursor <= historical_end:
|
||||
date_str = cursor.isoformat()
|
||||
if date_str not in stats_map:
|
||||
missing_dates.append(date_str)
|
||||
cursor += timedelta(days=1)
|
||||
|
||||
if missing_dates:
|
||||
for date_str in missing_dates[-7:]:
|
||||
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
|
||||
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
|
||||
stats_map[date_str] = {
|
||||
"requests": computed["total_requests"],
|
||||
"tokens": (
|
||||
computed["input_tokens"]
|
||||
+ computed["output_tokens"]
|
||||
+ computed["cache_creation_tokens"]
|
||||
+ computed["cache_read_tokens"]
|
||||
),
|
||||
"cost": computed["total_cost"],
|
||||
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
|
||||
if computed["avg_response_time_ms"]
|
||||
else 0,
|
||||
"unique_models": computed["unique_models"],
|
||||
"unique_providers": computed["unique_providers"],
|
||||
"fallback_count": computed["fallback_count"],
|
||||
}
|
||||
else:
|
||||
# 普通用户:仍需实时查询(用户级预聚合可选)
|
||||
query = db.query(Usage).filter(
|
||||
|
||||
@@ -411,9 +411,10 @@ class BaseMessageHandler:
|
||||
QuotaExceededException,
|
||||
RateLimitException,
|
||||
ModelNotSupportedException,
|
||||
UpstreamClientException,
|
||||
)
|
||||
|
||||
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
||||
if isinstance(error, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException, UpstreamClientException)):
|
||||
# 业务异常:简洁日志,不打印堆栈
|
||||
logger.error(f"{message}: [{type(error).__name__}] {error}")
|
||||
else:
|
||||
|
||||
@@ -266,8 +266,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
if mapping and mapping.model:
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
# 传入 api_format 用于过滤适用的别名作用域
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
)
|
||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||
return mapped_name
|
||||
|
||||
|
||||
@@ -155,8 +155,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
if mapping and mapping.model:
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
# 传入 api_format 用于过滤适用的别名作用域
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(affinity_key)
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
)
|
||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||
return mapped_name
|
||||
|
||||
|
||||
@@ -210,9 +210,9 @@ class PublicModelsAdapter(PublicApiAdapter):
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
description=global_model.config.get("description") if global_model and global_model.config else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
@@ -274,7 +274,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
|
||||
Model.provider_model_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.display_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.description.ilike(f"%{self.query}%")
|
||||
)
|
||||
query_stmt = query_stmt.filter(search_filter)
|
||||
if self.provider_id is not None:
|
||||
@@ -293,9 +292,9 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
description=global_model.config.get("description") if global_model and global_model.config else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
@@ -499,7 +498,6 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
|
||||
or_(
|
||||
GlobalModel.name.ilike(search_term),
|
||||
GlobalModel.display_name.ilike(search_term),
|
||||
GlobalModel.description.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -517,21 +515,11 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
|
||||
id=gm.id,
|
||||
name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
icon_url=gm.icon_url,
|
||||
is_active=gm.is_active,
|
||||
default_price_per_request=gm.default_price_per_request,
|
||||
default_tiered_pricing=gm.default_tiered_pricing,
|
||||
default_supports_vision=gm.default_supports_vision or False,
|
||||
default_supports_function_calling=gm.default_supports_function_calling or False,
|
||||
default_supports_streaming=(
|
||||
gm.default_supports_streaming
|
||||
if gm.default_supports_streaming is not None
|
||||
else True
|
||||
),
|
||||
default_supports_extended_thinking=gm.default_supports_extended_thinking
|
||||
or False,
|
||||
supported_capabilities=gm.supported_capabilities,
|
||||
config=gm.config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -251,8 +251,8 @@ def _build_gemini_list_response(
|
||||
"version": "001",
|
||||
"displayName": m.display_name,
|
||||
"description": m.description or f"Model {m.id}",
|
||||
"inputTokenLimit": 128000,
|
||||
"outputTokenLimit": 8192,
|
||||
"inputTokenLimit": m.context_limit if m.context_limit is not None else 128000,
|
||||
"outputTokenLimit": m.output_limit if m.output_limit is not None else 8192,
|
||||
"supportedGenerationMethods": ["generateContent", "countTokens"],
|
||||
"temperature": 1.0,
|
||||
"maxTemperature": 2.0,
|
||||
@@ -297,8 +297,8 @@ def _build_gemini_model_response(model_info: ModelInfo) -> dict:
|
||||
"version": "001",
|
||||
"displayName": model_info.display_name,
|
||||
"description": model_info.description or f"Model {model_info.id}",
|
||||
"inputTokenLimit": 128000,
|
||||
"outputTokenLimit": 8192,
|
||||
"inputTokenLimit": model_info.context_limit if model_info.context_limit is not None else 128000,
|
||||
"outputTokenLimit": model_info.output_limit if model_info.output_limit is not None else 8192,
|
||||
"supportedGenerationMethods": ["generateContent", "countTokens"],
|
||||
"temperature": 1.0,
|
||||
"maxTemperature": 2.0,
|
||||
|
||||
@@ -267,6 +267,9 @@ async def get_redis_client(require_redis: bool = False) -> Optional[aioredis.Red
|
||||
|
||||
if _redis_manager is None:
|
||||
_redis_manager = RedisClientManager()
|
||||
# 如果尚未连接(例如启动时降级、或 close() 后),尝试重新初始化。
|
||||
# initialize() 内部包含熔断器逻辑,避免频繁重试导致抖动。
|
||||
if _redis_manager.get_client() is None:
|
||||
await _redis_manager.initialize(require_redis=require_redis)
|
||||
|
||||
return _redis_manager.get_client()
|
||||
|
||||
@@ -41,8 +41,8 @@ class CacheSize:
|
||||
class ConcurrencyDefaults:
|
||||
"""并发控制默认值"""
|
||||
|
||||
# 自适应并发初始限制(保守值)
|
||||
INITIAL_LIMIT = 3
|
||||
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
|
||||
INITIAL_LIMIT = 50
|
||||
|
||||
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
|
||||
COOLDOWN_AFTER_429_MINUTES = 5
|
||||
@@ -67,13 +67,14 @@ class ConcurrencyDefaults:
|
||||
MIN_SAMPLES_FOR_DECISION = 5
|
||||
|
||||
# 扩容步长 - 每次扩容增加的并发数
|
||||
INCREASE_STEP = 1
|
||||
INCREASE_STEP = 2
|
||||
|
||||
# 缩容乘数 - 遇到 429 时的缩容比例
|
||||
DECREASE_MULTIPLIER = 0.7
|
||||
# 缩容乘数 - 遇到 429 时基于当前并发数的缩容比例
|
||||
# 0.85 表示降到触发 429 时并发数的 85%
|
||||
DECREASE_MULTIPLIER = 0.85
|
||||
|
||||
# 最大并发限制上限
|
||||
MAX_CONCURRENT_LIMIT = 100
|
||||
MAX_CONCURRENT_LIMIT = 200
|
||||
|
||||
# 最小并发限制下限
|
||||
MIN_CONCURRENT_LIMIT = 1
|
||||
@@ -85,6 +86,11 @@ class ConcurrencyDefaults:
|
||||
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
|
||||
PROBE_INCREASE_MIN_REQUESTS = 10
|
||||
|
||||
# === 缓存用户预留比例 ===
|
||||
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
|
||||
# 0.1 表示缓存用户预留 10%,新用户可用 90%
|
||||
CACHE_RESERVATION_RATIO = 0.1
|
||||
|
||||
|
||||
class CircuitBreakerDefaults:
|
||||
"""熔断器配置默认值(滑动窗口 + 半开状态模式)
|
||||
|
||||
@@ -122,9 +122,9 @@ class Config:
|
||||
|
||||
# 并发控制配置
|
||||
# CONCURRENCY_SLOT_TTL: 并发槽位 TTL(秒),防止死锁
|
||||
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 30%)
|
||||
# CACHE_RESERVATION_RATIO: 缓存用户预留比例(默认 10%,新用户可用 90%)
|
||||
self.concurrency_slot_ttl = int(os.getenv("CONCURRENCY_SLOT_TTL", "600"))
|
||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.3"))
|
||||
self.cache_reservation_ratio = float(os.getenv("CACHE_RESERVATION_RATIO", "0.1"))
|
||||
|
||||
# HTTP 请求超时配置(秒)
|
||||
self.http_connect_timeout = float(os.getenv("HTTP_CONNECT_TIMEOUT", "10.0"))
|
||||
|
||||
@@ -46,6 +46,11 @@ class BatchCommitter:
|
||||
|
||||
def mark_dirty(self, session: Session):
    """Register ``session`` as having changes that still need to be committed.

    Nothing is registered when ``session`` is ``None`` or when its ``info``
    dict carries a truthy ``"managed_by_middleware"`` entry.
    """
    # Request-scoped transactions are committed / rolled back by the
    # middleware; leaving them out keeps the background task from committing
    # in the middle of a request.
    if session is not None and not session.info.get("managed_by_middleware"):
        self._pending_sessions.add(session)
|
||||
|
||||
async def _batch_commit_loop(self):
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
"""
|
||||
统一的请求上下文
|
||||
|
||||
RequestContext 贯穿整个请求生命周期,包含所有请求相关信息。
|
||||
这确保了数据在各层之间传递时不会丢失。
|
||||
|
||||
使用方式:
|
||||
1. Pipeline 层创建 RequestContext
|
||||
2. 各层通过 context 访问和更新信息
|
||||
3. Adapter 层使用 context 记录 Usage
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@dataclass
class RequestContext:
    """
    Per-request context object carried through the whole request lifecycle.

    Design:
    1. Built when the request starts, holding everything known at that point.
    2. Provider details are filled in while the request executes.
    3. Read at the end of the request to build the Usage record metadata.
    """

    # ---- request identity ----
    request_id: str

    # ---- authentication ----
    user: Any  # User model
    api_key: Any  # ApiKey model
    db: Any  # database session

    # ---- request details ----
    api_format: str  # CLAUDE, OPENAI, GEMINI, etc.
    model: str  # model name requested by the caller
    is_stream: bool = False

    # ---- original request ----
    original_headers: Dict[str, str] = field(default_factory=dict)
    original_body: Dict[str, Any] = field(default_factory=dict)

    # ---- client details ----
    client_ip: str = "unknown"
    user_agent: str = ""

    # ---- timing ----
    start_time: float = field(default_factory=time.time)

    # ---- provider details (filled in after the request is executed) ----
    provider_name: Optional[str] = None
    provider_id: Optional[str] = None
    endpoint_id: Optional[str] = None
    provider_api_key_id: Optional[str] = None

    # ---- model mapping ----
    resolved_model: Optional[str] = None  # model name after mapping
    original_model: Optional[str] = None  # original model name (used for pricing)

    # ---- provider request / response headers ----
    provider_request_headers: Dict[str, str] = field(default_factory=dict)
    provider_response_headers: Dict[str, str] = field(default_factory=dict)

    # ---- tracing ----
    attempt_id: Optional[str] = None

    # ---- capability requirements ----
    # Computed at runtime from:
    # 1. the user's model_capability_settings
    # 2. the user's ApiKey.force_capabilities
    # 3. the X-Require-Capability request header
    # 4. requirements added dynamically when a failed request is retried
    capability_requirements: Dict[str, bool] = field(default_factory=dict)

    @classmethod
    def create(
        cls,
        *,
        db: Any,
        user: Any,
        api_key: Any,
        api_format: str,
        model: str,
        is_stream: bool = False,
        original_headers: Optional[Dict[str, str]] = None,
        original_body: Optional[Dict[str, Any]] = None,
        client_ip: str = "unknown",
        user_agent: str = "",
        request_id: Optional[str] = None,
    ) -> "RequestContext":
        """Build a context; a falsy ``request_id`` is replaced by a fresh UUID4 string."""
        effective_request_id = request_id if request_id else str(uuid.uuid4())
        headers = original_headers if original_headers else {}
        body = original_body if original_body else {}
        return cls(
            request_id=effective_request_id,
            db=db,
            user=user,
            api_key=api_key,
            api_format=api_format,
            model=model,
            is_stream=is_stream,
            original_headers=headers,
            original_body=body,
            client_ip=client_ip,
            user_agent=user_agent,
            # At creation time the original model equals the requested model.
            original_model=model,
        )

    def update_provider_info(
        self,
        *,
        provider_name: str,
        provider_id: str,
        endpoint_id: str,
        provider_api_key_id: str,
        resolved_model: Optional[str] = None,
    ) -> None:
        """Store provider details (called after the request has been executed).

        ``resolved_model`` is only overwritten when a truthy value is passed.
        """
        self.provider_name = provider_name
        self.provider_id = provider_id
        self.endpoint_id = endpoint_id
        self.provider_api_key_id = provider_api_key_id
        if not resolved_model:
            return
        self.resolved_model = resolved_model

    def update_headers(
        self,
        *,
        request_headers: Optional[Dict[str, str]] = None,
        response_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """Replace the provider request/response headers; falsy arguments are ignored."""
        if request_headers:
            self.provider_request_headers = request_headers
        if response_headers:
            self.provider_response_headers = response_headers

    @property
    def elapsed_ms(self) -> int:
        """Milliseconds elapsed since ``start_time``, truncated to an int."""
        elapsed_seconds = time.time() - self.start_time
        return int(elapsed_seconds * 1000)

    @property
    def effective_model(self) -> str:
        """Model name in effect: the mapped name when set, else the requested one."""
        if self.resolved_model:
            return self.resolved_model
        return self.model

    @property
    def billing_model(self) -> str:
        """Model name used for billing: the original name when set, else the requested one."""
        if self.original_model:
            return self.original_model
        return self.model

    def to_metadata_dict(self) -> Dict[str, Any]:
        """Return the metadata dict used when recording Usage."""
        provider = self.provider_name if self.provider_name else "unknown"
        return dict(
            api_format=self.api_format,
            provider=provider,
            model=self.effective_model,
            original_model=self.billing_model,
            provider_id=self.provider_id,
            provider_endpoint_id=self.endpoint_id,
            provider_api_key_id=self.provider_api_key_id,
            provider_request_headers=self.provider_request_headers,
            provider_response_headers=self.provider_response_headers,
            attempt_id=self.attempt_id,
        )
|
||||
@@ -9,7 +9,7 @@
|
||||
|
||||
输出策略:
|
||||
- 控制台: 开发环境=DEBUG, 生产环境=INFO (通过 LOG_LEVEL 控制)
|
||||
- 文件: 始终保存 DEBUG 级别,保留30天,每日轮转
|
||||
- 文件: 始终保存 DEBUG 级别,保留30天,按大小轮转 (100MB)
|
||||
|
||||
使用方式:
|
||||
from src.core.logger import logger
|
||||
@@ -72,12 +72,15 @@ def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
|
||||
|
||||
|
||||
if IS_DOCKER:
|
||||
# 生产环境:禁用 backtrace 和 diagnose,减少日志噪音
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
format=CONSOLE_FORMAT_PROD,
|
||||
level=LOG_LEVEL,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
colorize=False,
|
||||
backtrace=False,
|
||||
diagnose=False,
|
||||
)
|
||||
else:
|
||||
logger.add(
|
||||
@@ -92,30 +95,37 @@ if not DISABLE_FILE_LOG:
|
||||
log_dir = PROJECT_ROOT / "logs"
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
|
||||
# 文件日志通用配置
|
||||
file_log_config = {
|
||||
"format": FILE_FORMAT,
|
||||
"filter": _log_filter,
|
||||
"rotation": "100 MB",
|
||||
"retention": "30 days",
|
||||
"compression": "gz",
|
||||
"enqueue": True,
|
||||
"encoding": "utf-8",
|
||||
"catch": True,
|
||||
}
|
||||
|
||||
# 生产环境禁用详细堆栈
|
||||
if IS_DOCKER:
|
||||
file_log_config["backtrace"] = False
|
||||
file_log_config["diagnose"] = False
|
||||
|
||||
# 主日志文件 - 所有级别
|
||||
logger.add(
|
||||
log_dir / "app.log",
|
||||
format=FILE_FORMAT,
|
||||
level="DEBUG",
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
enqueue=True,
|
||||
encoding="utf-8",
|
||||
**file_log_config, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# 错误日志文件 - 仅 ERROR 及以上
|
||||
error_log_config = file_log_config.copy()
|
||||
error_log_config["rotation"] = "50 MB"
|
||||
logger.add(
|
||||
log_dir / "error.log",
|
||||
format=FILE_FORMAT,
|
||||
level="ERROR",
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
enqueue=True,
|
||||
encoding="utf-8",
|
||||
**error_log_config, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import time
|
||||
from typing import AsyncGenerator, Generator, Optional
|
||||
|
||||
from starlette.requests import Request
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
@@ -150,9 +151,22 @@ def _log_pool_capacity():
|
||||
theoretical = config.db_pool_size + config.db_max_overflow
|
||||
workers = max(1, config.worker_processes)
|
||||
total_estimated = theoretical * workers
|
||||
logger.info("数据库连接池配置")
|
||||
if total_estimated > config.db_pool_warn_threshold:
|
||||
logger.warning("数据库连接需求可能超过阈值,请调小池大小或减少 worker 数")
|
||||
safe_limit = config.pg_max_connections - config.pg_reserved_connections
|
||||
logger.info(
|
||||
"数据库连接池配置: pool_size=%s, max_overflow=%s, workers=%s, total_estimated=%s, safe_limit=%s",
|
||||
config.db_pool_size,
|
||||
config.db_max_overflow,
|
||||
workers,
|
||||
total_estimated,
|
||||
safe_limit,
|
||||
)
|
||||
if total_estimated > safe_limit:
|
||||
logger.warning(
|
||||
"数据库连接池总需求可能超过 PostgreSQL 限制: %s > %s (pg_max_connections - reserved),"
|
||||
"建议调整 DB_POOL_SIZE/DB_MAX_OVERFLOW 或减少 worker 数",
|
||||
total_estimated,
|
||||
safe_limit,
|
||||
)
|
||||
|
||||
|
||||
def _ensure_async_engine() -> AsyncEngine:
|
||||
@@ -185,7 +199,7 @@ def _ensure_async_engine() -> AsyncEngine:
|
||||
# 创建异步引擎
|
||||
_async_engine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
poolclass=QueuePool, # 使用队列连接池
|
||||
# AsyncEngine 不能使用 QueuePool;默认使用 AsyncAdaptedQueuePool
|
||||
pool_size=config.db_pool_size,
|
||||
max_overflow=config.db_max_overflow,
|
||||
pool_timeout=config.db_pool_timeout,
|
||||
@@ -209,7 +223,18 @@ def _ensure_async_engine() -> AsyncEngine:
|
||||
|
||||
|
||||
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""获取异步数据库会话"""
|
||||
"""获取异步数据库会话
|
||||
|
||||
.. deprecated::
|
||||
此方法已废弃,项目统一使用同步 Session。
|
||||
未来版本可能移除此方法。请使用 get_db() 代替。
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"get_async_db() 已废弃,项目统一使用同步 Session。请使用 get_db() 代替。",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# 确保异步引擎已初始化
|
||||
_ensure_async_engine()
|
||||
|
||||
@@ -220,16 +245,61 @@ async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
await session.close()
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
def get_db(request: Request = None) -> Generator[Session, None, None]: # type: ignore[assignment]
|
||||
"""获取数据库会话
|
||||
|
||||
注意:事务管理由业务逻辑层显式控制(手动调用 commit/rollback)
|
||||
这里只负责会话的创建和关闭,不自动提交
|
||||
事务策略说明
|
||||
============
|
||||
本项目采用**混合事务管理**策略:
|
||||
|
||||
1. **LLM 请求路径**:
|
||||
- 由 PluginMiddleware 统一管理事务
|
||||
- Service 层使用 db.flush() 使更改可见,但不提交
|
||||
- 请求结束时由中间件统一 commit 或 rollback
|
||||
- 例外:UsageService.record_usage() 会显式 commit,因为使用记录需要立即持久化
|
||||
|
||||
2. **管理后台 API**:
|
||||
- 路由层显式调用 db.commit()
|
||||
- 每个操作独立提交,不依赖中间件
|
||||
|
||||
3. **后台任务/调度器**:
|
||||
- 使用独立 Session(通过 create_session() 或 next(get_db()))
|
||||
- 自行管理事务生命周期
|
||||
|
||||
使用方式
|
||||
========
|
||||
- FastAPI 请求:通过 Depends(get_db) 注入,支持中间件管理的 session 复用
|
||||
- 非请求上下文:直接调用 get_db(),退化为独立 session 模式
|
||||
|
||||
注意事项
|
||||
========
|
||||
- 本函数不自动提交事务
|
||||
- 异常时会自动回滚
|
||||
- 中间件管理模式下,session 关闭由中间件负责
|
||||
"""
|
||||
# FastAPI 请求上下文:优先复用中间件绑定的 request.state.db
|
||||
if request is not None:
|
||||
existing_db = getattr(getattr(request, "state", None), "db", None)
|
||||
if isinstance(existing_db, Session):
|
||||
yield existing_db
|
||||
return
|
||||
|
||||
# 确保引擎已初始化
|
||||
_ensure_engine()
|
||||
|
||||
db = _SessionLocal()
|
||||
|
||||
# 如果中间件声明会统一管理会话生命周期,则把 session 绑定到 request.state,
|
||||
# 并由中间件负责 commit/rollback/close(这里不关闭,避免流式响应提前释放会话)。
|
||||
managed_by_middleware = bool(
|
||||
request is not None
|
||||
and hasattr(request, "state")
|
||||
and getattr(request.state, "db_managed_by_middleware", False)
|
||||
)
|
||||
if managed_by_middleware:
|
||||
request.state.db = db
|
||||
db.info["managed_by_middleware"] = True
|
||||
|
||||
try:
|
||||
yield db
|
||||
# 不再自动 commit,由业务代码显式管理事务
|
||||
@@ -241,12 +311,13 @@ def get_db() -> Generator[Session, None, None]:
|
||||
logger.debug(f"回滚事务时出错(可忽略): {rollback_error}")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
db.close() # 确保连接返回池
|
||||
except Exception as close_error:
|
||||
# 记录关闭错误(如 IllegalStateChangeError)
|
||||
# 连接池会处理连接的回收
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
if not managed_by_middleware:
|
||||
try:
|
||||
db.close() # 确保连接返回池
|
||||
except Exception as close_error:
|
||||
# 记录关闭错误(如 IllegalStateChangeError)
|
||||
# 连接池会处理连接的回收
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
|
||||
def create_session() -> Session:
|
||||
@@ -273,16 +344,17 @@ def get_db_url() -> str:
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
"""初始化数据库
|
||||
|
||||
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
|
||||
"""
|
||||
logger.info("初始化数据库...")
|
||||
|
||||
# 确保引擎已创建
|
||||
engine = _ensure_engine()
|
||||
_ensure_engine()
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# 数据库表已通过SQLAlchemy自动创建
|
||||
# 数据库表结构由 Alembic 迁移管理
|
||||
# 首次部署或更新后请运行: ./migrate.sh
|
||||
|
||||
db = _SessionLocal()
|
||||
try:
|
||||
@@ -335,7 +407,7 @@ def init_admin_user(db: Session):
|
||||
admin.set_password(config.admin_password)
|
||||
|
||||
db.add(admin)
|
||||
db.commit() # 刷新以获取ID,但不提交
|
||||
db.flush() # 分配ID,但不提交事务(由外层 init_db 统一 commit)
|
||||
|
||||
logger.info(f"创建管理员账户成功: {admin.email} ({admin.username})")
|
||||
except Exception as e:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
采用模块化架构设计
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -39,14 +38,12 @@ async def initialize_providers():
|
||||
"""从数据库初始化提供商(仅用于日志记录)"""
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.database import get_db
|
||||
from src.database.database import create_session
|
||||
from src.models.database import Provider
|
||||
|
||||
try:
|
||||
# 创建数据库会话
|
||||
db_gen = get_db()
|
||||
db: Session = next(db_gen)
|
||||
db: Session = create_session()
|
||||
|
||||
try:
|
||||
# 从数据库加载所有活跃的提供商
|
||||
@@ -75,7 +72,7 @@ async def initialize_providers():
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("从数据库初始化提供商失败")
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from starlette.responses import Response as StarletteResponse
|
||||
|
||||
from src.config import config
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.plugins.manager import get_plugin_manager
|
||||
from src.plugins.rate_limit.base import RateLimitResult
|
||||
|
||||
@@ -71,26 +70,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
start_time = time.time()
|
||||
request.state.request_id = request.headers.get("x-request-id", "")
|
||||
request.state.start_time = start_time
|
||||
# 标记:若请求过程中通过 Depends(get_db) 创建了会话,则由本中间件统一管理其生命周期
|
||||
request.state.db_managed_by_middleware = True
|
||||
|
||||
# 从 request.app 获取 FastAPI 应用实例(而不是从 __init__ 的 app 参数)
|
||||
# 这样才能访问到真正的 FastAPI 实例和其 dependency_overrides
|
||||
db_func = get_db
|
||||
if hasattr(request, "app") and hasattr(request.app, "dependency_overrides"):
|
||||
if get_db in request.app.dependency_overrides:
|
||||
db_func = request.app.dependency_overrides[get_db]
|
||||
logger.debug("Using overridden get_db from app.dependency_overrides")
|
||||
|
||||
# 创建数据库会话供需要的插件或后续处理使用
|
||||
db_gen = db_func()
|
||||
db = None
|
||||
response = None
|
||||
exception_to_raise = None
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
db = next(db_gen)
|
||||
request.state.db = db
|
||||
|
||||
# 1. 限流插件调用(可选功能)
|
||||
rate_limit_result = await self._call_rate_limit_plugins(request)
|
||||
if rate_limit_result and not rate_limit_result.allowed:
|
||||
@@ -111,10 +97,17 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
# 3. 提交关键数据库事务(在返回响应前)
|
||||
# 这确保了 Usage 记录、配额扣减等关键数据在响应返回前持久化
|
||||
try:
|
||||
db.commit()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
logger.error(f"关键事务提交失败: {commit_error}")
|
||||
db.rollback()
|
||||
try:
|
||||
if isinstance(db, Session):
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
await self._call_error_plugins(request, commit_error, start_time)
|
||||
# 返回 500 错误,因为数据可能不一致
|
||||
response = JSONResponse(
|
||||
status_code=500,
|
||||
@@ -139,14 +132,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except RuntimeError as e:
|
||||
if str(e) == "No response returned.":
|
||||
if db:
|
||||
db.rollback()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.error("Downstream handler completed without returning a response")
|
||||
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
if db:
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except Exception:
|
||||
@@ -167,14 +164,18 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
except Exception as e:
|
||||
# 回滚数据库事务
|
||||
if db:
|
||||
db.rollback()
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 错误处理插件调用
|
||||
await self._call_error_plugins(request, e, start_time)
|
||||
|
||||
# 尝试提交错误日志
|
||||
if db:
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
db.commit()
|
||||
except:
|
||||
@@ -183,38 +184,13 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
exception_to_raise = e
|
||||
|
||||
finally:
|
||||
# 确保数据库会话被正确关闭
|
||||
# 注意:需要安全地处理各种状态,避免 IllegalStateChangeError
|
||||
if db is not None:
|
||||
db = getattr(request.state, "db", None)
|
||||
if isinstance(db, Session):
|
||||
try:
|
||||
# 检查会话是否可以安全地进行回滚
|
||||
# 只有当没有进行中的事务操作时才尝试回滚
|
||||
if db.is_active and not db.get_transaction().is_active:
|
||||
# 事务不在活跃状态,可以安全回滚
|
||||
pass
|
||||
elif db.is_active:
|
||||
# 事务在活跃状态,尝试回滚
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception as rollback_error:
|
||||
# 回滚失败(可能是 commit 正在进行中),忽略错误
|
||||
logger.debug(f"Rollback skipped: {rollback_error}")
|
||||
except Exception:
|
||||
# 检查状态时出错,忽略
|
||||
pass
|
||||
|
||||
# 通过触发生成器的 finally 块来关闭会话(标准模式)
|
||||
# 这会调用 get_db() 的 finally 块,执行 db.close()
|
||||
try:
|
||||
next(db_gen, None)
|
||||
except StopIteration:
|
||||
# 正常情况:生成器已耗尽
|
||||
pass
|
||||
except Exception as cleanup_error:
|
||||
# 忽略 IllegalStateChangeError 等清理错误
|
||||
# 这些错误通常是由于事务状态不一致导致的,不影响业务逻辑
|
||||
if "IllegalStateChangeError" not in str(type(cleanup_error).__name__):
|
||||
logger.warning(f"Database cleanup warning: {cleanup_error}")
|
||||
db.close()
|
||||
except Exception as close_error:
|
||||
# 连接池会处理连接的回收,这里的异常不应影响响应
|
||||
logger.debug(f"关闭数据库连接时出错(可忽略): {close_error}")
|
||||
|
||||
# 在 finally 块之后处理异常和响应
|
||||
if exception_to_raise:
|
||||
@@ -250,7 +226,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
return False
|
||||
|
||||
async def _get_rate_limit_key_and_config(
|
||||
self, request: Request, db: Session
|
||||
self, request: Request
|
||||
) -> tuple[Optional[str], Optional[int]]:
|
||||
"""
|
||||
获取速率限制的key和配置
|
||||
@@ -318,14 +294,8 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
# 如果没有限流插件,允许通过
|
||||
return None
|
||||
|
||||
# 获取数据库会话
|
||||
db = getattr(request.state, "db", None)
|
||||
if not db:
|
||||
logger.warning("速率限制检查:无法获取数据库会话")
|
||||
return None
|
||||
|
||||
# 获取速率限制的key和配置(从数据库)
|
||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request, db)
|
||||
# 获取速率限制的 key 和配置
|
||||
key, rate_limit_value = await self._get_rate_limit_key_and_config(request)
|
||||
if not key:
|
||||
# 不需要限流的端点(如未分类路径),静默跳过
|
||||
return None
|
||||
@@ -336,7 +306,7 @@ class PluginMiddleware(BaseHTTPMiddleware):
|
||||
key=key,
|
||||
endpoint=request.url.path,
|
||||
method=request.method,
|
||||
rate_limit=rate_limit_value, # 传入数据库配置的限制值
|
||||
rate_limit=rate_limit_value, # 传入配置的限制值
|
||||
)
|
||||
# 类型检查:确保返回的是RateLimitResult类型
|
||||
if isinstance(result, RateLimitResult):
|
||||
|
||||
@@ -107,20 +107,6 @@ class CreateProviderRequest(BaseModel):
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
v = f"https://{v}"
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v
|
||||
|
||||
@field_validator("billing_type")
|
||||
@@ -195,19 +181,6 @@ class CreateEndpointRequest(BaseModel):
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
@field_validator("api_format")
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from ..core.enums import UserRole
|
||||
|
||||
@@ -336,8 +336,7 @@ class ProviderResponse(BaseModel):
|
||||
active_models_count: int = 0
|
||||
api_keys_count: int = 0
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 模型管理 ==========
|
||||
@@ -442,8 +441,7 @@ class ModelResponse(BaseModel):
|
||||
global_model_name: Optional[str] = None
|
||||
global_model_display_name: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ModelDetailResponse(BaseModel):
|
||||
@@ -469,8 +467,7 @@ class ModelDetailResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 系统设置 ==========
|
||||
@@ -562,20 +559,15 @@ class PublicGlobalModelResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
icon_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = None
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: Optional[dict] = None
|
||||
# 默认能力
|
||||
default_supports_vision: bool = False
|
||||
default_supports_function_calling: bool = False
|
||||
default_supports_streaming: bool = True
|
||||
default_supports_extended_thinking: bool = False
|
||||
# Key 能力配置
|
||||
supported_capabilities: Optional[List[str]] = None
|
||||
# 模型配置(JSON)
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
class PublicGlobalModelListResponse(BaseModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ Provider API Key相关的API模型
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ProviderAPIKeyBase(BaseModel):
|
||||
@@ -53,8 +53,7 @@ class ProviderAPIKeyResponse(ProviderAPIKeyBase):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProviderAPIKeyStats(BaseModel):
|
||||
|
||||
@@ -26,8 +26,8 @@ from sqlalchemy import (
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..config import config
|
||||
from ..core.enums import ProviderBillingType, UserRole
|
||||
@@ -576,11 +576,6 @@ class GlobalModel(Base):
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
name = Column(String(100), unique=True, nullable=False, index=True) # 统一模型名(唯一)
|
||||
display_name = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# 模型元数据
|
||||
icon_url = Column(String(500), nullable=True)
|
||||
official_url = Column(String(500), nullable=True) # 官方文档链接
|
||||
|
||||
# 按次计费配置(每次请求的固定费用,美元)- 可选,与按 token 计费叠加
|
||||
default_price_per_request = Column(Float, nullable=True, default=None) # 每次请求固定费用
|
||||
@@ -606,17 +601,34 @@ class GlobalModel(Base):
|
||||
# }
|
||||
default_tiered_pricing = Column(JSON, nullable=False)
|
||||
|
||||
# 默认能力配置 - Provider 可覆盖
|
||||
default_supports_vision = Column(Boolean, default=False, nullable=True)
|
||||
default_supports_function_calling = Column(Boolean, default=False, nullable=True)
|
||||
default_supports_streaming = Column(Boolean, default=True, nullable=True)
|
||||
default_supports_extended_thinking = Column(Boolean, default=False, nullable=True)
|
||||
default_supports_image_generation = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
# Key 只能启用模型支持的能力
|
||||
supported_capabilities = Column(JSON, nullable=True, default=list)
|
||||
|
||||
# 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
# 结构示例:
|
||||
# {
|
||||
# # 能力配置
|
||||
# "streaming": true,
|
||||
# "vision": true,
|
||||
# "function_calling": true,
|
||||
# "extended_thinking": false,
|
||||
# "image_generation": false,
|
||||
# # 规格参数
|
||||
# "context_limit": 200000,
|
||||
# "output_limit": 8192,
|
||||
# # 元信息
|
||||
# "description": "...",
|
||||
# "icon_url": "...",
|
||||
# "official_url": "...",
|
||||
# "knowledge_cutoff": "2024-04",
|
||||
# "family": "claude-3.5",
|
||||
# "release_date": "2024-10-22",
|
||||
# "input_modalities": ["text", "image"],
|
||||
# "output_modalities": ["text"],
|
||||
# }
|
||||
config = Column(JSONB, nullable=True, default=dict)
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
@@ -767,11 +779,22 @@ class Model(Base):
|
||||
"""获取有效的能力配置(通用辅助方法)"""
|
||||
local_value = getattr(self, attr_name, None)
|
||||
if local_value is not None:
|
||||
return local_value
|
||||
return bool(local_value)
|
||||
if self.global_model:
|
||||
global_value = getattr(self.global_model, f"default_{attr_name}", None)
|
||||
if global_value is not None:
|
||||
return global_value
|
||||
config_key_map = {
|
||||
"supports_vision": "vision",
|
||||
"supports_function_calling": "function_calling",
|
||||
"supports_streaming": "streaming",
|
||||
"supports_extended_thinking": "extended_thinking",
|
||||
"supports_image_generation": "image_generation",
|
||||
}
|
||||
config_key = config_key_map.get(attr_name)
|
||||
if config_key:
|
||||
global_config = getattr(self.global_model, "config", None)
|
||||
if isinstance(global_config, dict):
|
||||
global_value = global_config.get(config_key)
|
||||
if global_value is not None:
|
||||
return bool(global_value)
|
||||
return default
|
||||
|
||||
def get_effective_supports_vision(self) -> bool:
|
||||
@@ -789,7 +812,9 @@ class Model(Base):
|
||||
def get_effective_supports_image_generation(self) -> bool:
|
||||
return self._get_effective_capability("supports_image_generation", False)
|
||||
|
||||
def select_provider_model_name(self, affinity_key: Optional[str] = None) -> str:
|
||||
def select_provider_model_name(
|
||||
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
|
||||
) -> str:
|
||||
"""按优先级选择要使用的 Provider 模型名称
|
||||
|
||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||
@@ -798,6 +823,7 @@ class Model(Base):
|
||||
|
||||
Args:
|
||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的别名
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
@@ -816,6 +842,13 @@ class Model(Base):
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
continue
|
||||
|
||||
# 检查 api_formats 作用域(如果配置了且当前有 api_format)
|
||||
alias_api_formats = raw.get("api_formats")
|
||||
if api_format and alias_api_formats:
|
||||
# 如果配置了作用域,只有匹配时才生效
|
||||
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
|
||||
continue
|
||||
|
||||
raw_priority = raw.get("priority", 1)
|
||||
try:
|
||||
priority = int(raw_priority)
|
||||
|
||||
@@ -6,7 +6,7 @@ import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
# ========== ProviderEndpoint CRUD ==========
|
||||
|
||||
@@ -45,24 +45,9 @@ class ProviderEndpointCreate(BaseModel):
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: str) -> str:
|
||||
"""验证 API URL(SSRF 防护)"""
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
|
||||
@@ -83,27 +68,13 @@ class ProviderEndpointUpdate(BaseModel):
|
||||
@field_validator("base_url")
|
||||
@classmethod
|
||||
def validate_base_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""验证 API URL(SSRF 防护)"""
|
||||
"""验证 API URL"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if not re.match(r"^https?://", v, re.IGNORECASE):
|
||||
raise ValueError("URL 必须以 http:// 或 https:// 开头")
|
||||
|
||||
# 防止 SSRF 攻击:禁止内网地址
|
||||
forbidden_patterns = [
|
||||
r"localhost",
|
||||
r"127\.0\.0\.1",
|
||||
r"0\.0\.0\.0",
|
||||
r"192\.168\.",
|
||||
r"10\.",
|
||||
r"172\.(1[6-9]|2[0-9]|3[0-1])\.",
|
||||
r"169\.254\.",
|
||||
]
|
||||
for pattern in forbidden_patterns:
|
||||
if re.search(pattern, v, re.IGNORECASE):
|
||||
raise ValueError("不允许使用内网地址")
|
||||
|
||||
return v.rstrip("/") # 移除末尾斜杠
|
||||
|
||||
|
||||
@@ -141,8 +112,7 @@ class ProviderEndpointResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== ProviderAPIKey 相关(新架构) ==========
|
||||
@@ -384,8 +354,7 @@ class EndpointAPIKeyResponse(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 健康监控相关 ==========
|
||||
@@ -535,8 +504,7 @@ class ProviderWithEndpointsSummary(BaseModel):
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ========== 健康监控可视化模型 ==========
|
||||
|
||||
@@ -5,7 +5,7 @@ Pydantic 数据模型(阶段一统一模型管理)
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
# ========== 阶梯计费相关模型 ==========
|
||||
@@ -187,9 +187,6 @@ class GlobalModelCreate(BaseModel):
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="统一模型名(唯一)")
|
||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
official_url: Optional[str] = Field(None, max_length=500, description="官方文档链接")
|
||||
icon_url: Optional[str] = Field(None, max_length=500, description="图标 URL")
|
||||
# 按次计费配置(可选,与阶梯计费叠加)
|
||||
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
# 统一阶梯计费配置(必填)
|
||||
@@ -197,22 +194,15 @@ class GlobalModelCreate(BaseModel):
|
||||
default_tiered_pricing: TieredPricingConfig = Field(
|
||||
..., description="阶梯计费配置(固定价格用单阶梯表示)"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = Field(False, description="默认是否支持视觉")
|
||||
default_supports_function_calling: Optional[bool] = Field(
|
||||
False, description="默认是否支持函数调用"
|
||||
)
|
||||
default_supports_streaming: Optional[bool] = Field(True, description="默认是否支持流式输出")
|
||||
default_supports_extended_thinking: Optional[bool] = Field(
|
||||
False, description="默认是否支持扩展思考"
|
||||
)
|
||||
default_supports_image_generation: Optional[bool] = Field(
|
||||
False, description="默认是否支持图像生成"
|
||||
)
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="模型配置(streaming, vision, context_limit, description 等)"
|
||||
)
|
||||
is_active: Optional[bool] = Field(True, description="是否激活")
|
||||
|
||||
|
||||
@@ -220,9 +210,6 @@ class GlobalModelUpdate(BaseModel):
|
||||
"""更新 GlobalModel 请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
official_url: Optional[str] = Field(None, max_length=500)
|
||||
icon_url: Optional[str] = Field(None, max_length=500)
|
||||
is_active: Optional[bool] = None
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
@@ -230,16 +217,15 @@ class GlobalModelUpdate(BaseModel):
|
||||
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
|
||||
None, description="阶梯计费配置"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = None
|
||||
default_supports_function_calling: Optional[bool] = None
|
||||
default_supports_streaming: Optional[bool] = None
|
||||
default_supports_extended_thinking: Optional[bool] = None
|
||||
default_supports_image_generation: Optional[bool] = None
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="模型配置(streaming, vision, context_limit, description 等)"
|
||||
)
|
||||
|
||||
|
||||
class GlobalModelResponse(BaseModel):
|
||||
@@ -248,34 +234,29 @@ class GlobalModelResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
official_url: Optional[str]
|
||||
icon_url: Optional[str]
|
||||
is_active: bool
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: TieredPricingConfig = Field(
|
||||
..., description="阶梯计费配置"
|
||||
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
|
||||
default=None, description="阶梯计费配置"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool]
|
||||
default_supports_function_calling: Optional[bool]
|
||||
default_supports_streaming: Optional[bool]
|
||||
default_supports_extended_thinking: Optional[bool]
|
||||
default_supports_image_generation: Optional[bool]
|
||||
# Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
default=None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 模型配置(JSON格式)
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="模型配置(streaming, vision, context_limit, description 等)"
|
||||
)
|
||||
# 统计数据(可选)
|
||||
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
|
||||
usage_count: Optional[int] = Field(default=0, description="调用次数")
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class GlobalModelWithStats(GlobalModelResponse):
|
||||
|
||||
@@ -51,7 +51,7 @@ class JwtAuthPlugin(AuthPlugin):
|
||||
|
||||
try:
|
||||
# 验证JWT token
|
||||
payload = AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
logger.debug(f"JWT token验证成功, payload: {payload}")
|
||||
|
||||
# 从payload中提取用户信息
|
||||
|
||||
@@ -93,8 +93,8 @@ class AuthService:
|
||||
@staticmethod
|
||||
async def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""用户登录认证"""
|
||||
# 使用缓存查询用户
|
||||
user = await UserCacheService.get_user_by_email(db, email)
|
||||
# 登录校验必须读取密码哈希,不能使用不包含 password_hash 的缓存对象
|
||||
user = db.query(User).filter(User.email == email).first()
|
||||
|
||||
if not user:
|
||||
logger.warning(f"登录失败 - 用户不存在: {email}")
|
||||
@@ -109,13 +109,10 @@ class AuthService:
|
||||
return None
|
||||
|
||||
# 更新最后登录时间
|
||||
# 需要重新从数据库获取以便更新(缓存的对象是分离的)
|
||||
db_user = db.query(User).filter(User.id == user.id).first()
|
||||
if db_user:
|
||||
db_user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
# 清除缓存,因为用户信息已更新
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
user.last_login_at = datetime.now(timezone.utc)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
# 清除缓存,因为用户信息已更新
|
||||
await UserCacheService.invalidate_user_cache(user.id, user.email)
|
||||
|
||||
logger.info(f"用户登录成功: {email} (ID: {user.id})")
|
||||
return user
|
||||
@@ -198,7 +195,10 @@ class AuthService:
|
||||
if user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
if user.role.value >= required_role.value:
|
||||
# 避免使用字符串比较导致权限判断错误(例如 'user' >= 'admin')
|
||||
role_rank = {UserRole.USER: 0, UserRole.ADMIN: 1}
|
||||
# 未知用户角色默认 -1(拒绝),未知要求角色默认 999(拒绝)
|
||||
if role_rank.get(user.role, -1) >= role_rank.get(required_role, 999):
|
||||
return True
|
||||
|
||||
logger.warning(f"权限不足: 用户 {user.email} 角色 {user.role.value} < 需要 {required_role.value}")
|
||||
@@ -230,7 +230,7 @@ class AuthService:
|
||||
)
|
||||
|
||||
if success:
|
||||
user_id = payload.get("sub")
|
||||
user_id = payload.get("user_id")
|
||||
logger.info(f"用户登出成功: user_id={user_id}")
|
||||
|
||||
return success
|
||||
|
||||
71
src/services/cache/aware_scheduler.py
vendored
71
src/services/cache/aware_scheduler.py
vendored
@@ -59,7 +59,6 @@ from src.services.health.monitor import health_monitor
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.rate_limit.adaptive_reservation import (
|
||||
AdaptiveReservationManager,
|
||||
ReservationResult,
|
||||
get_adaptive_reservation_manager,
|
||||
)
|
||||
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
|
||||
@@ -112,8 +111,6 @@ class CacheAwareScheduler:
|
||||
- 健康度监控
|
||||
"""
|
||||
|
||||
# 静态常量作为默认值(实际由 AdaptiveReservationManager 动态计算)
|
||||
CACHE_RESERVATION_RATIO = 0.3
|
||||
# 优先级模式常量
|
||||
PRIORITY_MODE_PROVIDER = "provider" # 提供商优先模式
|
||||
PRIORITY_MODE_GLOBAL_KEY = "global_key" # 全局 Key 优先模式
|
||||
@@ -121,8 +118,17 @@ class CacheAwareScheduler:
|
||||
PRIORITY_MODE_PROVIDER,
|
||||
PRIORITY_MODE_GLOBAL_KEY,
|
||||
}
|
||||
# 调度模式常量
|
||||
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式
|
||||
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式
|
||||
ALLOWED_SCHEDULING_MODES = {
|
||||
SCHEDULING_MODE_FIXED_ORDER,
|
||||
SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
}
|
||||
|
||||
def __init__(self, redis_client=None, priority_mode: Optional[str] = None):
|
||||
def __init__(
|
||||
self, redis_client=None, priority_mode: Optional[str] = None, scheduling_mode: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化调度器
|
||||
|
||||
@@ -132,12 +138,16 @@ class CacheAwareScheduler:
|
||||
Args:
|
||||
redis_client: Redis客户端(可选)
|
||||
priority_mode: 候选排序策略(provider | global_key)
|
||||
scheduling_mode: 调度模式(fixed_order | cache_affinity)
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.priority_mode = self._normalize_priority_mode(
|
||||
priority_mode or self.PRIORITY_MODE_PROVIDER
|
||||
)
|
||||
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}")
|
||||
self.scheduling_mode = self._normalize_scheduling_mode(
|
||||
scheduling_mode or self.SCHEDULING_MODE_CACHE_AFFINITY
|
||||
)
|
||||
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}, 调度模式: {self.scheduling_mode}")
|
||||
|
||||
# 初始化子组件(将在第一次使用时异步初始化)
|
||||
self._affinity_manager: Optional[CacheAffinityManager] = None
|
||||
@@ -673,14 +683,19 @@ class CacheAwareScheduler:
|
||||
f"(api_format={target_format.value}, model={model_name})"
|
||||
)
|
||||
|
||||
# 4. 应用缓存亲和性排序(使用 global_model_id 作为模型标识)
|
||||
if affinity_key and candidates:
|
||||
candidates = await self._apply_cache_affinity(
|
||||
candidates=candidates,
|
||||
affinity_key=affinity_key,
|
||||
api_format=target_format,
|
||||
global_model_id=global_model_id,
|
||||
)
|
||||
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
|
||||
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
|
||||
if affinity_key and candidates:
|
||||
candidates = await self._apply_cache_affinity(
|
||||
candidates=candidates,
|
||||
affinity_key=affinity_key,
|
||||
api_format=target_format,
|
||||
global_model_id=global_model_id,
|
||||
)
|
||||
else:
|
||||
# 固定顺序模式:标记所有候选为非缓存
|
||||
for candidate in candidates:
|
||||
candidate.is_cached = False
|
||||
|
||||
return candidates, global_model_id
|
||||
|
||||
@@ -1060,6 +1075,22 @@ class CacheAwareScheduler:
|
||||
self.priority_mode = normalized
|
||||
logger.debug(f"[CacheAwareScheduler] 切换优先级模式为: {self.priority_mode}")
|
||||
|
||||
def _normalize_scheduling_mode(self, mode: Optional[str]) -> str:
|
||||
normalized = (mode or "").strip().lower()
|
||||
if normalized not in self.ALLOWED_SCHEDULING_MODES:
|
||||
if normalized:
|
||||
logger.warning(f"[CacheAwareScheduler] 无效的调度模式 '{mode}',回退为 cache_affinity")
|
||||
return self.SCHEDULING_MODE_CACHE_AFFINITY
|
||||
return normalized
|
||||
|
||||
def set_scheduling_mode(self, mode: Optional[str]) -> None:
    """Switch the scheduling mode at runtime.

    The requested mode is passed through ``_normalize_scheduling_mode``;
    nothing is written or logged when the result equals the current mode.

    Args:
        mode: Requested scheduling mode (may be ``None``).
    """
    new_mode = self._normalize_scheduling_mode(mode)
    if new_mode != self.scheduling_mode:
        self.scheduling_mode = new_mode
        logger.debug(f"[CacheAwareScheduler] 切换调度模式为: {self.scheduling_mode}")
|
||||
|
||||
def _apply_priority_mode_sort(
|
||||
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
|
||||
) -> List[ProviderCandidate]:
|
||||
@@ -1286,7 +1317,6 @@ class CacheAwareScheduler:
|
||||
|
||||
return {
|
||||
"scheduler": "cache_aware",
|
||||
"cache_reservation_ratio": self.CACHE_RESERVATION_RATIO,
|
||||
"dynamic_reservation": {
|
||||
"enabled": True,
|
||||
"config": reservation_stats["config"],
|
||||
@@ -1307,6 +1337,7 @@ _scheduler: Optional[CacheAwareScheduler] = None
|
||||
async def get_cache_aware_scheduler(
|
||||
redis_client=None,
|
||||
priority_mode: Optional[str] = None,
|
||||
scheduling_mode: Optional[str] = None,
|
||||
) -> CacheAwareScheduler:
|
||||
"""
|
||||
获取全局CacheAwareScheduler实例
|
||||
@@ -1317,6 +1348,7 @@ async def get_cache_aware_scheduler(
|
||||
Args:
|
||||
redis_client: Redis客户端(可选)
|
||||
priority_mode: 外部覆盖的优先级模式(provider | global_key)
|
||||
scheduling_mode: 外部覆盖的调度模式(fixed_order | cache_affinity)
|
||||
|
||||
Returns:
|
||||
CacheAwareScheduler实例
|
||||
@@ -1324,8 +1356,13 @@ async def get_cache_aware_scheduler(
|
||||
global _scheduler
|
||||
|
||||
if _scheduler is None:
|
||||
_scheduler = CacheAwareScheduler(redis_client, priority_mode=priority_mode)
|
||||
elif priority_mode:
|
||||
_scheduler.set_priority_mode(priority_mode)
|
||||
_scheduler = CacheAwareScheduler(
|
||||
redis_client, priority_mode=priority_mode, scheduling_mode=scheduling_mode
|
||||
)
|
||||
else:
|
||||
if priority_mode:
|
||||
_scheduler.set_priority_mode(priority_mode)
|
||||
if scheduling_mode:
|
||||
_scheduler.set_scheduling_mode(scheduling_mode)
|
||||
|
||||
return _scheduler
|
||||
|
||||
52
src/services/cache/model_cache.py
vendored
52
src/services/cache/model_cache.py
vendored
@@ -1,5 +1,21 @@
|
||||
"""
|
||||
Model 映射缓存服务 - 减少模型查询
|
||||
|
||||
架构说明
|
||||
========
|
||||
本服务采用混合 async/sync 模式:
|
||||
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||
|
||||
设计决策
|
||||
--------
|
||||
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||
2. 缓存未命中时的同步查询:在 async 方法内直接执行,查询期间会阻塞事件循环;FastAPI 只会把同步(def)路由和依赖放入线程池执行
|
||||
3. 调用方必须在 async 上下文中使用 await
|
||||
|
||||
使用示例
|
||||
--------
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, "gpt-4")
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -19,7 +35,11 @@ from src.models.database import GlobalModel, Model
|
||||
|
||||
|
||||
class ModelCacheService:
|
||||
"""Model 映射缓存服务"""
|
||||
"""Model 映射缓存服务
|
||||
|
||||
提供 GlobalModel 和 Model 的缓存查询功能,减少数据库访问。
|
||||
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||
"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.MODEL
|
||||
@@ -385,7 +405,7 @@ class ModelCacheService:
|
||||
"is_active": model.is_active,
|
||||
"is_available": model.is_available if hasattr(model, "is_available") else True,
|
||||
"price_per_request": (
|
||||
float(model.price_per_request) if model.price_per_request else None
|
||||
float(model.price_per_request) if model.price_per_request is not None else None
|
||||
),
|
||||
"tiered_pricing": model.tiered_pricing,
|
||||
"supports_vision": model.supports_vision,
|
||||
@@ -425,14 +445,15 @@ class ModelCacheService:
|
||||
"id": global_model.id,
|
||||
"name": global_model.name,
|
||||
"display_name": global_model.display_name,
|
||||
"default_supports_vision": global_model.default_supports_vision,
|
||||
"default_supports_function_calling": global_model.default_supports_function_calling,
|
||||
"default_supports_streaming": global_model.default_supports_streaming,
|
||||
"default_supports_extended_thinking": global_model.default_supports_extended_thinking,
|
||||
"default_supports_image_generation": global_model.default_supports_image_generation,
|
||||
"supported_capabilities": global_model.supported_capabilities,
|
||||
"config": global_model.config,
|
||||
"default_tiered_pricing": global_model.default_tiered_pricing,
|
||||
"default_price_per_request": (
|
||||
float(global_model.default_price_per_request)
|
||||
if global_model.default_price_per_request is not None
|
||||
else None
|
||||
),
|
||||
"is_active": global_model.is_active,
|
||||
"description": global_model.description,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -442,19 +463,10 @@ class ModelCacheService:
|
||||
id=global_model_dict["id"],
|
||||
name=global_model_dict["name"],
|
||||
display_name=global_model_dict.get("display_name"),
|
||||
default_supports_vision=global_model_dict.get("default_supports_vision", False),
|
||||
default_supports_function_calling=global_model_dict.get(
|
||||
"default_supports_function_calling", False
|
||||
),
|
||||
default_supports_streaming=global_model_dict.get("default_supports_streaming", True),
|
||||
default_supports_extended_thinking=global_model_dict.get(
|
||||
"default_supports_extended_thinking", False
|
||||
),
|
||||
default_supports_image_generation=global_model_dict.get(
|
||||
"default_supports_image_generation", False
|
||||
),
|
||||
supported_capabilities=global_model_dict.get("supported_capabilities") or [],
|
||||
config=global_model_dict.get("config"),
|
||||
default_tiered_pricing=global_model_dict.get("default_tiered_pricing"),
|
||||
default_price_per_request=global_model_dict.get("default_price_per_request"),
|
||||
is_active=global_model_dict.get("is_active", True),
|
||||
description=global_model_dict.get("description"),
|
||||
)
|
||||
return global_model
|
||||
|
||||
254
src/services/cache/provider_cache.py
vendored
254
src/services/cache/provider_cache.py
vendored
@@ -1,254 +0,0 @@
|
||||
"""
|
||||
Provider 配置缓存服务 - 减少 Provider/Endpoint/APIKey 查询
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheKeys, CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
|
||||
|
||||
|
||||
|
||||
class ProviderCacheService:
|
||||
"""Provider 配置缓存服务"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.PROVIDER
|
||||
|
||||
@staticmethod
|
||||
async def get_provider_by_id(db: Session, provider_id: str) -> Optional[Provider]:
|
||||
"""
|
||||
获取 Provider(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
Provider 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.provider_by_id(provider_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Provider 缓存命中: {provider_id}")
|
||||
return ProviderCacheService._dict_to_provider(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if provider:
|
||||
provider_dict = ProviderCacheService._provider_to_dict(provider)
|
||||
await CacheService.set(
|
||||
cache_key, provider_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Provider 已缓存: {provider_id}")
|
||||
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
async def get_endpoint_by_id(db: Session, endpoint_id: str) -> Optional[ProviderEndpoint]:
|
||||
"""
|
||||
获取 Endpoint(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
endpoint_id: Endpoint ID
|
||||
|
||||
Returns:
|
||||
ProviderEndpoint 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.endpoint_by_id(endpoint_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Endpoint 缓存命中: {endpoint_id}")
|
||||
return ProviderCacheService._dict_to_endpoint(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == endpoint_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if endpoint:
|
||||
endpoint_dict = ProviderCacheService._endpoint_to_dict(endpoint)
|
||||
await CacheService.set(
|
||||
cache_key, endpoint_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"Endpoint 已缓存: {endpoint_id}")
|
||||
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
async def get_api_key_by_id(db: Session, api_key_id: str) -> Optional[ProviderAPIKey]:
|
||||
"""
|
||||
获取 API Key(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
api_key_id: API Key ID
|
||||
|
||||
Returns:
|
||||
ProviderAPIKey 对象或 None
|
||||
"""
|
||||
cache_key = CacheKeys.api_key_by_id(api_key_id)
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"API Key 缓存命中: {api_key_id}")
|
||||
return ProviderCacheService._dict_to_api_key(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
api_key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == api_key_id).first()
|
||||
|
||||
# 3. 写入缓存
|
||||
if api_key:
|
||||
api_key_dict = ProviderCacheService._api_key_to_dict(api_key)
|
||||
await CacheService.set(
|
||||
cache_key, api_key_dict, ttl_seconds=ProviderCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"API Key 已缓存: {api_key_id}")
|
||||
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_provider_cache(provider_id: str):
|
||||
"""
|
||||
清除 Provider 缓存
|
||||
|
||||
Args:
|
||||
provider_id: Provider ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.provider_by_id(provider_id))
|
||||
logger.debug(f"Provider 缓存已清除: {provider_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_endpoint_cache(endpoint_id: str):
|
||||
"""
|
||||
清除 Endpoint 缓存
|
||||
|
||||
Args:
|
||||
endpoint_id: Endpoint ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.endpoint_by_id(endpoint_id))
|
||||
logger.debug(f"Endpoint 缓存已清除: {endpoint_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_api_key_cache(api_key_id: str):
|
||||
"""
|
||||
清除 API Key 缓存
|
||||
|
||||
Args:
|
||||
api_key_id: API Key ID
|
||||
"""
|
||||
await CacheService.delete(CacheKeys.api_key_by_id(api_key_id))
|
||||
logger.debug(f"API Key 缓存已清除: {api_key_id}")
|
||||
|
||||
@staticmethod
|
||||
def _provider_to_dict(provider: Provider) -> dict:
|
||||
"""将 Provider 对象转换为字典(用于缓存)"""
|
||||
return {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"api_format": provider.api_format,
|
||||
"base_url": provider.base_url,
|
||||
"is_active": provider.is_active,
|
||||
"priority": provider.priority,
|
||||
"rpm_limit": provider.rpm_limit,
|
||||
"rpm_used": provider.rpm_used,
|
||||
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
|
||||
"config": provider.config,
|
||||
"description": provider.description,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_provider(provider_dict: dict) -> Provider:
|
||||
"""从字典重建 Provider 对象(分离的对象,不在 Session 中)"""
|
||||
from datetime import datetime
|
||||
|
||||
provider = Provider(
|
||||
id=provider_dict["id"],
|
||||
name=provider_dict["name"],
|
||||
api_format=provider_dict["api_format"],
|
||||
base_url=provider_dict.get("base_url"),
|
||||
is_active=provider_dict["is_active"],
|
||||
priority=provider_dict.get("priority", 0),
|
||||
rpm_limit=provider_dict.get("rpm_limit"),
|
||||
rpm_used=provider_dict.get("rpm_used", 0),
|
||||
config=provider_dict.get("config"),
|
||||
description=provider_dict.get("description"),
|
||||
)
|
||||
|
||||
if provider_dict.get("rpm_reset_at"):
|
||||
provider.rpm_reset_at = datetime.fromisoformat(provider_dict["rpm_reset_at"])
|
||||
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _endpoint_to_dict(endpoint: ProviderEndpoint) -> dict:
|
||||
"""将 Endpoint 对象转换为字典"""
|
||||
return {
|
||||
"id": endpoint.id,
|
||||
"provider_id": endpoint.provider_id,
|
||||
"name": endpoint.name,
|
||||
"base_url": endpoint.base_url,
|
||||
"is_active": endpoint.is_active,
|
||||
"priority": endpoint.priority,
|
||||
"weight": endpoint.weight,
|
||||
"custom_path": endpoint.custom_path,
|
||||
"config": endpoint.config,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_endpoint(endpoint_dict: dict) -> ProviderEndpoint:
|
||||
"""从字典重建 Endpoint 对象"""
|
||||
endpoint = ProviderEndpoint(
|
||||
id=endpoint_dict["id"],
|
||||
provider_id=endpoint_dict["provider_id"],
|
||||
name=endpoint_dict["name"],
|
||||
base_url=endpoint_dict["base_url"],
|
||||
is_active=endpoint_dict["is_active"],
|
||||
priority=endpoint_dict.get("priority", 0),
|
||||
weight=endpoint_dict.get("weight", 1.0),
|
||||
custom_path=endpoint_dict.get("custom_path"),
|
||||
config=endpoint_dict.get("config"),
|
||||
)
|
||||
return endpoint
|
||||
|
||||
@staticmethod
|
||||
def _api_key_to_dict(api_key: ProviderAPIKey) -> dict:
|
||||
"""将 API Key 对象转换为字典"""
|
||||
return {
|
||||
"id": api_key.id,
|
||||
"endpoint_id": api_key.endpoint_id,
|
||||
"key_value": api_key.key_value,
|
||||
"is_active": api_key.is_active,
|
||||
"max_rpm": api_key.max_rpm,
|
||||
"current_rpm": api_key.current_rpm,
|
||||
"health_score": api_key.health_score,
|
||||
"circuit_breaker_state": api_key.circuit_breaker_state,
|
||||
"adaptive_concurrency_limit": api_key.adaptive_concurrency_limit,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _dict_to_api_key(api_key_dict: dict) -> ProviderAPIKey:
|
||||
"""从字典重建 API Key 对象"""
|
||||
api_key = ProviderAPIKey(
|
||||
id=api_key_dict["id"],
|
||||
endpoint_id=api_key_dict["endpoint_id"],
|
||||
key_value=api_key_dict["key_value"],
|
||||
is_active=api_key_dict["is_active"],
|
||||
max_rpm=api_key_dict.get("max_rpm"),
|
||||
current_rpm=api_key_dict.get("current_rpm", 0),
|
||||
health_score=api_key_dict.get("health_score", 1.0),
|
||||
circuit_breaker_state=api_key_dict.get("circuit_breaker_state"),
|
||||
adaptive_concurrency_limit=api_key_dict.get("adaptive_concurrency_limit"),
|
||||
)
|
||||
return api_key
|
||||
24
src/services/cache/user_cache.py
vendored
24
src/services/cache/user_cache.py
vendored
@@ -1,5 +1,22 @@
|
||||
"""
|
||||
用户缓存服务 - 减少数据库查询
|
||||
|
||||
架构说明
|
||||
========
|
||||
本服务采用混合 async/sync 模式:
|
||||
- 缓存操作(CacheService):真正的 async,使用 aioredis
|
||||
- 数据库查询(db.query):同步的 SQLAlchemy Session
|
||||
|
||||
设计决策
|
||||
--------
|
||||
1. 保持 async 方法签名:因为缓存命中时完全异步,性能最优
|
||||
2. 缓存未命中时的同步查询:在 async 方法内直接执行,查询期间会阻塞事件循环;FastAPI 只会把同步(def)路由和依赖放入线程池执行
|
||||
3. 调用方必须在 async 上下文中使用 await
|
||||
|
||||
使用示例
|
||||
--------
|
||||
user = await UserCacheService.get_user_by_id(db, user_id)
|
||||
await UserCacheService.invalidate_user_cache(user_id, email)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
@@ -12,9 +29,12 @@ from src.core.logger import logger
|
||||
from src.models.database import User
|
||||
|
||||
|
||||
|
||||
class UserCacheService:
|
||||
"""用户缓存服务"""
|
||||
"""用户缓存服务
|
||||
|
||||
提供 User 的缓存查询功能,减少数据库访问。
|
||||
所有公开方法均为 async,需要在 async 上下文中调用。
|
||||
"""
|
||||
|
||||
# 缓存 TTL(秒)- 使用统一常量
|
||||
CACHE_TTL = CacheTTL.USER
|
||||
|
||||
@@ -62,7 +62,6 @@ class GlobalModelService:
|
||||
query = query.filter(
|
||||
(GlobalModel.name.ilike(search_pattern))
|
||||
| (GlobalModel.display_name.ilike(search_pattern))
|
||||
| (GlobalModel.description.ilike(search_pattern))
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
@@ -75,21 +74,15 @@ class GlobalModelService:
|
||||
db: Session,
|
||||
name: str,
|
||||
display_name: str,
|
||||
description: Optional[str] = None,
|
||||
official_url: Optional[str] = None,
|
||||
icon_url: Optional[str] = None,
|
||||
is_active: Optional[bool] = True,
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = None,
|
||||
# 阶梯计费配置(必填)
|
||||
default_tiered_pricing: dict = None,
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = None,
|
||||
default_supports_function_calling: Optional[bool] = None,
|
||||
default_supports_streaming: Optional[bool] = None,
|
||||
default_supports_extended_thinking: Optional[bool] = None,
|
||||
# Key 能力配置
|
||||
supported_capabilities: Optional[List[str]] = None,
|
||||
# 模型配置(JSON)
|
||||
config: Optional[dict] = None,
|
||||
) -> GlobalModel:
|
||||
"""创建 GlobalModel"""
|
||||
# 检查名称是否已存在
|
||||
@@ -100,21 +93,15 @@ class GlobalModelService:
|
||||
global_model = GlobalModel(
|
||||
name=name,
|
||||
display_name=display_name,
|
||||
description=description,
|
||||
official_url=official_url,
|
||||
icon_url=icon_url,
|
||||
is_active=is_active,
|
||||
# 按次计费配置
|
||||
default_price_per_request=default_price_per_request,
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing=default_tiered_pricing,
|
||||
# 默认能力配置
|
||||
default_supports_vision=default_supports_vision,
|
||||
default_supports_function_calling=default_supports_function_calling,
|
||||
default_supports_streaming=default_supports_streaming,
|
||||
default_supports_extended_thinking=default_supports_extended_thinking,
|
||||
# Key 能力配置
|
||||
supported_capabilities=supported_capabilities,
|
||||
# 模型配置(JSON)
|
||||
config=config,
|
||||
)
|
||||
|
||||
db.add(global_model)
|
||||
|
||||
@@ -69,24 +69,29 @@ class ErrorClassifier:
|
||||
# 这些错误是由用户请求本身导致的,换 Provider 也无济于事
|
||||
# 注意:标准 API 返回的 error.type 已在 CLIENT_ERROR_TYPES 中处理
|
||||
# 这里主要用于匹配非标准格式或第三方代理的错误消息
|
||||
#
|
||||
# 重要:不要在此列表中包含 Provider Key 配置问题(如 invalid_api_key)
|
||||
# 这类错误应该触发故障转移,而不是直接返回给用户
|
||||
CLIENT_ERROR_PATTERNS: Tuple[str, ...] = (
|
||||
"could not process image", # 图片处理失败
|
||||
"image too large", # 图片过大
|
||||
"invalid image", # 无效图片
|
||||
"unsupported image", # 不支持的图片格式
|
||||
"content_policy_violation", # 内容违规
|
||||
"invalid_api_key", # 无效的 API Key(不同于认证失败)
|
||||
"context_length_exceeded", # 上下文长度超限
|
||||
"content_length_limit", # 请求内容长度超限 (Claude API)
|
||||
"content_length_exceeds", # 内容长度超限变体 (AWS CodeWhisperer)
|
||||
"max_tokens", # token 数超限
|
||||
"invalid_prompt", # 无效的提示词
|
||||
"content too long", # 内容过长
|
||||
"input is too long", # 输入过长 (AWS)
|
||||
"message is too long", # 消息过长
|
||||
"prompt is too long", # Prompt 超长(第三方代理常见格式)
|
||||
"image exceeds", # 图片超出限制
|
||||
"pdf too large", # PDF 过大
|
||||
"file too large", # 文件过大
|
||||
"tool_use_id", # tool_result 引用了不存在的 tool_use(兼容非标准代理)
|
||||
"validationexception", # AWS 验证异常
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -110,18 +115,124 @@ class ErrorClassifier:
|
||||
# 表示客户端错误的 error type(不区分大小写)
|
||||
# 这些 type 表明是请求本身的问题,不应重试
|
||||
CLIENT_ERROR_TYPES: Tuple[str, ...] = (
|
||||
"invalid_request_error", # Claude/OpenAI 标准客户端错误类型
|
||||
"invalid_argument", # Gemini 参数错误
|
||||
"failed_precondition", # Gemini 前置条件错误
|
||||
# Claude/OpenAI 标准
|
||||
"invalid_request_error",
|
||||
# Gemini
|
||||
"invalid_argument",
|
||||
"failed_precondition",
|
||||
# AWS
|
||||
"validationexception",
|
||||
# 通用
|
||||
"validation_error",
|
||||
"bad_request",
|
||||
)
|
||||
|
||||
# 表示客户端错误的 reason/code 字段值
|
||||
CLIENT_ERROR_REASONS: Tuple[str, ...] = (
|
||||
"CONTENT_LENGTH_EXCEEDS_THRESHOLD",
|
||||
"CONTEXT_LENGTH_EXCEEDED",
|
||||
"MAX_TOKENS_EXCEEDED",
|
||||
"INVALID_CONTENT",
|
||||
"CONTENT_POLICY_VIOLATION",
|
||||
)
|
||||
|
||||
def _parse_error_response(self, error_text: Optional[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
解析错误响应为结构化数据
|
||||
|
||||
支持多种格式:
|
||||
- {"error": {"type": "...", "message": "..."}} (Claude/OpenAI)
|
||||
- {"error": {"message": "...", "__type": "..."}} (AWS)
|
||||
- {"errorMessage": "..."} (Lambda)
|
||||
- {"error": "..."}
|
||||
- {"message": "...", "reason": "..."}
|
||||
|
||||
Returns:
|
||||
结构化的错误信息: {
|
||||
"type": str, # 错误类型
|
||||
"message": str, # 错误消息
|
||||
"reason": str, # 错误原因/代码
|
||||
"raw": str, # 原始文本
|
||||
}
|
||||
"""
|
||||
result = {"type": "", "message": "", "reason": "", "raw": error_text or ""}
|
||||
|
||||
if not error_text:
|
||||
return result
|
||||
|
||||
try:
|
||||
data = json.loads(error_text)
|
||||
|
||||
# 格式 1: {"error": {"type": "...", "message": "..."}}
|
||||
if isinstance(data.get("error"), dict):
|
||||
error_obj = data["error"]
|
||||
result["type"] = str(error_obj.get("type", ""))
|
||||
result["message"] = str(error_obj.get("message", ""))
|
||||
|
||||
# AWS 格式: {"error": {"__type": "...", "message": "...", "reason": "..."}}
|
||||
# __type 直接在 error 对象中,而不是嵌套在 message 里
|
||||
if "__type" in error_obj:
|
||||
result["type"] = result["type"] or str(error_obj.get("__type", ""))
|
||||
if "reason" in error_obj:
|
||||
result["reason"] = str(error_obj.get("reason", ""))
|
||||
if "code" in error_obj:
|
||||
result["reason"] = result["reason"] or str(error_obj.get("code", ""))
|
||||
|
||||
# 嵌套 JSON 格式: message 字段本身是 JSON 字符串
|
||||
# 支持多种嵌套格式:
|
||||
# - AWS: {"__type": "...", "message": "...", "reason": "..."}
|
||||
# - 第三方代理: {"error": {"type": "...", "message": "..."}}
|
||||
if result["message"].startswith("{"):
|
||||
try:
|
||||
nested = json.loads(result["message"])
|
||||
if isinstance(nested, dict):
|
||||
# AWS 格式
|
||||
if "__type" in nested:
|
||||
result["type"] = result["type"] or str(nested.get("__type", ""))
|
||||
result["message"] = str(nested.get("message", result["message"]))
|
||||
result["reason"] = str(nested.get("reason", ""))
|
||||
# 第三方代理格式: {"error": {"message": "..."}}
|
||||
elif isinstance(nested.get("error"), dict):
|
||||
inner_error = nested["error"]
|
||||
inner_msg = str(inner_error.get("message", ""))
|
||||
if inner_msg:
|
||||
result["message"] = inner_msg
|
||||
# 简单格式: {"message": "..."}
|
||||
elif "message" in nested:
|
||||
result["message"] = str(nested["message"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 格式 2: {"error": "..."}
|
||||
elif isinstance(data.get("error"), str):
|
||||
result["message"] = str(data["error"])
|
||||
|
||||
# 格式 3: {"errorMessage": "..."} (Lambda)
|
||||
elif "errorMessage" in data:
|
||||
result["message"] = str(data["errorMessage"])
|
||||
|
||||
# 格式 4: {"message": "...", "reason": "..."}
|
||||
elif "message" in data:
|
||||
result["message"] = str(data["message"])
|
||||
result["reason"] = str(data.get("reason", ""))
|
||||
|
||||
# 提取顶层的 reason/code
|
||||
if not result["reason"]:
|
||||
result["reason"] = str(data.get("reason", data.get("code", "")))
|
||||
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
result["message"] = error_text[:500] if len(error_text) > 500 else error_text
|
||||
|
||||
return result
|
||||
|
||||
def _is_client_error(self, error_text: Optional[str]) -> bool:
|
||||
"""
|
||||
检测错误响应是否为客户端错误(不应重试)
|
||||
|
||||
判断逻辑:
|
||||
判断逻辑(按优先级):
|
||||
1. 检查 error.type 是否为已知的客户端错误类型
|
||||
2. 检查错误文本是否包含已知的客户端错误模式
|
||||
2. 检查 reason/code 是否为已知的客户端错误原因
|
||||
3. 回退到关键词匹配
|
||||
|
||||
Args:
|
||||
error_text: 错误响应文本
|
||||
@@ -132,67 +243,53 @@ class ErrorClassifier:
|
||||
if not error_text:
|
||||
return False
|
||||
|
||||
# 尝试解析 JSON 并检查 error type
|
||||
try:
|
||||
data = json.loads(error_text)
|
||||
if isinstance(data.get("error"), dict):
|
||||
error_type = data["error"].get("type", "")
|
||||
if error_type and any(
|
||||
t.lower() in error_type.lower() for t in self.CLIENT_ERROR_TYPES
|
||||
):
|
||||
return True
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
pass
|
||||
parsed = self._parse_error_response(error_text)
|
||||
|
||||
# 回退到关键词匹配
|
||||
error_lower = error_text.lower()
|
||||
return any(pattern.lower() in error_lower for pattern in self.CLIENT_ERROR_PATTERNS)
|
||||
# 1. 检查 error type
|
||||
if parsed["type"]:
|
||||
error_type_lower = parsed["type"].lower()
|
||||
if any(t.lower() in error_type_lower for t in self.CLIENT_ERROR_TYPES):
|
||||
return True
|
||||
|
||||
# 2. 检查 reason/code
|
||||
if parsed["reason"]:
|
||||
reason_upper = parsed["reason"].upper()
|
||||
if any(r in reason_upper for r in self.CLIENT_ERROR_REASONS):
|
||||
return True
|
||||
|
||||
# 3. 回退到关键词匹配(合并 message 和 raw)
|
||||
search_text = f"{parsed['message']} {parsed['raw']}".lower()
|
||||
return any(pattern.lower() in search_text for pattern in self.CLIENT_ERROR_PATTERNS)
|
||||
|
||||
def _extract_error_message(self, error_text: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
从错误响应中提取错误消息
|
||||
|
||||
支持格式:
|
||||
- {"error": {"message": "..."}} (OpenAI/Claude)
|
||||
- {"error": {"type": "...", "message": "..."}}
|
||||
- {"error": "..."}
|
||||
- {"message": "..."}
|
||||
|
||||
Args:
|
||||
error_text: 错误响应文本
|
||||
|
||||
Returns:
|
||||
提取的错误消息,如果无法解析则返回原始文本
|
||||
提取的错误消息
|
||||
"""
|
||||
if not error_text:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = json.loads(error_text)
|
||||
parsed = self._parse_error_response(error_text)
|
||||
|
||||
# {"error": {"message": "..."}} 或 {"error": {"type": "...", "message": "..."}}
|
||||
if isinstance(data.get("error"), dict):
|
||||
error_obj = data["error"]
|
||||
message = error_obj.get("message", "")
|
||||
error_type = error_obj.get("type", "")
|
||||
if message:
|
||||
if error_type:
|
||||
return f"{error_type}: {message}"
|
||||
return str(message)
|
||||
# 构建可读的错误消息
|
||||
parts = []
|
||||
if parsed["type"]:
|
||||
parts.append(parsed["type"])
|
||||
if parsed["reason"]:
|
||||
parts.append(f"[{parsed['reason']}]")
|
||||
if parsed["message"]:
|
||||
parts.append(parsed["message"])
|
||||
|
||||
# {"error": "..."}
|
||||
if isinstance(data.get("error"), str):
|
||||
return str(data["error"])
|
||||
|
||||
# {"message": "..."}
|
||||
if isinstance(data.get("message"), str):
|
||||
return str(data["message"])
|
||||
|
||||
except (json.JSONDecodeError, TypeError, KeyError):
|
||||
pass
|
||||
if parts:
|
||||
return ": ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
# 无法解析,返回原始文本(截断)
|
||||
return error_text[:500] if len(error_text) > 500 else error_text
|
||||
return parsed["raw"][:500] if len(parsed["raw"]) > 500 else parsed["raw"]
|
||||
|
||||
def classify(
|
||||
self,
|
||||
|
||||
@@ -102,9 +102,15 @@ class FallbackOrchestrator:
|
||||
"provider_priority_mode",
|
||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||
)
|
||||
scheduling_mode = SystemConfigService.get_config(
|
||||
self.db,
|
||||
"scheduling_mode",
|
||||
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
)
|
||||
self.cache_scheduler = await get_cache_aware_scheduler(
|
||||
self.redis,
|
||||
priority_mode=priority_mode,
|
||||
scheduling_mode=scheduling_mode,
|
||||
)
|
||||
else:
|
||||
# 确保运行时配置变更能生效
|
||||
@@ -113,7 +119,13 @@ class FallbackOrchestrator:
|
||||
"provider_priority_mode",
|
||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||
)
|
||||
scheduling_mode = SystemConfigService.get_config(
|
||||
self.db,
|
||||
"scheduling_mode",
|
||||
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
)
|
||||
self.cache_scheduler.set_priority_mode(priority_mode)
|
||||
self.cache_scheduler.set_scheduling_mode(scheduling_mode)
|
||||
|
||||
# 确保 cache_scheduler 内部组件也已初始化
|
||||
await self.cache_scheduler._ensure_initialized()
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
- 使用滑动窗口采样,容忍并发波动
|
||||
- 基于窗口内高利用率采样比例决策,而非要求连续高利用率
|
||||
- 增加探测性扩容机制,长时间稳定时主动尝试扩容
|
||||
|
||||
AIMD 参数说明:
|
||||
- 扩容:加性增加 (+INCREASE_STEP)
|
||||
- 缩容:乘性减少 (*DECREASE_MULTIPLIER,默认 0.85)
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
@@ -34,7 +38,7 @@ class AdaptiveConcurrencyManager:
|
||||
核心算法:基于滑动窗口利用率的 AIMD
|
||||
- 滑动窗口记录最近 N 次请求的利用率
|
||||
- 当窗口内高利用率采样比例 >= 60% 时触发扩容
|
||||
- 遇到 429 错误时乘性减少 (*0.7)
|
||||
- 遇到 429 错误时乘性减少 (*0.85)
|
||||
- 长时间无 429 且有流量时触发探测性扩容
|
||||
|
||||
扩容条件(满足任一即可):
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta # noqa: F401 - kept for potential future use
|
||||
from typing import Optional, Tuple
|
||||
@@ -40,6 +39,7 @@ class ConcurrencyManager:
|
||||
self._memory_lock: asyncio.Lock = asyncio.Lock()
|
||||
self._memory_endpoint_counts: dict[str, int] = {}
|
||||
self._memory_key_counts: dict[str, int] = {}
|
||||
self._owns_redis: bool = False
|
||||
self._memory_initialized = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@@ -47,41 +47,29 @@ class ConcurrencyManager:
|
||||
if self._redis is not None:
|
||||
return
|
||||
|
||||
# 优先使用 REDIS_URL,如果没有则根据密码构建 URL
|
||||
redis_url = os.getenv("REDIS_URL")
|
||||
|
||||
if not redis_url:
|
||||
# 本地开发模式:从 REDIS_PASSWORD 构建 URL
|
||||
redis_password = os.getenv("REDIS_PASSWORD")
|
||||
if redis_password:
|
||||
redis_url = f"redis://:{redis_password}@localhost:6379/0"
|
||||
else:
|
||||
redis_url = "redis://localhost:6379/0"
|
||||
|
||||
try:
|
||||
self._redis = await aioredis.from_url(
|
||||
redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
socket_timeout=5.0,
|
||||
socket_connect_timeout=5.0,
|
||||
)
|
||||
# 测试连接
|
||||
await self._redis.ping()
|
||||
# 脱敏显示(隐藏密码)
|
||||
safe_url = redis_url.split("@")[-1] if "@" in redis_url else redis_url
|
||||
logger.info(f"[OK] Redis 连接成功: {safe_url}")
|
||||
# 复用全局 Redis 客户端(带熔断/降级),避免重复创建连接池
|
||||
from src.clients.redis_client import get_redis_client
|
||||
|
||||
self._redis = await get_redis_client(require_redis=False)
|
||||
self._owns_redis = False
|
||||
if self._redis:
|
||||
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
|
||||
else:
|
||||
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
|
||||
except Exception as e:
|
||||
logger.error(f"[ERROR] Redis 连接失败: {e}")
|
||||
logger.warning("[WARN] 并发控制将被禁用(仅在单实例环境下安全)")
|
||||
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
|
||||
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
|
||||
self._redis = None
|
||||
self._owns_redis = False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Redis 连接"""
|
||||
if self._redis:
|
||||
if self._redis and self._owns_redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
logger.info("Redis 连接已关闭")
|
||||
logger.info("ConcurrencyManager Redis 连接已关闭")
|
||||
self._redis = None
|
||||
self._owns_redis = False
|
||||
|
||||
def _get_endpoint_key(self, endpoint_id: str) -> str:
|
||||
"""获取 Endpoint 并发计数的 Redis Key"""
|
||||
|
||||
@@ -3,7 +3,7 @@ RPM (Requests Per Minute) 限流服务
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -72,11 +72,7 @@ class RPMLimiter:
|
||||
# 获取当前分钟窗口
|
||||
now = datetime.now(timezone.utc)
|
||||
window_start = now.replace(second=0, microsecond=0)
|
||||
window_end = (
|
||||
window_start.replace(minute=window_start.minute + 1)
|
||||
if window_start.minute < 59
|
||||
else window_start.replace(hour=window_start.hour + 1, minute=0)
|
||||
)
|
||||
window_end = window_start + timedelta(minutes=1)
|
||||
|
||||
# 查找或创建追踪记录
|
||||
tracking = (
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import AuditEventType, AuditLog
|
||||
from src.utils.transaction_manager import transactional
|
||||
|
||||
|
||||
|
||||
@@ -19,10 +18,13 @@ from src.utils.transaction_manager import transactional
|
||||
|
||||
|
||||
class AuditService:
|
||||
"""审计服务"""
|
||||
"""审计服务
|
||||
|
||||
事务策略:本服务不负责事务提交,由中间件统一管理。
|
||||
所有方法只做 db.add/flush,提交由请求结束时的中间件处理。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@transactional(commit=False) # 不自动提交,让调用方决定
|
||||
def log_event(
|
||||
db: Session,
|
||||
event_type: AuditEventType,
|
||||
@@ -54,47 +56,44 @@ class AuditService:
|
||||
|
||||
Returns:
|
||||
审计日志记录
|
||||
|
||||
Note:
|
||||
不在此方法内提交事务,由调用方或中间件统一管理。
|
||||
"""
|
||||
try:
|
||||
audit_log = AuditLog(
|
||||
event_type=event_type.value,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
audit_log = AuditLog(
|
||||
event_type=event_type.value,
|
||||
description=description,
|
||||
user_id=user_id,
|
||||
api_key_id=api_key_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
request_id=request_id,
|
||||
status_code=status_code,
|
||||
error_message=error_message,
|
||||
event_metadata=metadata,
|
||||
)
|
||||
|
||||
db.add(audit_log)
|
||||
db.commit() # 立即提交事务,释放数据库锁
|
||||
db.refresh(audit_log)
|
||||
db.add(audit_log)
|
||||
# 使用 flush 使记录可见但不提交事务,事务由中间件统一管理
|
||||
db.flush()
|
||||
|
||||
# 同时记录到系统日志
|
||||
log_message = (
|
||||
f"AUDIT [{event_type.value}] - {description} | "
|
||||
f"user_id={user_id}, ip={ip_address}"
|
||||
)
|
||||
# 同时记录到系统日志
|
||||
log_message = (
|
||||
f"AUDIT [{event_type.value}] - {description} | "
|
||||
f"user_id={user_id}, ip={ip_address}"
|
||||
)
|
||||
|
||||
if event_type in [
|
||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||
]:
|
||||
logger.warning(log_message)
|
||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||
logger.info(log_message)
|
||||
else:
|
||||
logger.debug(log_message)
|
||||
if event_type in [
|
||||
AuditEventType.UNAUTHORIZED_ACCESS,
|
||||
AuditEventType.SUSPICIOUS_ACTIVITY,
|
||||
]:
|
||||
logger.warning(log_message)
|
||||
elif event_type in [AuditEventType.LOGIN_FAILED, AuditEventType.REQUEST_FAILED]:
|
||||
logger.info(log_message)
|
||||
else:
|
||||
logger.debug(log_message)
|
||||
|
||||
return audit_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to log audit event: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
return audit_log
|
||||
|
||||
@staticmethod
|
||||
def log_login_attempt(
|
||||
|
||||
@@ -35,6 +35,7 @@ class CleanupScheduler:
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self._interval_tasks = []
|
||||
self._stats_aggregation_lock = asyncio.Lock()
|
||||
|
||||
async def start(self):
|
||||
"""启动调度器"""
|
||||
@@ -56,6 +57,14 @@ class CleanupScheduler:
|
||||
job_id="stats_aggregation",
|
||||
name="统计数据聚合",
|
||||
)
|
||||
# 统计聚合补偿任务 - 每 30 分钟检查缺失并回填
|
||||
scheduler.add_interval_job(
|
||||
self._scheduled_stats_aggregation,
|
||||
minutes=30,
|
||||
job_id="stats_aggregation_backfill",
|
||||
name="统计数据聚合补偿",
|
||||
backfill=True,
|
||||
)
|
||||
|
||||
# 清理任务 - 凌晨 3 点执行
|
||||
scheduler.add_cron_job(
|
||||
@@ -115,9 +124,9 @@ class CleanupScheduler:
|
||||
|
||||
# ========== 任务函数(APScheduler 直接调用异步函数) ==========
|
||||
|
||||
async def _scheduled_stats_aggregation(self):
|
||||
async def _scheduled_stats_aggregation(self, backfill: bool = False):
|
||||
"""统计聚合任务(定时调用)"""
|
||||
await self._perform_stats_aggregation()
|
||||
await self._perform_stats_aggregation(backfill=backfill)
|
||||
|
||||
async def _scheduled_cleanup(self):
|
||||
"""清理任务(定时调用)"""
|
||||
@@ -144,136 +153,157 @@ class CleanupScheduler:
|
||||
Args:
|
||||
backfill: 是否回填历史数据(启动时检查缺失的日期)
|
||||
"""
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用统计聚合
|
||||
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
||||
logger.info("统计聚合已禁用,跳过聚合任务")
|
||||
return
|
||||
if self._stats_aggregation_lock.locked():
|
||||
logger.info("统计聚合任务正在运行,跳过本次触发")
|
||||
return
|
||||
|
||||
logger.info("开始执行统计数据聚合...")
|
||||
|
||||
from src.models.database import StatsDaily, User as DBUser
|
||||
from src.services.system.scheduler import APP_TIMEZONE
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
# 使用业务时区计算日期,确保与定时任务触发时间一致
|
||||
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
|
||||
app_tz = ZoneInfo(APP_TIMEZONE)
|
||||
now_local = datetime.now(app_tz)
|
||||
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if backfill:
|
||||
# 启动时检查并回填缺失的日期
|
||||
from src.models.database import StatsSummary
|
||||
|
||||
summary = db.query(StatsSummary).first()
|
||||
if not summary:
|
||||
# 首次运行,回填所有历史数据
|
||||
logger.info("检测到首次运行,开始回填历史统计数据...")
|
||||
days_to_backfill = SystemConfigService.get_config(
|
||||
db, "stats_backfill_days", 365
|
||||
)
|
||||
count = StatsAggregatorService.backfill_historical_data(
|
||||
db, days=days_to_backfill
|
||||
)
|
||||
logger.info(f"历史数据回填完成,共 {count} 天")
|
||||
async with self._stats_aggregation_lock:
|
||||
db = create_session()
|
||||
try:
|
||||
# 检查是否启用统计聚合
|
||||
if not SystemConfigService.get_config(db, "enable_stats_aggregation", True):
|
||||
logger.info("统计聚合已禁用,跳过聚合任务")
|
||||
return
|
||||
|
||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||
latest_stat = (
|
||||
db.query(StatsDaily)
|
||||
.order_by(StatsDaily.date.desc())
|
||||
.first()
|
||||
)
|
||||
logger.info("开始执行统计数据聚合...")
|
||||
|
||||
if latest_stat:
|
||||
latest_date_utc = latest_stat.date
|
||||
if latest_date_utc.tzinfo is None:
|
||||
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
||||
from src.models.database import StatsDaily, User as DBUser
|
||||
from src.services.system.scheduler import APP_TIMEZONE
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
||||
yesterday_business_date = (today_local.date() - timedelta(days=1))
|
||||
missing_start_date = latest_business_date + timedelta(days=1)
|
||||
# 使用业务时区计算日期,确保与定时任务触发时间一致
|
||||
# 定时任务在 Asia/Shanghai 凌晨 1 点触发,此时应聚合 Asia/Shanghai 的"昨天"
|
||||
app_tz = ZoneInfo(APP_TIMEZONE)
|
||||
now_local = datetime.now(app_tz)
|
||||
today_local = now_local.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if missing_start_date <= yesterday_business_date:
|
||||
missing_days = (yesterday_business_date - missing_start_date).days + 1
|
||||
logger.info(
|
||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
||||
if backfill:
|
||||
# 启动时检查并回填缺失的日期
|
||||
from src.models.database import StatsSummary
|
||||
|
||||
summary = db.query(StatsSummary).first()
|
||||
if not summary:
|
||||
# 首次运行,回填所有历史数据
|
||||
logger.info("检测到首次运行,开始回填历史统计数据...")
|
||||
days_to_backfill = SystemConfigService.get_config(
|
||||
db, "stats_backfill_days", 365
|
||||
)
|
||||
count = StatsAggregatorService.backfill_historical_data(
|
||||
db, days=days_to_backfill
|
||||
)
|
||||
logger.info(f"历史数据回填完成,共 {count} 天")
|
||||
return
|
||||
|
||||
current_date = missing_start_date
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
# 非首次运行,检查最近是否有缺失的日期需要回填
|
||||
latest_stat = db.query(StatsDaily).order_by(StatsDaily.date.desc()).first()
|
||||
|
||||
while current_date <= yesterday_business_date:
|
||||
try:
|
||||
current_date_local = datetime.combine(
|
||||
current_date, datetime.min.time(), tzinfo=app_tz
|
||||
if latest_stat:
|
||||
latest_date_utc = latest_stat.date
|
||||
if latest_date_utc.tzinfo is None:
|
||||
latest_date_utc = latest_date_utc.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
latest_date_utc = latest_date_utc.astimezone(timezone.utc)
|
||||
|
||||
# 使用业务日期计算缺失区间(避免用 UTC 年月日导致日期偏移,且对 DST 更安全)
|
||||
latest_business_date = latest_date_utc.astimezone(app_tz).date()
|
||||
yesterday_business_date = today_local.date() - timedelta(days=1)
|
||||
missing_start_date = latest_business_date + timedelta(days=1)
|
||||
|
||||
if missing_start_date <= yesterday_business_date:
|
||||
missing_days = (
|
||||
yesterday_business_date - missing_start_date
|
||||
).days + 1
|
||||
|
||||
# 限制最大回填天数,防止停机很久后一次性回填太多
|
||||
max_backfill_days: int = SystemConfigService.get_config(
|
||||
db, "max_stats_backfill_days", 30
|
||||
) or 30
|
||||
if missing_days > max_backfill_days:
|
||||
logger.warning(
|
||||
f"缺失 {missing_days} 天数据超过最大回填限制 "
|
||||
f"{max_backfill_days} 天,只回填最近 {max_backfill_days} 天"
|
||||
)
|
||||
StatsAggregatorService.aggregate_daily_stats(db, current_date_local)
|
||||
# 聚合用户数据
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
db, user_id, current_date_local
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
||||
)
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
||||
missing_start_date = yesterday_business_date - timedelta(
|
||||
days=max_backfill_days - 1
|
||||
)
|
||||
missing_days = max_backfill_days
|
||||
|
||||
logger.info(
|
||||
f"检测到缺失 {missing_days} 天的统计数据 "
|
||||
f"({missing_start_date} ~ {yesterday_business_date}),开始回填..."
|
||||
)
|
||||
|
||||
current_date = missing_start_date
|
||||
users = (
|
||||
db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
)
|
||||
|
||||
while current_date <= yesterday_business_date:
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
current_date_local = datetime.combine(
|
||||
current_date, datetime.min.time(), tzinfo=app_tz
|
||||
)
|
||||
StatsAggregatorService.aggregate_daily_stats(
|
||||
db, current_date_local
|
||||
)
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
db, user_id, current_date_local
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"回填用户 {user_id} 日期 {current_date} 失败: {e}"
|
||||
)
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"回填日期 {current_date} 失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
current_date += timedelta(days=1)
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# 更新全局汇总
|
||||
StatsAggregatorService.update_summary(db)
|
||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
||||
else:
|
||||
logger.info("统计数据已是最新,无需回填")
|
||||
return
|
||||
StatsAggregatorService.update_summary(db)
|
||||
logger.info(f"缺失数据回填完成,共 {missing_days} 天")
|
||||
else:
|
||||
logger.info("统计数据已是最新,无需回填")
|
||||
return
|
||||
|
||||
# 定时任务:聚合昨天的数据
|
||||
# 注意:aggregate_daily_stats 期望业务时区的日期,不是 UTC
|
||||
yesterday_local = today_local - timedelta(days=1)
|
||||
# 定时任务:聚合昨天的数据
|
||||
yesterday_local = today_local - timedelta(days=1)
|
||||
|
||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||
StatsAggregatorService.aggregate_daily_stats(db, yesterday_local)
|
||||
|
||||
# 聚合所有用户的昨日数据
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
StatsAggregatorService.aggregate_user_daily_stats(db, user_id, yesterday_local)
|
||||
except Exception as e:
|
||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||
# 回滚当前用户的失败操作,继续处理其他用户
|
||||
users = db.query(DBUser.id).filter(DBUser.is_active.is_(True)).all()
|
||||
for (user_id,) in users:
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
StatsAggregatorService.aggregate_user_daily_stats(
|
||||
db, user_id, yesterday_local
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"聚合用户 {user_id} 统计数据失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 更新全局汇总
|
||||
StatsAggregatorService.update_summary(db)
|
||||
StatsAggregatorService.update_summary(db)
|
||||
|
||||
logger.info("统计数据聚合完成")
|
||||
logger.info("统计数据聚合完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.exception(f"统计聚合任务执行失败: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def _perform_pending_cleanup(self):
|
||||
"""执行 pending 状态清理"""
|
||||
|
||||
@@ -71,6 +71,10 @@ class SystemConfigService:
|
||||
"value": "provider",
|
||||
"description": "优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)",
|
||||
},
|
||||
"scheduling_mode": {
|
||||
"value": "cache_affinity",
|
||||
"description": "调度模式:fixed_order(固定顺序模式,严格按优先级顺序) 或 cache_affinity(缓存亲和模式,优先使用已缓存的Provider)",
|
||||
},
|
||||
"auto_delete_expired_keys": {
|
||||
"value": False,
|
||||
"description": "是否自动删除过期的API Key(True=物理删除,False=仅禁用),仅管理员可配置",
|
||||
|
||||
@@ -56,65 +56,44 @@ class StatsAggregatorService:
|
||||
"""统计数据聚合服务"""
|
||||
|
||||
@staticmethod
|
||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||
"""聚合指定日期的统计数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
||||
|
||||
Returns:
|
||||
StatsDaily 记录
|
||||
"""
|
||||
# 将业务日期转换为 UTC 时间范围
|
||||
def compute_daily_stats(db: Session, date: datetime) -> dict:
|
||||
"""计算指定业务日期的统计数据(不写入数据库)"""
|
||||
day_start, day_end = _get_business_day_range(date)
|
||||
|
||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||
# 检查是否已存在该日期的记录
|
||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||
if existing:
|
||||
stats = existing
|
||||
else:
|
||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||
|
||||
# 基础请求统计
|
||||
base_query = db.query(Usage).filter(
|
||||
and_(Usage.created_at >= day_start, Usage.created_at < day_end)
|
||||
)
|
||||
|
||||
total_requests = base_query.count()
|
||||
|
||||
# 如果没有请求,直接返回空记录
|
||||
if total_requests == 0:
|
||||
stats.total_requests = 0
|
||||
stats.success_requests = 0
|
||||
stats.error_requests = 0
|
||||
stats.input_tokens = 0
|
||||
stats.output_tokens = 0
|
||||
stats.cache_creation_tokens = 0
|
||||
stats.cache_read_tokens = 0
|
||||
stats.total_cost = 0.0
|
||||
stats.actual_total_cost = 0.0
|
||||
stats.input_cost = 0.0
|
||||
stats.output_cost = 0.0
|
||||
stats.cache_creation_cost = 0.0
|
||||
stats.cache_read_cost = 0.0
|
||||
stats.avg_response_time_ms = 0.0
|
||||
stats.fallback_count = 0
|
||||
return {
|
||||
"day_start": day_start,
|
||||
"total_requests": 0,
|
||||
"success_requests": 0,
|
||||
"error_requests": 0,
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_tokens": 0,
|
||||
"cache_read_tokens": 0,
|
||||
"total_cost": 0.0,
|
||||
"actual_total_cost": 0.0,
|
||||
"input_cost": 0.0,
|
||||
"output_cost": 0.0,
|
||||
"cache_creation_cost": 0.0,
|
||||
"cache_read_cost": 0.0,
|
||||
"avg_response_time_ms": 0.0,
|
||||
"fallback_count": 0,
|
||||
"unique_models": 0,
|
||||
"unique_providers": 0,
|
||||
}
|
||||
|
||||
if not existing:
|
||||
db.add(stats)
|
||||
db.commit()
|
||||
return stats
|
||||
|
||||
# 错误请求数
|
||||
error_requests = (
|
||||
base_query.filter(
|
||||
(Usage.status_code >= 400) | (Usage.error_message.isnot(None))
|
||||
).count()
|
||||
)
|
||||
|
||||
# Token 和成本聚合
|
||||
aggregated = (
|
||||
db.query(
|
||||
func.sum(Usage.input_tokens).label("input_tokens"),
|
||||
@@ -157,7 +136,6 @@ class StatsAggregatorService:
|
||||
or 0
|
||||
)
|
||||
|
||||
# 使用维度统计
|
||||
unique_models = (
|
||||
db.query(func.count(func.distinct(Usage.model)))
|
||||
.filter(and_(Usage.created_at >= day_start, Usage.created_at < day_end))
|
||||
@@ -171,31 +149,74 @@ class StatsAggregatorService:
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"day_start": day_start,
|
||||
"total_requests": total_requests,
|
||||
"success_requests": total_requests - error_requests,
|
||||
"error_requests": error_requests,
|
||||
"input_tokens": int(aggregated.input_tokens or 0) if aggregated else 0,
|
||||
"output_tokens": int(aggregated.output_tokens or 0) if aggregated else 0,
|
||||
"cache_creation_tokens": int(aggregated.cache_creation_tokens or 0) if aggregated else 0,
|
||||
"cache_read_tokens": int(aggregated.cache_read_tokens or 0) if aggregated else 0,
|
||||
"total_cost": float(aggregated.total_cost or 0) if aggregated else 0.0,
|
||||
"actual_total_cost": float(aggregated.actual_total_cost or 0) if aggregated else 0.0,
|
||||
"input_cost": float(aggregated.input_cost or 0) if aggregated else 0.0,
|
||||
"output_cost": float(aggregated.output_cost or 0) if aggregated else 0.0,
|
||||
"cache_creation_cost": float(aggregated.cache_creation_cost or 0) if aggregated else 0.0,
|
||||
"cache_read_cost": float(aggregated.cache_read_cost or 0) if aggregated else 0.0,
|
||||
"avg_response_time_ms": float(aggregated.avg_response_time or 0) if aggregated else 0.0,
|
||||
"fallback_count": fallback_count,
|
||||
"unique_models": unique_models,
|
||||
"unique_providers": unique_providers,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_daily_stats(db: Session, date: datetime) -> StatsDaily:
|
||||
"""聚合指定日期的统计数据
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
date: 要聚合的业务日期(使用 APP_TIMEZONE 时区)
|
||||
|
||||
Returns:
|
||||
StatsDaily 记录
|
||||
"""
|
||||
computed = StatsAggregatorService.compute_daily_stats(db, date)
|
||||
day_start = computed["day_start"]
|
||||
|
||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||
# 检查是否已存在该日期的记录
|
||||
existing = db.query(StatsDaily).filter(StatsDaily.date == day_start).first()
|
||||
if existing:
|
||||
stats = existing
|
||||
else:
|
||||
stats = StatsDaily(id=str(uuid.uuid4()), date=day_start)
|
||||
|
||||
# 更新统计记录
|
||||
stats.total_requests = total_requests
|
||||
stats.success_requests = total_requests - error_requests
|
||||
stats.error_requests = error_requests
|
||||
stats.input_tokens = int(aggregated.input_tokens or 0)
|
||||
stats.output_tokens = int(aggregated.output_tokens or 0)
|
||||
stats.cache_creation_tokens = int(aggregated.cache_creation_tokens or 0)
|
||||
stats.cache_read_tokens = int(aggregated.cache_read_tokens or 0)
|
||||
stats.total_cost = float(aggregated.total_cost or 0)
|
||||
stats.actual_total_cost = float(aggregated.actual_total_cost or 0)
|
||||
stats.input_cost = float(aggregated.input_cost or 0)
|
||||
stats.output_cost = float(aggregated.output_cost or 0)
|
||||
stats.cache_creation_cost = float(aggregated.cache_creation_cost or 0)
|
||||
stats.cache_read_cost = float(aggregated.cache_read_cost or 0)
|
||||
stats.avg_response_time_ms = float(aggregated.avg_response_time or 0)
|
||||
stats.fallback_count = fallback_count
|
||||
stats.unique_models = unique_models
|
||||
stats.unique_providers = unique_providers
|
||||
stats.total_requests = computed["total_requests"]
|
||||
stats.success_requests = computed["success_requests"]
|
||||
stats.error_requests = computed["error_requests"]
|
||||
stats.input_tokens = computed["input_tokens"]
|
||||
stats.output_tokens = computed["output_tokens"]
|
||||
stats.cache_creation_tokens = computed["cache_creation_tokens"]
|
||||
stats.cache_read_tokens = computed["cache_read_tokens"]
|
||||
stats.total_cost = computed["total_cost"]
|
||||
stats.actual_total_cost = computed["actual_total_cost"]
|
||||
stats.input_cost = computed["input_cost"]
|
||||
stats.output_cost = computed["output_cost"]
|
||||
stats.cache_creation_cost = computed["cache_creation_cost"]
|
||||
stats.cache_read_cost = computed["cache_read_cost"]
|
||||
stats.avg_response_time_ms = computed["avg_response_time_ms"]
|
||||
stats.fallback_count = computed["fallback_count"]
|
||||
stats.unique_models = computed["unique_models"]
|
||||
stats.unique_providers = computed["unique_providers"]
|
||||
|
||||
if not existing:
|
||||
db.add(stats)
|
||||
db.commit()
|
||||
|
||||
# 日志使用业务日期(输入参数),而不是 UTC 日期
|
||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {total_requests} 请求")
|
||||
logger.info(f"[StatsAggregator] 聚合日期 {date.date()} 完成: {computed['total_requests']} 请求")
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -71,8 +71,8 @@ class PreferenceService:
|
||||
raise NotFoundException("Provider not found or inactive")
|
||||
preferences.default_provider_id = default_provider_id
|
||||
if theme is not None:
|
||||
if theme not in ["light", "dark", "auto"]:
|
||||
raise ValueError("Invalid theme. Must be 'light', 'dark', or 'auto'")
|
||||
if theme not in ["light", "dark", "auto", "system"]:
|
||||
raise ValueError("Invalid theme. Must be 'light', 'dark', 'auto', or 'system'")
|
||||
preferences.theme = theme
|
||||
if language is not None:
|
||||
preferences.language = language
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..models.database import User, UserRole
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
def get_current_user(
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security), db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
@@ -41,7 +41,7 @@ def get_current_user(
|
||||
try:
|
||||
# 验证Token格式和签名
|
||||
try:
|
||||
payload = AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
except HTTPException as token_error:
|
||||
# 保持原始的HTTP状态码(如401 Unauthorized),不要转换为403
|
||||
logger.error(f"Token验证失败: {token_error.status_code}: {token_error.detail}, Token前10位: {token[:10]}...")
|
||||
@@ -122,7 +122,7 @@ def get_current_user(
|
||||
)
|
||||
|
||||
|
||||
def get_current_user_from_header(
|
||||
async def get_current_user_from_header(
|
||||
authorization: Optional[str] = Header(None), db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
@@ -144,7 +144,7 @@ def get_current_user_from_header(
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
try:
|
||||
payload = AuthService.verify_token(token)
|
||||
payload = await AuthService.verify_token(token, token_type="access")
|
||||
user_id = payload.get("user_id")
|
||||
|
||||
if not user_id:
|
||||
|
||||
363
tests/api/test_pipeline.py
Normal file
363
tests/api/test_pipeline.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""
|
||||
API Pipeline 测试
|
||||
|
||||
测试 ApiRequestPipeline 的核心功能:
|
||||
- 认证流程(API Key、JWT Token)
|
||||
- 配额计算
|
||||
- 审计日志记录
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
|
||||
|
||||
class TestPipelineQuotaCalculation:
    """Unit tests for ``ApiRequestPipeline._calculate_quota_remaining``."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        """Return a fresh pipeline instance for each test."""
        return ApiRequestPipeline()

    @staticmethod
    def _user(quota_usd, used_usd) -> MagicMock:
        """Build a mock user that carries only the two quota attributes."""
        return MagicMock(quota_usd=quota_usd, used_usd=used_usd)

    def test_calculate_quota_remaining_with_quota(self, pipeline: ApiRequestPipeline) -> None:
        """With a finite quota, the remainder is quota minus usage."""
        user = self._user(100.0, 30.0)

        assert pipeline._calculate_quota_remaining(user) == 70.0

    def test_calculate_quota_remaining_no_quota(self, pipeline: ApiRequestPipeline) -> None:
        """Without a quota limit (``None``), the result is ``None``."""
        user = self._user(None, 30.0)

        assert pipeline._calculate_quota_remaining(user) is None

    def test_calculate_quota_remaining_negative_quota(self, pipeline: ApiRequestPipeline) -> None:
        """A negative quota yields ``None``."""
        user = self._user(-1, 0.0)

        assert pipeline._calculate_quota_remaining(user) is None

    def test_calculate_quota_remaining_exceeded(self, pipeline: ApiRequestPipeline) -> None:
        """When usage exceeds the quota, the remainder is reported as 0."""
        user = self._user(100.0, 150.0)

        assert pipeline._calculate_quota_remaining(user) == 0.0

    def test_calculate_quota_remaining_none_user(self, pipeline: ApiRequestPipeline) -> None:
        """Passing ``None`` as the user yields ``None``."""
        assert pipeline._calculate_quota_remaining(None) is None
|
||||
|
||||
|
||||
class TestPipelineAuditLogging:
|
||||
"""测试 Pipeline 审计日志"""
|
||||
|
||||
@pytest.fixture
def pipeline(self) -> ApiRequestPipeline:
    """Return a fresh pipeline instance for each audit-logging test."""
    instance = ApiRequestPipeline()
    return instance
|
||||
|
||||
def test_record_audit_event_success(self, pipeline: ApiRequestPipeline) -> None:
    """A successful request is forwarded to ``audit_service.log_event`` once."""
    # Request context carrying the identity and request metadata.
    ctx = MagicMock()
    ctx.db = MagicMock()
    ctx.user = MagicMock()
    ctx.user.id = "user-123"
    ctx.api_key = MagicMock()
    ctx.api_key.id = "key-123"
    ctx.request_id = "req-123"
    ctx.client_ip = "127.0.0.1"
    ctx.user_agent = "test-agent"
    ctx.request = MagicMock()
    ctx.request.method = "POST"
    ctx.request.url.path = "/v1/messages"
    ctx.start_time = 1000.0

    # Adapter with auditing switched on and no custom event types.
    adapter = MagicMock(
        audit_log_enabled=True,
        audit_success_event=None,
        audit_failure_event=None,
    )
    # ``name`` must be assigned afterwards: as a constructor argument it
    # would name the mock itself instead of setting the attribute.
    adapter.name = "test-adapter"

    with patch.object(pipeline.audit_service, "log_event") as log_event, patch(
        "time.time", return_value=1001.0
    ):
        pipeline._record_audit_event(ctx, adapter, success=True, status_code=200)

    log_event.assert_called_once()
    logged = log_event.call_args.kwargs
    assert logged["user_id"] == "user-123"
    assert logged["status_code"] == 200
|
||||
|
||||
def test_record_audit_event_failure(self, pipeline: ApiRequestPipeline) -> None:
    """A failed request is logged with its status code and error message."""
    # Request context carrying the identity and request metadata.
    ctx = MagicMock()
    ctx.db = MagicMock()
    ctx.user = MagicMock()
    ctx.user.id = "user-123"
    ctx.api_key = MagicMock()
    ctx.api_key.id = "key-123"
    ctx.request_id = "req-123"
    ctx.client_ip = "127.0.0.1"
    ctx.user_agent = "test-agent"
    ctx.request = MagicMock()
    ctx.request.method = "POST"
    ctx.request.url.path = "/v1/messages"
    ctx.start_time = 1000.0

    # Adapter with auditing switched on and no custom event types.
    adapter = MagicMock(
        audit_log_enabled=True,
        audit_success_event=None,
        audit_failure_event=None,
    )
    # ``name`` must be assigned afterwards: as a constructor argument it
    # would name the mock itself instead of setting the attribute.
    adapter.name = "test-adapter"

    with patch.object(pipeline.audit_service, "log_event") as log_event, patch(
        "time.time", return_value=1001.0
    ):
        pipeline._record_audit_event(
            ctx, adapter, success=False, status_code=500, error="Internal error"
        )

    log_event.assert_called_once()
    logged = log_event.call_args.kwargs
    assert logged["status_code"] == 500
    assert logged["error_message"] == "Internal error"
|
||||
|
||||
def test_record_audit_event_no_db(self, pipeline: ApiRequestPipeline) -> None:
    """With ``context.db`` set to None, ``log_event`` is never called."""
    ctx = MagicMock(db=None)
    adapter = MagicMock(audit_log_enabled=True)

    with patch.object(pipeline.audit_service, "log_event") as mock_log:
        # The call must return normally rather than raise.
        pipeline._record_audit_event(ctx, adapter, success=True)

    # No audit entry may have been written.
    mock_log.assert_not_called()
|
||||
|
||||
def test_record_audit_event_disabled(self, pipeline: ApiRequestPipeline) -> None:
    """With ``audit_log_enabled`` False on the adapter, ``log_event`` is never called."""
    ctx = MagicMock(db=MagicMock())
    adapter = MagicMock(audit_log_enabled=False)

    with patch.object(pipeline.audit_service, "log_event") as mock_log:
        pipeline._record_audit_event(ctx, adapter, success=True)

    mock_log.assert_not_called()
|
||||
|
||||
def test_record_audit_event_exception_handling(self, pipeline: ApiRequestPipeline) -> None:
    """An exception raised by ``log_event`` must not escape ``_record_audit_event``."""
    ctx = MagicMock(
        db=MagicMock(),
        user=MagicMock(id="user-123"),
        api_key=MagicMock(id="key-123"),
        request_id="req-123",
        client_ip="127.0.0.1",
        user_agent="test-agent",
        request=MagicMock(method="POST"),
        start_time=1000.0,
    )
    ctx.request.url.path = "/v1/messages"

    adapter = MagicMock(audit_log_enabled=True, audit_success_event=None)
    # ``name`` is consumed by the Mock constructor, so it is assigned afterwards.
    adapter.name = "test-adapter"

    failing_log = patch.object(
        pipeline.audit_service,
        "log_event",
        side_effect=Exception("DB error"),
    )
    with failing_log, patch("time.time", return_value=1001.0):
        # The test passes as long as this call completes without raising.
        pipeline._record_audit_event(ctx, adapter, success=True)
|
||||
|
||||
|
||||
class TestPipelineAuthentication:
    """Tests for the pipeline's client (API key) authentication logic."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        """Provide a fresh pipeline instance for each test."""
        return ApiRequestPipeline()

    @staticmethod
    def _make_request(headers: dict) -> MagicMock:
        """Build a mock request for ``/v1/messages`` carrying *headers*."""
        request = MagicMock(headers=headers, state=MagicMock())
        request.url.path = "/v1/messages"
        return request

    @staticmethod
    def _make_adapter(extracted_key) -> MagicMock:
        """Build a mock adapter whose ``extract_api_key`` returns *extracted_key*."""
        adapter = MagicMock()
        adapter.extract_api_key = MagicMock(return_value=extracted_key)
        return adapter

    def test_authenticate_client_missing_key(self, pipeline: ApiRequestPipeline) -> None:
        """A request from which no API key can be extracted raises HTTP 401."""
        request = self._make_request({})
        db = MagicMock()
        adapter = self._make_adapter(None)

        with pytest.raises(HTTPException) as exc_info:
            pipeline._authenticate_client(request, db, adapter)

        raised = exc_info.value
        assert raised.status_code == 401
        assert "API密钥" in raised.detail

    def test_authenticate_client_invalid_key(self, pipeline: ApiRequestPipeline) -> None:
        """A key for which ``authenticate_api_key`` returns None raises HTTP 401."""
        request = self._make_request({"Authorization": "Bearer sk-invalid"})
        db = MagicMock()
        adapter = self._make_adapter("sk-invalid")

        rejected = patch.object(
            pipeline.auth_service,
            "authenticate_api_key",
            return_value=None,
        )
        with rejected, pytest.raises(HTTPException) as exc_info:
            pipeline._authenticate_client(request, db, adapter)

        assert exc_info.value.status_code == 401

    def test_authenticate_client_quota_exceeded(self, pipeline: ApiRequestPipeline) -> None:
        """A failed quota check raises ``QuotaExceededException``."""
        from src.core.exceptions import QuotaExceededException

        user = MagicMock(id="user-123", quota_usd=100.0, used_usd=100.0)
        api_key = MagicMock(id="key-123", is_standalone=False)

        request = self._make_request({"Authorization": "Bearer sk-test"})
        db = MagicMock()
        adapter = self._make_adapter("sk-test")

        auth_ok = patch.object(
            pipeline.auth_service,
            "authenticate_api_key",
            return_value=(user, api_key),
        )
        quota_denied = patch.object(
            pipeline.usage_service,
            "check_user_quota",
            return_value=(False, "配额不足"),
        )
        with auth_ok, quota_denied, pytest.raises(QuotaExceededException):
            pipeline._authenticate_client(request, db, adapter)
|
||||
|
||||
|
||||
class TestPipelineAdminAuth:
    """Tests for administrator authentication in the pipeline."""

    @pytest.fixture
    def pipeline(self) -> ApiRequestPipeline:
        """Provide a fresh pipeline instance for each test."""
        return ApiRequestPipeline()

    @pytest.mark.asyncio
    async def test_authenticate_admin_missing_token(self, pipeline: ApiRequestPipeline) -> None:
        """A request with no headers at all raises HTTP 401."""
        request = MagicMock(headers={})
        db = MagicMock()

        with pytest.raises(HTTPException) as exc_info:
            await pipeline._authenticate_admin(request, db)

        raised = exc_info.value
        assert raised.status_code == 401
        assert "管理员凭证" in raised.detail

    @pytest.mark.asyncio
    async def test_authenticate_admin_invalid_token(self, pipeline: ApiRequestPipeline) -> None:
        """An HTTP 401 raised by ``verify_token`` propagates to the caller."""
        request = MagicMock(headers={"authorization": "Bearer invalid-token"})
        db = MagicMock()

        rejecting_verify = patch.object(
            pipeline.auth_service,
            "verify_token",
            side_effect=HTTPException(status_code=401, detail="Invalid token"),
        )
        with rejecting_verify, pytest.raises(HTTPException) as exc_info:
            await pipeline._authenticate_admin(request, db)

        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_authenticate_admin_success(self, pipeline: ApiRequestPipeline) -> None:
        """A valid token returns the user from the database and stores its id on the request state."""
        admin = MagicMock(id="admin-123", is_active=True)

        request = MagicMock(
            headers={"authorization": "Bearer valid-token"},
            state=MagicMock(),
        )

        db = MagicMock()
        db.query.return_value.filter.return_value.first.return_value = admin

        accepting_verify = patch.object(
            pipeline.auth_service,
            "verify_token",
            new_callable=AsyncMock,
            return_value={"user_id": "admin-123"},
        )
        with accepting_verify:
            authenticated = await pipeline._authenticate_admin(request, db)

        assert authenticated == admin
        assert request.state.user_id == "admin-123"
|
||||
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""服务层测试"""
|
||||
299
tests/services/test_auth.py
Normal file
299
tests/services/test_auth.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
认证服务测试
|
||||
|
||||
测试 AuthService 的核心功能:
|
||||
- JWT Token 创建和验证
|
||||
- 用户登录认证
|
||||
- API Key 认证
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from src.services.auth.service import (
|
||||
AuthService,
|
||||
JWT_SECRET_KEY,
|
||||
JWT_ALGORITHM,
|
||||
JWT_EXPIRATION_HOURS,
|
||||
)
|
||||
|
||||
|
||||
class TestJWTTokenCreation:
    """Tests for JWT token creation."""

    @staticmethod
    def _decode(token: str) -> dict:
        """Decode *token* with the service's secret key and algorithm."""
        return jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])

    def test_create_access_token_contains_required_fields(self) -> None:
        """The access token carries the supplied claims plus ``type`` and ``exp``."""
        claims_in = {"sub": "user123", "email": "test@example.com"}
        token = AuthService.create_access_token(claims_in)

        # Decode and verify the payload.
        claims_out = self._decode(token)

        assert claims_out["sub"] == "user123"
        assert claims_out["email"] == "test@example.com"
        assert claims_out["type"] == "access"
        assert "exp" in claims_out

    def test_create_access_token_expiration(self) -> None:
        """The access token expires ``JWT_EXPIRATION_HOURS`` from now, within a minute."""
        claims_in = {"sub": "user123"}
        token = AuthService.create_access_token(claims_in)

        claims_out = self._decode(token)

        # Compare the actual expiry to the expected one, allowing one minute of slack.
        actual_expiry = datetime.fromtimestamp(claims_out["exp"], tz=timezone.utc)
        expected_expiry = datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRATION_HOURS)
        drift = actual_expiry - expected_expiry

        assert abs(drift.total_seconds()) < 60

    def test_create_refresh_token_type(self) -> None:
        """The refresh token is tagged with ``type == "refresh"``."""
        claims_in = {"sub": "user123"}
        token = AuthService.create_refresh_token(claims_in)

        claims_out = self._decode(token)

        assert claims_out["type"] == "refresh"

    def test_create_refresh_token_longer_expiration(self) -> None:
        """The refresh token expires later than the access token."""
        # The same dict object is deliberately handed to both factory calls.
        claims_in = {"sub": "user123"}
        access_token = AuthService.create_access_token(claims_in)
        refresh_token = AuthService.create_refresh_token(claims_in)

        access_claims = self._decode(access_token)
        refresh_claims = self._decode(refresh_token)

        # The refresh token must outlive the access token.
        assert refresh_claims["exp"] > access_claims["exp"]
|
||||
|
||||
|
||||
class TestJWTTokenVerification:
    """Tests for JWT token verification."""

    @staticmethod
    def _blacklist_lookup(result: bool):
        """Return a patcher making ``JWTBlacklistService.is_blacklisted`` resolve to *result*."""
        return patch(
            "src.services.auth.service.JWTBlacklistService.is_blacklisted",
            new_callable=AsyncMock,
            return_value=result,
        )

    @pytest.mark.asyncio
    async def test_verify_valid_access_token(self) -> None:
        """A fresh, non-blacklisted access token verifies and yields its payload."""
        claims_in = {"sub": "user123", "email": "test@example.com"}
        token = AuthService.create_access_token(claims_in)

        with self._blacklist_lookup(False):
            claims_out = await AuthService.verify_token(token, token_type="access")

        assert claims_out["sub"] == "user123"
        assert claims_out["type"] == "access"

    @pytest.mark.asyncio
    async def test_verify_expired_token_raises_error(self) -> None:
        """An expired token raises HTTP 401 with an expiry message."""
        from fastapi import HTTPException

        # Hand-craft a token whose expiry lies one hour in the past.
        one_hour_ago = datetime.now(timezone.utc) - timedelta(hours=1)
        claims_in = {"sub": "user123", "type": "access", "exp": one_hour_ago}
        expired_token = jwt.encode(claims_in, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)

        with pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token(expired_token)

        raised = exc_info.value
        assert raised.status_code == 401
        assert "过期" in raised.detail

    @pytest.mark.asyncio
    async def test_verify_invalid_token_raises_error(self) -> None:
        """A malformed token string raises HTTP 401."""
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token("invalid.token.here")

        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_verify_wrong_token_type_raises_error(self) -> None:
        """Verifying a refresh token as an access token raises HTTP 401."""
        from fastapi import HTTPException

        claims_in = {"sub": "user123"}
        refresh_token = AuthService.create_refresh_token(claims_in)

        with self._blacklist_lookup(False), pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token(refresh_token, token_type="access")

        raised = exc_info.value
        assert raised.status_code == 401
        assert "类型错误" in raised.detail

    @pytest.mark.asyncio
    async def test_verify_blacklisted_token_raises_error(self) -> None:
        """A token reported as blacklisted raises HTTP 401 with a revocation message."""
        from fastapi import HTTPException

        claims_in = {"sub": "user123"}
        token = AuthService.create_access_token(claims_in)

        with self._blacklist_lookup(True), pytest.raises(HTTPException) as exc_info:
            await AuthService.verify_token(token)

        raised = exc_info.value
        assert raised.status_code == 401
        assert "撤销" in raised.detail
|
||||
|
||||
|
||||
class TestUserAuthentication:
    """Tests for user login authentication."""

    @staticmethod
    def _db_returning(user) -> MagicMock:
        """Build a mock session whose ``query().filter().first()`` yields *user*."""
        db = MagicMock()
        db.query.return_value.filter.return_value.first.return_value = user
        return db

    @pytest.mark.asyncio
    async def test_authenticate_user_success(self) -> None:
        """An active user with the correct password is returned and the session committed."""
        # Mocked user record and database session.
        user = MagicMock(id="user-123", email="test@example.com", is_active=True)
        user.verify_password.return_value = True
        db = self._db_returning(user)

        cache_invalidation = patch(
            "src.services.auth.service.UserCacheService.invalidate_user_cache",
            new_callable=AsyncMock,
        )
        with cache_invalidation:
            outcome = await AuthService.authenticate_user(db, "test@example.com", "password123")

        assert outcome == user
        user.verify_password.assert_called_once_with("password123")
        db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_authenticate_user_not_found(self) -> None:
        """An unknown e-mail address yields None."""
        db = self._db_returning(None)

        outcome = await AuthService.authenticate_user(db, "nonexistent@example.com", "password")

        assert outcome is None

    @pytest.mark.asyncio
    async def test_authenticate_user_wrong_password(self) -> None:
        """A password that fails verification yields None."""
        user = MagicMock(email="test@example.com")
        user.verify_password.return_value = False
        db = self._db_returning(user)

        outcome = await AuthService.authenticate_user(db, "test@example.com", "wrongpassword")

        assert outcome is None

    @pytest.mark.asyncio
    async def test_authenticate_user_inactive(self) -> None:
        """A deactivated user yields None even with the correct password."""
        user = MagicMock(email="test@example.com", is_active=False)
        user.verify_password.return_value = True
        db = self._db_returning(user)

        outcome = await AuthService.authenticate_user(db, "test@example.com", "password123")

        assert outcome is None
|
||||
|
||||
|
||||
class TestAPIKeyAuthentication:
    """Tests for API key authentication."""

    @staticmethod
    def _db_returning(api_key) -> MagicMock:
        """Build a mock session whose ``query().options().filter().first()`` yields *api_key*."""
        db = MagicMock()
        lookup = db.query.return_value.options.return_value.filter.return_value
        lookup.first.return_value = api_key
        return db

    @staticmethod
    def _fixed_hash():
        """Return a patcher pinning ``ApiKey.hash_key`` to a constant digest."""
        return patch("src.services.auth.service.ApiKey.hash_key", return_value="hashed_key")

    def test_authenticate_api_key_success(self) -> None:
        """An active, unexpired key with balance yields the ``(user, api_key)`` pair."""
        user = MagicMock(id="user-123", email="test@example.com", is_active=True)
        api_key = MagicMock(
            is_active=True,
            expires_at=None,
            user=user,
            balance_used_usd=0.0,
        )
        db = self._db_returning(api_key)

        balance_ok = patch(
            "src.services.auth.service.ApiKeyService.check_balance",
            return_value=(True, 100.0),
        )
        with self._fixed_hash(), balance_ok:
            outcome = AuthService.authenticate_api_key(db, "sk-test-key")

        assert outcome is not None
        assert outcome[0] == user
        assert outcome[1] == api_key

    def test_authenticate_api_key_not_found(self) -> None:
        """A key with no matching database row yields None."""
        db = self._db_returning(None)

        with self._fixed_hash():
            outcome = AuthService.authenticate_api_key(db, "sk-invalid-key")

        assert outcome is None

    def test_authenticate_api_key_inactive(self) -> None:
        """A deactivated key yields None."""
        api_key = MagicMock(is_active=False)
        db = self._db_returning(api_key)

        with self._fixed_hash():
            outcome = AuthService.authenticate_api_key(db, "sk-inactive-key")

        assert outcome is None

    def test_authenticate_api_key_expired(self) -> None:
        """A key whose ``expires_at`` lies one day in the past yields None."""
        yesterday = datetime.now(timezone.utc) - timedelta(days=1)
        api_key = MagicMock(is_active=True, expires_at=yesterday)
        db = self._db_returning(api_key)

        with self._fixed_hash():
            outcome = AuthService.authenticate_api_key(db, "sk-expired-key")

        assert outcome is None
|
||||
292
tests/services/test_usage_service.py
Normal file
292
tests/services/test_usage_service.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
UsageService 测试
|
||||
|
||||
测试用量统计服务的核心功能:
|
||||
- 成本计算
|
||||
- 配额检查
|
||||
- 用量统计查询
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
|
||||
from src.services.usage.service import UsageService
|
||||
|
||||
|
||||
class TestCostCalculation:
    """Tests for ``UsageService.calculate_cost``.

    Each test unpacks the result as a 7-tuple:
    ``(input_cost, output_cost, cache_creation_cost, cache_read_cost,
    cache_cost, request_cost, total_cost)``.

    Non-zero floating-point costs are compared with ``pytest.approx`` rather
    than ``==`` so the tests do not depend on the exact order in which the
    implementation sums its terms.
    """

    def test_calculate_cost_basic(self) -> None:
        """Input and output costs follow ``tokens * price / 1M``."""
        # Prices: input $3 per 1M tokens, output $15 per 1M tokens.
        result = UsageService.calculate_cost(
            input_tokens=1000,
            output_tokens=500,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
        )

        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result

        # 1000 tokens * $3 / 1M = $0.003
        assert input_cost == pytest.approx(0.003, abs=1e-4)
        # 500 tokens * $15 / 1M = $0.0075
        assert output_cost == pytest.approx(0.0075, abs=1e-4)
        # Total = $0.003 + $0.0075 = $0.0105
        assert total_cost == pytest.approx(0.0105, abs=1e-4)

    def test_calculate_cost_with_cache(self) -> None:
        """Cache creation/read tokens produce positive costs that sum to ``cache_cost``."""
        result = UsageService.calculate_cost(
            input_tokens=1000,
            output_tokens=500,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            cache_creation_input_tokens=200,
            cache_read_input_tokens=300,
            cache_creation_price_per_1m=3.75,  # 1.25x input price
            cache_read_price_per_1m=0.3,  # 0.1x input price
        )

        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result

        # Both cache components must have been charged.
        assert cache_creation_cost > 0
        assert cache_read_cost > 0
        assert cache_cost == pytest.approx(cache_creation_cost + cache_read_cost)

    def test_calculate_cost_with_request_price(self) -> None:
        """A per-request price is reported separately and included in the total."""
        result = UsageService.calculate_cost(
            input_tokens=1000,
            output_tokens=500,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
            price_per_request=0.01,
        )

        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result

        assert request_cost == pytest.approx(0.01)
        # The total includes request_cost.
        assert total_cost == pytest.approx(input_cost + output_cost + request_cost)

    def test_calculate_cost_zero_tokens(self) -> None:
        """Zero tokens cost nothing."""
        result = UsageService.calculate_cost(
            input_tokens=0,
            output_tokens=0,
            input_price_per_1m=3.0,
            output_price_per_1m=15.0,
        )

        (
            input_cost,
            output_cost,
            cache_creation_cost,
            cache_read_cost,
            cache_cost,
            request_cost,
            total_cost,
        ) = result

        # Exact comparison is intentional: the original test pins these to exactly 0.
        assert input_cost == 0
        assert output_cost == 0
        assert total_cost == 0
|
||||
|
||||
|
||||
class TestQuotaCheck:
    """Tests for ``UsageService.check_user_quota``."""

    @staticmethod
    def _regular_user(quota_usd, used_usd) -> MagicMock:
        """Build a mock user whose ``role.value`` is ``"user"`` with the given quota figures."""
        return MagicMock(
            quota_usd=quota_usd,
            used_usd=used_usd,
            role=MagicMock(value="user"),
        )

    @staticmethod
    def _shared_key() -> MagicMock:
        """Build a mock API key with ``is_standalone`` False."""
        return MagicMock(is_standalone=False)

    @staticmethod
    def _standalone_key(current_balance_usd: float, balance_used_usd: float) -> MagicMock:
        """Build a mock API key with ``is_standalone`` True and the given balance figures."""
        return MagicMock(
            is_standalone=True,
            current_balance_usd=current_balance_usd,
            balance_used_usd=balance_used_usd,
        )

    def test_check_user_quota_sufficient(self) -> None:
        """A user well below the quota passes the check."""
        user = self._regular_user(100.0, 30.0)
        db = MagicMock()

        is_ok, _ = UsageService.check_user_quota(db, user, api_key=self._shared_key())

        assert is_ok is True

    def test_check_user_quota_exceeded(self) -> None:
        """An estimated cost larger than the remaining quota fails the check."""
        # 99 of 100 USD already used: close to the quota ceiling.
        user = self._regular_user(100.0, 99.0)
        db = MagicMock()

        # The estimate exceeds what is left, so the check must return False.
        is_ok, message = UsageService.check_user_quota(
            db, user, estimated_cost=5.0, api_key=self._shared_key()
        )

        assert is_ok is False
        assert "配额" in message

    def test_check_user_quota_no_limit(self) -> None:
        """A quota of None passes the check regardless of usage."""
        user = self._regular_user(None, 1000.0)
        db = MagicMock()

        is_ok, _ = UsageService.check_user_quota(db, user, api_key=self._shared_key())

        assert is_ok is True

    def test_check_user_quota_admin_bypass(self) -> None:
        """An administrator passes the check even with usage far above a zero quota."""
        from src.models.database import UserRole

        admin = MagicMock(quota_usd=0.0, used_usd=1000.0, role=UserRole.ADMIN)
        db = MagicMock()

        is_ok, _ = UsageService.check_user_quota(db, admin, api_key=self._shared_key())

        assert is_ok is True

    def test_check_standalone_api_key_balance(self) -> None:
        """A standalone key with balance passes even though the user quota is zero."""
        user = self._regular_user(0.0, 0.0)
        api_key = self._standalone_key(50.0, 10.0)
        db = MagicMock()

        is_ok, _ = UsageService.check_user_quota(db, user, api_key=api_key)

        assert is_ok is True

    def test_check_standalone_api_key_insufficient_balance(self) -> None:
        """A standalone key whose remaining balance is below the estimate fails the check."""
        user = self._regular_user(100.0, 0.0)
        api_key = self._standalone_key(10.0, 9.0)  # $1 remaining
        db = MagicMock()

        # ApiKeyService.get_remaining_balance has to be mocked for this path.
        remaining_balance = patch(
            "src.services.user.apikey.ApiKeyService.get_remaining_balance",
            return_value=1.0,
        )
        with remaining_balance:
            # Estimated cost of $5 exceeds the remaining $1.
            is_ok, _ = UsageService.check_user_quota(
                db, user, estimated_cost=5.0, api_key=api_key
            )

        assert is_ok is False
|
||||
|
||||
|
||||
class TestUsageStatistics:
    """Tests for usage statistics queries.

    Note: ``get_usage_summary`` internally uses database-dialect-specific date
    functions, so it needs a real database or a more elaborate mock. Only the
    method's existence is checked here.
    """

    def test_get_usage_summary_exists(self) -> None:
        """``UsageService`` exposes a callable ``get_usage_summary``."""
        assert hasattr(UsageService, "get_usage_summary")
        assert callable(UsageService.get_usage_summary)
|
||||
|
||||
|
||||
class TestHelperMethods:
    """Tests for ``UsageService`` helper methods."""

    @pytest.mark.asyncio
    async def test_get_rate_multiplier_and_free_tier_default(self) -> None:
        """With no ids supplied the helper returns a multiplier of 1.0 and not free tier."""
        db = MagicMock()
        # Simulate the provider_api_key lookup finding nothing.
        db.query.return_value.filter.return_value.first.return_value = None

        multiplier, free_tier = await UsageService._get_rate_multiplier_and_free_tier(
            db, provider_api_key_id=None, provider_id=None
        )

        assert multiplier == 1.0
        assert free_tier is False

    @pytest.mark.asyncio
    async def test_get_rate_multiplier_from_provider_api_key(self) -> None:
        """The multiplier configured on the provider API key record is returned."""
        provider_api_key = MagicMock(rate_multiplier=0.8)
        endpoint = MagicMock(provider_id="provider-123")
        provider = MagicMock(billing_type="standard")

        db = MagicMock()
        # Successive ``first()`` calls yield these records in order,
        # starting with the provider_api_key.
        lookup = db.query.return_value.filter.return_value
        lookup.first.side_effect = [provider_api_key, endpoint, provider]

        multiplier, free_tier = await UsageService._get_rate_multiplier_and_free_tier(
            db, provider_api_key_id="pak-123", provider_id=None
        )

        assert multiplier == 0.8
        assert free_tier is False
|
||||
Reference in New Issue
Block a user