mirror of
https://github.com/fawney19/Aether.git
synced 2026-01-03 00:02:28 +08:00
Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d5c84f9d3 | ||
|
|
53e6a82480 | ||
|
|
bd11ebdbd5 | ||
|
|
1dac4cb156 | ||
|
|
50abb55c94 | ||
|
|
73d3c9d3e4 | ||
|
|
d24c3885ab | ||
|
|
d696c575e6 | ||
|
|
46ff5a1a50 | ||
|
|
edce43d45f | ||
|
|
33265b4b13 | ||
|
|
a94aeca2d3 | ||
|
|
c42ebdd0ee | ||
|
|
f1e3c2ab11 | ||
|
|
4e2ba0e57f | ||
|
|
a3df41d63d | ||
|
|
ad1c8c394c | ||
|
|
9b496abb73 | ||
|
|
f3a69a6160 | ||
|
|
adcdb73d29 | ||
|
|
cf67160821 | ||
|
|
718f56ba75 | ||
|
|
d87de10f62 | ||
|
|
c6b9582978 | ||
|
|
3d583b0a8d | ||
|
|
f849a54027 | ||
|
|
f2cd96c34c | ||
|
|
34d480910a | ||
|
|
f16fb28405 | ||
|
|
a0ffc2c406 | ||
|
|
a7bfab1475 | ||
|
|
84d4db0f8d | ||
|
|
903b182fdf | ||
|
|
d9bd0790fe | ||
|
|
c6fcc7982d | ||
|
|
aaa6a8f60d | ||
|
|
11774c69b6 | ||
|
|
8f0a0cbdb1 | ||
|
|
51b85915d2 | ||
|
|
b0d295c6c9 | ||
|
|
c94f011462 | ||
|
|
3296d026e3 | ||
|
|
2e01c7cf5a | ||
|
|
88e37594cf | ||
|
|
03ee6c16d9 | ||
|
|
743f23e640 | ||
|
|
7068aa9130 | ||
|
|
56fb6bf36c | ||
|
|
728f9bb126 | ||
|
|
5319c06f0e |
@@ -23,7 +23,7 @@ RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# Python 依赖(安装到系统,不用 -e 模式)
|
||||
COPY pyproject.toml README.md ./
|
||||
RUN mkdir -p src && touch src/__init__.py && \
|
||||
pip install --no-cache-dir .
|
||||
SETUPTOOLS_SCM_PRETEND_VERSION=0.1.0 pip install --no-cache-dir .
|
||||
|
||||
# 前端依赖
|
||||
COPY frontend/package*.json /tmp/frontend/
|
||||
|
||||
@@ -60,8 +60,11 @@ python generate_keys.py # 生成密钥, 并将生成的密钥填入 .env
|
||||
# 3. 部署
|
||||
docker-compose up -d
|
||||
|
||||
# 4. 更新
|
||||
docker-compose pull && docker-compose up -d
|
||||
# 4. 首次部署时, 初始化数据库
|
||||
./migrate.sh
|
||||
|
||||
# 5. 更新
|
||||
docker-compose pull && docker-compose up -d && ./migrate.sh
|
||||
```
|
||||
|
||||
### Docker Compose(本地构建镜像)
|
||||
|
||||
@@ -20,10 +20,10 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create ENUM types
|
||||
op.execute("CREATE TYPE userrole AS ENUM ('admin', 'user')")
|
||||
# Create ENUM types (with IF NOT EXISTS for idempotency)
|
||||
op.execute("DO $$ BEGIN CREATE TYPE userrole AS ENUM ('admin', 'user'); EXCEPTION WHEN duplicate_object THEN NULL; END $$")
|
||||
op.execute(
|
||||
"CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier')"
|
||||
"DO $$ BEGIN CREATE TYPE providerbillingtype AS ENUM ('monthly_quota', 'pay_as_you_go', 'free_tier'); EXCEPTION WHEN duplicate_object THEN NULL; END $$"
|
||||
)
|
||||
|
||||
# ==================== users ====================
|
||||
@@ -35,7 +35,7 @@ def upgrade() -> None:
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column(
|
||||
"role",
|
||||
sa.Enum("admin", "user", name="userrole", create_type=False),
|
||||
postgresql.ENUM("admin", "user", name="userrole", create_type=False),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
),
|
||||
@@ -67,7 +67,7 @@ def upgrade() -> None:
|
||||
sa.Column("website", sa.String(500), nullable=True),
|
||||
sa.Column(
|
||||
"billing_type",
|
||||
sa.Enum(
|
||||
postgresql.ENUM(
|
||||
"monthly_quota", "pay_as_you_go", "free_tier", name="providerbillingtype", create_type=False
|
||||
),
|
||||
nullable=False,
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
"""remove_model_mappings_add_aliases
|
||||
|
||||
合并迁移:
|
||||
1. 添加 provider_model_aliases 字段到 models 表
|
||||
2. 迁移 model_mappings 数据到 provider_model_aliases
|
||||
3. 删除 model_mappings 表
|
||||
4. 添加索引优化别名解析性能
|
||||
|
||||
Revision ID: e9b3d63f0cbf
|
||||
Revises: 20251210_baseline
|
||||
Create Date: 2025-12-14 13:00:22.828183+00:00
|
||||
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e9b3d63f0cbf'
|
||||
down_revision = '20251210_baseline'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(bind, table_name: str, column_name: str) -> bool:
    """Return whether ``information_schema.columns`` lists the given column.

    Args:
        bind: Connection-like object whose ``execute`` runs the lookup.
        table_name: Name of the table to inspect.
        column_name: Name of the column to look for on that table.

    Returns:
        The scalar result of the ``SELECT EXISTS`` query.
    """
    # The lookup filters on table and column name only.
    lookup = sa.text(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.columns
            WHERE table_name = :table_name AND column_name = :column_name
        )
        """
    )
    params = {"table_name": table_name, "column_name": column_name}
    return bind.execute(lookup, params).scalar()
|
||||
|
||||
|
||||
def table_exists(bind, table_name: str) -> bool:
    """Return whether ``information_schema.tables`` lists the given table.

    Args:
        bind: Connection-like object whose ``execute`` runs the lookup.
        table_name: Name of the table to look for.

    Returns:
        The scalar result of the ``SELECT EXISTS`` query.
    """
    # The lookup filters on table name only.
    lookup = sa.text(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.tables
            WHERE table_name = :table_name
        )
        """
    )
    params = {"table_name": table_name}
    return bind.execute(lookup, params).scalar()
|
||||
|
||||
|
||||
def index_exists(bind, index_name: str) -> bool:
    """Return whether the ``pg_indexes`` view lists the given index.

    Args:
        bind: Connection-like object whose ``execute`` runs the lookup.
        index_name: Name of the index to look for.

    Returns:
        The scalar result of the ``SELECT EXISTS`` query.
    """
    # pg_indexes is a PostgreSQL catalog view; the lookup filters on
    # index name only.
    lookup = sa.text(
        """
        SELECT EXISTS (
            SELECT 1 FROM pg_indexes
            WHERE indexname = :index_name
        )
        """
    )
    params = {"index_name": index_name}
    return bind.execute(lookup, params).scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add ``models.provider_model_aliases``, migrate data, drop ``model_mappings``.

    Steps:
        1. Add the ``provider_model_aliases`` JSON column if it is missing.
        2. Copy active, provider-level ``alias`` rows from ``model_mappings``
           into the matching ``models`` row (only when that table exists).
        3. Drop the ``model_mappings`` table (only when it exists).
        4. Create the indexes used for alias resolution: a partial index on
           ``provider_model_name`` and, on PostgreSQL, a GIN index on the
           alias column after converting it to ``jsonb``.
    """
    bind = op.get_bind()

    # 1. Add the provider_model_aliases column (if it does not exist yet).
    if not column_exists(bind, "models", "provider_model_aliases"):
        op.add_column(
            'models',
            sa.Column('provider_model_aliases', sa.JSON(), nullable=True)
        )

    # 2. Migrate model_mappings data (if the table exists).
    session = Session(bind=bind)

    # Lightweight table descriptions: only the columns this migration touches.
    model_mappings_table = sa.table(
        "model_mappings",
        sa.column("source_model", sa.String),
        sa.column("target_global_model_id", sa.String),
        sa.column("provider_id", sa.String),
        sa.column("mapping_type", sa.String),
        sa.column("is_active", sa.Boolean),
    )

    models_table = sa.table(
        "models",
        sa.column("id", sa.String),
        sa.column("provider_id", sa.String),
        sa.column("global_model_id", sa.String),
        sa.column("provider_model_aliases", sa.JSON),
        sa.column("updated_at", sa.DateTime(timezone=True)),
    )

    def normalize_alias_list(value) -> list[dict]:
        """Normalize a stored JSON value to ``[{'name': str, 'priority': int}]``.

        Anything that is not a list of dicts with a non-blank string ``name``
        is discarded; a missing, unparsable or sub-1 priority becomes 1.
        """
        if value is None:
            return []

        # The driver may hand the JSON column back as raw text.
        if isinstance(value, str):
            try:
                value = json.loads(value) if value else []
            except ValueError:  # includes json.JSONDecodeError
                return []

        if not isinstance(value, list):
            return []

        normalized: list[dict] = []
        for item in value:
            if not isinstance(item, dict):
                continue

            raw_name = item.get("name")
            if not isinstance(raw_name, str):
                continue
            name = raw_name.strip()
            if not name:
                continue

            try:
                priority = int(item.get("priority", 1))
            except (TypeError, ValueError, OverflowError):
                priority = 1
            priority = max(priority, 1)

            normalized.append({"name": name, "priority": priority})

        return normalized

    # Only active provider-level rows with mapping_type='alias' are migrated.
    # Global aliases/mappings are not migrated (the new design no longer
    # resolves source_model -> GlobalModel.name).
    if table_exists(bind, "model_mappings"):
        mappings = session.execute(
            sa.select(
                model_mappings_table.c.source_model,
                model_mappings_table.c.target_global_model_id,
                model_mappings_table.c.provider_id,
            )
            .where(
                model_mappings_table.c.is_active.is_(True),
                model_mappings_table.c.provider_id.isnot(None),
                model_mappings_table.c.mapping_type == "alias",
            )
            .order_by(model_mappings_table.c.provider_id, model_mappings_table.c.source_model)
        ).all()

        # Group aliases by (provider_id, target_global_model_id); priority is
        # the 1-based position within the group.
        alias_groups: dict = {}
        for source_model, target_global_model_id, provider_id in mappings:
            if not isinstance(source_model, str):
                continue
            source_model = source_model.strip()
            if not source_model:
                continue
            if not isinstance(provider_id, str) or not provider_id:
                continue
            if not isinstance(target_global_model_id, str) or not target_global_model_id:
                continue

            group = alias_groups.setdefault((provider_id, target_global_model_id), [])
            group.append({"name": source_model, "priority": len(group) + 1})

        # Merge each group into the matching models row.
        for (provider_id, global_model_id), aliases in alias_groups.items():
            model_row = session.execute(
                sa.select(models_table.c.id, models_table.c.provider_model_aliases)
                .where(
                    models_table.c.provider_id == provider_id,
                    models_table.c.global_model_id == global_model_id,
                )
                .limit(1)
            ).first()

            if not model_row:
                continue

            model_id = model_row[0]
            existing_aliases = normalize_alias_list(model_row[1])

            # Aliases already stored on the row keep their place; new names
            # are appended with the next priority.
            existing_names = {a["name"] for a in existing_aliases}
            merged_aliases = list(existing_aliases)
            for alias in aliases:
                name = alias.get("name")
                if not isinstance(name, str):
                    continue
                name = name.strip()
                if not name or name in existing_names:
                    continue

                merged_aliases.append(
                    {
                        "name": name,
                        "priority": len(merged_aliases) + 1,
                    }
                )
                existing_names.add(name)

            session.execute(
                models_table.update()
                .where(models_table.c.id == model_id)
                .values(
                    provider_model_aliases=merged_aliases if merged_aliases else None,
                    updated_at=datetime.now(timezone.utc),
                )
            )

        session.commit()

        # 3. Drop the model_mappings table. Kept inside the existence check so
        # the migration stays re-runnable on a schema without the table.
        op.drop_table('model_mappings')

    # 4. Add indexes that speed up alias resolution.
    # Partial index on provider_model_name (exact match), if not present.
    if not index_exists(bind, "idx_model_provider_model_name"):
        op.create_index(
            "idx_model_provider_model_name",
            "models",
            ["provider_model_name"],
            unique=False,
            postgresql_where=sa.text("is_active = true"),
        )

    # GIN index on provider_model_aliases (JSONB queries, PostgreSQL only).
    if bind.dialect.name == "postgresql":
        # Convert the json column to jsonb; the DO block only alters the
        # column while it is still of type json, so re-running is a no-op.
        op.execute(
            """
            DO $$
            BEGIN
                IF EXISTS (
                    SELECT 1 FROM information_schema.columns
                    WHERE table_name = 'models'
                    AND column_name = 'provider_model_aliases'
                    AND data_type = 'json'
                ) THEN
                    ALTER TABLE models
                    ALTER COLUMN provider_model_aliases TYPE jsonb
                    USING provider_model_aliases::jsonb;
                END IF;
            END $$;
            """
        )
        # Create the GIN index.
        op.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_model_provider_model_aliases_gin
            ON models USING gin(provider_model_aliases jsonb_path_ops)
            WHERE is_active = true
            """
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Recreate ``model_mappings`` and remove the alias column and its indexes."""
    connection = op.get_bind()

    # 1. Drop the indexes.
    op.drop_index("idx_model_provider_model_name", table_name="models")

    if connection.dialect.name == "postgresql":
        op.execute("DROP INDEX IF EXISTS idx_model_provider_model_aliases_gin")
        # Convert the jsonb column back to json.
        op.execute(
            """
            ALTER TABLE models
            ALTER COLUMN provider_model_aliases TYPE json
            USING provider_model_aliases::json
            """
        )

    # 2. Recreate the model_mappings table.
    mapping_schema = [
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('source_model', sa.String(200), nullable=False),
        sa.Column(
            'target_global_model_id',
            sa.String(36),
            sa.ForeignKey('global_models.id', ondelete='CASCADE'),
            nullable=False,
        ),
        sa.Column('provider_id', sa.String(36), sa.ForeignKey('providers.id'), nullable=True),
        sa.Column('mapping_type', sa.String(20), nullable=False, server_default='alias'),
        sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
        sa.UniqueConstraint('source_model', 'provider_id', name='uq_model_mapping_source_provider'),
    ]
    op.create_table('model_mappings', *mapping_schema)

    # One single-column index per lookup column, named ix_model_mappings_<column>.
    for indexed_column in ('source_model', 'target_global_model_id', 'provider_id', 'mapping_type'):
        op.create_index(f'ix_model_mappings_{indexed_column}', 'model_mappings', [indexed_column])

    # 3. Remove the provider_model_aliases column.
    op.drop_column('models', 'provider_model_aliases')
|
||||
@@ -0,0 +1,47 @@
|
||||
"""add first_byte_time_ms to usage table
|
||||
|
||||
Revision ID: 180e63a9c83a
|
||||
Revises: e9b3d63f0cbf
|
||||
Create Date: 2025-12-15 17:07:44.631032+00:00
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '180e63a9c83a'
|
||||
down_revision = 'e9b3d63f0cbf'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(bind, table_name: str, column_name: str) -> bool:
    """Return whether ``information_schema.columns`` lists the given column.

    Args:
        bind: Connection-like object whose ``execute`` runs the lookup.
        table_name: Name of the table to inspect.
        column_name: Name of the column to look for on that table.

    Returns:
        The scalar result of the ``SELECT EXISTS`` query.
    """
    # The lookup filters on table and column name only.
    lookup = sa.text(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.columns
            WHERE table_name = :table_name AND column_name = :column_name
        )
        """
    )
    params = {"table_name": table_name, "column_name": column_name}
    return bind.execute(lookup, params).scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the migration: add ``usage.first_byte_time_ms`` if it is missing."""
    connection = op.get_bind()

    # Nothing to do when the first-byte-time column is already present.
    if column_exists(connection, "usage", "first_byte_time_ms"):
        return

    first_byte_column = sa.Column('first_byte_time_ms', sa.Integer(), nullable=True)
    op.add_column('usage', first_byte_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration: drop ``usage.first_byte_time_ms``."""
    # Remove the first-byte-time column.
    table, column = 'usage', 'first_byte_time_ms'
    op.drop_column(table, column)
|
||||
@@ -0,0 +1,110 @@
|
||||
"""refactor global_model to use config json field
|
||||
|
||||
Revision ID: 1cc6942cf06f
|
||||
Revises: 180e63a9c83a
|
||||
Create Date: 2025-12-16 03:11:32.480976+00:00
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1cc6942cf06f'
|
||||
down_revision = '180e63a9c83a'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def column_exists(bind, table_name: str, column_name: str) -> bool:
    """Return whether ``information_schema.columns`` lists the given column.

    Args:
        bind: Connection-like object whose ``execute`` runs the lookup.
        table_name: Name of the table to inspect.
        column_name: Name of the column to look for on that table.

    Returns:
        The scalar result of the ``SELECT EXISTS`` query.
    """
    # The lookup filters on table and column name only.
    lookup = sa.text(
        """
        SELECT EXISTS (
            SELECT 1 FROM information_schema.columns
            WHERE table_name = :table_name AND column_name = :column_name
        )
        """
    )
    params = {"table_name": table_name, "column_name": column_name}
    return bind.execute(lookup, params).scalar()
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Apply the migration: fold legacy ``global_models`` columns into ``config``.

    Steps:
        1. Add the ``config`` column.
        2. Copy the legacy column values into ``config``.
        3. Drop the legacy columns.

    The presence of ``default_supports_streaming`` is used as the marker for
    "legacy columns still exist". Steps 2 and 3 run only in that case, so the
    migration is safe to re-run on an already-migrated schema.
    """
    bind = op.get_bind()

    # Detect an already-migrated schema (config present, legacy column gone).
    has_config = column_exists(bind, "global_models", "config")
    has_old_columns = column_exists(bind, "global_models", "default_supports_streaming")

    if has_config and not has_old_columns:
        # Migration already applied; skip.
        return

    # 1. Add the config column (JSONB: indexable and cheaper to query).
    if not has_config:
        op.add_column('global_models', sa.Column('config', postgresql.JSONB(), nullable=True))

    if has_old_columns:
        # 2. Merge the legacy fields into the config JSON. Capability flags
        # other than 'streaming' are stored only when true; NULL entries are
        # removed by jsonb_strip_nulls.
        op.execute("""
            UPDATE global_models
            SET config = jsonb_strip_nulls(jsonb_build_object(
                'streaming', COALESCE(default_supports_streaming, true),
                'vision', CASE WHEN COALESCE(default_supports_vision, false) THEN true ELSE NULL END,
                'function_calling', CASE WHEN COALESCE(default_supports_function_calling, false) THEN true ELSE NULL END,
                'extended_thinking', CASE WHEN COALESCE(default_supports_extended_thinking, false) THEN true ELSE NULL END,
                'image_generation', CASE WHEN COALESCE(default_supports_image_generation, false) THEN true ELSE NULL END,
                'description', description,
                'icon_url', icon_url,
                'official_url', official_url
            ))
        """)

        # 3. Drop the legacy columns. Guarded by the same check as the copy
        # so a schema without them does not fail on drop_column.
        op.drop_column('global_models', 'default_supports_streaming')
        op.drop_column('global_models', 'default_supports_vision')
        op.drop_column('global_models', 'default_supports_function_calling')
        op.drop_column('global_models', 'default_supports_extended_thinking')
        op.drop_column('global_models', 'default_supports_image_generation')
        op.drop_column('global_models', 'description')
        op.drop_column('global_models', 'icon_url')
        op.drop_column('global_models', 'official_url')
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert the migration: restore the legacy columns from ``config``."""
    # 1. Re-add the legacy columns, all nullable, in the original order.
    legacy_columns = [
        ('icon_url', sa.VARCHAR(length=500)),
        ('official_url', sa.VARCHAR(length=500)),
        ('description', sa.TEXT()),
        ('default_supports_streaming', sa.BOOLEAN()),
        ('default_supports_vision', sa.BOOLEAN()),
        ('default_supports_function_calling', sa.BOOLEAN()),
        ('default_supports_extended_thinking', sa.BOOLEAN()),
        ('default_supports_image_generation', sa.BOOLEAN()),
    ]
    for column_name, column_type in legacy_columns:
        op.add_column('global_models', sa.Column(column_name, column_type, nullable=True))

    # 2. Restore the values from config. A missing 'streaming' key becomes
    # true; the other capability flags become false when missing.
    restore_sql = """
        UPDATE global_models
        SET
            default_supports_streaming = COALESCE((config->>'streaming')::boolean, true),
            default_supports_vision = COALESCE((config->>'vision')::boolean, false),
            default_supports_function_calling = COALESCE((config->>'function_calling')::boolean, false),
            default_supports_extended_thinking = COALESCE((config->>'extended_thinking')::boolean, false),
            default_supports_image_generation = COALESCE((config->>'image_generation')::boolean, false),
            description = config->>'description',
            icon_url = config->>'icon_url',
            official_url = config->>'official_url'
    """
    op.execute(restore_sql)

    # 3. Drop the config column.
    op.drop_column('global_models', 'config')
|
||||
@@ -12,8 +12,6 @@ services:
|
||||
TZ: Asia/Shanghai
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
ports:
|
||||
- "${DB_PORT:-5432}:5432"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
@@ -27,8 +25,6 @@ services:
|
||||
command: redis-server --appendonly yes --requirepass ${REDIS_PASSWORD}
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
ports:
|
||||
- "${REDIS_PORT:-6379}:6379"
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 5s
|
||||
|
||||
@@ -1,5 +1,179 @@
|
||||
import apiClient from './client'
|
||||
|
||||
// 配置导出数据结构
|
||||
export interface ConfigExportData {
|
||||
version: string
|
||||
exported_at: string
|
||||
global_models: GlobalModelExport[]
|
||||
providers: ProviderExport[]
|
||||
}
|
||||
|
||||
// 用户导出数据结构
|
||||
export interface UsersExportData {
|
||||
version: string
|
||||
exported_at: string
|
||||
users: UserExport[]
|
||||
}
|
||||
|
||||
export interface UserExport {
|
||||
email: string
|
||||
username: string
|
||||
password_hash: string
|
||||
role: string
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
model_capability_settings?: any
|
||||
quota_usd?: number | null
|
||||
used_usd?: number
|
||||
total_usd?: number
|
||||
is_active: boolean
|
||||
api_keys: UserApiKeyExport[]
|
||||
}
|
||||
|
||||
export interface UserApiKeyExport {
|
||||
key_hash: string
|
||||
key_encrypted?: string | null
|
||||
name?: string | null
|
||||
is_standalone: boolean
|
||||
balance_used_usd?: number
|
||||
current_balance_usd?: number | null
|
||||
allowed_providers?: string[] | null
|
||||
allowed_endpoints?: string[] | null
|
||||
allowed_api_formats?: string[] | null
|
||||
allowed_models?: string[] | null
|
||||
rate_limit?: number
|
||||
concurrent_limit?: number | null
|
||||
force_capabilities?: any
|
||||
is_active: boolean
|
||||
auto_delete_on_expiry?: boolean
|
||||
total_requests?: number
|
||||
total_cost_usd?: number
|
||||
}
|
||||
|
||||
export interface GlobalModelExport {
|
||||
name: string
|
||||
display_name: string
|
||||
default_price_per_request?: number | null
|
||||
default_tiered_pricing: any
|
||||
supported_capabilities?: string[] | null
|
||||
config?: any
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
export interface ProviderExport {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string | null
|
||||
website?: string | null
|
||||
billing_type?: string | null
|
||||
monthly_quota_usd?: number | null
|
||||
quota_reset_day?: number
|
||||
rpm_limit?: number | null
|
||||
provider_priority?: number
|
||||
is_active: boolean
|
||||
rate_limit?: number | null
|
||||
concurrent_limit?: number | null
|
||||
config?: any
|
||||
endpoints: EndpointExport[]
|
||||
models: ModelExport[]
|
||||
}
|
||||
|
||||
export interface EndpointExport {
|
||||
api_format: string
|
||||
base_url: string
|
||||
headers?: any
|
||||
timeout?: number
|
||||
max_retries?: number
|
||||
max_concurrent?: number | null
|
||||
rate_limit?: number | null
|
||||
is_active: boolean
|
||||
custom_path?: string | null
|
||||
config?: any
|
||||
keys: KeyExport[]
|
||||
}
|
||||
|
||||
export interface KeyExport {
|
||||
api_key: string
|
||||
name?: string | null
|
||||
note?: string | null
|
||||
rate_multiplier?: number
|
||||
internal_priority?: number
|
||||
global_priority?: number | null
|
||||
max_concurrent?: number | null
|
||||
rate_limit?: number | null
|
||||
daily_limit?: number | null
|
||||
monthly_limit?: number | null
|
||||
allowed_models?: string[] | null
|
||||
capabilities?: any
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
export interface ModelExport {
|
||||
global_model_name: string | null
|
||||
provider_model_name: string
|
||||
provider_model_aliases?: any
|
||||
price_per_request?: number | null
|
||||
tiered_pricing?: any
|
||||
supports_vision?: boolean | null
|
||||
supports_function_calling?: boolean | null
|
||||
supports_streaming?: boolean | null
|
||||
supports_extended_thinking?: boolean | null
|
||||
supports_image_generation?: boolean | null
|
||||
is_active: boolean
|
||||
config?: any
|
||||
}
|
||||
|
||||
// Provider 模型查询响应
|
||||
export interface ProviderModelsQueryResponse {
|
||||
success: boolean
|
||||
data: {
|
||||
models: Array<{
|
||||
id: string
|
||||
object?: string
|
||||
created?: number
|
||||
owned_by?: string
|
||||
display_name?: string
|
||||
api_format?: string
|
||||
}>
|
||||
error?: string
|
||||
}
|
||||
provider: {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface ConfigImportRequest extends ConfigExportData {
|
||||
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||
}
|
||||
|
||||
export interface UsersImportRequest extends UsersExportData {
|
||||
merge_mode: 'skip' | 'overwrite' | 'error'
|
||||
}
|
||||
|
||||
export interface UsersImportResponse {
|
||||
message: string
|
||||
stats: {
|
||||
users: { created: number; updated: number; skipped: number }
|
||||
api_keys: { created: number; skipped: number }
|
||||
errors: string[]
|
||||
}
|
||||
}
|
||||
|
||||
export interface ConfigImportResponse {
|
||||
message: string
|
||||
stats: {
|
||||
global_models: { created: number; updated: number; skipped: number }
|
||||
providers: { created: number; updated: number; skipped: number }
|
||||
endpoints: { created: number; updated: number; skipped: number }
|
||||
keys: { created: number; updated: number; skipped: number }
|
||||
models: { created: number; updated: number; skipped: number }
|
||||
errors: string[]
|
||||
}
|
||||
}
|
||||
|
||||
// API密钥管理相关接口定义
|
||||
export interface AdminApiKey {
|
||||
id: string // UUID
|
||||
@@ -173,5 +347,44 @@ export const adminApi = {
|
||||
'/api/admin/system/api-formats'
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 导出配置
|
||||
async exportConfig(): Promise<ConfigExportData> {
|
||||
const response = await apiClient.get<ConfigExportData>('/api/admin/system/config/export')
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 导入配置
|
||||
async importConfig(data: ConfigImportRequest): Promise<ConfigImportResponse> {
|
||||
const response = await apiClient.post<ConfigImportResponse>(
|
||||
'/api/admin/system/config/import',
|
||||
data
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 导出用户数据
|
||||
async exportUsers(): Promise<UsersExportData> {
|
||||
const response = await apiClient.get<UsersExportData>('/api/admin/system/users/export')
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 导入用户数据
|
||||
async importUsers(data: UsersImportRequest): Promise<UsersImportResponse> {
|
||||
const response = await apiClient.post<UsersImportResponse>(
|
||||
'/api/admin/system/users/import',
|
||||
data
|
||||
)
|
||||
return response.data
|
||||
},
|
||||
|
||||
// 查询 Provider 可用模型(从上游 API 获取)
|
||||
async queryProviderModels(providerId: string, apiKeyId?: string): Promise<ProviderModelsQueryResponse> {
|
||||
const response = await apiClient.post<ProviderModelsQueryResponse>(
|
||||
'/api/admin/provider-query/models',
|
||||
{ provider_id: providerId, api_key_id: apiKeyId }
|
||||
)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,3 +271,93 @@ export const cacheAnalysisApi = {
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 模型映射缓存管理 API ====================
|
||||
|
||||
// 映射条目
|
||||
export interface ModelMappingItem {
|
||||
mapping_name: string
|
||||
global_model_name: string | null
|
||||
global_model_display_name: string | null
|
||||
providers: string[]
|
||||
ttl: number | null
|
||||
}
|
||||
|
||||
// 未映射的条目(NOT_FOUND、invalid、error)
|
||||
export interface UnmappedEntry {
|
||||
mapping_name: string
|
||||
status: 'not_found' | 'invalid' | 'error'
|
||||
ttl: number | null
|
||||
}
|
||||
|
||||
// Provider 模型映射缓存(Redis 缓存)
|
||||
export interface ProviderModelMapping {
|
||||
provider_id: string
|
||||
provider_name: string
|
||||
global_model_id: string
|
||||
global_model_name: string
|
||||
global_model_display_name: string | null
|
||||
provider_model_name: string
|
||||
aliases: string[] | null
|
||||
ttl: number | null
|
||||
hit_count: number
|
||||
}
|
||||
|
||||
export interface ModelMappingCacheStats {
|
||||
available: boolean
|
||||
message?: string
|
||||
ttl_seconds?: number
|
||||
total_keys?: number
|
||||
breakdown?: {
|
||||
model_by_id: number
|
||||
model_by_provider_global: number
|
||||
global_model_by_id: number
|
||||
global_model_by_name: number
|
||||
global_model_resolve: number
|
||||
}
|
||||
mappings?: ModelMappingItem[]
|
||||
provider_model_mappings?: ProviderModelMapping[] | null
|
||||
unmapped?: UnmappedEntry[] | null
|
||||
}
|
||||
|
||||
export interface ClearModelMappingCacheResponse {
|
||||
status: string
|
||||
message: string
|
||||
deleted_count?: number
|
||||
model_name?: string
|
||||
deleted_keys?: string[]
|
||||
}
|
||||
|
||||
export const modelMappingCacheApi = {
|
||||
/**
|
||||
* 获取模型映射缓存统计
|
||||
*/
|
||||
async getStats(): Promise<ModelMappingCacheStats> {
|
||||
const response = await api.get('/api/admin/monitoring/cache/model-mapping/stats')
|
||||
return response.data.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 清除所有模型映射缓存
|
||||
*/
|
||||
async clearAll(): Promise<ClearModelMappingCacheResponse> {
|
||||
const response = await api.delete('/api/admin/monitoring/cache/model-mapping')
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 清除指定模型名称的映射缓存
|
||||
*/
|
||||
async clearByName(modelName: string): Promise<ClearModelMappingCacheResponse> {
|
||||
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/${encodeURIComponent(modelName)}`)
|
||||
return response.data
|
||||
},
|
||||
|
||||
/**
|
||||
* 清除指定 Provider 和 GlobalModel 的映射缓存
|
||||
*/
|
||||
async clearProviderModel(providerId: string, globalModelId: string): Promise<ClearModelMappingCacheResponse> {
|
||||
const response = await api.delete(`/api/admin/monitoring/cache/model-mapping/provider/${providerId}/${globalModelId}`)
|
||||
return response.data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
/**
|
||||
* 模型别名管理 API
|
||||
*/
|
||||
|
||||
import client from '../client'
|
||||
import type { ModelMapping, ModelMappingCreate, ModelMappingUpdate } from './types'
|
||||
|
||||
export interface ModelAlias {
|
||||
id: string
|
||||
alias: string
|
||||
global_model_id: string
|
||||
global_model_name: string | null
|
||||
global_model_display_name: string | null
|
||||
provider_id: string | null
|
||||
provider_name: string | null
|
||||
scope: 'global' | 'provider'
|
||||
mapping_type: 'alias' | 'mapping'
|
||||
is_active: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface CreateModelAliasRequest {
|
||||
alias: string
|
||||
global_model_id: string
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface UpdateModelAliasRequest {
|
||||
alias?: string
|
||||
global_model_id?: string
|
||||
provider_id?: string | null
|
||||
mapping_type?: 'alias' | 'mapping'
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
function transformMapping(mapping: ModelMapping): ModelAlias {
|
||||
return {
|
||||
id: mapping.id,
|
||||
alias: mapping.source_model,
|
||||
global_model_id: mapping.target_global_model_id,
|
||||
global_model_name: mapping.target_global_model_name,
|
||||
global_model_display_name: mapping.target_global_model_display_name,
|
||||
provider_id: mapping.provider_id ?? null,
|
||||
provider_name: mapping.provider_name ?? null,
|
||||
scope: mapping.scope,
|
||||
mapping_type: mapping.mapping_type || 'alias',
|
||||
is_active: mapping.is_active,
|
||||
created_at: mapping.created_at,
|
||||
updated_at: mapping.updated_at
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取别名列表
|
||||
*/
|
||||
export async function getAliases(params?: {
|
||||
provider_id?: string
|
||||
global_model_id?: string
|
||||
is_active?: boolean
|
||||
skip?: number
|
||||
limit?: number
|
||||
}): Promise<ModelAlias[]> {
|
||||
const response = await client.get('/api/admin/models/mappings', {
|
||||
params: {
|
||||
provider_id: params?.provider_id,
|
||||
target_global_model_id: params?.global_model_id,
|
||||
is_active: params?.is_active,
|
||||
skip: params?.skip,
|
||||
limit: params?.limit
|
||||
}
|
||||
})
|
||||
return (response.data as ModelMapping[]).map(transformMapping)
|
||||
}
|
||||
|
||||
/**
 * Fetch a single alias by ID via `GET /api/admin/models/mappings/{id}`.
 */
export async function getAlias(id: string): Promise<ModelAlias> {
  const { data } = await client.get(`/api/admin/models/mappings/${id}`)
  return transformMapping(data)
}
|
||||
|
||||
/**
 * Create an alias via `POST /api/admin/models/mappings`.
 *
 * Omitted optional fields are filled in before sending: `provider_id` → `null`,
 * `mapping_type` → `'alias'`, `is_active` → `true`.
 */
export async function createAlias(data: CreateModelAliasRequest): Promise<ModelAlias> {
  const { alias, global_model_id, provider_id, mapping_type, is_active } = data
  const body: ModelMappingCreate = {
    source_model: alias,
    target_global_model_id: global_model_id,
    provider_id: provider_id ?? null,
    mapping_type: mapping_type ?? 'alias',
    is_active: is_active ?? true
  }
  const { data: created } = await client.post('/api/admin/models/mappings', body)
  return transformMapping(created)
}
|
||||
|
||||
/**
 * Update an alias via `PATCH /api/admin/models/mappings/{id}`.
 *
 * Only the fields present in `data` are meant to change. Fields left
 * `undefined` are dropped when the payload is serialised to JSON, so they are
 * not sent at all.
 *
 * @param id   ID of the alias/mapping to update.
 * @param data Partial set of fields to change.
 * @returns The updated record converted with `transformMapping`.
 */
export async function updateAlias(id: string, data: UpdateModelAliasRequest): Promise<ModelAlias> {
  const payload: ModelMappingUpdate = {
    source_model: data.alias,
    target_global_model_id: data.global_model_id,
    // Forward provider_id untouched: `undefined` means "not part of this update",
    // `null` means "explicitly no provider". Coalescing undefined to null would
    // put an explicit `provider_id: null` into every partial update.
    provider_id: data.provider_id,
    mapping_type: data.mapping_type,
    is_active: data.is_active
  }
  const response = await client.patch(`/api/admin/models/mappings/${id}`, payload)
  return transformMapping(response.data)
}
|
||||
|
||||
/**
 * Delete an alias via `DELETE /api/admin/models/mappings/{id}`.
 * The response body is ignored.
 */
export async function deleteAlias(id: string): Promise<void> {
  const url = `/api/admin/models/mappings/${id}`
  await client.delete(url)
}
|
||||
@@ -4,6 +4,5 @@ export * from './endpoints'
|
||||
export * from './keys'
|
||||
export * from './health'
|
||||
export * from './models'
|
||||
export * from './aliases'
|
||||
export * from './adaptive'
|
||||
export * from './global-models'
|
||||
|
||||
@@ -5,9 +5,6 @@ import type {
|
||||
ModelUpdate,
|
||||
ModelCatalogResponse,
|
||||
ProviderAvailableSourceModelsResponse,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
DeleteModelMappingResponse
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
@@ -99,27 +96,6 @@ export async function getProviderAvailableSourceModels(
|
||||
return response.data
|
||||
}
|
||||
|
||||
/**
 * Update a model mapping in the catalog via
 * `PUT /api/admin/models/catalog/mappings/{mappingId}` and return the response body.
 */
export async function updateCatalogMapping(
  mappingId: string,
  data: UpdateModelMappingRequest
): Promise<UpdateModelMappingResponse> {
  const url = `/api/admin/models/catalog/mappings/${mappingId}`
  const { data: result } = await client.put(url, data)
  return result
}
|
||||
|
||||
/**
 * Delete a model mapping from the catalog via
 * `DELETE /api/admin/models/catalog/mappings/{mappingId}` and return the response body.
 */
export async function deleteCatalogMapping(
  mappingId: string
): Promise<DeleteModelMappingResponse> {
  const url = `/api/admin/models/catalog/mappings/${mappingId}`
  const { data: result } = await client.delete(url)
  return result
}
|
||||
|
||||
/**
|
||||
* 批量为 Provider 关联 GlobalModels
|
||||
*/
|
||||
|
||||
@@ -1,3 +1,25 @@
|
||||
/** API format constants. Each key maps to the identical string value. */
export const API_FORMATS = {
  CLAUDE: 'CLAUDE',
  CLAUDE_CLI: 'CLAUDE_CLI',
  OPENAI: 'OPENAI',
  OPENAI_CLI: 'OPENAI_CLI',
  GEMINI: 'GEMINI',
  GEMINI_CLI: 'GEMINI_CLI',
} as const

/** Union of the string values of `API_FORMATS`. */
export type APIFormat = (typeof API_FORMATS)[keyof typeof API_FORMATS]

/**
 * Display names for the API formats, grouped by brand with the API variant
 * first and the CLI variant second.
 */
export const API_FORMAT_LABELS: Record<string, string> = {
  [API_FORMATS.CLAUDE]: 'Claude',
  [API_FORMATS.CLAUDE_CLI]: 'Claude CLI',
  [API_FORMATS.OPENAI]: 'OpenAI',
  [API_FORMATS.OPENAI_CLI]: 'OpenAI CLI',
  [API_FORMATS.GEMINI]: 'Gemini',
  [API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
}
|
||||
|
||||
export interface ProviderEndpoint {
|
||||
id: string
|
||||
provider_id: string
|
||||
@@ -211,11 +233,18 @@ export interface ConcurrencyStatus {
|
||||
key_max_concurrent?: number
|
||||
}
|
||||
|
||||
/** One alias entry for a provider-side model name. */
export interface ProviderModelAlias {
  name: string
  /** Priority; a smaller number means a higher priority. */
  priority: number
  /** Scope: the API formats this alias applies to. Empty means it applies to all formats. */
  api_formats?: string[]
}
|
||||
|
||||
export interface Model {
|
||||
id: string
|
||||
provider_id: string
|
||||
global_model_id?: string // 关联的 GlobalModel ID
|
||||
provider_model_name: string // Provider 侧的模型名称(原 name)
|
||||
provider_model_name: string // Provider 侧的主模型名称
|
||||
provider_model_aliases?: ProviderModelAlias[] | null // 模型名称别名列表(带优先级)
|
||||
// 原始配置值(可能为空,为空时使用 GlobalModel 默认值)
|
||||
price_per_request?: number | null // 按次计费价格
|
||||
tiered_pricing?: TieredPricingConfig | null // 阶梯计费配置
|
||||
@@ -244,7 +273,8 @@ export interface Model {
|
||||
}
|
||||
|
||||
export interface ModelCreate {
|
||||
provider_model_name: string // Provider 侧的模型名称(原 name)
|
||||
provider_model_name: string // Provider 侧的主模型名称
|
||||
provider_model_aliases?: ProviderModelAlias[] // 模型名称别名列表(带优先级)
|
||||
global_model_id: string // 关联的 GlobalModel ID(必填)
|
||||
// 计费配置(可选,为空时使用 GlobalModel 默认值)
|
||||
price_per_request?: number // 按次计费价格
|
||||
@@ -261,6 +291,7 @@ export interface ModelCreate {
|
||||
|
||||
export interface ModelUpdate {
|
||||
provider_model_name?: string
|
||||
provider_model_aliases?: ProviderModelAlias[] | null // 模型名称别名列表(带优先级)
|
||||
global_model_id?: string
|
||||
price_per_request?: number | null // 按次计费价格(null 表示清空/使用默认值)
|
||||
tiered_pricing?: TieredPricingConfig | null // 阶梯计费配置
|
||||
@@ -273,21 +304,6 @@ export interface ModelUpdate {
|
||||
is_available?: boolean
|
||||
}
|
||||
|
||||
/** Model mapping record as returned by the backend. */
export interface ModelMapping {
  id: string
  /** Alias / source model name. */
  source_model: string
  /** Target GlobalModel ID. */
  target_global_model_id: string
  target_global_model_name: string | null
  target_global_model_display_name: string | null
  provider_id: string | null
  provider_name: string | null
  scope: 'global' | 'provider'
  mapping_type: 'alias' | 'mapping'
  is_active: boolean
  created_at: string
  updated_at: string
}
|
||||
|
||||
export interface ModelCapabilities {
|
||||
supports_vision: boolean
|
||||
supports_function_calling: boolean
|
||||
@@ -335,7 +351,6 @@ export interface ModelCatalogItem {
|
||||
global_model_name: string // GlobalModel.name(原 source_model)
|
||||
display_name: string // GlobalModel.display_name
|
||||
description?: string | null // GlobalModel.description
|
||||
aliases: string[] // 所有指向该 GlobalModel 的别名列表
|
||||
providers: ModelCatalogProviderDetail[] // 支持该模型的 Provider 列表
|
||||
price_range: ModelPriceRange // 价格区间
|
||||
total_providers: number
|
||||
@@ -351,8 +366,6 @@ export interface ProviderAvailableSourceModel {
|
||||
global_model_name: string // GlobalModel.name(原 source_model)
|
||||
display_name: string // GlobalModel.display_name
|
||||
provider_model_name: string // Model.provider_model_name(Provider 侧的模型名)
|
||||
has_alias: boolean // 是否有别名指向该 GlobalModel
|
||||
aliases: string[] // 别名列表
|
||||
model_id?: string | null // Model.id
|
||||
price: ProviderModelPriceInfo
|
||||
capabilities: ModelCapabilities
|
||||
@@ -371,65 +384,6 @@ export interface BatchAssignProviderConfig {
|
||||
model_id?: string
|
||||
}
|
||||
|
||||
/** Request body for assigning one GlobalModel to several providers at once. */
export interface BatchAssignModelMappingRequest {
  /** ID of the GlobalModel to assign (formerly `source_model`). */
  global_model_id: string
  providers: BatchAssignProviderConfig[]
}
|
||||
|
||||
/** Per-provider result entry of a batch assignment. */
export interface BatchAssignProviderResult {
  provider_id: string
  mapping_id?: string | null
  created_model: boolean
  model_id?: string | null
  updated: boolean
}
|
||||
|
||||
/** Per-provider error entry of a batch assignment. */
export interface BatchAssignError {
  provider_id: string
  error: string
}
|
||||
|
||||
/** Response body of a batch assignment: overall flag plus per-provider results and errors. */
export interface BatchAssignModelMappingResponse {
  success: boolean
  created_mappings: BatchAssignProviderResult[]
  errors: BatchAssignError[]
}
|
||||
|
||||
/** Backend payload for creating a model mapping. */
export interface ModelMappingCreate {
  /** Source model name or alias. */
  source_model: string
  /** Target GlobalModel ID. */
  target_global_model_id: string
  provider_id?: string | null
  mapping_type?: 'alias' | 'mapping'
  is_active?: boolean
}
|
||||
|
||||
/** Backend payload for updating a model mapping; every field is optional. */
export interface ModelMappingUpdate {
  /** Source model name or alias. */
  source_model?: string
  /** Target GlobalModel ID. */
  target_global_model_id?: string
  provider_id?: string | null
  mapping_type?: 'alias' | 'mapping'
  is_active?: boolean
}
|
||||
|
||||
/** Request body for updating a mapping in the model catalog; every field is optional. */
export interface UpdateModelMappingRequest {
  source_model?: string
  target_global_model_id?: string
  provider_id?: string | null
  mapping_type?: 'alias' | 'mapping'
  is_active?: boolean
}
|
||||
|
||||
/** Response body returned after updating a catalog mapping. */
export interface UpdateModelMappingResponse {
  success: boolean
  mapping_id: string
  message?: string
}
|
||||
|
||||
/** Response body returned after deleting a catalog mapping. */
export interface DeleteModelMappingResponse {
  success: boolean
  message?: string
}
|
||||
|
||||
export interface AdaptiveStatsResponse {
|
||||
adaptive_mode: boolean
|
||||
current_limit: number | null
|
||||
@@ -476,67 +430,45 @@ export interface TieredPricingConfig {
|
||||
export interface GlobalModelCreate {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
official_url?: string
|
||||
icon_url?: string
|
||||
// 按次计费配置(可选,与阶梯计费叠加)
|
||||
default_price_per_request?: number
|
||||
// 阶梯计费配置(必填,固定价格用单阶梯表示)
|
||||
default_tiered_pricing: TieredPricingConfig
|
||||
// 默认能力配置
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
// Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities?: string[]
|
||||
// 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config?: Record<string, any>
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
export interface GlobalModelUpdate {
|
||||
display_name?: string
|
||||
description?: string
|
||||
official_url?: string
|
||||
icon_url?: string
|
||||
is_active?: boolean
|
||||
// 按次计费配置
|
||||
default_price_per_request?: number | null // null 表示清空
|
||||
// 阶梯计费配置
|
||||
default_tiered_pricing?: TieredPricingConfig
|
||||
// 默认能力配置
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
// Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities?: string[] | null
|
||||
// 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config?: Record<string, any> | null
|
||||
}
|
||||
|
||||
export interface GlobalModelResponse {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
official_url?: string
|
||||
icon_url?: string
|
||||
is_active: boolean
|
||||
// 按次计费配置
|
||||
default_price_per_request?: number
|
||||
// 阶梯计费配置(必填)
|
||||
default_tiered_pricing: TieredPricingConfig
|
||||
// 默认能力配置
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
// Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities?: string[] | null
|
||||
// 模型配置(JSON格式)
|
||||
config?: Record<string, any> | null
|
||||
// 统计数据
|
||||
provider_count?: number
|
||||
alias_count?: number
|
||||
usage_count?: number
|
||||
created_at: string
|
||||
updated_at?: string
|
||||
|
||||
288
frontend/src/api/models-dev.ts
Normal file
288
frontend/src/api/models-dev.ts
Normal file
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* Models.dev API 服务
|
||||
* 通过后端代理获取 models.dev 数据(解决跨域问题)
|
||||
*/
|
||||
|
||||
import api from './client'
|
||||
|
||||
// Cache configuration
/** localStorage key under which the fetched models.dev data is stored. */
const CACHE_KEY = 'models_dev_cache'
/** Cache lifetime in milliseconds (15 minutes). */
const CACHE_DURATION = 15 * 60 * 1000
|
||||
|
||||
// Models.dev API data structures

/** Cost figures attached to a models.dev model; every field is optional. */
export interface ModelsDevCost {
  input?: number
  output?: number
  reasoning?: number
  cache_read?: number
}
|
||||
|
||||
/** Context and output limits attached to a models.dev model. */
export interface ModelsDevLimit {
  context?: number
  output?: number
}
|
||||
|
||||
/** A single model entry as found under a provider's `models` map. */
export interface ModelsDevModel {
  id: string
  name: string
  family?: string
  reasoning?: boolean
  tool_call?: boolean
  structured_output?: boolean
  temperature?: boolean
  attachment?: boolean
  knowledge?: string
  release_date?: string
  last_updated?: string
  /** Input modalities: text, image, audio, video, pdf. */
  input?: string[]
  /** Output modalities: text, image, audio. */
  output?: string[]
  open_weights?: boolean
  cost?: ModelsDevCost
  limit?: ModelsDevLimit
  deprecated?: boolean
}
|
||||
|
||||
/** A provider entry, holding its models keyed by model ID. */
export interface ModelsDevProvider {
  id: string
  env?: string[]
  npm?: string
  api?: string
  name: string
  doc?: string
  models: Record<string, ModelsDevModel>
  /** Whether this is an official provider. */
  official?: boolean
}
|
||||
|
||||
/** The whole models.dev payload: provider entries keyed by provider ID. */
export type ModelsDevData = Record<string, ModelsDevProvider>
|
||||
|
||||
/** Flattened model list item (used for searching and selecting). */
export interface ModelsDevModelItem {
  providerId: string
  providerName: string
  modelId: string
  modelName: string
  family?: string
  inputPrice?: number
  outputPrice?: number
  contextLimit?: number
  outputLimit?: number
  supportsVision?: boolean
  supportsToolCall?: boolean
  supportsReasoning?: boolean
  supportsStructuredOutput?: boolean
  supportsTemperature?: boolean
  supportsAttachment?: boolean
  openWeights?: boolean
  deprecated?: boolean
  /** Whether the item comes from an official provider. */
  official?: boolean

  // Extra fields used for display_metadata
  knowledgeCutoff?: string
  releaseDate?: string
  inputModalities?: string[]
  outputModalities?: string[]
}
|
||||
|
||||
/** Cached payload together with the `Date.now()` value recorded when it was stored. */
interface CacheData {
  timestamp: number
  data: ModelsDevData
}
|
||||
|
||||
/** In-memory cache; `null` until data has been loaded or after it is cleared. */
let memoryCache: CacheData | null = null
|
||||
|
||||
/**
 * Report whether at least one provider entry carries a boolean `official` field.
 * Nullish provider entries are tolerated and count as "no flag".
 */
function hasOfficialFlag(data: ModelsDevData): boolean {
  for (const entry of Object.values(data)) {
    if (typeof entry?.official === 'boolean') {
      return true
    }
  }
  return false
}
|
||||
|
||||
/**
 * Get the models.dev data, with caching.
 *
 * Lookup order: in-memory cache, then localStorage, then the backend proxy
 * endpoint `/api/admin/models/external`. A cache entry is used only while it
 * is younger than `CACHE_DURATION` and passes `hasOfficialFlag`; a fresh entry
 * that fails that check is discarded so the data is fetched again.
 */
export async function getModelsDevData(): Promise<ModelsDevData> {
  const isFresh = (entry: CacheData): boolean =>
    Date.now() - entry.timestamp < CACHE_DURATION

  // 1. In-memory cache
  if (memoryCache && isFresh(memoryCache)) {
    // Legacy cache without the `official` field: drop it and force one refresh.
    if (hasOfficialFlag(memoryCache.data)) {
      return memoryCache.data
    }
    memoryCache = null
  }

  // 2. localStorage cache
  try {
    const raw = localStorage.getItem(CACHE_KEY)
    if (raw) {
      const stored: CacheData = JSON.parse(raw)
      if (isFresh(stored)) {
        // Legacy cache without the `official` field: drop it and force one refresh.
        if (hasOfficialFlag(stored.data)) {
          memoryCache = stored
          return stored.data
        }
        localStorage.removeItem(CACHE_KEY)
      }
    }
  } catch {
    // Reading or parsing the cache failed; ignore and fall through to the network.
  }

  // 3. Fetch fresh data through the backend proxy
  const { data } = await api.get<ModelsDevData>('/api/admin/models/external')

  // 4. Refresh both cache layers
  const entry: CacheData = { timestamp: Date.now(), data }
  memoryCache = entry
  try {
    localStorage.setItem(CACHE_KEY, JSON.stringify(entry))
  } catch {
    // Writing to localStorage failed; ignore.
  }

  return data
}
|
||||
|
||||
// Model list cache (avoids repeating the flattening step)
/** Flattened, sorted model list built by `getModelsDevList`. */
let modelsListCache: ModelsDevModelItem[] | null = null
/** `memoryCache` timestamp the list above was built from. */
let modelsListCacheTimestamp: number | null = null
|
||||
|
||||
/**
 * Get the flattened model list.
 *
 * The data is loaded through `getModelsDevData`. The flattened list is rebuilt
 * only when no list is cached yet or when the timestamp of the in-memory data
 * cache differs from the one the list was built from. Items are sorted by
 * provider name, then by model name.
 *
 * @param officialOnly When `true` (default) only items whose `official` flag is
 *                     truthy are returned; when `false` all items are returned.
 * @returns A new array on every call, so callers may sort or splice the result
 *          without affecting the cached list. Items themselves are shared.
 */
export async function getModelsDevList(officialOnly: boolean = true): Promise<ModelsDevModelItem[]> {
  const data = await getModelsDevData()
  const currentTimestamp = memoryCache?.timestamp ?? 0

  // Build once when there is no cached list or the underlying data was refreshed.
  if (!modelsListCache || modelsListCacheTimestamp !== currentTimestamp) {
    const items: ModelsDevModelItem[] = []

    for (const [providerId, provider] of Object.entries(data)) {
      // Tolerate nullish provider entries, as hasOfficialFlag already does.
      if (!provider?.models) continue

      // Fall back to the ID (as done for model names) so sorting and searching
      // never see an undefined provider name.
      const providerName = provider.name || providerId

      for (const [modelId, model] of Object.entries(provider.models)) {
        items.push({
          providerId,
          providerName,
          modelId,
          modelName: model.name || modelId,
          family: model.family,
          inputPrice: model.cost?.input,
          outputPrice: model.cost?.output,
          contextLimit: model.limit?.context,
          outputLimit: model.limit?.output,
          supportsVision: model.input?.includes('image'),
          supportsToolCall: model.tool_call,
          supportsReasoning: model.reasoning,
          supportsStructuredOutput: model.structured_output,
          supportsTemperature: model.temperature,
          supportsAttachment: model.attachment,
          openWeights: model.open_weights,
          deprecated: model.deprecated,
          official: provider.official,
          // Fields related to display_metadata
          knowledgeCutoff: model.knowledge,
          releaseDate: model.release_date,
          inputModalities: model.input,
          outputModalities: model.output,
        })
      }
    }

    // Sort by provider name, then by model name.
    items.sort((a, b) => {
      const providerCompare = a.providerName.localeCompare(b.providerName)
      if (providerCompare !== 0) return providerCompare
      return a.modelName.localeCompare(b.modelName)
    })

    modelsListCache = items
    modelsListCacheTimestamp = currentTimestamp
  }

  // Filter by the requested scope. Both paths return a fresh array so the
  // module-level cache is never handed out by reference.
  if (officialOnly) {
    return modelsListCache.filter(m => m.official)
  }
  return [...modelsListCache]
}
|
||||
|
||||
/**
 * Search models.
 *
 * The search covers every provider (third-party ones included). Matching is a
 * case-insensitive substring test against model ID, model name, provider name
 * and family. Exact matches on model ID or model name are moved to the front,
 * and the result is truncated to `limit` (default 50). Deprecated models are
 * dropped unless `excludeDeprecated` is `false`.
 */
export async function searchModelsDevModels(
  query: string,
  options?: {
    limit?: number
    excludeDeprecated?: boolean
  }
): Promise<ModelsDevModelItem[]> {
  // Search across all providers, not only the official ones.
  const candidates = await getModelsDevList(false)
  const { limit = 50, excludeDeprecated = true } = options || {}
  const needle = query.toLowerCase()

  const contains = (text: string | undefined) => text?.toLowerCase().includes(needle)
  const isExact = (item: ModelsDevModelItem): boolean =>
    item.modelId.toLowerCase() === needle || item.modelName.toLowerCase() === needle

  const matches = candidates.filter((item) => {
    if (excludeDeprecated && item.deprecated) {
      return false
    }
    // Match against model ID, name, provider name and family.
    return (
      item.modelId.toLowerCase().includes(needle) ||
      item.modelName.toLowerCase().includes(needle) ||
      item.providerName.toLowerCase().includes(needle) ||
      contains(item.family)
    )
  })

  // Exact matches first; everything else keeps its relative order.
  matches.sort((left, right) => Number(isExact(right)) - Number(isExact(left)))

  return matches.slice(0, limit)
}
|
||||
|
||||
/**
 * Get the details of one specific model.
 *
 * Resolves to `null` when the provider, its `models` map, or the model entry
 * is missing.
 */
export async function getModelsDevModel(
  providerId: string,
  modelId: string
): Promise<ModelsDevModel | null> {
  const allProviders = await getModelsDevData()
  const model = allProviders[providerId]?.models?.[modelId]
  return model || null
}
|
||||
|
||||
/**
 * Get the provider logo URL.
 *
 * The provider ID is percent-encoded before being placed in the path, so an ID
 * containing `/`, `?`, `#` or similar characters cannot change the URL's
 * structure. IDs made only of letters, digits, `-`, `_` and `.` are unaffected.
 *
 * @param providerId Provider ID used as the logo file name.
 * @returns `https://models.dev/logos/<encoded id>.svg`
 */
export function getProviderLogoUrl(providerId: string): string {
  return `https://models.dev/logos/${encodeURIComponent(providerId)}.svg`
}
|
||||
|
||||
/**
 * Clear the cache.
 *
 * Resets the in-memory data cache and the flattened-list cache, then removes
 * the localStorage entry. A failure while touching localStorage is ignored.
 */
export function clearModelsDevCache(): void {
  modelsListCacheTimestamp = null
  modelsListCache = null
  memoryCache = null

  try {
    localStorage.removeItem(CACHE_KEY)
  } catch {
    // Ignore errors.
  }
}
|
||||
@@ -9,20 +9,14 @@ export interface PublicGlobalModel {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string | null
|
||||
description: string | null
|
||||
icon_url: string | null
|
||||
is_active: boolean
|
||||
// 阶梯计费配置
|
||||
default_tiered_pricing: TieredPricingConfig
|
||||
default_price_per_request: number | null // 按次计费价格
|
||||
// 能力
|
||||
default_supports_vision: boolean
|
||||
default_supports_function_calling: boolean
|
||||
default_supports_streaming: boolean
|
||||
default_supports_extended_thinking: boolean
|
||||
default_supports_image_generation: boolean
|
||||
// Key 能力支持
|
||||
supported_capabilities: string[] | null
|
||||
// 模型配置(JSON)
|
||||
config: Record<string, any> | null
|
||||
}
|
||||
|
||||
export interface PublicGlobalModelListResponse {
|
||||
|
||||
@@ -299,7 +299,7 @@ function formatDuration(ms: number): string {
|
||||
const hours = Math.floor(ms / (1000 * 60 * 60))
|
||||
const minutes = Math.floor((ms % (1000 * 60 * 60)) / (1000 * 60))
|
||||
if (hours > 0) {
|
||||
return `${hours}h${minutes > 0 ? minutes + 'm' : ''}`
|
||||
return `${hours}h${minutes > 0 ? `${minutes}m` : ''}`
|
||||
}
|
||||
return `${minutes}m`
|
||||
}
|
||||
|
||||
@@ -34,11 +34,10 @@ const buttonClass = computed(() => {
|
||||
'inline-flex items-center justify-center rounded-xl text-sm font-semibold transition-all duration-200 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 active:scale-[0.98]'
|
||||
|
||||
const variantClasses = {
|
||||
default:
|
||||
'bg-primary text-white shadow-[0_20px_35px_rgba(204,120,92,0.35)] hover:bg-primary/90 hover:shadow-[0_25px_45px_rgba(204,120,92,0.45)]',
|
||||
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85 shadow-sm',
|
||||
default: 'bg-primary text-white hover:bg-primary/90',
|
||||
destructive: 'bg-destructive text-destructive-foreground hover:bg-destructive/85',
|
||||
outline:
|
||||
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 shadow-sm backdrop-blur transition-all',
|
||||
'border border-border/60 bg-card/60 text-foreground hover:border-primary/60 hover:text-primary hover:bg-primary/10 backdrop-blur transition-all',
|
||||
secondary:
|
||||
'bg-secondary text-secondary-foreground shadow-inner hover:bg-secondary/80',
|
||||
ghost: 'hover:bg-accent hover:text-accent-foreground',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="fixed inset-0 overflow-y-auto"
|
||||
class="fixed inset-0 overflow-y-auto pointer-events-none"
|
||||
:style="{ zIndex: containerZIndex }"
|
||||
>
|
||||
<!-- 背景遮罩 -->
|
||||
@@ -16,7 +16,7 @@
|
||||
>
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity"
|
||||
class="fixed inset-0 bg-black/40 backdrop-blur-sm transition-opacity pointer-events-auto"
|
||||
:style="{ zIndex: backdropZIndex }"
|
||||
@click="handleClose"
|
||||
/>
|
||||
@@ -34,7 +34,7 @@
|
||||
>
|
||||
<div
|
||||
v-if="isOpen"
|
||||
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border"
|
||||
class="relative transform rounded-lg bg-background text-left shadow-2xl transition-all sm:my-8 sm:w-full border border-border pointer-events-auto"
|
||||
:style="{ zIndex: contentZIndex }"
|
||||
:class="maxWidthClass"
|
||||
@click.stop
|
||||
|
||||
@@ -45,7 +45,7 @@ const props = withDefaults(defineProps<Props>(), {
|
||||
|
||||
const contentClass = computed(() =>
|
||||
cn(
|
||||
'z-[100] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
|
||||
'z-[200] max-h-96 min-w-[8rem] overflow-hidden rounded-2xl border border-border bg-card text-foreground shadow-2xl backdrop-blur-xl pointer-events-auto',
|
||||
'data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
|
||||
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
|
||||
props.class
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
<template>
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
:title="dialogTitle"
|
||||
:description="dialogDescription"
|
||||
:icon="dialogIcon"
|
||||
size="md"
|
||||
@update:model-value="handleDialogUpdate"
|
||||
>
|
||||
<form
|
||||
class="space-y-4"
|
||||
@submit.prevent="handleSubmit"
|
||||
>
|
||||
<!-- 模式选择(仅创建时显示) -->
|
||||
<div
|
||||
v-if="!isEditMode"
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label>创建类型 *</Label>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="p-3 rounded-lg border-2 text-left transition-all"
|
||||
:class="[
|
||||
form.mapping_type === 'alias'
|
||||
? 'border-primary bg-primary/5'
|
||||
: 'border-border hover:border-primary/50'
|
||||
]"
|
||||
@click="form.mapping_type = 'alias'"
|
||||
>
|
||||
<div class="font-medium text-sm">
|
||||
别名
|
||||
</div>
|
||||
<div class="text-xs text-muted-foreground mt-1">
|
||||
名称简写,按目标模型计费
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="p-3 rounded-lg border-2 text-left transition-all"
|
||||
:class="[
|
||||
form.mapping_type === 'mapping'
|
||||
? 'border-primary bg-primary/5'
|
||||
: 'border-border hover:border-primary/50'
|
||||
]"
|
||||
@click="form.mapping_type = 'mapping'"
|
||||
>
|
||||
<div class="font-medium text-sm">
|
||||
映射
|
||||
</div>
|
||||
<div class="text-xs text-muted-foreground mt-1">
|
||||
模型降级,按源模型计费
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 模式说明 -->
|
||||
<div class="rounded-lg border border-border bg-muted/50 p-3 text-sm">
|
||||
<p class="text-foreground font-medium mb-1">
|
||||
{{ form.mapping_type === 'alias' ? '别名模式' : '映射模式' }}
|
||||
</p>
|
||||
<p class="text-muted-foreground text-xs">
|
||||
{{ form.mapping_type === 'alias'
|
||||
? '用户请求此别名时,会路由到目标模型,并按目标模型价格计费。'
|
||||
: '将源模型的请求转发到目标模型处理,按源模型价格计费。' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- Provider 选择/作用范围 -->
|
||||
<div
|
||||
v-if="showProviderSelect"
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label>作用范围</Label>
|
||||
<!-- 固定 Provider 时显示只读 -->
|
||||
<div
|
||||
v-if="fixedProvider"
|
||||
class="px-3 py-2 border rounded-md bg-muted/50 text-sm"
|
||||
>
|
||||
仅 {{ fixedProvider.display_name || fixedProvider.name }}
|
||||
</div>
|
||||
<!-- 否则显示可选择的下拉 -->
|
||||
<Select
|
||||
v-else
|
||||
v-model:open="providerSelectOpen"
|
||||
:model-value="form.provider_id || 'global'"
|
||||
@update:model-value="handleProviderChange"
|
||||
>
|
||||
<SelectTrigger class="w-full">
|
||||
<SelectValue placeholder="选择作用范围" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="global">
|
||||
全局(所有 Provider)
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
v-for="p in providers"
|
||||
:key="p.id"
|
||||
:value="p.id"
|
||||
>
|
||||
仅 {{ p.display_name || p.name }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
|
||||
<!-- 别名模式:别名名称 -->
|
||||
<div
|
||||
v-if="form.mapping_type === 'alias'"
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label for="alias-name">别名名称 *</Label>
|
||||
<Input
|
||||
id="alias-name"
|
||||
v-model="form.alias"
|
||||
placeholder="如:sonnet, opus"
|
||||
:disabled="isEditMode"
|
||||
required
|
||||
/>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{{ isEditMode ? '创建后不可修改' : '用户将使用此名称请求模型' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 映射模式:选择源模型 -->
|
||||
<div
|
||||
v-else
|
||||
class="space-y-2"
|
||||
>
|
||||
<Label>源模型 (用户请求的模型) *</Label>
|
||||
<Select
|
||||
v-model:open="sourceModelSelectOpen"
|
||||
:model-value="form.alias"
|
||||
:disabled="isEditMode"
|
||||
@update:model-value="form.alias = $event"
|
||||
>
|
||||
<SelectTrigger
|
||||
class="w-full"
|
||||
:class="{ 'opacity-50': isEditMode }"
|
||||
>
|
||||
<SelectValue placeholder="请选择源模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="model in availableSourceModels"
|
||||
:key="model.id"
|
||||
:value="model.name"
|
||||
>
|
||||
{{ model.display_name }} ({{ model.name }})
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p class="text-xs text-muted-foreground">
|
||||
{{ isEditMode ? '创建后不可修改' : '选择要被映射的源模型,计费将按此模型价格' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 目标模型选择 -->
|
||||
<div class="space-y-2">
|
||||
<Label>
|
||||
{{ form.mapping_type === 'alias' ? '目标模型 *' : '目标模型 (实际处理请求) *' }}
|
||||
</Label>
|
||||
<!-- 固定目标模型时显示只读信息 -->
|
||||
<div
|
||||
v-if="fixedTargetModel"
|
||||
class="px-3 py-2 border rounded-md bg-muted/50"
|
||||
>
|
||||
<span class="font-medium">{{ fixedTargetModel.display_name }}</span>
|
||||
<span class="text-muted-foreground ml-1">({{ fixedTargetModel.name }})</span>
|
||||
</div>
|
||||
<!-- 否则显示下拉选择 -->
|
||||
<Select
|
||||
v-else
|
||||
v-model:open="targetModelSelectOpen"
|
||||
:model-value="form.global_model_id"
|
||||
@update:model-value="form.global_model_id = $event"
|
||||
>
|
||||
<SelectTrigger class="w-full">
|
||||
<SelectValue placeholder="请选择模型" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem
|
||||
v-for="model in availableTargetModels"
|
||||
:key="model.id"
|
||||
:value="model.id"
|
||||
>
|
||||
{{ model.display_name }} ({{ model.name }})
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
@click="handleCancel"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
v-if="submitting"
|
||||
class="w-4 h-4 mr-2 animate-spin"
|
||||
/>
|
||||
{{ isEditMode ? '保存' : '创建' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { Loader2, Tag, SquarePen } from 'lucide-vue-next'
|
||||
import { Dialog, Select, SelectTrigger, SelectValue, SelectContent, SelectItem } from '@/components/ui'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Input from '@/components/ui/input.vue'
|
||||
import Label from '@/components/ui/label.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
import type { ModelAlias, CreateModelAliasRequest, UpdateModelAliasRequest } from '@/api/endpoints/aliases'
|
||||
import type { GlobalModelResponse } from '@/api/global-models'
|
||||
|
||||
export interface ProviderOption {
|
||||
id: string
|
||||
name: string
|
||||
display_name?: string
|
||||
}
|
||||
|
||||
interface AliasFormData {
|
||||
alias: string
|
||||
global_model_id: string
|
||||
provider_id: string | null
|
||||
mapping_type: 'alias' | 'mapping'
|
||||
is_active: boolean
|
||||
}
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
open: boolean
|
||||
editingAlias?: ModelAlias | null
|
||||
globalModels: GlobalModelResponse[]
|
||||
providers?: ProviderOption[]
|
||||
fixedTargetModel?: GlobalModelResponse | null // 用于从模型详情抽屉打开时固定目标模型
|
||||
fixedProvider?: ProviderOption | null // 用于 Provider 特定别名固定 Provider
|
||||
showProviderSelect?: boolean // 是否显示 Provider 选择(默认 true)
|
||||
}>(), {
|
||||
editingAlias: null,
|
||||
providers: () => [],
|
||||
fixedTargetModel: null,
|
||||
fixedProvider: null,
|
||||
showProviderSelect: true
|
||||
})
|
||||
|
||||
const emit = defineEmits<{
|
||||
'update:open': [value: boolean]
|
||||
'submit': [data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean]
|
||||
}>()
|
||||
|
||||
const { error: showError } = useToast()
|
||||
|
||||
// 状态
|
||||
const submitting = ref(false)
|
||||
const providerSelectOpen = ref(false)
|
||||
const sourceModelSelectOpen = ref(false)
|
||||
const targetModelSelectOpen = ref(false)
|
||||
const form = ref<AliasFormData>({
|
||||
alias: '',
|
||||
global_model_id: '',
|
||||
provider_id: null,
|
||||
mapping_type: 'alias',
|
||||
is_active: true,
|
||||
})
|
||||
|
||||
/**
 * Handle a change of the provider select.
 *
 * The sentinel option value 'global' is stored as `null` in the form;
 * any other value is stored as the provider id.
 */
function handleProviderChange(value: string) {
  if (value === 'global') {
    form.value.provider_id = null
    return
  }
  form.value.provider_id = value
}
|
||||
|
||||
/**
 * Reset the form to its creation defaults.
 *
 * The target model and provider are pre-filled from the `fixedTargetModel`
 * and `fixedProvider` props when those are supplied.
 */
function resetForm() {
  const presetModelId = props.fixedTargetModel?.id || ''
  const presetProviderId = props.fixedProvider?.id || null

  form.value = {
    alias: '',
    global_model_id: presetModelId,
    provider_id: presetProviderId,
    mapping_type: 'alias',
    is_active: true,
  }
}
|
||||
|
||||
/**
 * Copy the alias being edited into the form (edit mode).
 *
 * Does nothing when no `editingAlias` prop is set. A missing
 * `mapping_type` falls back to 'alias'.
 */
function loadAliasData() {
  const source = props.editingAlias
  if (!source) return

  form.value = {
    alias: source.alias,
    global_model_id: source.global_model_id,
    provider_id: source.provider_id,
    mapping_type: source.mapping_type || 'alias',
    is_active: source.is_active,
  }
}
|
||||
|
||||
// 使用 useFormDialog 统一处理对话框逻辑
|
||||
const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
|
||||
isOpen: () => props.open,
|
||||
entity: () => props.editingAlias,
|
||||
isLoading: submitting,
|
||||
onClose: () => emit('update:open', false),
|
||||
loadData: loadAliasData,
|
||||
resetForm,
|
||||
})
|
||||
|
||||
/**
 * Dialog title.
 *
 * Edit mode: depends on whether the entry is a mapping or an alias.
 * Create mode: mentions the fixed provider (display name, falling back to
 * name) when one is supplied, otherwise uses the generic title.
 */
const dialogTitle = computed(() => {
  if (!isEditMode.value) {
    const pinnedProvider = props.fixedProvider
    if (!pinnedProvider) return '创建别名/映射'
    return `创建 ${pinnedProvider.display_name || pinnedProvider.name} 特定别名/映射`
  }

  if (form.value.mapping_type === 'mapping') return '编辑映射'
  return '编辑别名'
})
|
||||
|
||||
/**
 * Dialog description: a generic text in create mode, and a mapping- or
 * alias-specific text in edit mode.
 */
const dialogDescription = computed(() => {
  if (!isEditMode.value) return '为模型创建别名或映射规则'

  const editingMapping = form.value.mapping_type === 'mapping'
  return editingMapping ? '修改模型映射配置' : '修改别名设置'
})
|
||||
|
||||
// 对话框图标
|
||||
const dialogIcon = computed(() => isEditMode.value ? SquarePen : Tag)
|
||||
|
||||
/**
 * Source models selectable in mapping mode: every global model except the
 * one currently chosen as the target.
 */
const availableSourceModels = computed(() => {
  const targetId = form.value.global_model_id
  return props.globalModels.filter(candidate => candidate.id !== targetId)
})
|
||||
|
||||
/**
 * Selectable target models.
 *
 * In mapping mode with a source chosen, the global model whose `name`
 * equals the source alias is excluded. In every other case the full list
 * is returned.
 */
const availableTargetModels = computed(() => {
  const current = form.value
  if (current.mapping_type !== 'mapping' || !current.alias) {
    return props.globalModels
  }

  // Resolve the GlobalModel that corresponds to the chosen source name.
  const sourceModel = props.globalModels.find(m => m.name === current.alias)
  if (!sourceModel) return props.globalModels

  return props.globalModels.filter(m => m.id !== sourceModel.id)
})
|
||||
|
||||
/**
 * Validate the form and emit `submit` with the request payload.
 *
 * Shows an error toast and aborts when the alias/source is empty or when
 * no target model can be resolved. A fixed target model or fixed provider
 * passed via props takes precedence over the form values.
 */
async function handleSubmit() {
  const current = form.value

  if (!current.alias) {
    const message = current.mapping_type === 'alias' ? '请输入别名名称' : '请选择源模型'
    showError(message, '错误')
    return
  }

  const targetModelId = props.fixedTargetModel?.id || current.global_model_id
  if (!targetModelId) {
    showError('请选择目标模型', '错误')
    return
  }

  submitting.value = true
  try {
    const payload: CreateModelAliasRequest | UpdateModelAliasRequest = {
      alias: form.value.alias,
      global_model_id: targetModelId,
      provider_id: props.fixedProvider?.id || form.value.provider_id,
      mapping_type: form.value.mapping_type,
      is_active: form.value.is_active,
    }

    // The second argument tells the parent whether this is an edit.
    emit('submit', payload, Boolean(props.editingAlias))
  } finally {
    submitting.value = false
  }
}
|
||||
</script>
|
||||
@@ -2,174 +2,304 @@
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
:title="isEditMode ? '编辑模型' : '创建统一模型'"
|
||||
:description="isEditMode ? '修改模型配置和价格信息' : '添加一个新的全局模型定义'"
|
||||
:description="isEditMode ? '修改模型配置和价格信息' : ''"
|
||||
:icon="isEditMode ? SquarePen : Layers"
|
||||
size="xl"
|
||||
size="3xl"
|
||||
@update:model-value="handleDialogUpdate"
|
||||
>
|
||||
<form
|
||||
class="space-y-5 max-h-[70vh] overflow-y-auto pr-1"
|
||||
@submit.prevent="handleSubmit"
|
||||
<div
|
||||
class="flex gap-4"
|
||||
:class="isEditMode ? '' : 'h-[500px]'"
|
||||
>
|
||||
<!-- 基本信息 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
基本信息
|
||||
</h4>
|
||||
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-name"
|
||||
class="text-xs"
|
||||
>模型名称 *</Label>
|
||||
<Input
|
||||
id="model-name"
|
||||
v-model="form.name"
|
||||
placeholder="claude-3-5-sonnet-20241022"
|
||||
:disabled="isEditMode"
|
||||
required
|
||||
/>
|
||||
<p
|
||||
v-if="!isEditMode"
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
创建后不可修改
|
||||
</p>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-display-name"
|
||||
class="text-xs"
|
||||
>显示名称 *</Label>
|
||||
<Input
|
||||
id="model-display-name"
|
||||
v-model="form.display_name"
|
||||
placeholder="Claude 3.5 Sonnet"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-description"
|
||||
class="text-xs"
|
||||
>描述</Label>
|
||||
<Input
|
||||
id="model-description"
|
||||
v-model="form.description"
|
||||
placeholder="简短描述此模型的特点"
|
||||
/>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 能力配置 -->
|
||||
<section class="space-y-2">
|
||||
<h4 class="font-medium text-sm">
|
||||
默认能力
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_streaming"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>流式输出</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_vision"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>视觉理解</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_function_calling"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>工具调用</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_extended_thinking"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>深度思考</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
v-model="form.default_supports_image_generation"
|
||||
type="checkbox"
|
||||
class="rounded"
|
||||
>
|
||||
<Image class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>图像生成</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Key 能力配置 -->
|
||||
<section
|
||||
v-if="availableCapabilities.length > 0"
|
||||
class="space-y-2"
|
||||
<!-- 左侧:模型选择(仅创建模式) -->
|
||||
<div
|
||||
v-if="!isEditMode"
|
||||
class="w-[260px] shrink-0 flex flex-col h-full"
|
||||
>
|
||||
<h4 class="font-medium text-sm">
|
||||
Key 能力支持
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label
|
||||
v-for="cap in availableCapabilities"
|
||||
:key="cap.name"
|
||||
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.supported_capabilities?.includes(cap.name)"
|
||||
class="rounded"
|
||||
@change="toggleCapability(cap.name)"
|
||||
>
|
||||
<span>{{ cap.display_name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 价格配置 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
价格配置
|
||||
</h4>
|
||||
<TieredPricingEditor
|
||||
ref="tieredPricingEditorRef"
|
||||
v-model="tieredPricing"
|
||||
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
|
||||
/>
|
||||
|
||||
<!-- 按次计费 -->
|
||||
<div class="flex items-center gap-3 pt-2 border-t">
|
||||
<Label class="text-xs whitespace-nowrap">按次计费 ($/次)</Label>
|
||||
<!-- 搜索框 -->
|
||||
<div class="relative mb-3">
|
||||
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||
<Input
|
||||
:model-value="form.default_price_per_request ?? ''"
|
||||
type="number"
|
||||
step="0.001"
|
||||
min="0"
|
||||
class="w-32"
|
||||
placeholder="留空不启用"
|
||||
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
|
||||
v-model="searchQuery"
|
||||
type="text"
|
||||
placeholder="搜索模型、提供商..."
|
||||
class="pl-8 h-8 text-sm"
|
||||
/>
|
||||
<span class="text-xs text-muted-foreground">每次请求固定费用,可与 Token 计费叠加</span>
|
||||
</div>
|
||||
</section>
|
||||
</form>
|
||||
|
||||
<!-- 模型列表(两级结构) -->
|
||||
<div class="flex-1 overflow-y-auto border rounded-lg min-h-0 scrollbar-thin">
|
||||
<div
|
||||
v-if="loading"
|
||||
class="flex items-center justify-center h-32"
|
||||
>
|
||||
<Loader2 class="w-5 h-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
<template v-else>
|
||||
<!-- 提供商分组 -->
|
||||
<div
|
||||
v-for="group in groupedModels"
|
||||
:key="group.providerId"
|
||||
class="border-b last:border-b-0"
|
||||
>
|
||||
<!-- 提供商标题行 -->
|
||||
<div
|
||||
class="flex items-center gap-2 px-2.5 py-2 cursor-pointer hover:bg-muted text-sm"
|
||||
@click="toggleProvider(group.providerId)"
|
||||
>
|
||||
<ChevronRight
|
||||
class="w-3.5 h-3.5 text-muted-foreground transition-transform shrink-0"
|
||||
:class="expandedProvider === group.providerId ? 'rotate-90' : ''"
|
||||
/>
|
||||
<img
|
||||
:src="getProviderLogoUrl(group.providerId)"
|
||||
:alt="group.providerName"
|
||||
class="w-4 h-4 rounded shrink-0 dark:invert dark:brightness-90"
|
||||
@error="handleLogoError"
|
||||
>
|
||||
<span class="truncate font-medium text-xs flex-1">{{ group.providerName }}</span>
|
||||
<span class="text-[10px] text-muted-foreground shrink-0">{{ group.models.length }}</span>
|
||||
</div>
|
||||
<!-- 模型列表 -->
|
||||
<div
|
||||
v-if="expandedProvider === group.providerId"
|
||||
class="bg-muted/30"
|
||||
>
|
||||
<div
|
||||
v-for="model in group.models"
|
||||
:key="model.modelId"
|
||||
class="flex flex-col gap-0.5 pl-7 pr-2.5 py-1.5 cursor-pointer text-xs border-t"
|
||||
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
||||
? 'bg-primary text-primary-foreground'
|
||||
: 'hover:bg-muted'"
|
||||
@click="selectModel(model)"
|
||||
>
|
||||
<span class="truncate font-medium">{{ model.modelName }}</span>
|
||||
<span
|
||||
class="truncate text-[10px]"
|
||||
:class="selectedModel?.modelId === model.modelId && selectedModel?.providerId === model.providerId
|
||||
? 'text-primary-foreground/70'
|
||||
: 'text-muted-foreground'"
|
||||
>{{ model.modelId }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="groupedModels.length === 0"
|
||||
class="text-center py-8 text-sm text-muted-foreground"
|
||||
>
|
||||
{{ searchQuery ? '未找到模型' : '加载中...' }}
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 右侧:表单 -->
|
||||
<div
|
||||
class="flex-1 overflow-y-auto h-full scrollbar-thin"
|
||||
:class="isEditMode ? 'max-h-[70vh]' : ''"
|
||||
>
|
||||
<form
|
||||
class="space-y-5"
|
||||
@submit.prevent="handleSubmit"
|
||||
>
|
||||
<!-- 基本信息 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
基本信息
|
||||
</h4>
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-name"
|
||||
class="text-xs"
|
||||
>模型名称 *</Label>
|
||||
<Input
|
||||
id="model-name"
|
||||
v-model="form.name"
|
||||
placeholder="claude-3-5-sonnet-20241022"
|
||||
:disabled="isEditMode"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-display-name"
|
||||
class="text-xs"
|
||||
>显示名称 *</Label>
|
||||
<Input
|
||||
id="model-display-name"
|
||||
v-model="form.display_name"
|
||||
placeholder="Claude 3.5 Sonnet"
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-description"
|
||||
class="text-xs"
|
||||
>描述</Label>
|
||||
<Input
|
||||
id="model-description"
|
||||
:model-value="form.config?.description || ''"
|
||||
placeholder="简短描述此模型的特点"
|
||||
@update:model-value="(v) => setConfigField('description', v || undefined)"
|
||||
/>
|
||||
</div>
|
||||
<div class="grid grid-cols-3 gap-3">
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-family"
|
||||
class="text-xs"
|
||||
>模型系列</Label>
|
||||
<Input
|
||||
id="model-family"
|
||||
:model-value="form.config?.family || ''"
|
||||
placeholder="如 GPT-4、Claude 3"
|
||||
@update:model-value="(v) => setConfigField('family', v || undefined)"
|
||||
/>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-context-limit"
|
||||
class="text-xs"
|
||||
>上下文限制</Label>
|
||||
<Input
|
||||
id="model-context-limit"
|
||||
type="number"
|
||||
:model-value="form.config?.context_limit ?? ''"
|
||||
placeholder="如 128000"
|
||||
@update:model-value="(v) => setConfigField('context_limit', v ? Number(v) : undefined)"
|
||||
/>
|
||||
</div>
|
||||
<div class="space-y-1.5">
|
||||
<Label
|
||||
for="model-output-limit"
|
||||
class="text-xs"
|
||||
>输出限制</Label>
|
||||
<Input
|
||||
id="model-output-limit"
|
||||
type="number"
|
||||
:model-value="form.config?.output_limit ?? ''"
|
||||
placeholder="如 8192"
|
||||
@update:model-value="(v) => setConfigField('output_limit', v ? Number(v) : undefined)"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 能力配置 -->
|
||||
<section class="space-y-2">
|
||||
<h4 class="font-medium text-sm">
|
||||
默认能力
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.streaming !== false"
|
||||
class="rounded"
|
||||
@change="setConfigField('streaming', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Zap class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>流式</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.vision === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('vision', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Eye class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>视觉</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.function_calling === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('function_calling', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Wrench class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>工具</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.extended_thinking === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('extended_thinking', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Brain class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>思考</span>
|
||||
</label>
|
||||
<label class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.config?.image_generation === true"
|
||||
class="rounded"
|
||||
@change="setConfigField('image_generation', ($event.target as HTMLInputElement).checked)"
|
||||
>
|
||||
<Image class="w-3.5 h-3.5 text-muted-foreground" />
|
||||
<span>生图</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Key 能力配置 -->
|
||||
<section
|
||||
v-if="availableCapabilities.length > 0"
|
||||
class="space-y-2"
|
||||
>
|
||||
<h4 class="font-medium text-sm">
|
||||
Key 能力支持
|
||||
</h4>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
<label
|
||||
v-for="cap in availableCapabilities"
|
||||
:key="cap.name"
|
||||
class="flex items-center gap-2 px-2.5 py-1 rounded-md border bg-muted/30 cursor-pointer text-sm"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="form.supported_capabilities?.includes(cap.name)"
|
||||
class="rounded"
|
||||
@change="toggleCapability(cap.name)"
|
||||
>
|
||||
<span>{{ cap.display_name }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- 价格配置 -->
|
||||
<section class="space-y-3">
|
||||
<h4 class="font-medium text-sm">
|
||||
价格配置
|
||||
</h4>
|
||||
<TieredPricingEditor
|
||||
ref="tieredPricingEditorRef"
|
||||
v-model="tieredPricing"
|
||||
:show-cache1h="form.supported_capabilities?.includes('cache_1h')"
|
||||
/>
|
||||
<div class="flex items-center gap-3 pt-2 border-t">
|
||||
<Label class="text-xs whitespace-nowrap">按次计费</Label>
|
||||
<Input
|
||||
:model-value="form.default_price_per_request ?? ''"
|
||||
type="number"
|
||||
step="0.001"
|
||||
min="0"
|
||||
class="w-24"
|
||||
placeholder="$/次"
|
||||
@update:model-value="(v) => form.default_price_per_request = parseNumberInput(v, { allowFloat: true })"
|
||||
/>
|
||||
<span class="text-xs text-muted-foreground">可与 Token 计费叠加</span>
|
||||
</div>
|
||||
</section>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
@@ -180,7 +310,7 @@
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting"
|
||||
:disabled="submitting || !form.name || !form.display_name"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
@@ -189,19 +319,35 @@
|
||||
/>
|
||||
{{ isEditMode ? '保存' : '创建' }}
|
||||
</Button>
|
||||
<Button
|
||||
v-if="selectedModel && !isEditMode"
|
||||
type="button"
|
||||
variant="ghost"
|
||||
@click="clearSelection"
|
||||
>
|
||||
清空
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen } from 'lucide-vue-next'
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import {
|
||||
Eye, Wrench, Brain, Zap, Image, Loader2, Layers, SquarePen,
|
||||
Search, ChevronRight
|
||||
} from 'lucide-vue-next'
|
||||
import { Dialog, Button, Input, Label } from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useFormDialog } from '@/composables/useFormDialog'
|
||||
import { parseNumberInput } from '@/utils/form'
|
||||
import { log } from '@/utils/logger'
|
||||
import TieredPricingEditor from './TieredPricingEditor.vue'
|
||||
import {
|
||||
getModelsDevList,
|
||||
getProviderLogoUrl,
|
||||
type ModelsDevModelItem,
|
||||
} from '@/api/models-dev'
|
||||
import {
|
||||
createGlobalModel,
|
||||
updateGlobalModel,
|
||||
@@ -226,42 +372,147 @@ const { success, error: showError } = useToast()
|
||||
const submitting = ref(false)
|
||||
const tieredPricingEditorRef = ref<InstanceType<typeof TieredPricingEditor> | null>(null)
|
||||
|
||||
// 阶梯计费配置(统一使用,固定价格就是单阶梯)
|
||||
// 模型列表相关
|
||||
const loading = ref(false)
|
||||
const searchQuery = ref('')
|
||||
const allModelsCache = ref<ModelsDevModelItem[]>([]) // 全部模型(缓存)
|
||||
const selectedModel = ref<ModelsDevModelItem | null>(null)
|
||||
const expandedProvider = ref<string | null>(null)
|
||||
|
||||
/**
 * Models currently eligible for display: the whole cache while a search
 * query is present, otherwise only entries flagged `official`.
 */
const allModels = computed(() =>
  searchQuery.value
    ? allModelsCache.value
    : allModelsCache.value.filter(item => item.official)
)
|
||||
|
||||
// 按提供商分组的模型
|
||||
interface ProviderGroup {
|
||||
providerId: string
|
||||
providerName: string
|
||||
models: ModelsDevModelItem[]
|
||||
}
|
||||
|
||||
/**
 * Non-deprecated models grouped by provider, filtered by the search query.
 *
 * - The query is split on whitespace into keywords; a model is kept only
 *   when every keyword occurs in its provider id/name, model id/name or
 *   family (case-insensitive AND search).
 * - Groups whose provider id/name matches any keyword are listed first;
 *   ties, and the no-query case, are ordered by provider name.
 */
const groupedModels = computed(() => {
  // Parse the query once; it is used for both filtering and ordering.
  const keywords = searchQuery.value
    ? searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
    : []

  let models = allModels.value.filter(m => !m.deprecated)
  if (keywords.length > 0) {
    models = models.filter(model => {
      const searchableText = `${model.providerId} ${model.providerName} ${model.modelId} ${model.modelName} ${model.family || ''}`.toLowerCase()
      return keywords.every(keyword => searchableText.includes(keyword))
    })
  }

  // Group by provider, preserving the order in which providers first appear.
  const groups = new Map<string, ProviderGroup>()
  for (const model of models) {
    let group = groups.get(model.providerId)
    if (!group) {
      group = {
        providerId: model.providerId,
        providerName: model.providerName,
        models: []
      }
      groups.set(model.providerId, group)
    }
    group.models.push(model)
  }

  // With no keywords nothing "matches", so the comparator reduces to a
  // plain provider-name sort.
  const matchesProvider = (group: ProviderGroup) => {
    const text = `${group.providerId} ${group.providerName}`.toLowerCase()
    return keywords.some(k => text.includes(k))
  }

  return Array.from(groups.values()).sort((a, b) => {
    const aProviderMatch = matchesProvider(a)
    const bProviderMatch = matchesProvider(b)
    if (aProviderMatch !== bProviderMatch) return aProviderMatch ? -1 : 1
    return a.providerName.localeCompare(b.providerName)
  })
})
|
||||
|
||||
// While searching, expand the provider automatically when it is the only group left.
watch(groupedModels, (groups) => {
  if (!searchQuery.value || groups.length !== 1) return
  expandedProvider.value = groups[0].providerId
})
|
||||
|
||||
/**
 * Toggle a provider group: collapse it when it is the expanded one,
 * otherwise make it the (single) expanded group.
 */
function toggleProvider(providerId: string) {
  const alreadyExpanded = expandedProvider.value === providerId
  if (alreadyExpanded) {
    expandedProvider.value = null
  } else {
    expandedProvider.value = providerId
  }
}
|
||||
|
||||
// 阶梯计费配置
|
||||
const tieredPricing = ref<TieredPricingConfig | null>(null)
|
||||
|
||||
interface FormData {
|
||||
name: string
|
||||
display_name: string
|
||||
description?: string
|
||||
default_price_per_request?: number
|
||||
default_supports_streaming?: boolean
|
||||
default_supports_image_generation?: boolean
|
||||
default_supports_vision?: boolean
|
||||
default_supports_function_calling?: boolean
|
||||
default_supports_extended_thinking?: boolean
|
||||
supported_capabilities?: string[]
|
||||
config?: Record<string, any>
|
||||
is_active?: boolean
|
||||
}
|
||||
|
||||
const defaultForm = (): FormData => ({
|
||||
name: '',
|
||||
display_name: '',
|
||||
description: '',
|
||||
default_price_per_request: undefined,
|
||||
default_supports_streaming: true,
|
||||
default_supports_image_generation: false,
|
||||
default_supports_vision: false,
|
||||
default_supports_function_calling: false,
|
||||
default_supports_extended_thinking: false,
|
||||
supported_capabilities: [],
|
||||
config: { streaming: true },
|
||||
is_active: true,
|
||||
})
|
||||
|
||||
const form = ref<FormData>(defaultForm())
|
||||
|
||||
const KEEP_FALSE_CONFIG_KEYS = new Set(['streaming'])
|
||||
|
||||
/**
 * Set or remove a single key of `form.config`.
 *
 * The key is removed when the value is `undefined`, an empty string, or
 * `false` for a key not listed in KEEP_FALSE_CONFIG_KEYS. Any other value
 * is stored. The config object is created on first use.
 */
function setConfigField(key: string, value: any) {
  if (!form.value.config) {
    form.value.config = {}
  }

  const isBlank = value === undefined || value === ''
  const isDroppableFalse = value === false && !KEEP_FALSE_CONFIG_KEYS.has(key)

  // Always write through form.value.config so the reactive proxy sees the change.
  if (isBlank || isDroppableFalse) {
    delete form.value.config[key]
    return
  }
  form.value.config[key] = value
}
|
||||
|
||||
// Key 能力选项
|
||||
const availableCapabilities = ref<CapabilityDefinition[]>([])
|
||||
|
||||
/**
 * Fetch the model catalogue into `allModelsCache`.
 *
 * Skipped when the cache already holds entries. Failures are logged and
 * otherwise ignored; `loading` is cleared in every case.
 */
async function loadModels() {
  const alreadyCached = allModelsCache.value.length > 0
  if (alreadyCached) return

  loading.value = true
  try {
    // Load the full list once; filtering is done in the computed properties.
    const fetched = await getModelsDevList(false)
    allModelsCache.value = fetched
  } catch (err) {
    log.error('Failed to load models:', err)
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
// Load the model catalogue when the dialog opens without a model to edit.
watch(() => props.open, (isOpen) => {
  if (!isOpen) return
  if (props.model) return
  loadModels()
})
|
||||
|
||||
// 加载可用能力列表
|
||||
async function loadCapabilities() {
|
||||
try {
|
||||
@@ -284,38 +535,92 @@ function toggleCapability(capName: string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 组件挂载时加载能力列表
|
||||
onMounted(() => {
|
||||
loadCapabilities()
|
||||
})
|
||||
|
||||
/**
 * Select a catalogue model and pre-fill the form from it.
 *
 * - Marks the model as selected and expands its provider group.
 * - Copies the model id / name into `form.name` / `form.display_name`.
 * - Rebuilds `form.config` from the model's capability flags and metadata;
 *   `streaming` is always set to true, and every other key is added only
 *   when the catalogue provides a truthy / non-empty value.
 * - Creates a single unbounded pricing tier when an input or output price
 *   is present, otherwise clears the tiered pricing.
 */
function selectModel(model: ModelsDevModelItem) {
  selectedModel.value = model
  expandedProvider.value = model.providerId
  form.value.name = model.modelId
  form.value.display_name = model.modelName

  // Build the config.
  const config: Record<string, any> = {
    streaming: true,
  }
  if (model.supportsVision) config.vision = true
  if (model.supportsToolCall) config.function_calling = true
  if (model.supportsReasoning) config.extended_thinking = true
  if (model.supportsStructuredOutput) config.structured_output = true
  // Record only an explicit `true`. The previous `!== false` check also
  // copied `undefined` into the config when the flag was not reported,
  // leaving a `temperature: undefined` key in the form state.
  if (model.supportsTemperature === true) config.temperature = true
  if (model.supportsAttachment) config.attachment = true
  if (model.openWeights) config.open_weights = true
  if (model.contextLimit) config.context_limit = model.contextLimit
  if (model.outputLimit) config.output_limit = model.outputLimit
  if (model.knowledgeCutoff) config.knowledge_cutoff = model.knowledgeCutoff
  if (model.family) config.family = model.family
  if (model.releaseDate) config.release_date = model.releaseDate
  if (model.inputModalities?.length) config.input_modalities = model.inputModalities
  if (model.outputModalities?.length) config.output_modalities = model.outputModalities
  form.value.config = config

  const hasPrice = model.inputPrice !== undefined || model.outputPrice !== undefined
  if (!hasPrice) {
    tieredPricing.value = null
    return
  }

  // A fixed price is expressed as one tier without an upper bound.
  tieredPricing.value = {
    tiers: [{
      up_to: null,
      input_price_per_1m: model.inputPrice || 0,
      output_price_per_1m: model.outputPrice || 0,
    }]
  }
}
|
||||
|
||||
/**
 * Drop the catalogue selection so the form can be filled in manually:
 * the form returns to its defaults and the tiered pricing is cleared.
 */
function clearSelection() {
  form.value = defaultForm()
  tieredPricing.value = null
  selectedModel.value = null
}
|
||||
|
||||
/**
 * Hide a provider logo whose image failed to load.
 */
function handleLogoError(event: Event) {
  const brokenLogo = event.target as HTMLImageElement
  brokenLogo.style.display = 'none'
}
|
||||
|
||||
/**
 * Reset the dialog to a pristine state: default form values, no tiered
 * pricing, and a cleared model picker (query, selection and expansion).
 */
function resetForm() {
  // Form content
  form.value = defaultForm()
  tieredPricing.value = null

  // Model picker state
  searchQuery.value = ''
  selectedModel.value = null
  expandedProvider.value = null
}
|
||||
|
||||
// 加载模型数据(编辑模式)
|
||||
function loadModelData() {
|
||||
if (!props.model) return
|
||||
// 先重置创建模式的残留状态
|
||||
selectedModel.value = null
|
||||
searchQuery.value = ''
|
||||
expandedProvider.value = null
|
||||
|
||||
form.value = {
|
||||
name: props.model.name,
|
||||
display_name: props.model.display_name,
|
||||
description: props.model.description,
|
||||
default_price_per_request: props.model.default_price_per_request,
|
||||
default_supports_streaming: props.model.default_supports_streaming,
|
||||
default_supports_image_generation: props.model.default_supports_image_generation,
|
||||
default_supports_vision: props.model.default_supports_vision,
|
||||
default_supports_function_calling: props.model.default_supports_function_calling,
|
||||
default_supports_extended_thinking: props.model.default_supports_extended_thinking,
|
||||
supported_capabilities: [...(props.model.supported_capabilities || [])],
|
||||
config: props.model.config ? { ...props.model.config } : { streaming: true },
|
||||
is_active: props.model.is_active,
|
||||
}
|
||||
|
||||
// 加载阶梯计费配置(深拷贝)
|
||||
if (props.model.default_tiered_pricing) {
|
||||
tieredPricing.value = JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
|
||||
}
|
||||
// 确保 tieredPricing 也被正确设置或重置
|
||||
tieredPricing.value = props.model.default_tiered_pricing
|
||||
? JSON.parse(JSON.stringify(props.model.default_tiered_pricing))
|
||||
: null
|
||||
}
|
||||
|
||||
// 使用 useFormDialog 统一处理对话框逻辑
|
||||
@@ -339,24 +644,22 @@ async function handleSubmit() {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取包含自动计算缓存价格的最终数据
|
||||
const finalTiers = tieredPricingEditorRef.value?.getFinalTiers()
|
||||
const finalTieredPricing = finalTiers ? { tiers: finalTiers } : tieredPricing.value
|
||||
|
||||
// 清理空的 config
|
||||
const cleanConfig = form.value.config && Object.keys(form.value.config).length > 0
|
||||
? form.value.config
|
||||
: undefined
|
||||
|
||||
submitting.value = true
|
||||
try {
|
||||
if (isEditMode.value && props.model) {
|
||||
const updateData: GlobalModelUpdate = {
|
||||
display_name: form.value.display_name,
|
||||
description: form.value.description,
|
||||
// 使用 null 而不是 undefined 来显式清空字段
|
||||
config: cleanConfig || null,
|
||||
default_price_per_request: form.value.default_price_per_request ?? null,
|
||||
default_tiered_pricing: finalTieredPricing,
|
||||
default_supports_streaming: form.value.default_supports_streaming,
|
||||
default_supports_image_generation: form.value.default_supports_image_generation,
|
||||
default_supports_vision: form.value.default_supports_vision,
|
||||
default_supports_function_calling: form.value.default_supports_function_calling,
|
||||
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
|
||||
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : null,
|
||||
is_active: form.value.is_active,
|
||||
}
|
||||
@@ -366,14 +669,9 @@ async function handleSubmit() {
|
||||
const createData: GlobalModelCreate = {
|
||||
name: form.value.name!,
|
||||
display_name: form.value.display_name!,
|
||||
description: form.value.description,
|
||||
default_price_per_request: form.value.default_price_per_request || undefined,
|
||||
config: cleanConfig,
|
||||
default_price_per_request: form.value.default_price_per_request ?? undefined,
|
||||
default_tiered_pricing: finalTieredPricing,
|
||||
default_supports_streaming: form.value.default_supports_streaming,
|
||||
default_supports_image_generation: form.value.default_supports_image_generation,
|
||||
default_supports_vision: form.value.default_supports_vision,
|
||||
default_supports_function_calling: form.value.default_supports_function_calling,
|
||||
default_supports_extended_thinking: form.value.default_supports_extended_thinking,
|
||||
supported_capabilities: form.value.supported_capabilities?.length ? form.value.supported_capabilities : undefined,
|
||||
is_active: form.value.is_active,
|
||||
}
|
||||
|
||||
@@ -38,12 +38,12 @@
|
||||
>
|
||||
<Copy class="w-3 h-3" />
|
||||
</button>
|
||||
<template v-if="model.description">
|
||||
<template v-if="model.config?.description">
|
||||
<span class="shrink-0">·</span>
|
||||
<span
|
||||
class="text-xs truncate"
|
||||
:title="model.description"
|
||||
>{{ model.description }}</span>
|
||||
:title="model.config?.description"
|
||||
>{{ model.config?.description }}</span>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
@@ -104,19 +104,6 @@
|
||||
<span class="hidden sm:inline">关联提供商</span>
|
||||
<span class="sm:hidden">提供商</span>
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="flex-1 px-2 sm:px-4 py-2 text-xs sm:text-sm font-medium rounded-md transition-all duration-200"
|
||||
:class="[
|
||||
detailTab === 'aliases'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-background/50'
|
||||
]"
|
||||
@click="detailTab = 'aliases'"
|
||||
>
|
||||
<span class="hidden sm:inline">别名/映射</span>
|
||||
<span class="sm:hidden">别名</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Tab 内容 -->
|
||||
@@ -156,10 +143,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -173,10 +160,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -190,10 +177,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.vision === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.vision === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -207,10 +194,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -224,10 +211,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
@@ -409,11 +396,11 @@
|
||||
</div>
|
||||
<div class="p-3 rounded-lg border bg-muted/20">
|
||||
<div class="flex items-center justify-between">
|
||||
<Label class="text-xs text-muted-foreground">别名数量</Label>
|
||||
<Tag class="w-4 h-4 text-muted-foreground" />
|
||||
<Label class="text-xs text-muted-foreground">调用次数</Label>
|
||||
<BarChart3 class="w-4 h-4 text-muted-foreground" />
|
||||
</div>
|
||||
<p class="text-2xl font-bold mt-1">
|
||||
{{ model.alias_count || 0 }}
|
||||
{{ model.usage_count || 0 }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
@@ -468,105 +455,153 @@
|
||||
<template v-else-if="providers.length > 0">
|
||||
<!-- 桌面端表格 -->
|
||||
<Table class="hidden sm:table">
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="h-10 font-semibold">
|
||||
Provider
|
||||
</TableHead>
|
||||
<TableHead class="w-[120px] h-10 font-semibold">
|
||||
能力
|
||||
</TableHead>
|
||||
<TableHead class="w-[180px] h-10 font-semibold">
|
||||
价格 ($/M)
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] h-10 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="h-10 font-semibold">
|
||||
Provider
|
||||
</TableHead>
|
||||
<TableHead class="w-[120px] h-10 font-semibold">
|
||||
能力
|
||||
</TableHead>
|
||||
<TableHead class="w-[180px] h-10 font-semibold">
|
||||
价格 ($/M)
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] h-10 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="provider.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex gap-0.5">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="流式输出"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="视觉理解"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="工具调用"
|
||||
/>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="text-xs font-mono space-y-0.5">
|
||||
<!-- Token 计费:输入/输出 -->
|
||||
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
|
||||
<span class="text-muted-foreground">输入/输出:</span>
|
||||
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
|
||||
<!-- 阶梯标记 -->
|
||||
<span
|
||||
v-if="(provider.tier_count || 1) > 1"
|
||||
class="ml-1 text-muted-foreground"
|
||||
title="阶梯计费"
|
||||
>[阶梯]</span>
|
||||
</div>
|
||||
<!-- 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 1h 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>1h 缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 按次计费 -->
|
||||
<div v-if="(provider.price_per_request || 0) > 0">
|
||||
<span class="text-muted-foreground">按次:</span>
|
||||
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/次</span>
|
||||
</div>
|
||||
<!-- 无定价 -->
|
||||
<span
|
||||
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3 text-center">
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑此关联"
|
||||
@click="$emit('editProvider', provider)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="provider.is_active ? '停用此关联' : '启用此关联'"
|
||||
@click="$emit('toggleProviderStatus', provider)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除此关联"
|
||||
@click="$emit('deleteProvider', provider)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div class="sm:hidden divide-y divide-border/40">
|
||||
<div
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
class="p-4 space-y-3"
|
||||
>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex items-center gap-2 min-w-0">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="provider.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex gap-0.5">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="流式输出"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="视觉理解"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
title="工具调用"
|
||||
/>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<div class="text-xs font-mono space-y-0.5">
|
||||
<!-- Token 计费:输入/输出 -->
|
||||
<div v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0">
|
||||
<span class="text-muted-foreground">输入/输出:</span>
|
||||
<span class="ml-1">${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}</span>
|
||||
<!-- 阶梯标记 -->
|
||||
<span
|
||||
v-if="(provider.tier_count || 1) > 1"
|
||||
class="ml-1 text-muted-foreground"
|
||||
title="阶梯计费"
|
||||
>[阶梯]</span>
|
||||
</div>
|
||||
<!-- 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_creation_price_per_1m || 0) > 0 || (provider.cache_read_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_creation_price_per_1m || 0).toFixed(2) }}/${{ (provider.cache_read_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 1h 缓存价格 -->
|
||||
<div
|
||||
v-if="(provider.cache_1h_creation_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground"
|
||||
>
|
||||
<span>1h 缓存:</span>
|
||||
<span class="ml-1">${{ (provider.cache_1h_creation_price_per_1m || 0).toFixed(2) }}</span>
|
||||
</div>
|
||||
<!-- 按次计费 -->
|
||||
<div v-if="(provider.price_per_request || 0) > 0">
|
||||
<span class="text-muted-foreground">按次:</span>
|
||||
<span class="ml-1">${{ (provider.price_per_request || 0).toFixed(3) }}/次</span>
|
||||
</div>
|
||||
<!-- 无定价 -->
|
||||
<span
|
||||
v-if="!(provider.input_price_per_1m || 0) && !(provider.output_price_per_1m || 0) && !(provider.price_per_request || 0)"
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3 text-center">
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
<div class="flex items-center gap-1 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑此关联"
|
||||
@click="$emit('editProvider', provider)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
@@ -575,7 +610,6 @@
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="provider.is_active ? '停用此关联' : '启用此关联'"
|
||||
@click="$emit('toggleProviderStatus', provider)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
@@ -584,82 +618,35 @@
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除此关联"
|
||||
@click="$emit('deleteProvider', provider)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div class="sm:hidden divide-y divide-border/40">
|
||||
<div
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
class="p-4 space-y-3"
|
||||
>
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex items-center gap-2 min-w-0">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
/>
|
||||
<span class="font-medium truncate">{{ provider.display_name }}</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-1 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('editProvider', provider)"
|
||||
<div class="flex items-center gap-3 text-xs">
|
||||
<div class="flex gap-1">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground font-mono"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('toggleProviderStatus', provider)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('deleteProvider', provider)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-3 text-xs">
|
||||
<div class="flex gap-1">
|
||||
<Zap
|
||||
v-if="provider.supports_streaming"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Eye
|
||||
v-if="provider.supports_vision"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="provider.supports_function_calling"
|
||||
class="w-3.5 h-3.5 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
v-if="(provider.input_price_per_1m || 0) > 0 || (provider.output_price_per_1m || 0) > 0"
|
||||
class="text-muted-foreground font-mono"
|
||||
>
|
||||
${{ (provider.input_price_per_1m || 0).toFixed(1) }}/${{ (provider.output_price_per_1m || 0).toFixed(1) }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -684,236 +671,6 @@
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<!-- Tab 3: 别名 -->
|
||||
<div v-show="detailTab === 'aliases'">
|
||||
<Card class="overflow-hidden">
|
||||
<!-- 标题栏 -->
|
||||
<div class="px-4 py-3 border-b border-border/60">
|
||||
<div class="flex items-center justify-between gap-4">
|
||||
<div>
|
||||
<h4 class="text-sm font-semibold">
|
||||
别名与映射
|
||||
</h4>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="添加别名/映射"
|
||||
@click="$emit('addAlias')"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="刷新"
|
||||
@click="$emit('refreshAliases')"
|
||||
>
|
||||
<RefreshCw
|
||||
class="w-3.5 h-3.5"
|
||||
:class="loadingAliases ? 'animate-spin' : ''"
|
||||
/>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 表格内容 -->
|
||||
<div
|
||||
v-if="loadingAliases"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<Loader2 class="w-6 h-6 animate-spin text-primary" />
|
||||
</div>
|
||||
|
||||
<template v-else-if="aliases.length > 0">
|
||||
<!-- 桌面端表格 -->
|
||||
<Table class="hidden sm:table">
|
||||
<TableHeader>
|
||||
<TableRow class="border-b border-border/60 hover:bg-transparent">
|
||||
<TableHead class="h-10 font-semibold">
|
||||
别名
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] h-10 font-semibold">
|
||||
类型
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] h-10 font-semibold">
|
||||
作用域
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] h-10 font-semibold text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="alias in aliases"
|
||||
:key="alias.id"
|
||||
class="border-b border-border/40 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<TableCell class="py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="alias.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="alias.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<code class="text-sm font-medium bg-muted px-1.5 py-0.5 rounded">{{ alias.alias }}</code>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="py-3">
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs truncate max-w-[90px]"
|
||||
:title="alias.provider_name || 'Provider'"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="py-3 text-center">
|
||||
<div class="flex items-center justify-center gap-0.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑"
|
||||
@click="$emit('editAlias', alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="alias.is_active ? '停用' : '启用'"
|
||||
@click="$emit('toggleAliasStatus', alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除"
|
||||
@click="$emit('deleteAlias', alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div class="sm:hidden divide-y divide-border/40">
|
||||
<div
|
||||
v-for="alias in aliases"
|
||||
:key="alias.id"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex items-center gap-2 min-w-0 flex-1">
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="alias.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
/>
|
||||
<code class="text-sm font-medium bg-muted px-1.5 py-0.5 rounded truncate">{{ alias.alias }}</code>
|
||||
</div>
|
||||
<div class="flex items-center gap-1 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('editAlias', alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('toggleAliasStatus', alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="$emit('deleteAlias', alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs truncate max-w-[120px]"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div
|
||||
v-else
|
||||
class="text-center py-12"
|
||||
>
|
||||
<!-- 空状态 -->
|
||||
<Tag class="w-12 h-12 mx-auto text-muted-foreground/30 mb-3" />
|
||||
<p class="text-sm text-muted-foreground">
|
||||
暂无别名或映射
|
||||
</p>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
class="mt-4"
|
||||
@click="$emit('addAlias')"
|
||||
>
|
||||
<Plus class="w-4 h-4 mr-1" />
|
||||
添加别名/映射
|
||||
</Button>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
@@ -931,7 +688,6 @@ import {
|
||||
Zap,
|
||||
Image,
|
||||
Building2,
|
||||
Tag,
|
||||
Plus,
|
||||
Edit,
|
||||
Trash2,
|
||||
@@ -939,7 +695,8 @@ import {
|
||||
Loader2,
|
||||
RefreshCw,
|
||||
Copy,
|
||||
Layers
|
||||
Layers,
|
||||
BarChart3
|
||||
} from 'lucide-vue-next'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
@@ -955,13 +712,11 @@ import TableCell from '@/components/ui/table-cell.vue'
|
||||
|
||||
// 使用外部类型定义
|
||||
import type { GlobalModelResponse } from '@/api/global-models'
|
||||
import type { ModelAlias } from '@/api/endpoints/aliases'
|
||||
import type { TieredPricingConfig, PricingTier } from '@/api/endpoints/types'
|
||||
import type { CapabilityDefinition } from '@/api/endpoints'
|
||||
|
||||
const props = withDefaults(defineProps<Props>(), {
|
||||
loadingProviders: false,
|
||||
loadingAliases: false,
|
||||
hasBlockingDialogOpen: false,
|
||||
})
|
||||
const emit = defineEmits<{
|
||||
@@ -973,11 +728,6 @@ const emit = defineEmits<{
|
||||
'deleteProvider': [provider: any]
|
||||
'toggleProviderStatus': [provider: any]
|
||||
'refreshProviders': []
|
||||
'addAlias': []
|
||||
'editAlias': [alias: ModelAlias]
|
||||
'toggleAliasStatus': [alias: ModelAlias]
|
||||
'deleteAlias': [alias: ModelAlias]
|
||||
'refreshAliases': []
|
||||
}>()
|
||||
const { success: showSuccess, error: showError } = useToast()
|
||||
|
||||
@@ -985,9 +735,7 @@ interface Props {
|
||||
model: GlobalModelResponse | null
|
||||
open: boolean
|
||||
providers: any[]
|
||||
aliases: ModelAlias[]
|
||||
loadingProviders?: boolean
|
||||
loadingAliases?: boolean
|
||||
hasBlockingDialogOpen?: boolean
|
||||
capabilities?: CapabilityDefinition[]
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
export { default as GlobalModelFormDialog } from './GlobalModelFormDialog.vue'
|
||||
export { default as AliasDialog } from './AliasDialog.vue'
|
||||
export { default as ModelDetailDrawer } from './ModelDetailDrawer.vue'
|
||||
export { default as TieredPricingEditor } from './TieredPricingEditor.vue'
|
||||
|
||||
337
frontend/src/features/providers/components/ModelAliasDialog.vue
Normal file
337
frontend/src/features/providers/components/ModelAliasDialog.vue
Normal file
@@ -0,0 +1,337 @@
|
||||
<template>
|
||||
<Dialog
|
||||
:model-value="open"
|
||||
title="管理模型名称映射"
|
||||
description="配置 Provider 对此模型使用的名称变体,系统会按优先级顺序选择"
|
||||
:icon="Tag"
|
||||
size="lg"
|
||||
@update:model-value="handleClose"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<!-- 模型信息 -->
|
||||
<div class="rounded-lg border bg-muted/30 p-3">
|
||||
<p class="font-medium">
|
||||
{{ model?.global_model_display_name || model?.provider_model_name }}
|
||||
</p>
|
||||
<p class="text-sm text-muted-foreground font-mono">
|
||||
主名称: {{ model?.provider_model_name }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 别名列表 -->
|
||||
<div class="space-y-3">
|
||||
<div class="flex items-center justify-between">
|
||||
<Label class="text-sm font-medium">名称映射</Label>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
@click="addAlias"
|
||||
>
|
||||
<Plus class="w-4 h-4 mr-1" />
|
||||
添加
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<!-- 提示信息 -->
|
||||
<div
|
||||
v-if="aliases.length > 0"
|
||||
class="flex items-center gap-2 px-3 py-2 text-xs text-muted-foreground bg-muted/30 rounded-md"
|
||||
>
|
||||
<Info class="w-3.5 h-3.5 shrink-0" />
|
||||
<span>拖拽调整顺序,点击序号可编辑(相同数字为同级,负载均衡)</span>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="aliases.length > 0"
|
||||
class="space-y-2"
|
||||
>
|
||||
<div
|
||||
v-for="(alias, index) in aliases"
|
||||
:key="index"
|
||||
class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200"
|
||||
:class="[
|
||||
draggedIndex === index
|
||||
? 'border-primary/50 bg-primary/5 shadow-md scale-[1.01]'
|
||||
: dragOverIndex === index
|
||||
? 'border-primary/30 bg-primary/5'
|
||||
: 'border-border/50 bg-background hover:border-border hover:bg-muted/30'
|
||||
]"
|
||||
draggable="true"
|
||||
@dragstart="handleDragStart(index, $event)"
|
||||
@dragend="handleDragEnd"
|
||||
@dragover.prevent="handleDragOver(index)"
|
||||
@dragleave="handleDragLeave"
|
||||
@drop="handleDrop(index)"
|
||||
>
|
||||
<!-- 拖拽手柄 -->
|
||||
<div class="cursor-grab active:cursor-grabbing p-1 rounded hover:bg-muted text-muted-foreground/40 group-hover:text-muted-foreground transition-colors shrink-0">
|
||||
<GripVertical class="w-4 h-4" />
|
||||
</div>
|
||||
|
||||
<!-- 可编辑优先级 -->
|
||||
<div class="shrink-0">
|
||||
<input
|
||||
v-if="editingPriorityIndex === index"
|
||||
type="number"
|
||||
min="1"
|
||||
:value="alias.priority"
|
||||
class="w-8 h-6 rounded-md bg-background border border-primary text-xs font-medium text-center focus:outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
|
||||
autofocus
|
||||
@blur="finishEditPriority(index, $event)"
|
||||
@keydown.enter="($event.target as HTMLInputElement).blur()"
|
||||
@keydown.escape="cancelEditPriority"
|
||||
>
|
||||
<div
|
||||
v-else
|
||||
class="w-6 h-6 rounded-md bg-muted/50 flex items-center justify-center text-xs font-medium text-muted-foreground cursor-pointer hover:bg-primary/10 hover:text-primary transition-colors"
|
||||
title="点击编辑优先级,相同数字为同级(负载均衡)"
|
||||
@click.stop="startEditPriority(index)"
|
||||
>
|
||||
{{ alias.priority }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 别名输入框 -->
|
||||
<Input
|
||||
v-model="alias.name"
|
||||
placeholder="映射名称,如 Claude-Sonnet-4.5"
|
||||
class="flex-1"
|
||||
/>
|
||||
|
||||
<!-- 删除按钮 -->
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="shrink-0 text-destructive hover:text-destructive h-8 w-8"
|
||||
@click="removeAlias(index)"
|
||||
>
|
||||
<X class="w-4 h-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-else
|
||||
class="text-center py-6 text-muted-foreground border rounded-lg border-dashed"
|
||||
>
|
||||
<Tag class="w-8 h-8 mx-auto mb-2 opacity-50" />
|
||||
<p class="text-sm">
|
||||
未配置映射
|
||||
</p>
|
||||
<p class="text-xs mt-1">
|
||||
将只使用主模型名称
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="handleClose(false)"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="submitting"
|
||||
@click="handleSubmit"
|
||||
>
|
||||
<Loader2
|
||||
v-if="submitting"
|
||||
class="w-4 h-4 mr-2 animate-spin"
|
||||
/>
|
||||
保存
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import { Tag, Plus, X, Loader2, GripVertical, Info } from 'lucide-vue-next'
|
||||
import { Dialog, Button, Input, Label } from '@/components/ui'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { updateModel } from '@/api/endpoints/models'
|
||||
import type { Model, ProviderModelAlias } from '@/api/endpoints'
|
||||
|
||||
/** Props accepted by the model name-mapping dialog. */
interface Props {
  /** Bound to the dialog's `model-value`; controls visibility. */
  open: boolean
  /** Provider id forwarded to `updateModel` when saving. */
  providerId: string
  /** Model whose name mappings are being edited, or `null` when none is set. */
  model: Model | null
}

const props = defineProps<Props>()

// Emits, declared in call-signature form:
//   update:open - request to open/close the dialog
//   saved       - fired after the mappings were stored successfully
const emit = defineEmits<{
  (event: 'update:open', value: boolean): void
  (event: 'saved'): void
}>()

const { error: showError, success: showSuccess } = useToast()

// True while a save request is in flight
const submitting = ref(false)
// Working copy of the mappings edited inside the dialog
const aliases = ref<ProviderModelAlias[]>([])

// Drag-and-drop state: row being dragged and row currently hovered
const draggedIndex = ref<number | null>(null)
const dragOverIndex = ref<number | null>(null)

// Index of the row whose priority is being edited inline (null = none)
const editingPriorityIndex = ref<number | null>(null)
|
||||
|
||||
// Reload the editable state every time the dialog is opened with a model
watch(
  () => props.open,
  (isOpen) => {
    if (!isOpen || !props.model) return

    // Deep-copy the stored mappings so edits stay local until saved
    const stored = props.model.provider_model_aliases
    aliases.value = Array.isArray(stored)
      ? JSON.parse(JSON.stringify(stored))
      : []

    // Clear any leftover inline-edit / drag state from a previous session
    editingPriorityIndex.value = null
    draggedIndex.value = null
    dragOverIndex.value = null
  }
)
|
||||
|
||||
/**
 * Append an empty mapping row.
 *
 * The new row's priority is one more than the highest existing priority,
 * or 1 when the list is empty.
 */
function addAlias() {
  const priorities = aliases.value.map(entry => entry.priority)
  const highest = priorities.length > 0 ? Math.max(...priorities) : 0
  aliases.value.push({ name: '', priority: highest + 1 })
}
|
||||
|
||||
/**
 * Remove the mapping row at `index` from the working list (in place).
 */
function removeAlias(index: number) {
  const removeCount = 1
  aliases.value.splice(index, removeCount)
}
|
||||
|
||||
// ===== Drag-and-drop reordering =====

/**
 * Remember which row is being dragged and mark the drag as a move.
 */
function handleDragStart(index: number, event: DragEvent) {
  draggedIndex.value = index

  const transfer = event.dataTransfer
  if (!transfer) return
  transfer.effectAllowed = 'move'
}
|
||||
|
||||
/**
 * Clear both drag markers once the drag gesture ends.
 */
function handleDragEnd() {
  dragOverIndex.value = null
  draggedIndex.value = null
}
|
||||
|
||||
/**
 * Highlight the hovered row as a drop target, unless nothing is being
 * dragged or the hovered row is the dragged row itself.
 */
function handleDragOver(index: number) {
  const source = draggedIndex.value
  if (source === null || source === index) return
  dragOverIndex.value = index
}
|
||||
|
||||
/**
 * Drop the hover highlight when the pointer leaves a row.
 */
function handleDragLeave(): void {
  dragOverIndex.value = null
}
|
||||
|
||||
/**
 * Move the dragged row to `targetIndex` and renumber priorities.
 *
 * After the move, priorities are reassigned 1, 2, 3, ... in list order:
 *  - the dragged row always becomes its own group with a fresh priority;
 *  - every other row keeps sharing a priority with the rows it shared one
 *    with before the drop (a group gets the number of its first occurrence).
 *
 * Does nothing except clear the hover marker when there is no active drag or
 * the row is dropped onto itself.
 */
function handleDrop(targetIndex: number) {
  const dragIndex = draggedIndex.value
  if (dragIndex === null || dragIndex === targetIndex) {
    dragOverIndex.value = null
    return
  }

  const items = [...aliases.value]
  const draggedItem = items[dragIndex]

  // Stale drag index (e.g. the row was removed mid-drag): reset and bail out
  // instead of splicing `undefined` into the list.
  if (!draggedItem) {
    draggedIndex.value = null
    dragOverIndex.value = null
    return
  }

  // Snapshot each row's priority before any of them is rewritten below
  const originalPriorities = new Map<ProviderModelAlias, number>()
  for (const alias of items) {
    originalPriorities.set(alias, alias.priority)
  }

  // Reorder the list
  items.splice(dragIndex, 1)
  items.splice(targetIndex, 0, draggedItem)

  // Original priority -> newly assigned priority, for rows that stay grouped
  const groupNewPriority = new Map<number, number>()
  let nextPriority = 1

  for (const alias of items) {
    if (alias === draggedItem) {
      // The dragged row is always a standalone group
      alias.priority = nextPriority++
      continue
    }

    const originalPriority = originalPriorities.get(alias) ?? alias.priority
    let assigned = groupNewPriority.get(originalPriority)
    if (assigned === undefined) {
      // First row of this group in the new order: allocate the next number
      assigned = nextPriority++
      groupNewPriority.set(originalPriority, assigned)
    }
    alias.priority = assigned
  }

  aliases.value = items
  draggedIndex.value = null
  dragOverIndex.value = null
}
|
||||
|
||||
// ===== 优先级编辑 =====
|
||||
// Switch the priority cell of row `index` into inline-edit mode.
function startEditPriority(index: number): void {
  editingPriorityIndex.value = index
}
|
||||
|
||||
// Commit the value typed into the inline priority editor of row `index` and leave edit mode.
// Empty, non-numeric, zero or negative input falls back to priority 1.
function finishEditPriority(index: number, event: FocusEvent) {
  const input = event.target as HTMLInputElement
  // Explicit radix: without it a value such as "0x10" would be parsed as hexadecimal.
  const parsed = Number.parseInt(input.value, 10) || 1
  aliases.value[index].priority = Math.max(1, parsed)
  editingPriorityIndex.value = null
}
|
||||
|
||||
// Leave inline priority editing without touching the stored value.
function cancelEditPriority(): void {
  editingPriorityIndex.value = null
}
|
||||
|
||||
// Close the dialog: forward the requested open state, unless a save is still in progress.
function handleClose(value: boolean) {
  if (submitting.value) return
  emit('update:open', value)
}
|
||||
|
||||
// Submit and save the alias list of the current model.
// Entries whose name is blank are dropped; if nothing remains, `null` is sent instead of a list.
async function handleSubmit() {
  if (submitting.value || !props.model) return

  submitting.value = true
  try {
    // Keep only the aliases that actually have a name.
    const namedAliases = aliases.value.filter(entry => entry.name.trim())
    const payload = {
      provider_model_aliases: namedAliases.length > 0 ? namedAliases : null
    }

    await updateModel(props.providerId, props.model.id, payload)

    showSuccess('映射配置已保存')
    emit('update:open', false)
    emit('saved')
  } catch (err: any) {
    // Prefer the detail message from the error response, if one is present.
    const detail = err.response?.data?.detail
    showError(detail || '保存失败', '错误')
  } finally {
    submitting.value = false
  }
}
|
||||
</script>
|
||||
@@ -312,8 +312,41 @@
|
||||
|
||||
<template #footer>
|
||||
<div class="flex items-center justify-between w-full">
|
||||
<div class="text-xs text-muted-foreground">
|
||||
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
|
||||
<div class="flex items-center gap-4">
|
||||
<div class="text-xs text-muted-foreground">
|
||||
当前模式: <span class="font-medium">{{ activeMainTab === 'provider' ? '提供商优先' : 'Key 优先' }}</span>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 pl-4 border-l border-border">
|
||||
<span class="text-xs text-muted-foreground">调度:</span>
|
||||
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'fixed_order'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="严格按优先级顺序,不考虑缓存"
|
||||
@click="schedulingMode = 'fixed_order'"
|
||||
>
|
||||
固定顺序
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-medium rounded transition-all"
|
||||
:class="[
|
||||
schedulingMode === 'cache_affinity'
|
||||
? 'bg-primary text-primary-foreground shadow-sm'
|
||||
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
|
||||
]"
|
||||
title="优先使用已缓存的Provider,利用Prompt Cache"
|
||||
@click="schedulingMode = 'cache_affinity'"
|
||||
>
|
||||
缓存亲和
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex gap-2">
|
||||
<Button
|
||||
@@ -410,6 +443,9 @@ const saving = ref(false)
|
||||
// Key 优先级编辑状态
|
||||
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
|
||||
|
||||
// 调度模式状态
|
||||
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity')
|
||||
|
||||
// 可用的 API 格式
|
||||
const availableFormats = computed(() => {
|
||||
return Object.keys(keysByFormat.value).sort()
|
||||
@@ -433,11 +469,18 @@ watch(internalOpen, async (open) => {
|
||||
// 加载当前的优先级模式配置
|
||||
async function loadCurrentPriorityMode() {
|
||||
try {
|
||||
const response = await adminApi.getSystemConfig('provider_priority_mode')
|
||||
const currentMode = response.value || 'provider'
|
||||
const [priorityResponse, schedulingResponse] = await Promise.all([
|
||||
adminApi.getSystemConfig('provider_priority_mode'),
|
||||
adminApi.getSystemConfig('scheduling_mode')
|
||||
])
|
||||
const currentMode = priorityResponse.value || 'provider'
|
||||
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
|
||||
|
||||
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
|
||||
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity'
|
||||
} catch {
|
||||
activeMainTab.value = 'provider'
|
||||
schedulingMode.value = 'cache_affinity'
|
||||
}
|
||||
}
|
||||
|
||||
@@ -611,11 +654,19 @@ async function save() {
|
||||
|
||||
const newMode = activeMainTab.value === 'key' ? 'global_key' : 'provider'
|
||||
|
||||
await adminApi.updateSystemConfig(
|
||||
'provider_priority_mode',
|
||||
newMode,
|
||||
'Provider/Key 优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)'
|
||||
)
|
||||
// 保存优先级模式和调度模式
|
||||
await Promise.all([
|
||||
adminApi.updateSystemConfig(
|
||||
'provider_priority_mode',
|
||||
newMode,
|
||||
'Provider/Key 优先级策略:provider(提供商优先模式) 或 global_key(全局Key优先模式)'
|
||||
),
|
||||
adminApi.updateSystemConfig(
|
||||
'scheduling_mode',
|
||||
schedulingMode.value,
|
||||
'调度模式:fixed_order(固定顺序模式) 或 cache_affinity(缓存亲和模式)'
|
||||
)
|
||||
])
|
||||
|
||||
const providerUpdates = sortedProviders.value.map((provider, index) =>
|
||||
updateProvider(provider.id, { provider_priority: index + 1 })
|
||||
|
||||
@@ -528,10 +528,10 @@
|
||||
@batch-assign="handleBatchAssign"
|
||||
/>
|
||||
|
||||
<!-- 模型映射 -->
|
||||
<MappingsTab
|
||||
<!-- 模型名称映射 -->
|
||||
<ModelAliasesTab
|
||||
v-if="provider"
|
||||
:key="`mappings-${provider.id}`"
|
||||
:key="`aliases-${provider.id}`"
|
||||
:provider="provider"
|
||||
@refresh="handleRelatedDataRefresh"
|
||||
/>
|
||||
@@ -663,8 +663,8 @@ import { getProvider, getProviderEndpoints } from '@/api/endpoints'
|
||||
import {
|
||||
KeyFormDialog,
|
||||
KeyAllowedModelsDialog,
|
||||
MappingsTab,
|
||||
ModelsTab,
|
||||
ModelAliasesTab,
|
||||
BatchAssignModelsDialog
|
||||
} from '@/features/providers/components'
|
||||
import EndpointFormDialog from '@/features/providers/components/EndpointFormDialog.vue'
|
||||
|
||||
@@ -7,6 +7,7 @@ export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vu
|
||||
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'
|
||||
export { default as EndpointHealthTimeline } from './EndpointHealthTimeline.vue'
|
||||
export { default as BatchAssignModelsDialog } from './BatchAssignModelsDialog.vue'
|
||||
export { default as ModelAliasDialog } from './ModelAliasDialog.vue'
|
||||
|
||||
export { default as MappingsTab } from './provider-tabs/MappingsTab.vue'
|
||||
export { default as ModelsTab } from './provider-tabs/ModelsTab.vue'
|
||||
export { default as ModelAliasesTab } from './provider-tabs/ModelAliasesTab.vue'
|
||||
|
||||
@@ -1,310 +0,0 @@
|
||||
<template>
|
||||
<Card class="overflow-hidden">
|
||||
<!-- 标题头部 -->
|
||||
<div class="p-4 border-b border-border/60">
|
||||
<div class="flex items-center justify-between">
|
||||
<div class="flex items-center gap-2">
|
||||
<h3 class="text-sm font-semibold leading-none">
|
||||
别名与映射管理
|
||||
</h3>
|
||||
</div>
|
||||
<Button
|
||||
v-if="!hideAddButton"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
class="h-8"
|
||||
@click="openCreateDialog"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5 mr-1.5" />
|
||||
创建别名/映射
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div
|
||||
v-if="loading"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<div class="animate-spin rounded-full h-8 w-8 border-b-2 border-primary" />
|
||||
</div>
|
||||
|
||||
<!-- 别名列表 -->
|
||||
<div
|
||||
v-else-if="mappings.length > 0"
|
||||
class="overflow-x-auto"
|
||||
>
|
||||
<table class="w-full text-sm">
|
||||
<thead class="bg-muted/50 text-xs uppercase tracking-wide text-muted-foreground">
|
||||
<tr>
|
||||
<th class="text-left px-4 py-3 font-semibold">
|
||||
名称
|
||||
</th>
|
||||
<th class="text-left px-4 py-3 font-semibold w-24">
|
||||
类型
|
||||
</th>
|
||||
<th class="text-left px-4 py-3 font-semibold">
|
||||
指向模型
|
||||
</th>
|
||||
<th
|
||||
v-if="!hideAddButton"
|
||||
class="px-4 py-3 font-semibold w-28 text-center"
|
||||
>
|
||||
操作
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr
|
||||
v-for="mapping in mappings"
|
||||
:key="mapping.id"
|
||||
class="border-b border-border/40 last:border-b-0 hover:bg-muted/30 transition-colors"
|
||||
>
|
||||
<td class="px-4 py-3">
|
||||
<div class="flex items-center gap-2">
|
||||
<!-- 状态指示灯 -->
|
||||
<span
|
||||
class="w-2 h-2 rounded-full shrink-0"
|
||||
:class="mapping.is_active ? 'bg-green-500' : 'bg-gray-300'"
|
||||
:title="mapping.is_active ? '活跃' : '停用'"
|
||||
/>
|
||||
<span class="font-mono">{{ mapping.alias }}</span>
|
||||
</div>
|
||||
</td>
|
||||
<td class="px-4 py-3">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ mapping.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
</td>
|
||||
<td class="px-4 py-3">
|
||||
{{ mapping.global_model_display_name || mapping.global_model_name }}
|
||||
</td>
|
||||
<td
|
||||
v-if="!hideAddButton"
|
||||
class="px-4 py-3"
|
||||
>
|
||||
<div class="flex justify-center gap-1.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="编辑"
|
||||
@click="openEditDialog(mapping)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
:disabled="togglingId === mapping.id"
|
||||
:title="mapping.is_active ? '点击停用' : '点击启用'"
|
||||
@click="toggleActive(mapping)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8 text-destructive hover:text-destructive"
|
||||
title="删除"
|
||||
@click="confirmDelete(mapping)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<div
|
||||
v-else
|
||||
class="p-8 text-center text-muted-foreground"
|
||||
>
|
||||
<ArrowLeftRight class="w-12 h-12 mx-auto mb-3 opacity-50" />
|
||||
<p class="text-sm">
|
||||
暂无特定别名/映射
|
||||
</p>
|
||||
<p class="text-xs mt-1">
|
||||
点击上方按钮添加
|
||||
</p>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- 使用共享的 AliasDialog 组件 -->
|
||||
<AliasDialog
|
||||
:open="dialogOpen"
|
||||
:editing-alias="editingAlias"
|
||||
:global-models="availableModels"
|
||||
:fixed-provider="fixedProviderOption"
|
||||
:show-provider-select="true"
|
||||
@update:open="handleDialogVisibility"
|
||||
@submit="handleAliasSubmit"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { ArrowLeftRight, Plus, Edit, Trash2, Power } from 'lucide-vue-next'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import AliasDialog from '@/features/models/components/AliasDialog.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import {
|
||||
getAliases,
|
||||
createAlias,
|
||||
updateAlias,
|
||||
deleteAlias,
|
||||
type ModelAlias,
|
||||
type CreateModelAliasRequest,
|
||||
type UpdateModelAliasRequest,
|
||||
} from '@/api/endpoints/aliases'
|
||||
import { listGlobalModels, type GlobalModelResponse } from '@/api/global-models'
|
||||
|
||||
const props = withDefaults(defineProps<{
|
||||
provider: any
|
||||
hideAddButton?: boolean
|
||||
}>(), {
|
||||
hideAddButton: false
|
||||
})
|
||||
|
||||
const emit = defineEmits<{
|
||||
refresh: []
|
||||
}>()
|
||||
|
||||
const { success, error: showError } = useToast()
|
||||
|
||||
// 状态
|
||||
const loading = ref(false)
|
||||
const submitting = ref(false)
|
||||
const togglingId = ref<string | null>(null)
|
||||
const mappings = ref<ModelAlias[]>([])
|
||||
const availableModels = ref<GlobalModelResponse[]>([])
|
||||
const dialogOpen = ref(false)
|
||||
const editingAlias = ref<ModelAlias | null>(null)
|
||||
|
||||
// 固定的 Provider 选项(传递给 AliasDialog)
|
||||
const fixedProviderOption = computed(() => ({
|
||||
id: props.provider.id,
|
||||
name: props.provider.name,
|
||||
display_name: props.provider.display_name
|
||||
}))
|
||||
|
||||
// Load the mappings (what is actually returned is this provider's alias list).
// The `loading` flag is raised for the duration of the request.
async function loadMappings() {
  loading.value = true
  try {
    const fetched = await getAliases({ provider_id: props.provider.id })
    mappings.value = fetched
  } catch (err: any) {
    const detail = err.response?.data?.detail
    showError(detail || '加载失败', '错误')
  } finally {
    loading.value = false
  }
}
|
||||
|
||||
// Load the list of selectable GlobalModels (active ones only, up to 1000).
async function loadAvailableModels() {
  try {
    const { models } = await listGlobalModels({ limit: 1000, is_active: true })
    // Fall back to an empty list when the response carries no models.
    availableModels.value = models || []
  } catch (err: any) {
    const detail = err.response?.data?.detail
    showError(detail || '加载模型列表失败', '错误')
  }
}
|
||||
|
||||
// Open the dialog in "create" mode (no alias selected for editing).
function openCreateDialog(): void {
  editingAlias.value = null
  dialogOpen.value = true
}
|
||||
|
||||
// Open the dialog in "edit" mode for the given alias.
function openEditDialog(alias: ModelAlias): void {
  editingAlias.value = alias
  dialogOpen.value = true
}
|
||||
|
||||
// React to the dialog being shown or hidden; closing it also clears the alias being edited.
function handleDialogVisibility(value: boolean) {
  dialogOpen.value = value
  if (value) return
  editingAlias.value = null
}
|
||||
|
||||
// Handle a submit coming from the AliasDialog component.
//
// `data`   - the form payload emitted by the dialog; it is never modified here.
// `isEdit` - true to update the alias currently being edited, false to create a new one.
//
// On success the dialog is closed, the list is reloaded and `refresh` is emitted.
// On failure a toast is shown and the dialog stays open.
async function handleAliasSubmit(data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean) {
  submitting.value = true
  try {
    // The toast wording depends on whether the entry is a mapping or an alias.
    const typeName = data.mapping_type === 'mapping' ? '映射' : '别名'

    if (isEdit && editingAlias.value) {
      await updateAlias(editingAlias.value.id, data as UpdateModelAliasRequest)
      success(`${typeName}已更新`)
    } else {
      // Create: force provider_id to the current provider. A copy is built so the
      // object owned by the dialog is not mutated behind its back.
      const createData: CreateModelAliasRequest = {
        ...(data as CreateModelAliasRequest),
        provider_id: props.provider.id
      }
      await createAlias(createData)
      success(`${typeName}已创建`)
    }

    dialogOpen.value = false
    editingAlias.value = null
    await loadMappings()
    emit('refresh')
  } catch (err: any) {
    const detail = err.response?.data?.detail || err.message
    let errorMessage = detail
    // Replace this known backend message with a more actionable hint.
    if (detail === '映射已存在') {
      errorMessage = '该名称已存在,请使用其他名称'
    }
    showError(errorMessage, isEdit ? '更新失败' : '创建失败')
  } finally {
    submitting.value = false
  }
}
|
||||
|
||||
// Toggle the enabled state of an alias. Only one toggle request may run at a time.
async function toggleActive(alias: ModelAlias) {
  // Another toggle is still in flight: ignore this click.
  if (togglingId.value) return

  togglingId.value = alias.id
  try {
    const nextActive = !alias.is_active
    await updateAlias(alias.id, { is_active: nextActive })
    // Update the local row only after the request has succeeded.
    alias.is_active = nextActive
  } catch (err: any) {
    const detail = err.response?.data?.detail
    showError(detail || '操作失败', '错误')
  } finally {
    togglingId.value = null
  }
}
|
||||
|
||||
// Ask for confirmation, then delete the alias/mapping, reload the list and emit `refresh`.
async function confirmDelete(alias: ModelAlias) {
  const typeName = alias.mapping_type === 'mapping' ? '映射' : '别名'
  const confirmed = confirm(`确定要删除${typeName} "${alias.alias}" 吗?`)
  if (!confirmed) return

  try {
    await deleteAlias(alias.id)
    success(`${typeName}已删除`)
    await loadMappings()
    emit('refresh')
  } catch (err: any) {
    const reason = err.response?.data?.detail || err.message
    showError(reason, '删除失败')
  }
}
|
||||
|
||||
onMounted(() => {
|
||||
loadMappings()
|
||||
loadAvailableModels()
|
||||
})
|
||||
</script>
|
||||
File diff suppressed because it is too large
Load Diff
@@ -479,10 +479,25 @@ const groupedTimeline = computed<NodeGroup[]>(() => {
|
||||
return groups
|
||||
})
|
||||
|
||||
// 计算链路总耗时(从第一个节点开始到最后一个节点结束)
|
||||
// 计算链路总耗时(使用成功候选的 latency_ms 字段)
|
||||
// 优先使用 latency_ms,因为它与 Usage.response_time_ms 使用相同的时间基准
|
||||
// 避免 finished_at - started_at 带来的额外延迟(数据库操作时间)
|
||||
const totalTraceLatency = computed(() => {
|
||||
if (!timeline.value || timeline.value.length === 0) return 0
|
||||
|
||||
// 查找成功的候选,使用其 latency_ms
|
||||
const successCandidate = timeline.value.find(c => c.status === 'success')
|
||||
if (successCandidate?.latency_ms != null) {
|
||||
return successCandidate.latency_ms
|
||||
}
|
||||
|
||||
// 如果没有成功的候选,查找失败但有 latency_ms 的候选
|
||||
const failedWithLatency = timeline.value.find(c => c.status === 'failed' && c.latency_ms != null)
|
||||
if (failedWithLatency?.latency_ms != null) {
|
||||
return failedWithLatency.latency_ms
|
||||
}
|
||||
|
||||
// 回退:使用 finished_at - started_at 计算
|
||||
let earliestStart: number | null = null
|
||||
let latestEnd: number | null = null
|
||||
|
||||
|
||||
@@ -177,8 +177,9 @@
|
||||
费用
|
||||
</TableHead>
|
||||
<TableHead class="h-12 font-semibold w-[70px] text-right">
|
||||
<div class="inline-block max-w-[2rem] leading-tight">
|
||||
响应时间
|
||||
<div class="flex flex-col items-end text-xs gap-0.5">
|
||||
<span>首字</span>
|
||||
<span class="text-muted-foreground font-normal">总耗时</span>
|
||||
</div>
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
@@ -356,15 +357,28 @@
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="text-right py-4 w-[70px]">
|
||||
<span
|
||||
<div
|
||||
v-if="record.status === 'pending' || record.status === 'streaming'"
|
||||
class="text-primary tabular-nums"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
>
|
||||
{{ getElapsedTime(record) }}
|
||||
</span>
|
||||
<span v-else-if="record.response_time_ms">
|
||||
{{ (record.response_time_ms / 1000).toFixed(2) }}s
|
||||
</span>
|
||||
<span class="text-primary tabular-nums">
|
||||
{{ getElapsedTime(record) }}
|
||||
</span>
|
||||
</div>
|
||||
<div
|
||||
v-else-if="record.response_time_ms != null"
|
||||
class="flex flex-col items-end text-xs gap-0.5"
|
||||
>
|
||||
<span
|
||||
v-if="record.first_byte_time_ms != null"
|
||||
class="tabular-nums"
|
||||
>{{ (record.first_byte_time_ms / 1000).toFixed(2) }}s</span>
|
||||
<span
|
||||
v-else
|
||||
class="text-muted-foreground"
|
||||
>-</span>
|
||||
<span class="text-muted-foreground tabular-nums">{{ (record.response_time_ms / 1000).toFixed(2) }}s</span>
|
||||
</div>
|
||||
<span
|
||||
v-else
|
||||
class="text-muted-foreground"
|
||||
@@ -543,13 +557,14 @@ function formatApiFormat(format: string): string {
|
||||
}
|
||||
|
||||
// 获取实际使用的模型(优先 target_model,其次 model_version)
|
||||
// 只有当实际模型与请求模型不同时才返回,用于显示映射箭头
|
||||
function getActualModel(record: UsageRecord): string | null {
|
||||
// 优先显示模型映射
|
||||
if (record.target_model) {
|
||||
if (record.target_model && record.target_model !== record.model) {
|
||||
return record.target_model
|
||||
}
|
||||
// 其次显示 Provider 返回的实际版本(如 Gemini 的 modelVersion)
|
||||
if (record.request_metadata?.model_version) {
|
||||
if (record.request_metadata?.model_version && record.request_metadata.model_version !== record.model) {
|
||||
return record.request_metadata.model_version
|
||||
}
|
||||
return null
|
||||
|
||||
@@ -78,6 +78,7 @@ export interface UsageRecord {
|
||||
cost: number
|
||||
actual_cost?: number
|
||||
response_time_ms?: number
|
||||
first_byte_time_ms?: number // 首字时间 (TTFB)
|
||||
is_stream: boolean
|
||||
status_code?: number
|
||||
error_message?: string
|
||||
|
||||
@@ -313,7 +313,6 @@ import {
|
||||
Gauge,
|
||||
Layers,
|
||||
FolderTree,
|
||||
Tag,
|
||||
Box,
|
||||
LogOut,
|
||||
SunMoon,
|
||||
@@ -411,7 +410,6 @@ const navigation = computed(() => {
|
||||
{ name: '用户管理', href: '/admin/users', icon: Users },
|
||||
{ name: '提供商', href: '/admin/providers', icon: FolderTree },
|
||||
{ name: '模型管理', href: '/admin/models', icon: Layers },
|
||||
{ name: '别名映射', href: '/admin/aliases', icon: Tag },
|
||||
{ name: '独立密钥', href: '/admin/keys', icon: Key },
|
||||
{ name: '使用记录', href: '/admin/usage', icon: BarChart3 },
|
||||
]
|
||||
|
||||
@@ -611,41 +611,42 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
|
||||
id: 'gm-001',
|
||||
name: 'claude-haiku-4-5-20251001',
|
||||
display_name: 'claude-haiku-4-5',
|
||||
description: 'Anthropic 最快速的 Claude 4 系列模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.00, output_price_per_1m: 5.00, cache_creation_price_per_1m: 1.25, cache_read_price_per_1m: 0.1 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Anthropic 最快速的 Claude 4 系列模型'
|
||||
},
|
||||
provider_count: 3,
|
||||
alias_count: 2,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-002',
|
||||
name: 'claude-opus-4-5-20251101',
|
||||
display_name: 'claude-opus-4-5',
|
||||
description: 'Anthropic 最强大的模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 5.00, output_price_per_1m: 25.00, cache_creation_price_per_1m: 6.25, cache_read_price_per_1m: 0.5 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Anthropic 最强大的模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 1,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-003',
|
||||
name: 'claude-sonnet-4-5-20250929',
|
||||
display_name: 'claude-sonnet-4-5',
|
||||
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [
|
||||
@@ -677,116 +678,124 @@ export const MOCK_GLOBAL_MODELS: GlobalModelResponse[] = [
|
||||
}
|
||||
]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Anthropic 平衡型模型,支持 1h 缓存和 CLI 1M 上下文'
|
||||
},
|
||||
supported_capabilities: ['cache_1h', 'cli_1m'],
|
||||
provider_count: 3,
|
||||
alias_count: 2,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-004',
|
||||
name: 'gemini-3-pro-image-preview',
|
||||
display_name: 'gemini-3-pro-image-preview',
|
||||
description: 'Google Gemini 3 Pro 图像生成预览版',
|
||||
is_active: true,
|
||||
default_price_per_request: 0.300,
|
||||
default_tiered_pricing: {
|
||||
tiers: []
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: false,
|
||||
default_supports_streaming: true,
|
||||
default_supports_image_generation: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: false,
|
||||
image_generation: true,
|
||||
description: 'Google Gemini 3 Pro 图像生成预览版'
|
||||
},
|
||||
provider_count: 1,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-005',
|
||||
name: 'gemini-3-pro-preview',
|
||||
display_name: 'gemini-3-pro-preview',
|
||||
description: 'Google Gemini 3 Pro 预览版',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 2.00, output_price_per_1m: 12.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'Google Gemini 3 Pro 预览版'
|
||||
},
|
||||
provider_count: 1,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-006',
|
||||
name: 'gpt-5.1',
|
||||
display_name: 'gpt-5.1',
|
||||
description: 'OpenAI GPT-5.1 模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 1,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-007',
|
||||
name: 'gpt-5.1-codex',
|
||||
display_name: 'gpt-5.1-codex',
|
||||
description: 'OpenAI GPT-5.1 Codex 代码专用模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 Codex 代码专用模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-008',
|
||||
name: 'gpt-5.1-codex-max',
|
||||
display_name: 'gpt-5.1-codex-max',
|
||||
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 Codex Max 代码专用增强版'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
},
|
||||
{
|
||||
id: 'gm-009',
|
||||
name: 'gpt-5.1-codex-mini',
|
||||
display_name: 'gpt-5.1-codex-mini',
|
||||
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型',
|
||||
is_active: true,
|
||||
default_tiered_pricing: {
|
||||
tiers: [{ up_to: null, input_price_per_1m: 1.25, output_price_per_1m: 10.00 }]
|
||||
},
|
||||
default_supports_vision: true,
|
||||
default_supports_function_calling: true,
|
||||
default_supports_streaming: true,
|
||||
default_supports_extended_thinking: true,
|
||||
config: {
|
||||
streaming: true,
|
||||
vision: true,
|
||||
function_calling: true,
|
||||
extended_thinking: true,
|
||||
description: 'OpenAI GPT-5.1 Codex Mini 轻量代码模型'
|
||||
},
|
||||
provider_count: 2,
|
||||
alias_count: 0,
|
||||
created_at: '2024-01-01T00:00:00Z'
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1000,17 +1000,11 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
display_name: m.display_name,
|
||||
description: m.description,
|
||||
icon_url: null,
|
||||
is_active: m.is_active,
|
||||
default_tiered_pricing: m.default_tiered_pricing,
|
||||
default_price_per_request: null,
|
||||
default_supports_vision: m.default_supports_vision,
|
||||
default_supports_function_calling: m.default_supports_function_calling,
|
||||
default_supports_streaming: m.default_supports_streaming,
|
||||
default_supports_extended_thinking: m.default_supports_extended_thinking || false,
|
||||
default_supports_image_generation: false,
|
||||
supported_capabilities: null
|
||||
default_price_per_request: m.default_price_per_request,
|
||||
supported_capabilities: m.supported_capabilities,
|
||||
config: m.config
|
||||
})),
|
||||
total: MOCK_GLOBAL_MODELS.length
|
||||
})
|
||||
|
||||
@@ -91,11 +91,6 @@ const routes: RouteRecordRaw[] = [
|
||||
name: 'ModelManagement',
|
||||
component: () => importWithRetry(() => import('@/views/admin/ModelManagement.vue'))
|
||||
},
|
||||
{
|
||||
path: 'aliases',
|
||||
name: 'AliasManagement',
|
||||
component: () => importWithRetry(() => import('@/views/admin/AliasManagement.vue'))
|
||||
},
|
||||
{
|
||||
path: 'health-monitor',
|
||||
name: 'HealthMonitor',
|
||||
|
||||
@@ -1169,4 +1169,26 @@ body[theme-mode='dark'] .literary-annotation {
|
||||
.scrollbar-hide::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
|
||||
.scrollbar-thin {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: hsl(var(--border)) transparent;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-thumb {
|
||||
background-color: hsl(var(--border));
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.scrollbar-thin::-webkit-scrollbar-thumb:hover {
|
||||
background-color: hsl(var(--muted-foreground) / 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,500 +0,0 @@
|
||||
<template>
|
||||
<div class="flex flex-col">
|
||||
<Card class="overflow-hidden">
|
||||
<!-- 搜索和过滤区域 -->
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
<div class="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3 sm:gap-4">
|
||||
<h3 class="text-sm sm:text-base font-semibold shrink-0">
|
||||
别名管理
|
||||
</h3>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<!-- 搜索框 -->
|
||||
<div class="relative">
|
||||
<Search class="absolute left-2.5 top-1/2 -translate-y-1/2 h-3.5 w-3.5 text-muted-foreground z-10 pointer-events-none" />
|
||||
<Input
|
||||
id="alias-search"
|
||||
v-model="aliasesSearch"
|
||||
placeholder="搜索别名或关联模型"
|
||||
class="w-32 sm:w-44 pl-8 pr-3 h-8 text-sm border-border/60 focus-visible:ring-1"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 提供商过滤器 -->
|
||||
<Select
|
||||
v-model:open="aliasProviderSelectOpen"
|
||||
:model-value="aliasProviderFilter"
|
||||
@update:model-value="aliasProviderFilter = $event"
|
||||
>
|
||||
<SelectTrigger class="w-28 sm:w-40 h-8 text-xs border-border/60">
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">
|
||||
全部别名
|
||||
</SelectItem>
|
||||
<SelectItem value="global">
|
||||
仅全局别名
|
||||
</SelectItem>
|
||||
<SelectItem
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
:value="provider.id"
|
||||
>
|
||||
{{ provider.display_name }}
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
<div class="hidden sm:block h-4 w-px bg-border" />
|
||||
|
||||
<!-- 操作按钮 -->
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8"
|
||||
title="新建别名"
|
||||
@click="openCreateAliasDialog"
|
||||
>
|
||||
<Plus class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<RefreshButton
|
||||
:loading="loadingAliases"
|
||||
@click="loadAliases"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
v-if="loadingAliases"
|
||||
class="flex items-center justify-center py-12"
|
||||
>
|
||||
<Loader2 class="w-10 h-10 animate-spin text-primary" />
|
||||
</div>
|
||||
<div v-else>
|
||||
<Table class="hidden xl:table text-sm">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead class="w-[200px]">
|
||||
别名
|
||||
</TableHead>
|
||||
<TableHead class="w-[280px]">
|
||||
关联模型
|
||||
</TableHead>
|
||||
<TableHead class="w-[70px] text-center">
|
||||
类型
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] text-center">
|
||||
作用域
|
||||
</TableHead>
|
||||
<TableHead class="w-[70px] text-center">
|
||||
状态
|
||||
</TableHead>
|
||||
<TableHead class="w-[100px] text-center">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow v-if="filteredAliases.length === 0">
|
||||
<TableCell
|
||||
colspan="6"
|
||||
class="text-center py-8 text-muted-foreground"
|
||||
>
|
||||
{{ aliasProviderFilter === 'global' ? '暂无全局别名' : '暂无别名' }}
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
<TableRow
|
||||
v-for="alias in paginatedAliases"
|
||||
:key="alias.id"
|
||||
>
|
||||
<TableCell>
|
||||
<span class="font-mono font-medium">{{ alias.alias }}</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div class="flex flex-col gap-0.5">
|
||||
<span class="font-medium">{{ alias.global_model_display_name || alias.global_model_name }}</span>
|
||||
<span class="text-xs text-muted-foreground font-mono">{{ alias.global_model_name }}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider 特定' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge
|
||||
:variant="alias.is_active ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.is_active ? '活跃' : '停用' }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="编辑别名"
|
||||
@click="openEditAliasDialog(alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
:title="alias.is_active ? '停用别名' : '启用别名'"
|
||||
@click="toggleAliasStatus(alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
title="删除别名"
|
||||
@click="confirmDeleteAlias(alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div
|
||||
v-if="filteredAliases.length > 0"
|
||||
class="xl:hidden divide-y divide-border/40"
|
||||
>
|
||||
<div
|
||||
v-for="alias in paginatedAliases"
|
||||
:key="alias.id"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-start justify-between gap-3">
|
||||
<div class="flex-1 min-w-0">
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="font-mono font-medium truncate">{{ alias.alias }}</span>
|
||||
<Badge
|
||||
:variant="alias.is_active ? 'default' : 'secondary'"
|
||||
class="text-xs shrink-0"
|
||||
>
|
||||
{{ alias.is_active ? '活跃' : '停用' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="text-xs text-muted-foreground mt-1">
|
||||
<span class="font-medium">{{ alias.global_model_display_name || alias.global_model_name }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-0.5 shrink-0">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="openEditAliasDialog(alias)"
|
||||
>
|
||||
<Edit class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="toggleAliasStatus(alias)"
|
||||
>
|
||||
<Power class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-7 w-7"
|
||||
@click="confirmDeleteAlias(alias)"
|
||||
>
|
||||
<Trash2 class="w-3.5 h-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2">
|
||||
<Badge
|
||||
variant="secondary"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.mapping_type === 'mapping' ? '映射' : '别名' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-if="alias.provider_id"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ alias.provider_name || 'Provider 特定' }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-else
|
||||
variant="default"
|
||||
class="text-xs"
|
||||
>
|
||||
全局
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 分页 -->
|
||||
<Pagination
|
||||
v-if="!loadingAliases && filteredAliases.length > 0"
|
||||
:current="aliasesCurrentPage"
|
||||
:total="filteredAliases.length"
|
||||
:page-size="aliasesPageSize"
|
||||
@update:current="aliasesCurrentPage = $event"
|
||||
@update:page-size="aliasesPageSize = $event"
|
||||
/>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- 创建/编辑别名对话框 -->
|
||||
<AliasDialog
|
||||
:open="createAliasDialogOpen"
|
||||
:editing-alias="editingAlias"
|
||||
:global-models="globalModels"
|
||||
:providers="providers"
|
||||
@update:open="handleAliasDialogUpdate"
|
||||
@submit="handleAliasSubmit"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, watch } from 'vue'
|
||||
import {
|
||||
Edit,
|
||||
Loader2,
|
||||
Plus,
|
||||
Power,
|
||||
Search,
|
||||
Trash2
|
||||
} from 'lucide-vue-next'
|
||||
import {
|
||||
Card,
|
||||
Button,
|
||||
Input,
|
||||
Badge,
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
RefreshButton,
|
||||
Pagination
|
||||
} from '@/components/ui'
|
||||
import AliasDialog from '@/features/models/components/AliasDialog.vue'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useConfirm } from '@/composables/useConfirm'
|
||||
import {
|
||||
getAliases,
|
||||
createAlias,
|
||||
updateAlias,
|
||||
deleteAlias,
|
||||
type ModelAlias,
|
||||
type CreateModelAliasRequest,
|
||||
type UpdateModelAliasRequest
|
||||
} from '@/api/endpoints/aliases'
|
||||
import { listGlobalModels, type GlobalModelResponse } from '@/api/global-models'
|
||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||
import { log } from '@/utils/logger'
|
||||
|
||||
const { success, error: showError } = useToast()
|
||||
const { confirmDanger } = useConfirm()
|
||||
|
||||
// State: UI flags, the search text, the scope filter and the dialog/edit bookkeeping
|
||||
const loadingAliases = ref(false) // true while loadAliases() is fetching; drives the spinner and the refresh button
|
||||
const submitting = ref(false) // true while handleAliasSubmit() has a create/update request in flight
|
||||
const aliasesSearch = ref('') // keyword typed into the search box
|
||||
const aliasProviderFilter = ref<string>('all') // 'all', 'global', or a provider id chosen in the select
|
||||
const aliasProviderSelectOpen = ref(false)
|
||||
const createAliasDialogOpen = ref(false) // the same dialog is used for both create and edit
|
||||
const editingAliasId = ref<string | null>(null) // null means the dialog is in "create" mode
|
||||
|
||||
// Data loaded on mount: aliases, global models and providers
|
||||
const allAliases = ref<ModelAlias[]>([])
|
||||
const globalModels = ref<GlobalModelResponse[]>([])
|
||||
const providers = ref<any[]>([])
|
||||
|
||||
// Pagination (client-side, applied to the filtered alias list)
|
||||
const aliasesCurrentPage = ref(1)
|
||||
const aliasesPageSize = ref(20)
|
||||
|
||||
// The alias object currently being edited, looked up by id in the loaded list.
// Yields null when no edit is in progress or the id matches nothing.
const editingAlias = computed(() => {
  const targetId = editingAliasId.value
  if (!targetId) {
    return null
  }
  const match = allAliases.value.find(candidate => candidate.id === targetId)
  return match || null
})
|
||||
|
||||
// Aliases that survive the scope filter and the keyword search.
const filteredAliases = computed(() => {
  const scope = aliasProviderFilter.value
  const keyword = aliasesSearch.value.trim().toLowerCase()

  // Scope filter: 'global' keeps aliases without a provider, any other
  // non-'all' value is treated as a provider id.
  const inScope = (alias: ModelAlias) =>
    scope === 'global' ? !alias.provider_id : alias.provider_id === scope

  // Keyword filter: case-insensitive substring match on the alias itself
  // or on either name of the model it points to.
  const hasKeyword = (alias: ModelAlias) =>
    alias.alias.toLowerCase().includes(keyword) ||
    alias.global_model_name?.toLowerCase().includes(keyword) ||
    alias.global_model_display_name?.toLowerCase().includes(keyword)

  let visible = allAliases.value
  if (scope !== 'all') {
    visible = visible.filter(inScope)
  }
  if (keyword) {
    visible = visible.filter(hasKeyword)
  }
  return visible
})
|
||||
|
||||
// The slice of the filtered aliases that belongs to the current page.
const paginatedAliases = computed(() => {
  const pageSize = aliasesPageSize.value
  const offset = (aliasesCurrentPage.value - 1) * pageSize
  return filteredAliases.value.slice(offset, offset + pageSize)
})
|
||||
|
||||
// Jump back to the first page whenever the search text or the scope filter changes.
watch([aliasesSearch, aliasProviderFilter], () => {
  aliasesCurrentPage.value = 1
})

// Keep the current page inside the valid range when the list shrinks (for
// example after deleting the last alias on the last page and reloading) or
// when the page size grows. Without this the page slice is empty while the
// "no aliases" row stays hidden because the filtered list is not empty.
watch(
  [() => filteredAliases.value.length, aliasesPageSize],
  ([total, pageSize]) => {
    const lastPage = Math.max(1, Math.ceil(total / pageSize))
    if (aliasesCurrentPage.value > lastPage) {
      aliasesCurrentPage.value = lastPage
    }
  }
)
|
||||
|
||||
// Fetch the alias list (up to 1000 entries) into allAliases, toggling the
// loading flag around the request and surfacing failures as an error toast.
async function loadAliases() {
  loadingAliases.value = true
  try {
    const fetched = await getAliases({ limit: 1000 })
    allAliases.value = fetched
  } catch (failure: any) {
    const reason = failure.response?.data?.detail || failure.message
    showError(reason, '加载别名失败')
  } finally {
    loadingAliases.value = false
  }
}
|
||||
|
||||
// Load the global models offered as alias targets in the dialog.
// Failures are only logged; no toast is shown for this request.
async function loadGlobalModelsList() {
  try {
    const { models } = await listGlobalModels()
    globalModels.value = models || []
  } catch (failure: any) {
    log.error('加载模型失败:', failure)
  }
}
|
||||
|
||||
// Load the provider summaries used by the scope filter and the alias dialog;
// a failure is reported through an error toast.
async function loadProviders() {
  try {
    const summaries = await getProvidersSummary()
    providers.value = summaries
  } catch (failure: any) {
    const reason = failure.response?.data?.detail || failure.message
    showError(reason, '加载 Provider 列表失败')
  }
}
|
||||
|
||||
// Open the alias dialog in "create" mode by clearing any alias selected for editing.
function openCreateAliasDialog(): void {
  editingAliasId.value = null
  createAliasDialogOpen.value = true
}
|
||||
|
||||
// Open the alias dialog in "edit" mode for the given alias; only its id is
// stored, the object itself is resolved by the editingAlias computed.
function openEditAliasDialog(alias: ModelAlias): void {
  const { id } = alias
  editingAliasId.value = id
  createAliasDialogOpen.value = true
}
|
||||
|
||||
// Mirror the dialog's open state; once it closes, drop the editing selection
// so the next open starts clean.
function handleAliasDialogUpdate(value: boolean) {
  createAliasDialogOpen.value = value
  if (value) {
    return
  }
  editingAliasId.value = null
}
|
||||
|
||||
// Persist the dialog's form data: update the alias being edited when isEdit is
// set and an id is selected, otherwise create a new one. On success the dialog
// is closed and the list reloaded; on failure an error toast is shown, with a
// friendlier message for the server's "mapping already exists" error.
async function handleAliasSubmit(data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean) {
  submitting.value = true
  try {
    const targetId = editingAliasId.value
    let notice: string
    if (isEdit && targetId) {
      await updateAlias(targetId, data as UpdateModelAliasRequest)
      notice = data.mapping_type === 'mapping' ? '映射已更新' : '别名已更新'
    } else {
      await createAlias(data as CreateModelAliasRequest)
      notice = data.mapping_type === 'mapping' ? '映射已创建' : '别名已创建'
    }
    success(notice)

    createAliasDialogOpen.value = false
    editingAliasId.value = null
    await loadAliases()
  } catch (failure: any) {
    const detail = failure.response?.data?.detail || failure.message
    const errorMessage = detail === '映射已存在'
      ? '目标作用域已存在同名别名,请先删除冲突的映射或选择其他作用域'
      : detail
    showError(errorMessage, isEdit ? '更新失败' : '创建失败')
  } finally {
    submitting.value = false
  }
}
|
||||
|
||||
// Ask for confirmation, then delete the alias and reload the list.
// Does nothing when the user declines; a failed delete shows an error toast.
async function confirmDeleteAlias(alias: ModelAlias) {
  const question = `确定要删除别名 "${alias.alias}" 吗?`
  const approved = await confirmDanger(question, '删除别名')
  if (!approved) {
    return
  }

  try {
    await deleteAlias(alias.id)
    success('别名已删除')
    await loadAliases()
  } catch (failure: any) {
    const reason = failure.response?.data?.detail || failure.message
    showError(reason, '删除失败')
  }
}
|
||||
|
||||
// Flip an alias between active and inactive on the server, then mirror the
// new state on the local object and show a toast. Failures leave the local
// state untouched and are reported through an error toast.
async function toggleAliasStatus(alias: ModelAlias) {
  try {
    // Decide the target state once and reuse it after the request. Re-reading
    // alias.is_active after the await would flip it a second time if another
    // toggle finished in between, leaving the UI out of sync with the value
    // that was actually sent to the server.
    const nextActive = !alias.is_active
    await updateAlias(alias.id, { is_active: nextActive })
    alias.is_active = nextActive
    success(nextActive ? '别名已启用' : '别名已停用')
  } catch (err: any) {
    showError(err.response?.data?.detail || err.message, '操作失败')
  }
}
|
||||
|
||||
// Kick off the three initial loads in parallel once the component is mounted.
onMounted(async () => {
  const initialLoads = [
    loadAliases(),
    loadGlobalModelsList(),
    loadProviders()
  ]
  await Promise.all(initialLoads)
})
|
||||
</script>
|
||||
@@ -751,15 +751,13 @@ const expiringSoonCount = computed(() => apiKeys.value.filter(key => isExpiringS
|
||||
const filteredApiKeys = computed(() => {
|
||||
let result = apiKeys.value
|
||||
|
||||
// 搜索筛选
|
||||
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(key =>
|
||||
(key.name && key.name.toLowerCase().includes(query)) ||
|
||||
(key.key_display && key.key_display.toLowerCase().includes(query)) ||
|
||||
(key.username && key.username.toLowerCase().includes(query)) ||
|
||||
(key.user_email && key.user_email.toLowerCase().includes(query))
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(key => {
|
||||
const searchableText = `${key.name || ''} ${key.key_display || ''} ${key.username || ''} ${key.user_email || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
|
||||
@@ -18,10 +18,10 @@ import SelectContent from '@/components/ui/select-content.vue'
|
||||
import SelectItem from '@/components/ui/select-item.vue'
|
||||
import SelectValue from '@/components/ui/select-value.vue'
|
||||
import ScatterChart from '@/components/charts/ScatterChart.vue'
|
||||
import { Trash2, Eraser, Search, X, BarChart3, ChevronDown, ChevronRight } from 'lucide-vue-next'
|
||||
import { Trash2, Eraser, Search, X, BarChart3, ChevronDown, ChevronRight, Database, ArrowRight } from 'lucide-vue-next'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useConfirm } from '@/composables/useConfirm'
|
||||
import { cacheApi, type CacheStats, type CacheConfig, type UserAffinity } from '@/api/cache'
|
||||
import { cacheApi, modelMappingCacheApi, type CacheStats, type CacheConfig, type UserAffinity, type ModelMappingCacheStats } from '@/api/cache'
|
||||
import type { TTLAnalysisUser } from '@/api/cache'
|
||||
import { formatNumber, formatTokens, formatCost, formatRemainingTime } from '@/utils/format'
|
||||
import {
|
||||
@@ -47,6 +47,13 @@ const currentPage = ref(1)
|
||||
const pageSize = ref(20)
|
||||
const currentTime = ref(Math.floor(Date.now() / 1000))
|
||||
|
||||
// ==================== Model mapping cache state ====================
|
||||
|
||||
const modelMappingStats = ref<ModelMappingCacheStats | null>(null) // last stats payload; null until the first fetch succeeds
|
||||
const modelMappingLoading = ref(false) // true while fetchModelMappingStats() is running
|
||||
const clearingModelMapping = ref(false) // true while the "clear all" request is in flight
|
||||
const clearingModelName = ref<string | null>(null) // mapping name whose single-entry clear is in flight, if any
|
||||
|
||||
const { success: showSuccess, error: showError, info: showInfo } = useToast()
|
||||
const { confirm: showConfirm } = useConfirm()
|
||||
|
||||
@@ -241,13 +248,107 @@ function stopCountdown() {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 模型映射缓存方法 ====================
|
||||
|
||||
// Refresh the model-mapping cache statistics, toggling the loading flag around
// the request. A failure is shown as a toast and logged.
async function fetchModelMappingStats() {
  modelMappingLoading.value = true
  try {
    const stats = await modelMappingCacheApi.getStats()
    modelMappingStats.value = stats
  } catch (failure) {
    showError('获取模型映射缓存统计失败')
    log.error('获取模型映射缓存统计失败', failure)
  } finally {
    modelMappingLoading.value = false
  }
}
|
||||
|
||||
// Clear every model-mapping cache entry after an explicit confirmation, report
// how many keys were deleted, then refresh the statistics.
async function clearAllModelMappingCache() {
  const dialogOptions = {
    title: '确认清除',
    message: '确定要清除所有模型映射缓存吗?这会影响所有模型的名称解析。',
    confirmText: '确认清除',
    variant: 'destructive'
  }
  const approved = await showConfirm(dialogOptions)
  if (!approved) {
    return
  }

  clearingModelMapping.value = true
  try {
    const outcome = await modelMappingCacheApi.clearAll()
    showSuccess(`已清除 ${outcome.deleted_count} 个缓存键`)
    await fetchModelMappingStats()
  } catch (failure) {
    showError('清除模型映射缓存失败')
    log.error('清除模型映射缓存失败', failure)
  } finally {
    clearingModelMapping.value = false
  }
}
|
||||
|
||||
// Clear the cached mapping for a single model name (no confirmation step) and
// refresh the statistics. clearingModelName marks the entry as busy meanwhile.
async function clearModelMappingByName(modelName: string) {
  clearingModelName.value = modelName
  try {
    await modelMappingCacheApi.clearByName(modelName)
    const notice = `已清除 ${modelName} 的映射缓存`
    showSuccess(notice)
    await fetchModelMappingStats()
  } catch (failure) {
    showError('清除缓存失败')
    log.error('清除模型映射缓存失败', failure)
  } finally {
    clearingModelName.value = null
  }
}
|
||||
|
||||
// Clear the cached mapping of one provider/global-model pair after the user
// confirms, then refresh the statistics. displayName only affects the wording
// of the confirmation prompt.
async function clearProviderModelMapping(providerId: string, globalModelId: string, displayName?: string) {
  const dialogOptions = {
    title: '确认清除',
    message: `确定要清除 ${displayName || 'Provider 模型映射'} 的缓存吗?`,
    confirmText: '确认清除',
    variant: 'destructive'
  }
  const approved = await showConfirm(dialogOptions)
  if (!approved) {
    return
  }

  try {
    await modelMappingCacheApi.clearProviderModel(providerId, globalModelId)
    showSuccess('已清除 Provider 模型映射缓存')
    await fetchModelMappingStats()
  } catch (failure) {
    showError('清除缓存失败')
    log.error('清除 Provider 模型映射缓存失败', failure)
  }
}
|
||||
|
||||
// Render a remaining time-to-live, given in seconds, as a compact label:
// "45s", "3m" or "3m20s". Returns "-" when no usable TTL is available
// (null, undefined, NaN/Infinity or a negative value).
function formatTTL(ttl: number | null | undefined): string {
  // `== null` also covers a field missing from the payload; without the
  // finiteness check such input would fall through and render "NaNmNaNs".
  if (ttl == null || !Number.isFinite(ttl) || ttl < 0) return '-'

  // Work in whole seconds so fractional input cannot leak into the label.
  const totalSeconds = Math.floor(ttl)
  if (totalSeconds < 60) return `${totalSeconds}s`

  const minutes = Math.floor(totalSeconds / 60)
  const seconds = totalSeconds % 60
  return seconds === 0 ? `${minutes}m` : `${minutes}m${seconds}s`
}
|
||||
|
||||
// Map the status of an unmapped cache entry to the badge variant and label
// used to display it. Unrecognised statuses are shown verbatim with the
// outline variant.
function getUnmappedStatusBadge(status: string): { variant: 'default' | 'secondary' | 'destructive' | 'outline', text: string } {
  if (status === 'not_found') {
    return { variant: 'secondary', text: '未找到' }
  }
  if (status === 'invalid') {
    return { variant: 'destructive', text: '无效' }
  }
  if (status === 'error') {
    return { variant: 'destructive', text: '错误' }
  }
  return { variant: 'outline', text: status }
}
|
||||
|
||||
// ==================== 刷新所有数据 ====================
|
||||
|
||||
// Reload every data set shown on the page in parallel: cache stats, cache
// config, the affinity list and the model-mapping cache stats.
async function refreshData() {
  // fetchAffinityList() appears exactly once, followed by a comma; the earlier
  // text listed it twice with no separator before the next element.
  await Promise.all([
    fetchCacheStats(),
    fetchCacheConfig(),
    fetchAffinityList(),
    fetchModelMappingStats()
  ])
}
|
||||
|
||||
@@ -272,6 +373,7 @@ onMounted(() => {
|
||||
fetchCacheStats()
|
||||
fetchCacheConfig()
|
||||
fetchAffinityList()
|
||||
fetchModelMappingStats()
|
||||
startCountdown()
|
||||
refreshAnalysis()
|
||||
})
|
||||
@@ -599,6 +701,344 @@ onBeforeUnmount(() => {
|
||||
/>
|
||||
</Card>
|
||||
|
||||
<!-- 模型映射缓存管理 -->
|
||||
<Card class="overflow-hidden">
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
<div class="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3 sm:gap-4">
|
||||
<div class="flex items-center gap-3 shrink-0">
|
||||
<Database class="h-5 w-5 text-muted-foreground hidden sm:block" />
|
||||
<h3 class="text-sm sm:text-base font-semibold">
|
||||
模型映射缓存
|
||||
</h3>
|
||||
</div>
|
||||
<div class="flex flex-wrap items-center gap-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
class="h-8 w-8 text-muted-foreground/70 hover:text-destructive"
|
||||
title="清除全部映射缓存"
|
||||
:disabled="clearingModelMapping || !modelMappingStats?.available"
|
||||
@click="clearAllModelMappingCache"
|
||||
>
|
||||
<Eraser class="h-4 w-4" />
|
||||
</Button>
|
||||
<RefreshButton
|
||||
:loading="modelMappingLoading"
|
||||
@click="fetchModelMappingStats"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 映射缓存表格 -->
|
||||
<Table
|
||||
v-if="modelMappingStats?.available && modelMappingStats.mappings && modelMappingStats.mappings.length > 0"
|
||||
class="hidden md:table"
|
||||
>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead class="w-[25%]">
|
||||
全局模型
|
||||
</TableHead>
|
||||
<TableHead class="w-8 text-center" />
|
||||
<TableHead class="w-[30%]">
|
||||
映射模型
|
||||
</TableHead>
|
||||
<TableHead class="w-[25%]">
|
||||
提供商
|
||||
</TableHead>
|
||||
<TableHead class="w-[10%] text-center">
|
||||
剩余
|
||||
</TableHead>
|
||||
<TableHead class="w-[5%] text-right">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<TableRow
|
||||
v-for="mapping in modelMappingStats.mappings"
|
||||
:key="mapping.mapping_name"
|
||||
>
|
||||
<TableCell>
|
||||
<div v-if="mapping.global_model_name">
|
||||
<div class="text-sm font-medium">
|
||||
{{ mapping.global_model_display_name || mapping.global_model_name }}
|
||||
</div>
|
||||
<div
|
||||
v-if="mapping.global_model_display_name && mapping.global_model_display_name !== mapping.global_model_name"
|
||||
class="text-xs text-muted-foreground font-mono"
|
||||
>
|
||||
{{ mapping.global_model_name }}
|
||||
</div>
|
||||
</div>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm text-muted-foreground"
|
||||
>-</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<ArrowRight class="h-4 w-4 text-muted-foreground" />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span class="text-sm font-mono">{{ mapping.mapping_name }}</span>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div
|
||||
v-if="mapping.providers && mapping.providers.length > 0"
|
||||
class="flex flex-wrap gap-1"
|
||||
>
|
||||
<Badge
|
||||
v-for="provider in mapping.providers.slice(0, 3)"
|
||||
:key="provider"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ provider }}
|
||||
</Badge>
|
||||
<Badge
|
||||
v-if="mapping.providers.length > 3"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
+{{ mapping.providers.length - 3 }}
|
||||
</Badge>
|
||||
</div>
|
||||
<span
|
||||
v-else
|
||||
class="text-sm text-muted-foreground"
|
||||
>-</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-right">
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class="h-6 w-6 text-muted-foreground/50 hover:text-destructive"
|
||||
:disabled="clearingModelName === mapping.mapping_name"
|
||||
title="清除缓存"
|
||||
@click="clearModelMappingByName(mapping.mapping_name)"
|
||||
>
|
||||
<X class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
<!-- 移动端卡片列表 -->
|
||||
<div
|
||||
v-if="modelMappingStats?.available && modelMappingStats.mappings && modelMappingStats.mappings.length > 0"
|
||||
class="md:hidden divide-y divide-border/40"
|
||||
>
|
||||
<div
|
||||
v-for="mapping in modelMappingStats.mappings"
|
||||
:key="`m-${mapping.mapping_name}`"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-center justify-between gap-2">
|
||||
<span class="text-sm font-medium truncate">{{ mapping.global_model_display_name || mapping.global_model_name || '-' }}</span>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class="h-6 w-6 text-muted-foreground/50 hover:text-destructive shrink-0"
|
||||
:disabled="clearingModelName === mapping.mapping_name"
|
||||
@click="clearModelMappingByName(mapping.mapping_name)"
|
||||
>
|
||||
<X class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<ArrowRight class="h-3.5 w-3.5 shrink-0" />
|
||||
<span class="font-mono">{{ mapping.mapping_name }}</span>
|
||||
</div>
|
||||
<div
|
||||
v-if="mapping.providers && mapping.providers.length > 0"
|
||||
class="flex flex-wrap gap-1"
|
||||
>
|
||||
<Badge
|
||||
v-for="provider in mapping.providers"
|
||||
:key="provider"
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ provider }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 未映射条目(NOT_FOUND 等) -->
|
||||
<div
|
||||
v-if="modelMappingStats?.available && modelMappingStats.unmapped && modelMappingStats.unmapped.length > 0"
|
||||
class="px-6 py-4 border-t border-border/40"
|
||||
>
|
||||
<div class="text-xs text-muted-foreground mb-2">
|
||||
未映射的缓存条目
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1.5">
|
||||
<Badge
|
||||
v-for="entry in modelMappingStats.unmapped"
|
||||
:key="entry.mapping_name"
|
||||
:variant="getUnmappedStatusBadge(entry.status).variant"
|
||||
class="text-xs font-mono cursor-pointer"
|
||||
:title="`${getUnmappedStatusBadge(entry.status).text} - 点击清除`"
|
||||
@click="clearModelMappingByName(entry.mapping_name)"
|
||||
>
|
||||
{{ entry.mapping_name }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Provider 模型映射缓存 -->
|
||||
<div
|
||||
v-if="modelMappingStats?.available && modelMappingStats.provider_model_mappings && modelMappingStats.provider_model_mappings.length > 0"
|
||||
class="border-t border-border/40"
|
||||
>
|
||||
<div class="px-6 py-3 text-xs text-muted-foreground border-b border-border/30 bg-muted/20">
|
||||
Provider 模型映射缓存
|
||||
</div>
|
||||
<!-- 桌面端表格 -->
|
||||
<Table class="hidden md:table">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead class="w-[15%]">
|
||||
提供商
|
||||
</TableHead>
|
||||
<TableHead class="w-[25%]">
|
||||
请求名称
|
||||
</TableHead>
|
||||
<TableHead class="w-8 text-center" />
|
||||
<TableHead class="w-[25%]">
|
||||
映射模型
|
||||
</TableHead>
|
||||
<TableHead class="w-[10%] text-center">
|
||||
剩余
|
||||
</TableHead>
|
||||
<TableHead class="w-[10%] text-center">
|
||||
次数
|
||||
</TableHead>
|
||||
<TableHead class="w-[7%] text-right">
|
||||
操作
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
<template
|
||||
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
|
||||
:key="index"
|
||||
>
|
||||
<TableRow
|
||||
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
|
||||
:key="`${index}-${aliasIndex}`"
|
||||
>
|
||||
<TableCell>
|
||||
<Badge
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ mapping.provider_name }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span class="text-sm font-mono">{{ alias }}</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<ArrowRight class="h-4 w-4 text-muted-foreground" />
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<span class="text-sm font-mono font-medium">{{ mapping.provider_model_name }}</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<span class="text-sm">{{ mapping.hit_count || 0 }}</span>
|
||||
</TableCell>
|
||||
<TableCell class="text-right">
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class="h-7 w-7 text-muted-foreground/70 hover:text-destructive"
|
||||
title="清除缓存"
|
||||
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
|
||||
>
|
||||
<Trash2 class="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
</template>
|
||||
</TableBody>
|
||||
</Table>
|
||||
<!-- 移动端卡片 -->
|
||||
<div class="md:hidden divide-y divide-border/40">
|
||||
<template
|
||||
v-for="(mapping, index) in modelMappingStats.provider_model_mappings"
|
||||
:key="`m-pm-${index}`"
|
||||
>
|
||||
<div
|
||||
v-for="(alias, aliasIndex) in (mapping.aliases || [])"
|
||||
:key="`m-pm-${index}-${aliasIndex}`"
|
||||
class="p-4 space-y-2"
|
||||
>
|
||||
<div class="flex items-center justify-between">
|
||||
<Badge
|
||||
variant="outline"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ mapping.provider_name }}
|
||||
</Badge>
|
||||
<div class="flex items-center gap-2">
|
||||
<span class="text-xs text-muted-foreground">{{ formatTTL(mapping.ttl) }}</span>
|
||||
<span class="text-xs">{{ mapping.hit_count || 0 }}次</span>
|
||||
<Button
|
||||
size="icon"
|
||||
variant="ghost"
|
||||
class="h-6 w-6 text-muted-foreground/70 hover:text-destructive"
|
||||
title="清除缓存"
|
||||
@click="clearProviderModelMapping(mapping.provider_id, mapping.global_model_id, `${mapping.provider_name} - ${alias}`)"
|
||||
>
|
||||
<Trash2 class="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 text-sm">
|
||||
<span class="font-mono">{{ alias }}</span>
|
||||
<ArrowRight class="h-3.5 w-3.5 shrink-0 text-muted-foreground/60" />
|
||||
<span class="font-mono font-medium">{{ mapping.provider_model_name }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 无缓存状态 -->
|
||||
<div
|
||||
v-else-if="modelMappingStats?.available && (!modelMappingStats.mappings || modelMappingStats.mappings.length === 0) && (!modelMappingStats.unmapped || modelMappingStats.unmapped.length === 0) && (!modelMappingStats.provider_model_mappings || modelMappingStats.provider_model_mappings.length === 0)"
|
||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||
>
|
||||
暂无模型解析缓存
|
||||
</div>
|
||||
|
||||
<!-- Redis 未启用 -->
|
||||
<div
|
||||
v-else-if="modelMappingStats && !modelMappingStats.available"
|
||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||
>
|
||||
{{ modelMappingStats.message || 'Redis 未启用' }}
|
||||
</div>
|
||||
|
||||
<!-- 加载中 -->
|
||||
<div
|
||||
v-else-if="modelMappingLoading"
|
||||
class="px-6 py-8 text-center text-sm text-muted-foreground"
|
||||
>
|
||||
加载中...
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- TTL 分析区域 -->
|
||||
<Card class="overflow-hidden">
|
||||
<div class="px-4 sm:px-6 py-3 sm:py-3.5 border-b border-border/60">
|
||||
|
||||
@@ -111,9 +111,6 @@
|
||||
<TableHead class="w-[80px] text-center">
|
||||
提供商
|
||||
</TableHead>
|
||||
<TableHead class="w-[70px] text-center">
|
||||
别名/映射
|
||||
</TableHead>
|
||||
<TableHead class="w-[80px] text-center">
|
||||
调用次数
|
||||
</TableHead>
|
||||
@@ -128,7 +125,7 @@
|
||||
<TableBody>
|
||||
<TableRow v-if="loading">
|
||||
<TableCell
|
||||
colspan="8"
|
||||
colspan="7"
|
||||
class="text-center py-8"
|
||||
>
|
||||
<Loader2 class="w-6 h-6 animate-spin mx-auto" />
|
||||
@@ -136,7 +133,7 @@
|
||||
</TableRow>
|
||||
<TableRow v-else-if="filteredGlobalModels.length === 0">
|
||||
<TableCell
|
||||
colspan="8"
|
||||
colspan="7"
|
||||
class="text-center py-8 text-muted-foreground"
|
||||
>
|
||||
没有找到匹配的模型
|
||||
@@ -171,27 +168,27 @@
|
||||
<div class="space-y-1 w-fit">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
<Zap
|
||||
v-if="model.default_supports_streaming"
|
||||
v-if="model.config?.streaming !== false"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="流式输出"
|
||||
/>
|
||||
<Image
|
||||
v-if="model.default_supports_image_generation"
|
||||
v-if="model.config?.image_generation === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="图像生成"
|
||||
/>
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="视觉理解"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="工具调用"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="深度思考"
|
||||
/>
|
||||
@@ -244,11 +241,6 @@
|
||||
{{ model.provider_count || 0 }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<Badge variant="secondary">
|
||||
{{ model.alias_count || 0 }}
|
||||
</Badge>
|
||||
</TableCell>
|
||||
<TableCell class="text-center">
|
||||
<span class="text-sm font-mono">{{ formatUsageCount(model.usage_count || 0) }}</span>
|
||||
</TableCell>
|
||||
@@ -369,23 +361,23 @@
|
||||
<!-- 第二行:能力图标 -->
|
||||
<div class="flex flex-wrap gap-1.5">
|
||||
<Zap
|
||||
v-if="model.default_supports_streaming"
|
||||
v-if="model.config?.streaming !== false"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Image
|
||||
v-if="model.default_supports_image_generation"
|
||||
v-if="model.config?.image_generation === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
@@ -393,7 +385,6 @@
|
||||
<!-- 第三行:统计信息 -->
|
||||
<div class="flex flex-wrap items-center gap-3 text-xs text-muted-foreground">
|
||||
<span>提供商 {{ model.provider_count || 0 }}</span>
|
||||
<span>别名 {{ model.alias_count || 0 }}</span>
|
||||
<span>调用 {{ formatUsageCount(model.usage_count || 0) }}</span>
|
||||
<span
|
||||
v-if="getFirstTierPrice(model, 'input') || getFirstTierPrice(model, 'output')"
|
||||
@@ -425,25 +416,12 @@
|
||||
@success="handleModelFormSuccess"
|
||||
/>
|
||||
|
||||
<!-- 创建/编辑别名/映射对话框 -->
|
||||
<AliasDialog
|
||||
:open="createAliasDialogOpen"
|
||||
:editing-alias="editingAlias"
|
||||
:global-models="globalModels"
|
||||
:providers="providers"
|
||||
:fixed-target-model="isTargetModelFixed ? selectedModel : null"
|
||||
@update:open="handleAliasDialogUpdate"
|
||||
@submit="handleAliasSubmit"
|
||||
/>
|
||||
|
||||
<!-- 模型详情抽屉 -->
|
||||
<ModelDetailDrawer
|
||||
:model="selectedModel"
|
||||
:open="!!selectedModel"
|
||||
:providers="selectedModelProviders"
|
||||
:aliases="selectedModelAliases"
|
||||
:loading-providers="loadingModelProviders"
|
||||
:loading-aliases="loadingModelAliases"
|
||||
:has-blocking-dialog-open="hasBlockingDialogOpen"
|
||||
:capabilities="capabilities"
|
||||
@update:open="handleDrawerOpenChange"
|
||||
@@ -454,11 +432,6 @@
|
||||
@delete-provider="confirmDeleteProviderImplementation"
|
||||
@toggle-provider-status="toggleProviderStatus"
|
||||
@refresh-providers="refreshSelectedModelProviders"
|
||||
@add-alias="openAddAliasDialog"
|
||||
@edit-alias="openEditAliasDialog"
|
||||
@toggle-alias-status="toggleAliasStatusFromDrawer"
|
||||
@delete-alias="confirmDeleteAliasFromDrawer"
|
||||
@refresh-aliases="refreshSelectedModelAliases"
|
||||
/>
|
||||
|
||||
<!-- 批量添加关联提供商对话框 -->
|
||||
@@ -736,9 +709,7 @@ import {
|
||||
} from 'lucide-vue-next'
|
||||
import ModelDetailDrawer from '@/features/models/components/ModelDetailDrawer.vue'
|
||||
import GlobalModelFormDialog from '@/features/models/components/GlobalModelFormDialog.vue'
|
||||
import AliasDialog from '@/features/models/components/AliasDialog.vue'
|
||||
import ProviderModelFormDialog from '@/features/providers/components/ProviderModelFormDialog.vue'
|
||||
import type { CreateModelAliasRequest, UpdateModelAliasRequest } from '@/api/endpoints/aliases'
|
||||
import type { Model } from '@/api/endpoints'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { useConfirm } from '@/composables/useConfirm'
|
||||
@@ -768,13 +739,6 @@ import {
|
||||
type GlobalModelResponse,
|
||||
} from '@/api/global-models'
|
||||
import { log } from '@/utils/logger'
|
||||
import {
|
||||
getAliases,
|
||||
createAlias,
|
||||
updateAlias,
|
||||
deleteAlias,
|
||||
type ModelAlias,
|
||||
} from '@/api/endpoints/aliases'
|
||||
import { getProvidersSummary } from '@/api/endpoints/providers'
|
||||
import { getAllCapabilities, type CapabilityDefinition } from '@/api/endpoints'
|
||||
|
||||
@@ -788,13 +752,9 @@ const searchQuery = ref('')
|
||||
const selectedModel = ref<GlobalModelResponse | null>(null)
|
||||
const createModelDialogOpen = ref(false)
|
||||
const editingModel = ref<GlobalModelResponse | null>(null)
|
||||
const createAliasDialogOpen = ref(false)
|
||||
const editingAliasId = ref<string | null>(null)
|
||||
const isTargetModelFixed = ref(false) // 目标模型是否固定(从模型详情抽屉打开时为 true)
|
||||
|
||||
// 数据
|
||||
const globalModels = ref<GlobalModelResponse[]>([])
|
||||
const allAliases = ref<ModelAlias[]>([])
|
||||
const providers = ref<any[]>([])
|
||||
const capabilities = ref<CapabilityDefinition[]>([])
|
||||
|
||||
@@ -804,9 +764,7 @@ const catalogPageSize = ref(20)
|
||||
|
||||
// 选中模型的详细数据
|
||||
const selectedModelProviders = ref<any[]>([])
|
||||
const selectedModelAliases = ref<ModelAlias[]>([])
|
||||
const loadingModelProviders = ref(false)
|
||||
const loadingModelAliases = ref(false)
|
||||
|
||||
// 批量添加关联提供商
|
||||
const batchAddProvidersDialogOpen = ref(false)
|
||||
@@ -876,19 +834,10 @@ function hasTieredPricing(model: GlobalModelResponse): boolean {
|
||||
// True while any modal dialog is open; passed to the detail drawer so it is
// not closed by accident while a dialog sits on top of it.
const hasBlockingDialogOpen = computed(() => {
  return (
    createModelDialogOpen.value ||
    createAliasDialogOpen.value ||
    batchAddProvidersDialogOpen.value ||
    editProviderDialogOpen.value
  )
})
|
||||
|
||||
// The alias currently being edited (handed to AliasDialog), or null.
// Looks in the full alias list first, then in the selected model's aliases.
const editingAlias = computed(() => {
  const targetId = editingAliasId.value
  if (!targetId) return null

  const fromAll = allAliases.value.find(item => item.id === targetId)
  if (fromAll) return fromAll

  return selectedModelAliases.value.find(item => item.id === targetId) ?? null
})
|
||||
|
||||
// 能力筛选
|
||||
const capabilityFilters = ref({
|
||||
streaming: false,
|
||||
@@ -1053,30 +1002,30 @@ async function batchRemoveSelectedProviders() {
|
||||
const filteredGlobalModels = computed(() => {
|
||||
let result = globalModels.value
|
||||
|
||||
// 搜索
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(m =>
|
||||
m.name.toLowerCase().includes(query) ||
|
||||
m.display_name?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 能力筛选
|
||||
if (capabilityFilters.value.streaming) {
|
||||
result = result.filter(m => m.default_supports_streaming)
|
||||
result = result.filter(m => m.config?.streaming !== false)
|
||||
}
|
||||
if (capabilityFilters.value.imageGeneration) {
|
||||
result = result.filter(m => m.default_supports_image_generation)
|
||||
result = result.filter(m => m.config?.image_generation === true)
|
||||
}
|
||||
if (capabilityFilters.value.vision) {
|
||||
result = result.filter(m => m.default_supports_vision)
|
||||
result = result.filter(m => m.config?.vision === true)
|
||||
}
|
||||
if (capabilityFilters.value.toolUse) {
|
||||
result = result.filter(m => m.default_supports_function_calling)
|
||||
result = result.filter(m => m.config?.function_calling === true)
|
||||
}
|
||||
if (capabilityFilters.value.extendedThinking) {
|
||||
result = result.filter(m => m.default_supports_extended_thinking)
|
||||
result = result.filter(m => m.config?.extended_thinking === true)
|
||||
}
|
||||
|
||||
return result
|
||||
@@ -1131,11 +1080,8 @@ async function selectModel(model: GlobalModelResponse) {
|
||||
selectedModel.value = model
|
||||
detailTab.value = 'basic'
|
||||
|
||||
// 加载该模型的关联提供商和别名
|
||||
await Promise.all([
|
||||
loadModelProviders(model.id),
|
||||
loadModelAliases(model.id)
|
||||
])
|
||||
// 加载该模型的关联提供商
|
||||
await loadModelProviders(model.id)
|
||||
}
|
||||
|
||||
// 加载指定模型的关联提供商
|
||||
@@ -1187,27 +1133,6 @@ async function loadModelProviders(_globalModelId: string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Load the aliases that belong to the given global model.
// On failure the list is reset to empty and the error is logged.
async function loadModelAliases(globalModelId: string) {
  loadingModelAliases.value = true
  try {
    // Fetch up to 1000 aliases, then keep only those targeting this model.
    const fetched = await getAliases({ limit: 1000 })
    selectedModelAliases.value = fetched.filter(item => item.global_model_id === globalModelId)
  } catch (err: any) {
    log.error('加载别名失败:', err)
    selectedModelAliases.value = []
  } finally {
    loadingModelAliases.value = false
  }
}
|
||||
|
||||
// Reload the aliases of the currently selected model; no-op when none is selected.
async function refreshSelectedModelAliases() {
  const current = selectedModel.value
  if (!current) return
  await loadModelAliases(current.id)
}
|
||||
|
||||
// 刷新当前选中模型的关联提供商
|
||||
async function refreshSelectedModelProviders() {
|
||||
if (selectedModel.value) {
|
||||
@@ -1329,14 +1254,6 @@ async function confirmDeleteProviderImplementation(provider: any) {
|
||||
}
|
||||
}
|
||||
|
||||
// Open the "add alias" dialog from the model detail drawer.
// Does nothing unless a model is selected.
function openAddAliasDialog() {
  if (selectedModel.value) {
    editingAliasId.value = null
    // The target model is locked to the currently selected model.
    isTargetModelFixed.value = true
    createAliasDialogOpen.value = true
  }
}
|
||||
|
||||
function openCreateModelDialog() {
|
||||
editingModel.value = null
|
||||
createModelDialogOpen.value = true
|
||||
@@ -1391,106 +1308,6 @@ async function toggleModelStatus(model: GlobalModelResponse) {
|
||||
}
|
||||
}
|
||||
|
||||
// Load every alias (up to 1000) into `allAliases`; show an error toast on failure.
async function loadAliases() {
  try {
    const fetched = await getAliases({ limit: 1000 })
    allAliases.value = fetched
  } catch (err: any) {
    const reason = err.response?.data?.detail || err.message
    showError(reason, '加载别名失败')
  }
}
|
||||
|
||||
// Open the alias dialog in edit mode for the given alias.
// The target model stays editable in this mode.
function openEditAliasDialog(alias: ModelAlias) {
  const { id } = alias
  editingAliasId.value = id
  isTargetModelFixed.value = false
  createAliasDialogOpen.value = true
}
|
||||
|
||||
// Sync the alias dialog's open state; when it closes, clear the editing context.
function handleAliasDialogUpdate(value: boolean) {
  createAliasDialogOpen.value = value
  if (value) return

  editingAliasId.value = null
  isTargetModelFixed.value = false
}
|
||||
|
||||
// Handle a submit coming from AliasDialog: create or update the alias/mapping,
// close the dialog, then refresh every list that depends on aliases.
async function handleAliasSubmit(data: CreateModelAliasRequest | UpdateModelAliasRequest, isEdit: boolean) {
  submitting.value = true
  try {
    const targetId = editingAliasId.value
    if (isEdit && targetId) {
      // Update an existing alias/mapping.
      await updateAlias(targetId, data as UpdateModelAliasRequest)
      success(data.mapping_type === 'mapping' ? '映射已更新' : '别名已更新')
    } else {
      // Create a new alias/mapping.
      await createAlias(data as CreateModelAliasRequest)
      success(data.mapping_type === 'mapping' ? '映射已创建' : '别名已创建')
    }

    // Close the dialog and reset the editing context.
    createAliasDialogOpen.value = false
    editingAliasId.value = null
    isTargetModelFixed.value = false

    // Refresh the alias data.
    await loadAliases()
    const current = selectedModel.value
    if (current) {
      await loadModelAliases(current.id)
    }
    // Refresh the outer model list so alias_count stays current.
    await loadGlobalModels()
  } catch (err: any) {
    const detail = err.response?.data?.detail || err.message
    // Replace the terse backend "already exists" message with a friendlier hint.
    const errorMessage = detail === '映射已存在'
      ? '目标作用域已存在同名别名,请先删除冲突的映射或选择其他作用域'
      : detail
    showError(errorMessage, isEdit ? '更新失败' : '创建失败')
  } finally {
    submitting.value = false
  }
}
|
||||
|
||||
// Ask for confirmation, then delete the alias and refresh the dependent lists.
async function confirmDeleteAlias(alias: ModelAlias) {
  const prompt = `确定要删除别名 "${alias.alias}" 吗?`
  const confirmed = await confirmDanger(prompt, '删除别名')
  if (!confirmed) {
    return
  }

  try {
    await deleteAlias(alias.id)
    success('别名已删除')
    await loadAliases()
    // Refresh the outer model list so alias_count stays current.
    await loadGlobalModels()
  } catch (err: any) {
    const reason = err.response?.data?.detail || err.message
    showError(reason, '删除失败')
  }
}
|
||||
|
||||
// Flip an alias between active and inactive on the server, mirror the change
// on the local object, and report the new state in a toast.
async function toggleAliasStatus(alias: ModelAlias) {
  try {
    await updateAlias(alias.id, { is_active: !alias.is_active })
    alias.is_active = !alias.is_active
    if (alias.is_active) {
      success('别名已启用')
    } else {
      success('别名已停用')
    }
  } catch (err: any) {
    const reason = err.response?.data?.detail || err.message
    showError(reason, '操作失败')
  }
}
|
||||
|
||||
// Drawer variant: toggle the alias status, then reload the selected model's aliases.
async function toggleAliasStatusFromDrawer(alias: ModelAlias) {
  await toggleAliasStatus(alias)
  await refreshSelectedModelAliases()
}
|
||||
|
||||
// Drawer variant: run the confirm-and-delete flow, then reload the selected
// model's aliases.
async function confirmDeleteAliasFromDrawer(alias: ModelAlias) {
  await confirmDeleteAlias(alias)
  await refreshSelectedModelAliases()
}
|
||||
|
||||
// Reload the page's data; currently this is just the global model list.
async function refreshData() {
  await loadGlobalModels()
}
|
||||
|
||||
@@ -505,13 +505,13 @@ const priorityModeConfig = computed(() => {
|
||||
const filteredProviders = computed(() => {
|
||||
let result = [...providers.value]
|
||||
|
||||
// 搜索筛选
|
||||
// 搜索筛选(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value.trim()) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(p =>
|
||||
p.display_name.toLowerCase().includes(query) ||
|
||||
p.name.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(p => {
|
||||
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 排序
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
<template #actions>
|
||||
<Button
|
||||
:disabled="loading"
|
||||
class="shadow-none hover:shadow-none"
|
||||
@click="saveSystemConfig"
|
||||
>
|
||||
{{ loading ? '保存中...' : '保存所有配置' }}
|
||||
@@ -15,6 +16,94 @@
|
||||
</PageHeader>
|
||||
|
||||
<div class="mt-6 space-y-6">
|
||||
<!-- 配置导出/导入 -->
|
||||
<CardSection
|
||||
title="配置管理"
|
||||
description="导出或导入提供商和模型配置,便于备份或迁移"
|
||||
>
|
||||
<div class="flex flex-wrap gap-4">
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
导出当前所有提供商、端点、API Key 和模型配置到 JSON 文件
|
||||
</p>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="exportLoading"
|
||||
@click="handleExportConfig"
|
||||
>
|
||||
<Download class="w-4 h-4 mr-2" />
|
||||
{{ exportLoading ? '导出中...' : '导出配置' }}
|
||||
</Button>
|
||||
</div>
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
从 JSON 文件导入配置,支持跳过、覆盖或报错三种冲突处理模式
|
||||
</p>
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
ref="configFileInput"
|
||||
type="file"
|
||||
accept=".json"
|
||||
class="hidden"
|
||||
@change="handleConfigFileSelect"
|
||||
>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="importLoading"
|
||||
@click="triggerConfigFileSelect"
|
||||
>
|
||||
<Upload class="w-4 h-4 mr-2" />
|
||||
{{ importLoading ? '导入中...' : '导入配置' }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- 用户数据导出/导入 -->
|
||||
<CardSection
|
||||
title="用户数据管理"
|
||||
description="导出或导入用户及其 API Keys 数据(不含管理员)"
|
||||
>
|
||||
<div class="flex flex-wrap gap-4">
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
导出所有普通用户及其 API Keys 到 JSON 文件
|
||||
</p>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="exportUsersLoading"
|
||||
@click="handleExportUsers"
|
||||
>
|
||||
<Download class="w-4 h-4 mr-2" />
|
||||
{{ exportUsersLoading ? '导出中...' : '导出用户数据' }}
|
||||
</Button>
|
||||
</div>
|
||||
<div class="flex-1 min-w-[200px]">
|
||||
<p class="text-sm text-muted-foreground mb-3">
|
||||
从 JSON 文件导入用户数据(需相同 ENCRYPTION_KEY)
|
||||
</p>
|
||||
<div class="flex items-center gap-2">
|
||||
<input
|
||||
ref="usersFileInput"
|
||||
type="file"
|
||||
accept=".json"
|
||||
class="hidden"
|
||||
@change="handleUsersFileSelect"
|
||||
>
|
||||
<Button
|
||||
variant="outline"
|
||||
:disabled="importUsersLoading"
|
||||
@click="triggerUsersFileSelect"
|
||||
>
|
||||
<Upload class="w-4 h-4 mr-2" />
|
||||
{{ importUsersLoading ? '导入中...' : '导入用户数据' }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardSection>
|
||||
|
||||
<!-- 基础配置 -->
|
||||
<CardSection
|
||||
title="基础配置"
|
||||
@@ -375,11 +464,312 @@
|
||||
</div>
|
||||
</CardSection>
|
||||
</div>
|
||||
|
||||
<!-- 导入配置对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importDialogOpen"
|
||||
title="导入配置"
|
||||
description="选择冲突处理模式并确认导入"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
配置预览
|
||||
</p>
|
||||
<ul class="space-y-1 text-muted-foreground">
|
||||
<li>全局模型: {{ importPreview.global_models?.length || 0 }} 个</li>
|
||||
<li>提供商: {{ importPreview.providers?.length || 0 }} 个</li>
|
||||
<!-- Endpoint and API-key totals for the import preview. Each level falls back
     to 0 so a provider without `endpoints` (or a file without `providers`)
     shows 0 instead of NaN or an empty string. -->
<li>
  端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) ?? 0 }} 个
</li>
<li>
  API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0) ?? 0), 0) ?? 0 }} 个
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||
<Select
|
||||
v-model="mergeMode"
|
||||
v-model:open="mergeModeSelectOpen"
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="skip">
|
||||
跳过 - 保留现有配置
|
||||
</SelectItem>
|
||||
<SelectItem value="overwrite">
|
||||
覆盖 - 用导入配置替换
|
||||
</SelectItem>
|
||||
<SelectItem value="error">
|
||||
报错 - 遇到冲突时中止
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
<template v-if="mergeMode === 'skip'">
|
||||
已存在的配置将被保留,仅导入新配置
|
||||
</template>
|
||||
<template v-else-if="mergeMode === 'overwrite'">
|
||||
已存在的配置将被导入的配置覆盖
|
||||
</template>
|
||||
<template v-else>
|
||||
如果发现任何冲突,导入将中止并回滚
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-muted-foreground">
|
||||
注意:相同的 API Keys 会自动跳过,不会创建重复记录。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="importDialogOpen = false; mergeModeSelectOpen = false"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="importLoading"
|
||||
@click="confirmImport"
|
||||
>
|
||||
{{ importLoading ? '导入中...' : '确认导入' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 导入结果对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importResultDialogOpen"
|
||||
title="导入完成"
|
||||
>
|
||||
<div
|
||||
v-if="importResult"
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
全局模型
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.global_models.created }},
|
||||
更新: {{ importResult.stats.global_models.updated }},
|
||||
跳过: {{ importResult.stats.global_models.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
提供商
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.providers.created }},
|
||||
更新: {{ importResult.stats.providers.updated }},
|
||||
跳过: {{ importResult.stats.providers.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
端点
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.endpoints.created }},
|
||||
更新: {{ importResult.stats.endpoints.updated }},
|
||||
跳过: {{ importResult.stats.endpoints.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.keys.created }},
|
||||
跳过: {{ importResult.stats.keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg col-span-2">
|
||||
<p class="font-medium">
|
||||
模型配置
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importResult.stats.models.created }},
|
||||
更新: {{ importResult.stats.models.updated }},
|
||||
跳过: {{ importResult.stats.models.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="importResult.stats.errors.length > 0"
|
||||
class="p-3 bg-destructive/10 rounded-lg"
|
||||
>
|
||||
<p class="font-medium text-destructive mb-2">
|
||||
警告信息
|
||||
</p>
|
||||
<ul class="text-sm text-destructive space-y-1">
|
||||
<li
|
||||
v-for="(err, index) in importResult.stats.errors"
|
||||
:key="index"
|
||||
>
|
||||
{{ err }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button @click="importResultDialogOpen = false">
|
||||
确定
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 用户数据导入对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importUsersDialogOpen"
|
||||
title="导入用户数据"
|
||||
description="选择冲突处理模式并确认导入"
|
||||
>
|
||||
<div class="space-y-4">
|
||||
<div
|
||||
v-if="importUsersPreview"
|
||||
class="p-3 bg-muted rounded-lg text-sm"
|
||||
>
|
||||
<p class="font-medium mb-2">
|
||||
数据预览
|
||||
</p>
|
||||
<ul class="space-y-1 text-muted-foreground">
|
||||
<li>用户: {{ importUsersPreview.users?.length || 0 }} 个</li>
|
||||
<!-- Total API keys across all users in the preview; falls back to 0 when the
     file has no `users` array, instead of rendering an empty value. -->
<li>
  API Keys: {{ importUsersPreview.users?.reduce((sum: number, u: any) => sum + (u.api_keys?.length || 0), 0) ?? 0 }} 个
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Label class="block text-sm font-medium mb-2">冲突处理模式</Label>
|
||||
<Select
|
||||
v-model="usersMergeMode"
|
||||
v-model:open="usersMergeModeSelectOpen"
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="skip">
|
||||
跳过 - 保留现有用户
|
||||
</SelectItem>
|
||||
<SelectItem value="overwrite">
|
||||
覆盖 - 用导入数据替换
|
||||
</SelectItem>
|
||||
<SelectItem value="error">
|
||||
报错 - 遇到冲突时中止
|
||||
</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<p class="mt-1 text-xs text-muted-foreground">
|
||||
<template v-if="usersMergeMode === 'skip'">
|
||||
已存在的用户将被保留,仅导入新用户
|
||||
</template>
|
||||
<template v-else-if="usersMergeMode === 'overwrite'">
|
||||
已存在的用户将被导入的数据覆盖
|
||||
</template>
|
||||
<template v-else>
|
||||
如果发现任何冲突,导入将中止并回滚
|
||||
</template>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<p class="text-xs text-muted-foreground">
|
||||
注意:用户 API Keys 需要目标系统使用相同的 ENCRYPTION_KEY 环境变量才能正常工作。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button
|
||||
variant="outline"
|
||||
@click="importUsersDialogOpen = false; usersMergeModeSelectOpen = false"
|
||||
>
|
||||
取消
|
||||
</Button>
|
||||
<Button
|
||||
:disabled="importUsersLoading"
|
||||
@click="confirmImportUsers"
|
||||
>
|
||||
{{ importUsersLoading ? '导入中...' : '确认导入' }}
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
|
||||
<!-- 用户数据导入结果对话框 -->
|
||||
<Dialog
|
||||
v-model:open="importUsersResultDialogOpen"
|
||||
title="用户数据导入完成"
|
||||
>
|
||||
<div
|
||||
v-if="importUsersResult"
|
||||
class="space-y-4"
|
||||
>
|
||||
<div class="grid grid-cols-2 gap-4 text-sm">
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
用户
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importUsersResult.stats.users.created }},
|
||||
更新: {{ importUsersResult.stats.users.updated }},
|
||||
跳过: {{ importUsersResult.stats.users.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
<div class="p-3 bg-muted rounded-lg">
|
||||
<p class="font-medium">
|
||||
API Keys
|
||||
</p>
|
||||
<p class="text-muted-foreground">
|
||||
创建: {{ importUsersResult.stats.api_keys.created }},
|
||||
跳过: {{ importUsersResult.stats.api_keys.skipped }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="importUsersResult.stats.errors.length > 0"
|
||||
class="p-3 bg-destructive/10 rounded-lg"
|
||||
>
|
||||
<p class="font-medium text-destructive mb-2">
|
||||
警告信息
|
||||
</p>
|
||||
<ul class="text-sm text-destructive space-y-1">
|
||||
<li
|
||||
v-for="(err, index) in importUsersResult.stats.errors"
|
||||
:key="index"
|
||||
>
|
||||
{{ err }}
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #footer>
|
||||
<Button @click="importUsersResultDialogOpen = false">
|
||||
确定
|
||||
</Button>
|
||||
</template>
|
||||
</Dialog>
|
||||
</PageContainer>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { Download, Upload } from 'lucide-vue-next'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Input from '@/components/ui/input.vue'
|
||||
import Label from '@/components/ui/label.vue'
|
||||
@@ -389,9 +779,12 @@ import SelectTrigger from '@/components/ui/select-trigger.vue'
|
||||
import SelectValue from '@/components/ui/select-value.vue'
|
||||
import SelectContent from '@/components/ui/select-content.vue'
|
||||
import SelectItem from '@/components/ui/select-item.vue'
|
||||
import {
|
||||
Dialog,
|
||||
} from '@/components/ui'
|
||||
import { PageHeader, PageContainer, CardSection } from '@/components/layout'
|
||||
import { useToast } from '@/composables/useToast'
|
||||
import { adminApi } from '@/api/admin'
|
||||
import { adminApi, type ConfigExportData, type ConfigImportResponse, type UsersExportData, type UsersImportResponse } from '@/api/admin'
|
||||
import { log } from '@/utils/logger'
|
||||
|
||||
const { success, error } = useToast()
|
||||
@@ -423,6 +816,28 @@ interface SystemConfig {
|
||||
const loading = ref(false)
|
||||
const logLevelSelectOpen = ref(false)
|
||||
|
||||
// 导出/导入相关
|
||||
const exportLoading = ref(false)
|
||||
const importLoading = ref(false)
|
||||
const importDialogOpen = ref(false)
|
||||
const importResultDialogOpen = ref(false)
|
||||
const configFileInput = ref<HTMLInputElement | null>(null)
|
||||
const importPreview = ref<ConfigExportData | null>(null)
|
||||
const importResult = ref<ConfigImportResponse | null>(null)
|
||||
const mergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||
const mergeModeSelectOpen = ref(false)
|
||||
|
||||
// 用户数据导出/导入相关
|
||||
const exportUsersLoading = ref(false)
|
||||
const importUsersLoading = ref(false)
|
||||
const importUsersDialogOpen = ref(false)
|
||||
const importUsersResultDialogOpen = ref(false)
|
||||
const usersFileInput = ref<HTMLInputElement | null>(null)
|
||||
const importUsersPreview = ref<UsersExportData | null>(null)
|
||||
const importUsersResult = ref<UsersImportResponse | null>(null)
|
||||
const usersMergeMode = ref<'skip' | 'overwrite' | 'error'>('skip')
|
||||
const usersMergeModeSelectOpen = ref(false)
|
||||
|
||||
const systemConfig = ref<SystemConfig>({
|
||||
// 基础配置
|
||||
default_user_quota_usd: 10.0,
|
||||
@@ -623,4 +1038,185 @@ async function saveSystemConfig() {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// Export the configuration: fetch it from the admin API, serialise it as
// pretty-printed JSON and trigger a browser download via a temporary link.
async function handleExportConfig() {
  exportLoading.value = true
  try {
    const exported = await adminApi.exportConfig()
    const json = JSON.stringify(exported, null, 2)
    const objectUrl = URL.createObjectURL(new Blob([json], { type: 'application/json' }))

    // The file name carries the current date (YYYY-MM-DD, taken from the ISO string).
    const datePart = new Date().toISOString().slice(0, 10)
    const link = document.createElement('a')
    link.href = objectUrl
    link.download = `aether-config-${datePart}.json`

    // The link is attached only for the duration of the click.
    document.body.appendChild(link)
    link.click()
    document.body.removeChild(link)
    URL.revokeObjectURL(objectUrl)

    success('配置已导出')
  } catch (err) {
    error('导出配置失败')
    log.error('导出配置失败:', err)
  } finally {
    exportLoading.value = false
  }
}
|
||||
|
||||
// Open the native file picker by clicking the hidden config file input.
function triggerConfigFileSelect() {
  const picker = configFileInput.value
  if (picker) {
    picker.click()
  }
}
|
||||
|
||||
// 文件大小限制 (10MB)
|
||||
const MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
// Handle the selection of a config file: enforce the size limit, read the file
// as text, parse it as JSON, check the export version and open the import
// dialog with a preview. Read failures and parse failures are both reported.
function handleConfigFileSelect(event: Event) {
  const input = event.target as HTMLInputElement
  const file = input.files?.[0]
  if (!file) return

  if (file.size > MAX_FILE_SIZE) {
    error('文件大小不能超过 10MB')
    input.value = ''
    return
  }

  const reader = new FileReader()
  reader.onload = (e) => {
    try {
      const content = e.target?.result as string
      const data = JSON.parse(content) as ConfigExportData

      // Only export format version 1.0 is accepted.
      if (data.version !== '1.0') {
        error(`不支持的配置版本: ${data.version}`)
        return
      }

      importPreview.value = data
      mergeMode.value = 'skip'
      importDialogOpen.value = true
    } catch (err) {
      error('解析配置文件失败,请确保是有效的 JSON 文件')
      log.error('解析配置文件失败:', err)
    }
  }
  // Without this handler a failed read (e.g. the file became unreadable)
  // would be ignored silently and the user would get no feedback.
  reader.onerror = () => {
    error('读取配置文件失败')
    log.error('读取配置文件失败:', reader.error)
  }
  reader.readAsText(file)

  // Reset the input so the same file can be selected again.
  input.value = ''
}
|
||||
|
||||
// Confirm the config import: send the previewed data plus the chosen merge
// mode to the admin API, then swap the import dialog for the result dialog.
async function confirmImport() {
  const preview = importPreview.value
  if (!preview) {
    return
  }

  importLoading.value = true
  try {
    const payload = { ...preview, merge_mode: mergeMode.value }
    importResult.value = await adminApi.importConfig(payload)

    // Close the import dialog (and its select popup) before showing the result.
    importDialogOpen.value = false
    mergeModeSelectOpen.value = false
    importResultDialogOpen.value = true
    success('配置导入成功')
  } catch (err: any) {
    const reason = err.response?.data?.detail || '导入配置失败'
    error(reason)
    log.error('导入配置失败:', err)
  } finally {
    importLoading.value = false
  }
}
|
||||
|
||||
// Export user data: fetch it from the admin API, serialise it as
// pretty-printed JSON and trigger a browser download via a temporary link.
async function handleExportUsers() {
  exportUsersLoading.value = true
  try {
    const exported = await adminApi.exportUsers()
    const json = JSON.stringify(exported, null, 2)
    const objectUrl = URL.createObjectURL(new Blob([json], { type: 'application/json' }))

    // The file name carries the current date (YYYY-MM-DD, taken from the ISO string).
    const datePart = new Date().toISOString().slice(0, 10)
    const link = document.createElement('a')
    link.href = objectUrl
    link.download = `aether-users-${datePart}.json`

    // The link is attached only for the duration of the click.
    document.body.appendChild(link)
    link.click()
    document.body.removeChild(link)
    URL.revokeObjectURL(objectUrl)

    success('用户数据已导出')
  } catch (err) {
    error('导出用户数据失败')
    log.error('导出用户数据失败:', err)
  } finally {
    exportUsersLoading.value = false
  }
}
|
||||
|
||||
// Open the native file picker by clicking the hidden users file input.
function triggerUsersFileSelect() {
  const picker = usersFileInput.value
  if (picker) {
    picker.click()
  }
}
|
||||
|
||||
// Handle the selection of a users data file: enforce the size limit, read the
// file as text, parse it as JSON, check the export version and open the users
// import dialog with a preview. Read failures and parse failures are both reported.
function handleUsersFileSelect(event: Event) {
  const input = event.target as HTMLInputElement
  const file = input.files?.[0]
  if (!file) return

  if (file.size > MAX_FILE_SIZE) {
    error('文件大小不能超过 10MB')
    input.value = ''
    return
  }

  const reader = new FileReader()
  reader.onload = (e) => {
    try {
      const content = e.target?.result as string
      const data = JSON.parse(content) as UsersExportData

      // Only export format version 1.0 is accepted.
      if (data.version !== '1.0') {
        error(`不支持的配置版本: ${data.version}`)
        return
      }

      importUsersPreview.value = data
      usersMergeMode.value = 'skip'
      importUsersDialogOpen.value = true
    } catch (err) {
      error('解析用户数据文件失败,请确保是有效的 JSON 文件')
      log.error('解析用户数据文件失败:', err)
    }
  }
  // Without this handler a failed read (e.g. the file became unreadable)
  // would be ignored silently and the user would get no feedback.
  reader.onerror = () => {
    error('读取用户数据文件失败')
    log.error('读取用户数据文件失败:', reader.error)
  }
  reader.readAsText(file)

  // Reset the input so the same file can be selected again.
  input.value = ''
}
|
||||
|
||||
// Confirm the users import: send the previewed data plus the chosen merge mode
// to the admin API, then swap the import dialog for the result dialog.
async function confirmImportUsers() {
  const preview = importUsersPreview.value
  if (!preview) {
    return
  }

  importUsersLoading.value = true
  try {
    const payload = { ...preview, merge_mode: usersMergeMode.value }
    importUsersResult.value = await adminApi.importUsers(payload)

    // Close the import dialog (and its select popup) before showing the result.
    importUsersDialogOpen.value = false
    usersMergeModeSelectOpen.value = false
    importUsersResultDialogOpen.value = true
    success('用户数据导入成功')
  } catch (err: any) {
    const reason = err.response?.data?.detail || '导入用户数据失败'
    error(reason)
    log.error('导入用户数据失败:', err)
  } finally {
    importUsersLoading.value = false
  }
}
|
||||
</script>
|
||||
|
||||
@@ -791,11 +791,13 @@ const filteredUsers = computed(() => {
|
||||
return new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
||||
})
|
||||
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
filtered = filtered.filter(
|
||||
u => u.username.toLowerCase().includes(query) || u.email?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
filtered = filtered.filter(u => {
|
||||
const searchableText = `${u.username} ${u.email || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
if (filterRole.value !== 'all') {
|
||||
|
||||
@@ -103,7 +103,7 @@
|
||||
</div>
|
||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
||||
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<Clock class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
平均响应
|
||||
@@ -114,7 +114,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
||||
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<AlertTriangle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
错误率
|
||||
@@ -128,7 +128,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
||||
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<Shuffle class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
转移次数
|
||||
@@ -142,7 +142,7 @@
|
||||
v-if="costStats"
|
||||
class="relative p-3 sm:p-4 border-manilla/40"
|
||||
>
|
||||
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<DollarSign class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
实际成本
|
||||
@@ -180,7 +180,7 @@
|
||||
</div>
|
||||
<div class="grid grid-cols-2 gap-2 sm:gap-3 xl:grid-cols-4">
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/30">
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
缓存命中率
|
||||
@@ -191,7 +191,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-kraft/30">
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
缓存读取
|
||||
@@ -202,7 +202,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
<Card class="relative p-3 sm:p-4 border-book-cloth/25">
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-kraft" />
|
||||
<Database class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
缓存创建
|
||||
@@ -216,7 +216,7 @@
|
||||
v-if="tokenBreakdown"
|
||||
class="relative p-3 sm:p-4 border-manilla/40"
|
||||
>
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-book-cloth" />
|
||||
<Hash class="absolute top-3 right-3 h-3.5 w-3.5 sm:h-4 sm:w-4 text-muted-foreground" />
|
||||
<div class="pr-6">
|
||||
<p class="text-[9px] sm:text-[10px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.3em] text-muted-foreground">
|
||||
总Token
|
||||
@@ -254,16 +254,16 @@
|
||||
<Card class="overflow-hidden p-4 flex flex-col flex-1 min-h-0 h-full max-h-[280px] sm:max-h-none">
|
||||
<div
|
||||
v-if="loadingAnnouncements"
|
||||
class="py-8 text-center"
|
||||
class="flex-1 flex items-center justify-center"
|
||||
>
|
||||
<Loader2 class="h-5 w-5 animate-spin mx-auto text-muted-foreground" />
|
||||
<Loader2 class="h-5 w-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-else-if="announcements.length === 0"
|
||||
class="py-8 text-center"
|
||||
class="flex-1 flex flex-col items-center justify-center"
|
||||
>
|
||||
<Bell class="h-8 w-8 mx-auto text-muted-foreground/40" />
|
||||
<Bell class="h-8 w-8 text-muted-foreground/40" />
|
||||
<p class="mt-2 text-xs text-muted-foreground">
|
||||
暂无公告
|
||||
</p>
|
||||
@@ -793,9 +793,8 @@ const statCardGlows = [
|
||||
'bg-kraft/30'
|
||||
]
|
||||
|
||||
const getStatIconColor = (index: number): string => {
|
||||
const colors = ['text-book-cloth', 'text-kraft', 'text-book-cloth', 'text-kraft']
|
||||
return colors[index % colors.length]
|
||||
const getStatIconColor = (_index: number): string => {
|
||||
return 'text-muted-foreground'
|
||||
}
|
||||
|
||||
// 统计数据
|
||||
|
||||
@@ -226,8 +226,8 @@
|
||||
<div
|
||||
v-for="announcement in announcements"
|
||||
:key="announcement.id"
|
||||
class="p-4 space-y-2 cursor-pointer transition-colors"
|
||||
:class="[
|
||||
'p-4 space-y-2 cursor-pointer transition-colors',
|
||||
announcement.is_read ? 'hover:bg-muted/30' : 'bg-primary/5 hover:bg-primary/10'
|
||||
]"
|
||||
@click="viewAnnouncementDetail(announcement)"
|
||||
|
||||
@@ -165,17 +165,17 @@
|
||||
<TableCell class="py-4">
|
||||
<div class="flex gap-1.5">
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="Vision"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="Tool Use"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
title="Extended Thinking"
|
||||
/>
|
||||
@@ -253,15 +253,15 @@
|
||||
<!-- 第二行:能力图标 -->
|
||||
<div class="flex gap-1.5">
|
||||
<Eye
|
||||
v-if="model.default_supports_vision"
|
||||
v-if="model.config?.vision === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Wrench
|
||||
v-if="model.default_supports_function_calling"
|
||||
v-if="model.config?.function_calling === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
<Brain
|
||||
v-if="model.default_supports_extended_thinking"
|
||||
v-if="model.config?.extended_thinking === true"
|
||||
class="w-4 h-4 text-muted-foreground"
|
||||
/>
|
||||
</div>
|
||||
@@ -474,24 +474,24 @@ async function toggleCapability(modelName: string, capName: string) {
|
||||
const filteredModels = computed(() => {
|
||||
let result = models.value
|
||||
|
||||
// 搜索
|
||||
// 搜索(支持空格分隔的多关键词 AND 搜索)
|
||||
if (searchQuery.value) {
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
result = result.filter(m =>
|
||||
m.name.toLowerCase().includes(query) ||
|
||||
m.display_name?.toLowerCase().includes(query)
|
||||
)
|
||||
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
|
||||
result = result.filter(m => {
|
||||
const searchableText = `${m.name} ${m.display_name || ''}`.toLowerCase()
|
||||
return keywords.every(keyword => searchableText.includes(keyword))
|
||||
})
|
||||
}
|
||||
|
||||
// 能力筛选
|
||||
if (capabilityFilters.value.vision) {
|
||||
result = result.filter(m => m.default_supports_vision)
|
||||
result = result.filter(m => m.config?.vision === true)
|
||||
}
|
||||
if (capabilityFilters.value.toolUse) {
|
||||
result = result.filter(m => m.default_supports_function_calling)
|
||||
result = result.filter(m => m.config?.function_calling === true)
|
||||
}
|
||||
if (capabilityFilters.value.extendedThinking) {
|
||||
result = result.filter(m => m.default_supports_extended_thinking)
|
||||
result = result.filter(m => m.config?.extended_thinking === true)
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -62,6 +62,7 @@
|
||||
<Button
|
||||
type="submit"
|
||||
:disabled="savingProfile"
|
||||
class="shadow-none hover:shadow-none"
|
||||
>
|
||||
{{ savingProfile ? '保存中...' : '保存修改' }}
|
||||
</Button>
|
||||
@@ -107,6 +108,7 @@
|
||||
<Button
|
||||
type="submit"
|
||||
:disabled="changingPassword"
|
||||
class="shadow-none hover:shadow-none"
|
||||
>
|
||||
{{ changingPassword ? '修改中...' : '修改密码' }}
|
||||
</Button>
|
||||
@@ -320,6 +322,7 @@
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { useAuthStore } from '@/stores/auth'
|
||||
import { meApi, type Profile } from '@/api/me'
|
||||
import { useDarkMode, type ThemeMode } from '@/composables/useDarkMode'
|
||||
import Card from '@/components/ui/card.vue'
|
||||
import Button from '@/components/ui/button.vue'
|
||||
import Badge from '@/components/ui/badge.vue'
|
||||
@@ -338,6 +341,7 @@ import { log } from '@/utils/logger'
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { success, error: showError } = useToast()
|
||||
const { setThemeMode } = useDarkMode()
|
||||
|
||||
const profile = ref<Profile | null>(null)
|
||||
|
||||
@@ -375,20 +379,8 @@ function handleThemeChange(value: string) {
|
||||
themeSelectOpen.value = false
|
||||
updatePreferences()
|
||||
|
||||
// 应用主题
|
||||
if (value === 'dark') {
|
||||
document.documentElement.classList.add('dark')
|
||||
} else if (value === 'light') {
|
||||
document.documentElement.classList.remove('dark')
|
||||
} else {
|
||||
// system: 跟随系统
|
||||
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches
|
||||
if (prefersDark) {
|
||||
document.documentElement.classList.add('dark')
|
||||
} else {
|
||||
document.documentElement.classList.remove('dark')
|
||||
}
|
||||
}
|
||||
// 使用 useDarkMode 统一切换主题
|
||||
setThemeMode(value as ThemeMode)
|
||||
}
|
||||
|
||||
function handleLanguageChange(value: string) {
|
||||
@@ -418,10 +410,16 @@ async function loadProfile() {
|
||||
async function loadPreferences() {
|
||||
try {
|
||||
const prefs = await meApi.getPreferences()
|
||||
|
||||
// 主题以本地 localStorage 为准(useDarkMode 在应用启动时已初始化)
|
||||
// 这样可以避免刷新页面时主题被服务端旧值覆盖
|
||||
const { themeMode: currentThemeMode } = useDarkMode()
|
||||
const localTheme = currentThemeMode.value
|
||||
|
||||
preferencesForm.value = {
|
||||
avatar_url: prefs.avatar_url || '',
|
||||
bio: prefs.bio || '',
|
||||
theme: prefs.theme || 'light',
|
||||
theme: localTheme, // 使用本地主题,而非服务端返回值
|
||||
language: prefs.language || 'zh-CN',
|
||||
timezone: prefs.timezone || 'Asia/Shanghai',
|
||||
notifications: {
|
||||
@@ -431,11 +429,12 @@ async function loadPreferences() {
|
||||
}
|
||||
}
|
||||
|
||||
// 应用主题
|
||||
if (preferencesForm.value.theme === 'dark') {
|
||||
document.documentElement.classList.add('dark')
|
||||
} else if (preferencesForm.value.theme === 'light') {
|
||||
document.documentElement.classList.remove('dark')
|
||||
// 如果本地主题和服务端不一致,同步到服务端(静默更新,不提示用户)
|
||||
const serverTheme = prefs.theme || 'light'
|
||||
if (localTheme !== serverTheme) {
|
||||
meApi.updatePreferences({ theme: localTheme }).catch(() => {
|
||||
// 静默失败,不影响用户体验
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
log.error('加载偏好设置失败:', error)
|
||||
|
||||
@@ -38,10 +38,10 @@
|
||||
</button>
|
||||
</div>
|
||||
<p
|
||||
v-if="model.description"
|
||||
v-if="model.config?.description"
|
||||
class="text-xs text-muted-foreground"
|
||||
>
|
||||
{{ model.description }}
|
||||
{{ model.config?.description }}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
@@ -73,10 +73,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_streaming ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.streaming !== false ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_streaming ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.streaming !== false ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -90,10 +90,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_image_generation ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.image_generation === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_image_generation ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.image_generation === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -107,10 +107,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_vision ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.vision === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_vision ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.vision === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -124,10 +124,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_function_calling ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.function_calling === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_function_calling ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.function_calling === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 p-3 rounded-lg border">
|
||||
@@ -141,10 +141,10 @@
|
||||
</p>
|
||||
</div>
|
||||
<Badge
|
||||
:variant="model.default_supports_extended_thinking ?? false ? 'default' : 'secondary'"
|
||||
:variant="model.config?.extended_thinking === true ? 'default' : 'secondary'"
|
||||
class="text-xs"
|
||||
>
|
||||
{{ model.default_supports_extended_thinking ?? false ? '支持' : '不支持' }}
|
||||
{{ model.config?.extended_thinking === true ? '支持' : '不支持' }}
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
12
migrate.sh
Executable file
12
migrate.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
# Database migration script: runs the Alembic migrations inside the Docker container.

# Abort immediately if any command fails.
set -e

# Target container; defaults to "aether-app" and can be overridden by
# exporting CONTAINER_NAME before invoking this script.
CONTAINER_NAME="${CONTAINER_NAME:-aether-app}"

echo "Running database migrations in container: $CONTAINER_NAME"

# Quoted so the container name is always passed to docker as a single argument.
docker exec "$CONTAINER_NAME" alembic upgrade head

echo "Database migration completed successfully"
|
||||
@@ -7,6 +7,7 @@ from .api_keys import router as api_keys_router
|
||||
from .endpoints import router as endpoints_router
|
||||
from .models import router as models_router
|
||||
from .monitoring import router as monitoring_router
|
||||
from .provider_query import router as provider_query_router
|
||||
from .provider_strategy import router as provider_strategy_router
|
||||
from .providers import router as providers_router
|
||||
from .security import router as security_router
|
||||
@@ -26,5 +27,6 @@ router.include_router(provider_strategy_router)
|
||||
router.include_router(adaptive_router)
|
||||
router.include_router(models_router)
|
||||
router.include_router(security_router)
|
||||
router.include_router(provider_query_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .catalog import router as catalog_router
|
||||
from .external import router as external_router
|
||||
from .global_models import router as global_models_router
|
||||
from .mappings import router as mappings_router
|
||||
|
||||
router = APIRouter(prefix="/api/admin/models", tags=["Admin - Model Management"])
|
||||
|
||||
# 挂载子路由
|
||||
router.include_router(catalog_router)
|
||||
router.include_router(global_models_router)
|
||||
router.include_router(mappings_router)
|
||||
router.include_router(external_router)
|
||||
|
||||
@@ -1,38 +1,26 @@
|
||||
"""
|
||||
统一模型目录 Admin API
|
||||
|
||||
阶段一:基于 ModelMapping 和 Model 的聚合视图
|
||||
基于 GlobalModel 的聚合视图
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlalchemy import func, or_
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import GlobalModel, Model, ModelMapping, Provider
|
||||
from src.models.database import GlobalModel, Model
|
||||
from src.models.pydantic_models import (
|
||||
BatchAssignError,
|
||||
BatchAssignModelMappingRequest,
|
||||
BatchAssignModelMappingResponse,
|
||||
BatchAssignProviderResult,
|
||||
DeleteModelMappingResponse,
|
||||
ModelCapabilities,
|
||||
ModelCatalogItem,
|
||||
ModelCatalogProviderDetail,
|
||||
ModelCatalogResponse,
|
||||
ModelPriceRange,
|
||||
OrphanedModel,
|
||||
UpdateModelMappingRequest,
|
||||
UpdateModelMappingResponse,
|
||||
)
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
from src.services.model.service import ModelService
|
||||
|
||||
router = APIRouter(prefix="/catalog", tags=["Admin - Model Catalog"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -47,24 +35,13 @@ async def get_model_catalog(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.post("/batch-assign", response_model=BatchAssignModelMappingResponse)
|
||||
async def batch_assign_model_mappings(
|
||||
request: Request,
|
||||
payload: BatchAssignModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> BatchAssignModelMappingResponse:
|
||||
adapter = AdminBatchAssignModelMappingsAdapter(payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
"""管理员查询统一模型目录
|
||||
|
||||
新架构说明:
|
||||
架构说明:
|
||||
1. 以 GlobalModel 为中心聚合数据
|
||||
2. ModelMapping 表提供别名信息(provider_id=NULL 表示全局)
|
||||
3. Model 表提供关联提供商和价格
|
||||
2. Model 表提供关联提供商和价格
|
||||
"""
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
@@ -75,29 +52,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
db.query(GlobalModel).filter(GlobalModel.is_active == True).all()
|
||||
)
|
||||
|
||||
# 2. 获取所有活跃的别名(含全局和 Provider 特定)
|
||||
aliases_rows: List[ModelMapping] = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 GlobalModel ID 组织别名
|
||||
aliases_by_global_model: Dict[str, List[str]] = {}
|
||||
for alias_row in aliases_rows:
|
||||
if not alias_row.target_global_model_id:
|
||||
continue
|
||||
gm_id = alias_row.target_global_model_id
|
||||
if gm_id not in aliases_by_global_model:
|
||||
aliases_by_global_model[gm_id] = []
|
||||
if alias_row.source_model not in aliases_by_global_model[gm_id]:
|
||||
aliases_by_global_model[gm_id].append(alias_row.source_model)
|
||||
|
||||
# 3. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
# 2. 获取所有活跃的 Model 实现(包含 global_model 以便计算有效价格)
|
||||
models: List[Model] = (
|
||||
db.query(Model)
|
||||
.options(joinedload(Model.provider), joinedload(Model.global_model))
|
||||
@@ -111,16 +66,18 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
if model.global_model_id:
|
||||
models_by_global_model.setdefault(model.global_model_id, []).append(model)
|
||||
|
||||
# 4. 为每个 GlobalModel 构建 catalog item
|
||||
# 3. 为每个 GlobalModel 构建 catalog item
|
||||
catalog_items: List[ModelCatalogItem] = []
|
||||
|
||||
for gm in global_models:
|
||||
gm_id = gm.id
|
||||
provider_entries: List[ModelCatalogProviderDetail] = []
|
||||
# 从 config JSON 读取能力标志
|
||||
gm_config = gm.config or {}
|
||||
capability_flags = {
|
||||
"supports_vision": gm.default_supports_vision or False,
|
||||
"supports_function_calling": gm.default_supports_function_calling or False,
|
||||
"supports_streaming": gm.default_supports_streaming or False,
|
||||
"supports_vision": gm_config.get("vision", False),
|
||||
"supports_function_calling": gm_config.get("function_calling", False),
|
||||
"supports_streaming": gm_config.get("streaming", True),
|
||||
}
|
||||
|
||||
# 遍历该 GlobalModel 的所有关联提供商
|
||||
@@ -168,7 +125,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
supports_function_calling=model.get_effective_supports_function_calling(),
|
||||
supports_streaming=model.get_effective_supports_streaming(),
|
||||
is_active=bool(model.is_active),
|
||||
mapping_id=None, # 新架构中不再有 mapping_id
|
||||
)
|
||||
)
|
||||
|
||||
@@ -186,8 +142,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
ModelCatalogItem(
|
||||
global_model_name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
aliases=aliases_by_global_model.get(gm_id, []),
|
||||
description=gm_config.get("description"),
|
||||
providers=provider_entries,
|
||||
price_range=price_range,
|
||||
total_providers=len(provider_entries),
|
||||
@@ -195,238 +150,7 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
|
||||
)
|
||||
)
|
||||
|
||||
# 5. 查找孤立的别名(别名指向的 GlobalModel 不存在或不活跃)
|
||||
orphaned_rows = (
|
||||
db.query(ModelMapping.source_model, GlobalModel.name, func.count(ModelMapping.id))
|
||||
.outerjoin(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
ModelMapping.is_active == True,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
or_(GlobalModel.id == None, GlobalModel.is_active == False),
|
||||
)
|
||||
.group_by(ModelMapping.source_model, GlobalModel.name)
|
||||
.all()
|
||||
)
|
||||
orphaned_models = [
|
||||
OrphanedModel(alias=row[0], global_model_name=row[1], mapping_count=row[2])
|
||||
for row in orphaned_rows
|
||||
if row[0]
|
||||
]
|
||||
|
||||
return ModelCatalogResponse(
|
||||
models=catalog_items,
|
||||
total=len(catalog_items),
|
||||
orphaned_models=orphaned_models,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminBatchAssignModelMappingsAdapter(AdminApiAdapter):
|
||||
payload: BatchAssignModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
created: List[BatchAssignProviderResult] = []
|
||||
errors: List[BatchAssignError] = []
|
||||
|
||||
for provider_config in self.payload.providers:
|
||||
provider_id = provider_config.provider_id
|
||||
try:
|
||||
provider: Provider = db.query(Provider).filter(Provider.id == provider_id).first()
|
||||
if not provider:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="Provider 不存在")
|
||||
)
|
||||
continue
|
||||
|
||||
model_id: Optional[str] = None
|
||||
created_model = False
|
||||
|
||||
if provider_config.create_model:
|
||||
model_data = provider_config.model_data
|
||||
if not model_data:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_data 配置")
|
||||
)
|
||||
continue
|
||||
|
||||
existing_model = ModelService.get_model_by_name(
|
||||
db, provider_id, model_data.provider_model_name
|
||||
)
|
||||
if existing_model:
|
||||
model_id = existing_model.id
|
||||
logger.info("模型 %s 已存在于 Provider %s,复用现有模型",
|
||||
model_data.provider_model_name,
|
||||
provider.name,
|
||||
)
|
||||
else:
|
||||
model = ModelService.create_model(db, provider_id, model_data)
|
||||
model_id = model.id
|
||||
created_model = True
|
||||
else:
|
||||
model_id = provider_config.model_id
|
||||
if not model_id:
|
||||
errors.append(
|
||||
BatchAssignError(provider_id=provider_id, error="缺少 model_id")
|
||||
)
|
||||
continue
|
||||
model = (
|
||||
db.query(Model)
|
||||
.filter(Model.id == model_id, Model.provider_id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not model:
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id, error="模型不存在或不属于当前 Provider")
|
||||
)
|
||||
continue
|
||||
|
||||
# 批量分配功能需要适配 GlobalModel 架构
|
||||
# 参见 docs/optimization-backlog.md 中的待办项
|
||||
errors.append(
|
||||
BatchAssignError(
|
||||
provider_id=provider_id,
|
||||
error="批量分配功能暂时不可用,需要适配新的 GlobalModel 架构",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as exc:
|
||||
db.rollback()
|
||||
logger.error("批量添加模型映射失败(需要适配新架构)")
|
||||
errors.append(BatchAssignError(provider_id=provider_id, error=str(exc)))
|
||||
|
||||
return BatchAssignModelMappingResponse(
|
||||
success=len(created) > 0,
|
||||
created_mappings=created,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/mappings/{mapping_id}", response_model=UpdateModelMappingResponse)
|
||||
async def update_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
payload: UpdateModelMappingRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> UpdateModelMappingResponse:
|
||||
"""更新模型映射"""
|
||||
adapter = AdminUpdateModelMappingAdapter(mapping_id=mapping_id, payload=payload)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.delete("/mappings/{mapping_id}", response_model=DeleteModelMappingResponse)
|
||||
async def delete_model_mapping(
|
||||
request: Request,
|
||||
mapping_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
) -> DeleteModelMappingResponse:
|
||||
"""删除模型映射"""
|
||||
adapter = AdminDeleteModelMappingAdapter(mapping_id=mapping_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminUpdateModelMappingAdapter(AdminApiAdapter):
|
||||
"""更新模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
payload: UpdateModelMappingRequest
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
update_data = self.payload.model_dump(exclude_unset=True)
|
||||
|
||||
if "provider_id" in update_data:
|
||||
new_provider_id = update_data["provider_id"]
|
||||
if new_provider_id:
|
||||
provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider 不存在")
|
||||
mapping.provider_id = new_provider_id
|
||||
|
||||
if "target_global_model_id" in update_data:
|
||||
target_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(
|
||||
GlobalModel.id == update_data["target_global_model_id"],
|
||||
GlobalModel.is_active == True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not target_model:
|
||||
raise HTTPException(status_code=404, detail="目标 GlobalModel 不存在或未激活")
|
||||
mapping.target_global_model_id = update_data["target_global_model_id"]
|
||||
|
||||
if "source_model" in update_data:
|
||||
new_source = update_data["source_model"].strip()
|
||||
if not new_source:
|
||||
raise HTTPException(status_code=400, detail="source_model 不能为空")
|
||||
mapping.source_model = new_source
|
||||
|
||||
if "is_active" in update_data:
|
||||
mapping.is_active = update_data["is_active"]
|
||||
|
||||
duplicate = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
ModelMapping.source_model == mapping.source_model,
|
||||
ModelMapping.provider_id == mapping.provider_id,
|
||||
ModelMapping.id != mapping.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if duplicate:
|
||||
raise HTTPException(status_code=400, detail="映射已存在")
|
||||
|
||||
db.commit()
|
||||
db.refresh(mapping)
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(mapping.source_model, mapping.provider_id)
|
||||
|
||||
return UpdateModelMappingResponse(
|
||||
success=True,
|
||||
mapping_id=mapping.id,
|
||||
message="映射更新成功",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDeleteModelMappingAdapter(AdminApiAdapter):
|
||||
"""删除模型映射"""
|
||||
|
||||
mapping_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db: Session = context.db
|
||||
|
||||
mapping: Optional[ModelMapping] = (
|
||||
db.query(ModelMapping).filter(ModelMapping.id == self.mapping_id).first()
|
||||
)
|
||||
|
||||
if not mapping:
|
||||
raise HTTPException(status_code=404, detail="映射不存在")
|
||||
|
||||
source_model = mapping.source_model
|
||||
provider_id = mapping.provider_id
|
||||
|
||||
db.delete(mapping)
|
||||
db.commit()
|
||||
|
||||
cache_service = get_cache_invalidation_service()
|
||||
cache_service.on_model_mapping_changed(source_model, provider_id)
|
||||
|
||||
return DeleteModelMappingResponse(
|
||||
success=True,
|
||||
message=f"映射 {self.mapping_id} 已删除",
|
||||
)
|
||||
|
||||
141
src/api/admin/models/external.py
Normal file
141
src/api/admin/models/external.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
models.dev 外部模型数据代理
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.clients import get_redis_client
|
||||
from src.core.logger import logger
|
||||
from src.models.database import User
|
||||
from src.utils.auth_utils import require_admin
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
CACHE_KEY = "aether:external:models_dev"
|
||||
CACHE_TTL = 15 * 60 # 15 分钟
|
||||
|
||||
# 标记官方/一手提供商,前端可据此过滤第三方转售商
|
||||
OFFICIAL_PROVIDERS = {
|
||||
"anthropic", # Claude 官方
|
||||
"openai", # OpenAI 官方
|
||||
"google", # Gemini 官方
|
||||
"google-vertex", # Google Vertex AI
|
||||
"azure", # Azure OpenAI
|
||||
"amazon-bedrock", # AWS Bedrock
|
||||
"xai", # Grok 官方
|
||||
"meta", # Llama 官方
|
||||
"deepseek", # DeepSeek 官方
|
||||
"mistral", # Mistral 官方
|
||||
"cohere", # Cohere 官方
|
||||
"zhipuai", # 智谱 AI 官方
|
||||
"alibaba", # 阿里云(通义千问)
|
||||
"minimax", # MiniMax 官方
|
||||
"moonshot", # 月之暗面(Kimi)
|
||||
"baichuan", # 百川智能
|
||||
"ai21", # AI21 Labs
|
||||
}
|
||||
|
||||
|
||||
async def _get_cached_data() -> Optional[dict[str, Any]]:
    """Return the models.dev payload cached in Redis.

    Returns ``None`` when no Redis client is available, when the key is
    missing or empty, or when reading/decoding fails (the failure is logged
    as a warning rather than raised).
    """
    client = await get_redis_client()
    if client is None:
        return None

    try:
        raw = await client.get(CACHE_KEY)
        payload: Optional[dict[str, Any]] = json.loads(raw) if raw else None
    except Exception as e:
        logger.warning(f"读取 models.dev 缓存失败: {e}")
        return None
    return payload
|
||||
|
||||
|
||||
async def _set_cached_data(data: dict) -> None:
    """Store ``data`` in Redis as JSON under ``CACHE_KEY`` with ``CACHE_TTL``.

    Does nothing when no Redis client is available. Serialization or write
    failures are logged as warnings and not raised.
    """
    client = await get_redis_client()
    if client is not None:
        try:
            serialized = json.dumps(data, ensure_ascii=False)
            await client.setex(CACHE_KEY, CACHE_TTL, serialized)
        except Exception as e:
            logger.warning(f"写入 models.dev 缓存失败: {e}")
|
||||
|
||||
|
||||
def _mark_official_providers(data: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``data`` with an ``official`` flag on every provider.

    Each value is shallow-copied into a new dict; ``official`` is ``True``
    when the provider id is a member of ``OFFICIAL_PROVIDERS`` and replaces
    any ``official`` key already present in the entry.
    """
    return {
        name: {**entry, "official": name in OFFICIAL_PROVIDERS}
        for name, entry in data.items()
    }
|
||||
|
||||
|
||||
@router.get("/external")
async def get_external_models(_: User = Depends(require_admin)) -> JSONResponse:
    """
    Return the models.dev model catalogue (proxied to avoid CORS issues).

    The payload is cached through ``_get_cached_data`` / ``_set_cached_data``
    (Redis, 15 minutes, shared between workers). Every provider entry carries
    an ``official`` flag that the frontend can use for filtering.

    Raises:
        HTTPException: 504 when the request to models.dev times out; 502 when
            models.dev answers with an error status, returns something other
            than a JSON object, or any other failure occurs while fetching.
    """
    cached = await _get_cached_data()
    if cached is not None:
        # Backward compatibility with old cache entries: if any provider lacks
        # the ``official`` flag, mark everything and write the result back.
        try:
            needs_mark = any(
                not isinstance(entry, dict) or "official" not in entry
                for entry in cached.values()
            )
            if needs_mark:
                marked_cached = _mark_official_providers(cached)
                await _set_cached_data(marked_cached)
                return JSONResponse(content=marked_cached)
        except Exception as e:
            # Best effort only: fall through and serve the cache untouched.
            logger.warning(f"处理 models.dev 缓存数据失败,将直接返回原缓存: {e}")
        return JSONResponse(content=cached)

    # Cache miss: fetch from models.dev.
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get("https://models.dev/api.json")
            response.raise_for_status()
            data = response.json()

        # Fail with a clear message instead of an AttributeError from
        # ``.items()`` when the upstream payload is not a JSON object.
        if not isinstance(data, dict):
            raise ValueError(f"unexpected payload type: {type(data).__name__}")

        marked_data = _mark_official_providers(data)
        await _set_cached_data(marked_data)
        return JSONResponse(content=marked_data)
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="请求 models.dev 超时") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(
            status_code=502, detail=f"models.dev 返回错误: {e.response.status_code}"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"获取外部模型数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/external/cache")
async def clear_external_models_cache(_: User = Depends(require_admin)) -> dict:
    """Delete the cached models.dev payload from Redis.

    Returns ``{"cleared": False, ...}`` when no Redis client is available and
    ``{"cleared": True}`` after the key has been deleted.

    Raises:
        HTTPException: 500 when the delete call fails.
    """
    client = await get_redis_client()
    if client is None:
        return {"cleared": False, "message": "Redis 未启用"}

    try:
        await client.delete(CACHE_KEY)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"清除缓存失败: {str(e)}")
    return {"cleared": True}
|
||||
@@ -123,7 +123,7 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy import func
|
||||
|
||||
from src.models.database import Model, ModelMapping
|
||||
from src.models.database import Model
|
||||
|
||||
models = GlobalModelService.list_global_models(
|
||||
db=context.db,
|
||||
@@ -144,17 +144,8 @@ class AdminListGlobalModelsAdapter(AdminApiAdapter):
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计别名数量
|
||||
alias_count = (
|
||||
context.db.query(func.count(ModelMapping.id))
|
||||
.filter(ModelMapping.target_global_model_id == gm.id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
response = GlobalModelResponse.model_validate(gm)
|
||||
response.provider_count = provider_count
|
||||
response.alias_count = alias_count
|
||||
# usage_count 直接从 GlobalModel 表读取,已在 model_validate 中自动映射
|
||||
model_responses.append(response)
|
||||
|
||||
@@ -196,21 +187,15 @@ class AdminCreateGlobalModelAdapter(AdminApiAdapter):
|
||||
db=context.db,
|
||||
name=self.payload.name,
|
||||
display_name=self.payload.display_name,
|
||||
description=self.payload.description,
|
||||
official_url=self.payload.official_url,
|
||||
icon_url=self.payload.icon_url,
|
||||
is_active=self.payload.is_active,
|
||||
# 按次计费配置
|
||||
default_price_per_request=self.payload.default_price_per_request,
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing=tiered_pricing_dict,
|
||||
# 默认能力配置
|
||||
default_supports_vision=self.payload.default_supports_vision,
|
||||
default_supports_function_calling=self.payload.default_supports_function_calling,
|
||||
default_supports_streaming=self.payload.default_supports_streaming,
|
||||
default_supports_extended_thinking=self.payload.default_supports_extended_thinking,
|
||||
# Key 能力配置
|
||||
supported_capabilities=self.payload.supported_capabilities,
|
||||
# 模型配置(JSON)
|
||||
config=self.payload.config,
|
||||
)
|
||||
|
||||
logger.info(f"GlobalModel 已创建: id={global_model.id} name={global_model.name}")
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
"""模型映射管理 API
|
||||
|
||||
提供模型映射的 CRUD 操作。
|
||||
|
||||
模型映射(Mapping)用于将源模型映射到目标模型,例如:
|
||||
- 请求 gpt-5.1 → Provider A 映射到 gpt-4
|
||||
- 用于处理 Provider 不支持请求模型的情况
|
||||
|
||||
映射必须关联到特定的 Provider。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.api import (
|
||||
ModelMappingCreate,
|
||||
ModelMappingResponse,
|
||||
ModelMappingUpdate,
|
||||
)
|
||||
from src.models.database import GlobalModel, ModelMapping, Provider, User
|
||||
from src.services.cache.invalidation import get_cache_invalidation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/mappings", tags=["Model Mappings"])
|
||||
|
||||
|
||||
def _serialize_mapping(mapping: ModelMapping) -> ModelMappingResponse:
    """Build the API response model for a ``ModelMapping`` row.

    Names of the related global model and provider are included when those
    relations are set; ``scope`` is ``"provider"`` when the mapping has a
    ``provider_id`` and ``"global"`` otherwise.
    """
    global_model = mapping.target_global_model
    owner = mapping.provider

    if global_model:
        target_name = global_model.name
        target_display_name = global_model.display_name
    else:
        target_name = None
        target_display_name = None

    if mapping.provider_id:
        scope = "provider"
    else:
        scope = "global"

    return ModelMappingResponse(
        id=mapping.id,
        source_model=mapping.source_model,
        target_global_model_id=mapping.target_global_model_id,
        target_global_model_name=target_name,
        target_global_model_display_name=target_display_name,
        provider_id=mapping.provider_id,
        provider_name=owner.name if owner else None,
        scope=scope,
        mapping_type=mapping.mapping_type,
        is_active=mapping.is_active,
        created_at=mapping.created_at,
        updated_at=mapping.updated_at,
    )
|
||||
|
||||
|
||||
@router.get("", response_model=List[ModelMappingResponse])
async def list_mappings(
    provider_id: Optional[str] = Query(None, description="按 Provider 筛选"),
    source_model: Optional[str] = Query(None, description="按源模型名筛选"),
    target_global_model_id: Optional[str] = Query(None, description="按目标模型筛选"),
    scope: Optional[str] = Query(None, description="global 或 provider"),
    mapping_type: Optional[str] = Query(None, description="映射类型: alias 或 mapping"),
    is_active: Optional[bool] = Query(None, description="按状态筛选"),
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=1, le=1000, description="返回记录数"),
    db: Session = Depends(get_db),
):
    """List model mappings, optionally filtered, with offset/limit paging.

    ``source_model`` is matched as a case-insensitive substring; every other
    filter is an exact match. ``scope`` selects mappings without a provider
    (``global``) or with one (``provider``); other values are ignored.
    """
    conditions = []
    if provider_id is not None:
        conditions.append(ModelMapping.provider_id == provider_id)
    if scope == "global":
        conditions.append(ModelMapping.provider_id.is_(None))
    elif scope == "provider":
        conditions.append(ModelMapping.provider_id.isnot(None))
    if mapping_type is not None:
        conditions.append(ModelMapping.mapping_type == mapping_type)
    if source_model:
        conditions.append(ModelMapping.source_model.ilike(f"%{source_model}%"))
    if target_global_model_id is not None:
        conditions.append(ModelMapping.target_global_model_id == target_global_model_id)
    if is_active is not None:
        conditions.append(ModelMapping.is_active == is_active)

    # Eager-load the relations that _serialize_mapping reads.
    query = db.query(ModelMapping).options(
        joinedload(ModelMapping.target_global_model),
        joinedload(ModelMapping.provider),
    )
    for condition in conditions:
        query = query.filter(condition)

    rows = query.offset(skip).limit(limit).all()
    return [_serialize_mapping(row) for row in rows]
|
||||
|
||||
|
||||
@router.get("/{mapping_id}", response_model=ModelMappingResponse)
async def get_mapping(
    mapping_id: str,
    db: Session = Depends(get_db),
):
    """Return a single model mapping by id.

    Raises:
        HTTPException: 404 when the mapping does not exist.
    """
    # Eager-load the relations that _serialize_mapping reads.
    loaded = db.query(ModelMapping).options(
        joinedload(ModelMapping.target_global_model),
        joinedload(ModelMapping.provider),
    )
    found = loaded.filter(ModelMapping.id == mapping_id).first()
    if found:
        return _serialize_mapping(found)
    raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")
|
||||
|
||||
|
||||
@router.post("", response_model=ModelMappingResponse, status_code=201)
async def create_mapping(
    data: ModelMappingCreate,
    db: Session = Depends(get_db),
):
    """Create a model mapping.

    Raises:
        HTTPException: 400 for an empty ``source_model``, an invalid
            ``mapping_type`` or an already existing (source_model, provider)
            pair; 404 when the target global model is missing/inactive or the
            provider does not exist.
    """
    source_model = data.source_model.strip()
    if not source_model:
        raise HTTPException(status_code=400, detail="source_model 不能为空")

    # Validate mapping_type.
    allowed_types = ("alias", "mapping")
    if data.mapping_type not in allowed_types:
        raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")

    # The target GlobalModel must exist and be active.
    target_query = db.query(GlobalModel).filter(
        GlobalModel.id == data.target_global_model_id, GlobalModel.is_active == True
    )
    target_model = target_query.first()
    if not target_model:
        raise HTTPException(
            status_code=404, detail=f"目标模型 {data.target_global_model_id} 不存在或未激活"
        )

    # The provider, when given, must exist.
    provider_id = data.provider_id
    provider = None
    if provider_id:
        provider = db.query(Provider).filter(Provider.id == provider_id).first()
        if not provider:
            raise HTTPException(status_code=404, detail=f"Provider {provider_id} 不存在")

    # Reject duplicates (globally, or within the same provider).
    clash = (
        db.query(ModelMapping)
        .filter(
            ModelMapping.source_model == source_model,
            ModelMapping.provider_id == provider_id,
        )
        .first()
    )
    if clash:
        raise HTTPException(status_code=400, detail="映射已存在")

    record = ModelMapping(
        id=str(uuid.uuid4()),
        source_model=source_model,
        target_global_model_id=data.target_global_model_id,
        provider_id=provider_id,
        mapping_type=data.mapping_type,
        is_active=data.is_active,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    db.add(record)
    db.commit()

    # Reload with the relations that _serialize_mapping reads.
    mapping = (
        db.query(ModelMapping)
        .options(
            joinedload(ModelMapping.target_global_model),
            joinedload(ModelMapping.provider),
        )
        .filter(ModelMapping.id == record.id)
        .first()
    )

    logger.info(f"创建模型映射: {source_model} -> {target_model.name} "
                f"(Provider: {provider.name if provider else 'global'}, ID: {mapping.id})")

    get_cache_invalidation_service().on_model_mapping_changed(source_model, provider_id)

    return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.patch("/{mapping_id}", response_model=ModelMappingResponse)
async def update_mapping(
    mapping_id: str,
    data: ModelMappingUpdate,
    db: Session = Depends(get_db),
):
    """Partially update a model mapping.

    Only fields present in the request body are applied. After the commit the
    cache is invalidated for the mapping's new (source_model, provider_id)
    pair and, when that pair changed, for the previous one as well.

    Raises:
        HTTPException: 404 when the mapping, the new provider or the new
            target global model (active) cannot be found; 400 for an empty
            ``source_model``, an invalid ``mapping_type`` or when the resulting
            (source_model, provider_id) pair already exists on another mapping.
    """
    mapping = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
    if not mapping:
        raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")

    update_data = data.model_dump(exclude_unset=True)

    # Remember the pair the mapping was cached under before the edit so the
    # stale entry can be invalidated if the pair changes.
    previous_key = (mapping.source_model, mapping.provider_id)

    # Provider
    if "provider_id" in update_data:
        new_provider_id = update_data["provider_id"]
        if new_provider_id:
            provider = db.query(Provider).filter(Provider.id == new_provider_id).first()
            if not provider:
                raise HTTPException(status_code=404, detail=f"Provider {new_provider_id} 不存在")
        mapping.provider_id = new_provider_id

    # Target model
    if "target_global_model_id" in update_data:
        target_model = (
            db.query(GlobalModel)
            .filter(
                GlobalModel.id == update_data["target_global_model_id"],
                GlobalModel.is_active == True,
            )
            .first()
        )
        if not target_model:
            raise HTTPException(
                status_code=404,
                detail=f"目标模型 {update_data['target_global_model_id']} 不存在或未激活",
            )
        mapping.target_global_model_id = update_data["target_global_model_id"]

    # Source model name
    if "source_model" in update_data:
        # An explicit null is treated like an empty name instead of crashing
        # on ``None.strip()``.
        new_source = (update_data["source_model"] or "").strip()
        if not new_source:
            raise HTTPException(status_code=400, detail="source_model 不能为空")
        mapping.source_model = new_source

    # Uniqueness: re-check whenever either half of the
    # (source_model, provider_id) pair was supplied.
    if "source_model" in update_data or "provider_id" in update_data:
        duplicate = (
            db.query(ModelMapping)
            .filter(
                ModelMapping.source_model == mapping.source_model,
                ModelMapping.provider_id == mapping.provider_id,
                ModelMapping.id != mapping_id,
            )
            .first()
        )
        if duplicate:
            raise HTTPException(status_code=400, detail="映射已存在")

    # Mapping type
    if "mapping_type" in update_data:
        if update_data["mapping_type"] not in ("alias", "mapping"):
            raise HTTPException(status_code=400, detail="mapping_type 必须是 'alias' 或 'mapping'")
        mapping.mapping_type = update_data["mapping_type"]

    # Active flag
    if "is_active" in update_data:
        mapping.is_active = update_data["is_active"]

    mapping.updated_at = datetime.now(timezone.utc)
    db.commit()
    db.refresh(mapping)

    logger.info(f"更新模型映射 (ID: {mapping.id})")

    # Reload with the relations that _serialize_mapping reads.
    mapping = (
        db.query(ModelMapping)
        .options(
            joinedload(ModelMapping.target_global_model),
            joinedload(ModelMapping.provider),
        )
        .filter(ModelMapping.id == mapping.id)
        .first()
    )

    cache_service = get_cache_invalidation_service()
    current_key = (mapping.source_model, mapping.provider_id)
    if previous_key != current_key:
        # The old pair no longer resolves to this mapping; drop its cache too.
        cache_service.on_model_mapping_changed(*previous_key)
    cache_service.on_model_mapping_changed(*current_key)

    return _serialize_mapping(mapping)
|
||||
|
||||
|
||||
@router.delete("/{mapping_id}", status_code=204)
async def delete_mapping(
    mapping_id: str,
    db: Session = Depends(get_db),
):
    """Delete a model mapping and invalidate its cache entry.

    Raises:
        HTTPException: 404 when the mapping does not exist.
    """
    doomed = db.query(ModelMapping).filter(ModelMapping.id == mapping_id).first()
    if not doomed:
        raise HTTPException(status_code=404, detail=f"映射 {mapping_id} 不存在")

    # Read the values needed for cache invalidation before the row is deleted.
    source_model, provider_id = doomed.source_model, doomed.provider_id

    logger.info(f"删除模型映射: {source_model} -> {doomed.target_global_model_id} (ID: {doomed.id})")

    db.delete(doomed)
    db.commit()

    get_cache_invalidation_service().on_model_mapping_changed(source_model, provider_id)

    return None
|
||||
@@ -12,6 +12,7 @@ from fastapi.responses import PlainTextResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
from src.api.base.context import ApiRequestContext
|
||||
from src.api.base.pagination import PaginationMeta, build_pagination_payload, paginate_sequence
|
||||
from src.api.base.pipeline import ApiRequestPipeline
|
||||
from src.clients.redis_client import get_redis_client_sync
|
||||
@@ -20,7 +21,8 @@ from src.core.logger import logger
|
||||
from src.database import get_db
|
||||
from src.models.database import ApiKey, User
|
||||
from src.services.cache.affinity_manager import get_affinity_manager
|
||||
from src.services.cache.aware_scheduler import get_cache_aware_scheduler
|
||||
from src.services.cache.aware_scheduler import CacheAwareScheduler, get_cache_aware_scheduler
|
||||
from src.services.system.config import SystemConfigService
|
||||
|
||||
router = APIRouter(prefix="/api/admin/monitoring/cache", tags=["Admin - Monitoring: Cache"])
|
||||
pipeline = ApiRequestPipeline()
|
||||
@@ -87,19 +89,19 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
|
||||
# 2. 尝试作为 Username 查询
|
||||
user = db.query(User).filter(User.username == identifier).first()
|
||||
if user:
|
||||
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...")
|
||||
logger.debug(f"通过Username解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
|
||||
return user.id
|
||||
|
||||
# 3. 尝试作为 Email 查询
|
||||
user = db.query(User).filter(User.email == identifier).first()
|
||||
if user:
|
||||
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...")
|
||||
logger.debug(f"通过Email解析: {identifier} -> {user.id[:8]}...") # type: ignore[index]
|
||||
return user.id
|
||||
|
||||
# 4. 尝试作为 API Key ID 查询
|
||||
api_key = db.query(ApiKey).filter(ApiKey.id == identifier).first()
|
||||
if api_key:
|
||||
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...")
|
||||
logger.debug(f"通过API Key ID解析: {identifier[:8]}... -> User ID: {api_key.user_id[:8]}...") # type: ignore[index]
|
||||
return api_key.user_id
|
||||
|
||||
# 无法识别
|
||||
@@ -111,7 +113,7 @@ def resolve_user_identifier(db: Session, identifier: str) -> Optional[str]:
|
||||
async def get_cache_stats(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
获取缓存亲和性统计信息
|
||||
|
||||
@@ -131,7 +133,7 @@ async def get_user_affinity(
|
||||
user_identifier: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
查询指定用户的所有缓存亲和性
|
||||
|
||||
@@ -157,7 +159,7 @@ async def list_affinities(
|
||||
limit: int = Query(100, ge=1, le=1000, description="返回数量限制"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
获取所有缓存亲和性列表,可选按关键词过滤
|
||||
|
||||
@@ -173,7 +175,7 @@ async def clear_user_cache(
|
||||
user_identifier: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
Clear cache affinity for a specific user
|
||||
|
||||
@@ -188,7 +190,7 @@ async def clear_user_cache(
|
||||
async def clear_all_cache(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
Clear all cache affinities
|
||||
|
||||
@@ -203,7 +205,7 @@ async def clear_provider_cache(
|
||||
provider_id: str,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
Clear cache affinities for a specific provider
|
||||
|
||||
@@ -218,7 +220,7 @@ async def clear_provider_cache(
|
||||
async def get_cache_config(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
获取缓存相关配置
|
||||
|
||||
@@ -234,7 +236,7 @@ async def get_cache_config(
|
||||
async def get_cache_metrics(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
) -> Any:
|
||||
"""
|
||||
以 Prometheus 文本格式暴露缓存调度指标,方便接入 Grafana。
|
||||
"""
|
||||
@@ -246,10 +248,25 @@ async def get_cache_metrics(
|
||||
|
||||
|
||||
class AdminCacheStatsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||
# 读取系统配置,确保监控接口与编排器使用一致的模式
|
||||
priority_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"provider_priority_mode",
|
||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||
)
|
||||
scheduling_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"scheduling_mode",
|
||||
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
)
|
||||
scheduler = await get_cache_aware_scheduler(
|
||||
redis_client,
|
||||
priority_mode=priority_mode,
|
||||
scheduling_mode=scheduling_mode,
|
||||
)
|
||||
stats = await scheduler.get_stats()
|
||||
logger.info("缓存统计信息查询成功")
|
||||
context.add_audit_metadata(
|
||||
@@ -266,10 +283,25 @@ class AdminCacheStatsAdapter(AdminApiAdapter):
|
||||
|
||||
|
||||
class AdminCacheMetricsAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> PlainTextResponse:
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
scheduler = await get_cache_aware_scheduler(redis_client)
|
||||
# 读取系统配置,确保监控接口与编排器使用一致的模式
|
||||
priority_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"provider_priority_mode",
|
||||
CacheAwareScheduler.PRIORITY_MODE_PROVIDER,
|
||||
)
|
||||
scheduling_mode = SystemConfigService.get_config(
|
||||
context.db,
|
||||
"scheduling_mode",
|
||||
CacheAwareScheduler.SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
)
|
||||
scheduler = await get_cache_aware_scheduler(
|
||||
redis_client,
|
||||
priority_mode=priority_mode,
|
||||
scheduling_mode=scheduling_mode,
|
||||
)
|
||||
stats = await scheduler.get_stats()
|
||||
payload = self._format_prometheus(stats)
|
||||
context.add_audit_metadata(
|
||||
@@ -391,7 +423,7 @@ class AdminCacheMetricsAdapter(AdminApiAdapter):
|
||||
class AdminGetUserAffinityAdapter(AdminApiAdapter):
|
||||
user_identifier: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
try:
|
||||
user_id = resolve_user_identifier(db, self.user_identifier)
|
||||
@@ -472,7 +504,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
redis_client = get_redis_client_sync()
|
||||
if not redis_client:
|
||||
@@ -682,7 +714,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
|
||||
class AdminClearUserCacheAdapter(AdminApiAdapter):
|
||||
user_identifier: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
db = context.db
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
@@ -786,7 +818,7 @@ class AdminClearUserCacheAdapter(AdminApiAdapter):
|
||||
|
||||
|
||||
class AdminClearAllCacheAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
@@ -806,7 +838,7 @@ class AdminClearAllCacheAdapter(AdminApiAdapter):
|
||||
class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
||||
provider_id: str
|
||||
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
try:
|
||||
redis_client = get_redis_client_sync()
|
||||
affinity_mgr = await get_affinity_manager(redis_client)
|
||||
@@ -829,7 +861,7 @@ class AdminClearProviderCacheAdapter(AdminApiAdapter):
|
||||
|
||||
|
||||
class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
async def handle(self, context: ApiRequestContext) -> Dict[str, Any]: # type: ignore[override]
|
||||
from src.services.cache.affinity_manager import CacheAffinityManager
|
||||
from src.services.cache.aware_scheduler import CacheAwareScheduler
|
||||
from src.services.rate_limit.adaptive_reservation import get_adaptive_reservation_manager
|
||||
@@ -869,3 +901,464 @@ class AdminCacheConfigAdapter(AdminApiAdapter):
|
||||
dynamic_reservation_enabled=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
# ==================== 模型映射缓存管理 ====================
|
||||
|
||||
|
||||
@router.get("/model-mapping/stats")
async def get_model_mapping_cache_stats(
    request: Request,
    db: Session = Depends(get_db),
) -> Any:
    """
    Return statistics about the model-mapping cache.

    The response includes:
    - the number of cache keys
    - the configured cache TTL
    - a per-type breakdown of cached entries
    """
    stats_adapter = AdminModelMappingCacheStatsAdapter()
    result = await pipeline.run(
        adapter=stats_adapter,
        http_request=request,
        db=db,
        mode=stats_adapter.mode,
    )
    return result
|
||||
|
||||
|
||||
@router.delete("/model-mapping")
async def clear_all_model_mapping_cache(
    request: Request,
    db: Session = Depends(get_db),
) -> Any:
    """
    Clear every model-mapping cache entry.

    Warning: this affects all model resolution; use with care.
    """
    clear_adapter = AdminClearAllModelMappingCacheAdapter()
    result = await pipeline.run(
        adapter=clear_adapter,
        http_request=request,
        db=db,
        mode=clear_adapter.mode,
    )
    return result
|
||||
|
||||
|
||||
@router.delete("/model-mapping/{model_name}")
async def clear_model_mapping_cache_by_name(
    model_name: str,
    request: Request,
    db: Session = Depends(get_db),
) -> Any:
    """
    Clear the mapping cache for one model name.

    Parameters:
    - model_name: the model name (either a GlobalModel.name or a mapping name)
    """
    by_name_adapter = AdminClearModelMappingCacheByNameAdapter(model_name=model_name)
    result = await pipeline.run(
        adapter=by_name_adapter,
        http_request=request,
        db=db,
        mode=by_name_adapter.mode,
    )
    return result
|
||||
|
||||
|
||||
@router.delete("/model-mapping/provider/{provider_id}/{global_model_id}")
async def clear_provider_model_mapping_cache(
    provider_id: str,
    global_model_id: str,
    request: Request,
    db: Session = Depends(get_db),
) -> Any:
    """
    Clear the model-mapping cache for one Provider / GlobalModel pair.

    Parameters:
    - provider_id: Provider ID
    - global_model_id: GlobalModel ID
    """
    pair_adapter = AdminClearProviderModelMappingCacheAdapter(
        provider_id=provider_id,
        global_model_id=global_model_id,
    )
    result = await pipeline.run(
        adapter=pair_adapter,
        http_request=request,
        db=db,
        mode=pair_adapter.mode,
    )
    return result
|
||||
|
||||
|
||||
class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
    """Report what the model-mapping cache in Redis currently holds.

    ``handle`` scans the ``model:*`` and ``global_model:*`` key spaces, counts
    the keys per type and decodes at most 100 ``global_model:resolve:*`` and
    100 ``model:provider_global:*`` entries into a readable listing.
    """

    @staticmethod
    def _as_text(value: Any) -> Any:
        """Decode a ``bytes`` value; return anything else unchanged."""
        return value.decode() if isinstance(value, bytes) else value

    @staticmethod
    def _strip_prefix(key: str, prefix: str) -> str:
        """Remove ``prefix`` from the start of ``key`` only.

        ``str.replace`` would also remove later occurrences of the prefix text
        inside the name itself.
        """
        return key[len(prefix):] if key.startswith(prefix) else key

    async def _scan_cache_keys(self, redis: Any) -> Dict[str, Any]:
        """Collect the model-related cache keys, grouped by key type."""
        groups: Dict[str, Any] = {
            "model_by_id": [],
            "model_by_provider_global": [],
            "global_model_by_id": [],
            "global_model_by_name": [],
            "global_model_resolve": [],
        }

        async for key in redis.scan_iter(match="model:*", count=100):
            key_str = self._as_text(key)
            if key_str.startswith("model:id:"):
                groups["model_by_id"].append(key_str)
            elif key_str.startswith("model:provider_global:"):
                # Hit counters share the prefix but are not cache entries.
                if not key_str.startswith("model:provider_global:hits:"):
                    groups["model_by_provider_global"].append(key_str)

        async for key in redis.scan_iter(match="global_model:*", count=100):
            key_str = self._as_text(key)
            if key_str.startswith("global_model:id:"):
                groups["global_model_by_id"].append(key_str)
            elif key_str.startswith("global_model:name:"):
                groups["global_model_by_name"].append(key_str)
            elif key_str.startswith("global_model:resolve:"):
                groups["global_model_resolve"].append(key_str)

        return groups

    @staticmethod
    def _providers_using_name(db: Any, global_model_id: Any, mapping_name: str) -> list:
        """Return sorted display names of providers that expose ``mapping_name``.

        A provider counts when one of its active models for the global model
        uses the name as ``provider_model_name`` or lists it in
        ``provider_model_aliases``.
        """
        from src.models.database import Model, Provider

        rows = (
            db.query(Model, Provider)
            .join(Provider, Model.provider_id == Provider.id)
            .filter(
                Model.global_model_id == global_model_id,
                Model.is_active,
                Provider.is_active,
            )
            .all()
        )
        names = set()
        for model, provider in rows:
            label = provider.display_name or provider.name
            if model.provider_model_name == mapping_name:
                names.add(label)
                continue
            if model.provider_model_aliases and any(
                isinstance(alias, dict) and alias.get("name") == mapping_name
                for alias in model.provider_model_aliases
            ):
                names.add(label)
        return sorted(names)

    async def _describe_resolve_entries(self, redis: Any, db: Any, keys: Any) -> Any:
        """Decode ``global_model:resolve:*`` entries.

        Returns ``(mappings, unmapped)``: resolved names that differ from the
        global model's own name, and entries that are negative-cached
        (``NOT_FOUND``), undecodable or failed while being processed.
        """
        import json

        prefix = "global_model:resolve:"
        mappings = []
        unmapped = []

        for key in keys:
            mapping_name = self._strip_prefix(key, prefix)
            try:
                cached_value = await redis.get(key)
                ttl = await redis.ttl(key)
                if not cached_value:
                    continue

                cached_str = self._as_text(cached_value)
                if cached_str == "NOT_FOUND":
                    unmapped.append({
                        "mapping_name": mapping_name,
                        "status": "not_found",
                        "ttl": ttl if ttl > 0 else None,
                    })
                    continue

                try:
                    cached_data = json.loads(cached_str)
                except json.JSONDecodeError:
                    unmapped.append({
                        "mapping_name": mapping_name,
                        "status": "invalid",
                        "ttl": ttl if ttl > 0 else None,
                    })
                    continue

                global_model_id = cached_data.get("id")
                global_model_name = cached_data.get("name")
                global_model_display_name = cached_data.get("display_name")

                # A name that resolves to itself is a direct match, not a mapping.
                if mapping_name == global_model_name:
                    continue

                # Only list providers that actually configured this name.
                provider_names = (
                    self._providers_using_name(db, global_model_id, mapping_name)
                    if global_model_id
                    else []
                )

                mappings.append({
                    "mapping_name": mapping_name,
                    "global_model_name": global_model_name,
                    "global_model_display_name": global_model_display_name,
                    "providers": provider_names,
                    "ttl": ttl if ttl > 0 else None,
                })
            except Exception as e:
                logger.warning(f"解析缓存键 {key} 失败: {e}")
                unmapped.append({
                    "mapping_name": mapping_name,
                    "status": "error",
                    "ttl": None,
                })

        mappings.sort(key=lambda x: x["mapping_name"])
        return mappings, unmapped

    async def _describe_provider_entries(self, redis: Any, db: Any, keys: Any) -> list:
        """Decode ``model:provider_global:{provider_id}:{global_model_id}`` entries.

        Only entries with an actual mapping are returned: the provider model
        name differs from the global model name, or aliases are configured.
        """
        import json

        from src.models.database import GlobalModel, Provider

        # Preload the active providers and global models once for all keys.
        provider_map = {
            str(p.id): p for p in db.query(Provider).filter(Provider.is_active.is_(True)).all()
        }
        global_model_map = {
            str(gm.id): gm
            for gm in db.query(GlobalModel).filter(GlobalModel.is_active.is_(True)).all()
        }

        prefix = "model:provider_global:"
        entries = []

        for key in keys:
            try:
                parts = self._strip_prefix(key, prefix).split(":")
                if len(parts) != 2:
                    continue
                provider_id, global_model_id = parts

                cached_value = await redis.get(key)
                ttl = await redis.ttl(key)

                # Hit counter stored next to the cache entry.
                hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
                hit_count_raw = await redis.get(hit_count_key)
                hit_count = int(hit_count_raw) if hit_count_raw else 0

                if not cached_value:
                    continue

                try:
                    cached_data = json.loads(self._as_text(cached_value))
                except json.JSONDecodeError:
                    continue

                provider_model_name = cached_data.get("provider_model_name")
                provider_model_aliases = cached_data.get("provider_model_aliases", [])

                provider = provider_map.get(provider_id)
                global_model = global_model_map.get(global_model_id)
                if not (provider and global_model):
                    continue

                alias_names = []
                if provider_model_aliases:
                    alias_names = [
                        alias_entry["name"]
                        for alias_entry in provider_model_aliases
                        if isinstance(alias_entry, dict) and alias_entry.get("name")
                    ]

                if not provider_model_name:
                    continue

                has_name_mapping = global_model.name != provider_model_name
                if not (has_name_mapping or alias_names):
                    continue

                # With a pure name mapping and no aliases, show the global
                # model name as the requested name.
                display_aliases = alias_names if alias_names else [global_model.name]

                entries.append({
                    "provider_id": provider_id,
                    "provider_name": provider.display_name or provider.name,
                    "global_model_id": global_model_id,
                    "global_model_name": global_model.name,
                    "global_model_display_name": global_model.display_name,
                    "provider_model_name": provider_model_name,
                    "aliases": display_aliases,
                    "ttl": ttl if ttl > 0 else None,
                    "hit_count": hit_count,
                })
            except Exception as e:
                logger.warning(f"解析 provider_global 缓存键 {key} 失败: {e}")

        entries.sort(key=lambda x: (x["provider_name"], x["global_model_name"]))
        return entries

    async def handle(self, context: ApiRequestContext) -> Dict[str, Any]:  # type: ignore[override]
        """Build the model-mapping cache statistics payload.

        Returns ``{"status": "ok", "data": ...}``. When Redis is unavailable
        ``data`` only reports ``available: False`` with a message.

        Raises:
            HTTPException: 500 when collecting the statistics fails.
        """
        from src.clients.redis_client import get_redis_client
        from src.config.constants import CacheTTL

        db = context.db

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                return {
                    "status": "ok",
                    "data": {
                        "available": False,
                        "message": "Redis 未启用,模型映射缓存不可用",
                    },
                }

            groups = await self._scan_cache_keys(redis)
            total_keys = sum(len(keys) for keys in groups.values())

            # At most 100 entries of each kind are decoded per request.
            mappings, unmapped_entries = await self._describe_resolve_entries(
                redis, db, groups["global_model_resolve"][:100]
            )
            provider_model_mappings = await self._describe_provider_entries(
                redis, db, groups["model_by_provider_global"][:100]
            )

            response_data = {
                "available": True,
                "ttl_seconds": CacheTTL.MODEL,
                "total_keys": total_keys,
                "breakdown": {
                    "model_by_id": len(groups["model_by_id"]),
                    "model_by_provider_global": len(groups["model_by_provider_global"]),
                    "global_model_by_id": len(groups["global_model_by_id"]),
                    "global_model_by_name": len(groups["global_model_by_name"]),
                    "global_model_resolve": len(groups["global_model_resolve"]),
                },
                "mappings": mappings,
                "provider_model_mappings": provider_model_mappings if provider_model_mappings else None,
                "unmapped": unmapped_entries if unmapped_entries else None,
            }

            context.add_audit_metadata(
                action="model_mapping_cache_stats",
                total_keys=total_keys,
            )
            return {"status": "ok", "data": response_data}

        except Exception as exc:
            logger.exception(f"获取模型映射缓存统计失败: {exc}")
            raise HTTPException(status_code=500, detail=f"获取统计失败: {exc}") from exc
|
||||
|
||||
|
||||
class AdminClearAllModelMappingCacheAdapter(AdminApiAdapter):
    """Admin adapter that removes every ``model:*`` and ``global_model:*`` Redis key."""

    async def handle(self, context: ApiRequestContext) -> Dict[str, Any]:  # type: ignore[override]
        """Delete all model-mapping cache keys and report how many were removed.

        Returns:
            A dict with ``status``, ``message`` and ``deleted_count``.

        Raises:
            HTTPException: 503 when no Redis client is available, 500 on any
                other failure while scanning or deleting.
        """
        from src.clients.redis_client import get_redis_client

        # Upper bound on keys sent in one DEL command, so a large cache does not
        # turn into a single command with an unbounded argument list.
        batch_size = 500

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis 未启用")

            # Collect every model-related cache key first, then delete.
            keys_to_delete = []
            for pattern in ("model:*", "global_model:*"):
                async for key in redis.scan_iter(match=pattern, count=100):
                    keys_to_delete.append(key)

            # Delete in bounded batches; an empty key list simply skips the loop.
            deleted_count = 0
            for start in range(0, len(keys_to_delete), batch_size):
                batch = keys_to_delete[start : start + batch_size]
                deleted_count += await redis.delete(*batch)

            logger.warning(f"已清除所有模型映射缓存(管理员操作): {deleted_count} 个键")
            context.add_audit_metadata(
                action="model_mapping_cache_clear_all",
                deleted_count=deleted_count,
            )
            return {
                "status": "ok",
                "message": "已清除所有模型映射缓存",
                "deleted_count": deleted_count,
            }

        except HTTPException:
            # Already a well-formed HTTP error (e.g. the 503 above); pass it through.
            raise
        except Exception as exc:
            logger.exception(f"清除模型映射缓存失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}") from exc
|
||||
|
||||
|
||||
@dataclass
class AdminClearModelMappingCacheByNameAdapter(AdminApiAdapter):
    """Admin adapter that clears the cached resolve/name entries for one model name."""

    # Model name whose ``global_model:resolve:*`` / ``global_model:name:*`` keys are cleared.
    model_name: str

    async def handle(self, context: ApiRequestContext) -> Dict[str, Any]:  # type: ignore[override]
        """Delete the cache keys derived from ``model_name``.

        Returns:
            A dict with ``status``, ``message``, ``model_name`` and the list of
            keys that actually existed and were deleted.

        Raises:
            HTTPException: 503 when no Redis client is available, 500 on any
                other failure.
        """
        from src.clients.redis_client import get_redis_client

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis 未启用")

            # Resolve cache first, then name cache (same order as reported keys).
            candidate_keys = (
                f"global_model:resolve:{self.model_name}",
                f"global_model:name:{self.model_name}",
            )

            deleted_keys = []
            for cache_key in candidate_keys:
                # DEL reports how many keys it removed, so a separate EXISTS
                # round-trip (and the race between the two calls) is unnecessary.
                if await redis.delete(cache_key):
                    deleted_keys.append(cache_key)

            logger.info(f"已清除模型映射缓存: model_name={self.model_name}, 删除键={deleted_keys}")
            context.add_audit_metadata(
                action="model_mapping_cache_clear_by_name",
                model_name=self.model_name,
                deleted_keys=deleted_keys,
            )
            return {
                "status": "ok",
                "message": f"已清除模型 {self.model_name} 的映射缓存",
                "model_name": self.model_name,
                "deleted_keys": deleted_keys,
            }

        except HTTPException:
            # Already a well-formed HTTP error (e.g. the 503 above); pass it through.
            raise
        except Exception as exc:
            logger.exception(f"清除模型映射缓存失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}") from exc
|
||||
|
||||
|
||||
@dataclass
class AdminClearProviderModelMappingCacheAdapter(AdminApiAdapter):
    """Admin adapter that clears the cached mapping for one provider/global-model pair."""

    # Identifiers used to build the ``model:provider_global:*`` cache keys.
    provider_id: str
    global_model_id: str

    async def handle(self, context: ApiRequestContext) -> Dict[str, Any]:  # type: ignore[override]
        """Delete the provider_global cache entry and its hit counter.

        Returns:
            A dict with ``status``, ``message``, the two identifiers and the
            list of keys that actually existed and were deleted.

        Raises:
            HTTPException: 503 when no Redis client is available, 500 on any
                other failure.
        """
        from src.clients.redis_client import get_redis_client

        try:
            redis = await get_redis_client(require_redis=False)
            if not redis:
                raise HTTPException(status_code=503, detail="Redis 未启用")

            # Mapping entry first, then its hit-count companion key.
            candidate_keys = (
                f"model:provider_global:{self.provider_id}:{self.global_model_id}",
                f"model:provider_global:hits:{self.provider_id}:{self.global_model_id}",
            )

            deleted_keys = []
            for cache_key in candidate_keys:
                # DEL reports how many keys it removed, so a separate EXISTS
                # round-trip (and the race between the two calls) is unnecessary.
                if await redis.delete(cache_key):
                    deleted_keys.append(cache_key)

            # Only the first 8 characters of each id are written to the log line.
            logger.info(
                f"已清除 Provider 模型映射缓存: provider_id={self.provider_id[:8]}..., "
                f"global_model_id={self.global_model_id[:8]}..., 删除键={deleted_keys}"
            )
            context.add_audit_metadata(
                action="provider_model_mapping_cache_clear",
                provider_id=self.provider_id,
                global_model_id=self.global_model_id,
                deleted_keys=deleted_keys,
            )
            return {
                "status": "ok",
                "message": "已清除 Provider 模型映射缓存",
                "provider_id": self.provider_id,
                "global_model_id": self.global_model_id,
                "deleted_keys": deleted_keys,
            }

        except HTTPException:
            # Already a well-formed HTTP error (e.g. the 503 above); pass it through.
            raise
        except Exception as exc:
            logger.exception(f"清除 Provider 模型映射缓存失败: {exc}")
            raise HTTPException(status_code=500, detail=f"清除失败: {exc}") from exc
|
||||
|
||||
@@ -1,46 +1,28 @@
|
||||
"""
|
||||
Provider Query API 端点
|
||||
用于查询提供商的余额、使用记录等信息
|
||||
用于查询提供商的模型列表等信息
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.core.crypto import crypto_service
|
||||
from src.core.logger import logger
|
||||
from src.database.database import get_db
|
||||
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
|
||||
|
||||
# 初始化适配器注册
|
||||
from src.plugins.provider_query import init # noqa
|
||||
from src.plugins.provider_query import get_query_registry
|
||||
from src.plugins.provider_query.base import QueryCapability
|
||||
from src.models.database import Provider, ProviderEndpoint, User
|
||||
from src.utils.auth_utils import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/provider-query", tags=["Provider Query"])
|
||||
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
|
||||
|
||||
|
||||
# ============ Request/Response Models ============
|
||||
|
||||
|
||||
class BalanceQueryRequest(BaseModel):
|
||||
"""余额查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None # 如果不指定,使用提供商的第一个可用 API Key
|
||||
|
||||
|
||||
class UsageSummaryQueryRequest(BaseModel):
|
||||
"""使用汇总查询请求"""
|
||||
|
||||
provider_id: str
|
||||
api_key_id: Optional[str] = None
|
||||
period: str = "month" # day, week, month, year
|
||||
|
||||
|
||||
class ModelsQueryRequest(BaseModel):
|
||||
"""模型列表查询请求"""
|
||||
|
||||
@@ -51,360 +33,281 @@ class ModelsQueryRequest(BaseModel):
|
||||
# ============ API Endpoints ============
|
||||
|
||||
|
||||
@router.get("/adapters")
|
||||
async def list_adapters(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取所有可用的查询适配器
|
||||
async def _fetch_openai_models(
|
||||
client: httpx.AsyncClient,
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
api_format: str,
|
||||
extra_headers: Optional[dict] = None,
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 OpenAI 格式的模型列表
|
||||
|
||||
Returns:
|
||||
适配器列表
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
registry = get_query_registry()
|
||||
adapters = registry.list_adapters()
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
if extra_headers:
|
||||
# 防止 extra_headers 覆盖 Authorization
|
||||
safe_headers = {k: v for k, v in extra_headers.items() if k.lower() != "authorization"}
|
||||
headers.update(safe_headers)
|
||||
|
||||
return {"success": True, "data": adapters}
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"OpenAI models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
# 记录详细的错误信息
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"OpenAI models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
@router.get("/capabilities/{provider_id}")
|
||||
async def get_provider_capabilities(
|
||||
provider_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取提供商支持的查询能力
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
async def _fetch_claude_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Claude 格式的模型列表
|
||||
|
||||
Returns:
|
||||
支持的查询能力列表
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
# 获取提供商
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
capabilities = registry.get_capabilities_for_provider(provider.name)
|
||||
|
||||
if capabilities is None:
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [],
|
||||
"has_adapter": False,
|
||||
"message": "No query adapter available for this provider",
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"provider_id": provider_id,
|
||||
"provider_name": provider.name,
|
||||
"capabilities": [c.name for c in capabilities],
|
||||
"has_adapter": True,
|
||||
},
|
||||
headers = {
|
||||
"x-api-key": api_key,
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"anthropic-version": "2023-06-01",
|
||||
}
|
||||
|
||||
# 构建 /v1/models URL
|
||||
if base_url.endswith("/v1"):
|
||||
models_url = f"{base_url}/models"
|
||||
else:
|
||||
models_url = f"{base_url}/v1/models"
|
||||
|
||||
@router.post("/balance")
|
||||
async def query_balance(
|
||||
request: BalanceQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商余额
|
||||
try:
|
||||
response = await client.get(models_url, headers=headers)
|
||||
logger.debug(f"Claude models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
models = []
|
||||
if "data" in data:
|
||||
models = data["data"]
|
||||
elif isinstance(data, list):
|
||||
models = data
|
||||
# 为每个模型添加 api_format 字段
|
||||
for m in models:
|
||||
m["api_format"] = api_format
|
||||
return models, None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Claude models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Claude models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
async def _fetch_gemini_models(
|
||||
client: httpx.AsyncClient, base_url: str, api_key: str, api_format: str
|
||||
) -> tuple[list, Optional[str]]:
|
||||
"""获取 Gemini 格式的模型列表
|
||||
|
||||
Returns:
|
||||
余额信息
|
||||
tuple[list, Optional[str]]: (模型列表, 错误信息)
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
# 兼容 base_url 已包含 /v1beta 的情况
|
||||
base_url_clean = base_url.rstrip("/")
|
||||
if base_url_clean.endswith("/v1beta"):
|
||||
models_url = f"{base_url_clean}/models?key={api_key}"
|
||||
else:
|
||||
models_url = f"{base_url_clean}/v1beta/models?key={api_key}"
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
# 查找指定的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
try:
|
||||
response = await client.get(models_url)
|
||||
logger.debug(f"Gemini models request to {models_url}: status={response.status_code}")
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "models" in data:
|
||||
# 转换为统一格式
|
||||
return [
|
||||
{
|
||||
"id": m.get("name", "").replace("models/", ""),
|
||||
"owned_by": "google",
|
||||
"display_name": m.get("displayName", ""),
|
||||
"api_format": api_format,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 使用第一个可用的 API Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format if endpoint.api_format else None,
|
||||
}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询余额
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_balance(
|
||||
provider_type=provider.name, api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
|
||||
if not query_result.success:
|
||||
logger.warning(f"Balance query failed for provider {provider.name}: {query_result.error}")
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/usage-summary")
|
||||
async def query_usage_summary(
|
||||
request: UsageSummaryQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商使用汇总
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
使用汇总信息
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key(逻辑同上)
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
|
||||
if request.api_key_id:
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询使用汇总
|
||||
registry = get_query_registry()
|
||||
query_result = await registry.query_provider_usage(
|
||||
provider_type=provider.name,
|
||||
api_key=api_key_value,
|
||||
period=request.period,
|
||||
endpoint_config=endpoint_config,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
for m in data["models"]
|
||||
], None
|
||||
return [], None
|
||||
else:
|
||||
error_body = response.text[:500] if response.text else "(empty)"
|
||||
error_msg = f"HTTP {response.status_code}: {error_body}"
|
||||
logger.warning(f"Gemini models request to {models_url} failed: {error_msg}")
|
||||
return [], error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Request error: {str(e)}"
|
||||
logger.warning(f"Failed to fetch Gemini models from {models_url}: {e}")
|
||||
return [], error_msg
|
||||
|
||||
|
||||
@router.post("/models")
|
||||
async def query_available_models(
|
||||
request: ModelsQueryRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
查询提供商可用模型
|
||||
|
||||
遍历所有活跃端点,根据端点的 API 格式选择正确的请求方式:
|
||||
- OPENAI/OPENAI_CLI: /v1/models (Bearer token)
|
||||
- CLAUDE/CLAUDE_CLI: /v1/models (x-api-key)
|
||||
- GEMINI/GEMINI_CLI: /v1beta/models (URL key parameter)
|
||||
|
||||
Args:
|
||||
request: 查询请求
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
所有端点的模型列表(合并)
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# 获取提供商及其端点
|
||||
result = await db.execute(
|
||||
select(Provider)
|
||||
.options(selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys))
|
||||
.where(Provider.id == request.provider_id)
|
||||
provider = (
|
||||
db.query(Provider)
|
||||
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
|
||||
.filter(Provider.id == request.provider_id)
|
||||
.first()
|
||||
)
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# 获取 API Key
|
||||
api_key_value = None
|
||||
endpoint_config = None
|
||||
# 收集所有活跃端点的配置
|
||||
endpoint_configs: list[dict] = []
|
||||
|
||||
if request.api_key_id:
|
||||
# 指定了特定的 API Key,只使用该 Key 对应的端点
|
||||
for endpoint in provider.endpoints:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.id == request.api_key_id:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break
|
||||
if api_key_value:
|
||||
if endpoint_configs:
|
||||
break
|
||||
|
||||
if not api_key_value:
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=404, detail="API Key not found")
|
||||
else:
|
||||
# 遍历所有活跃端点,每个端点取第一个可用的 Key
|
||||
for endpoint in provider.endpoints:
|
||||
if endpoint.is_active and endpoint.api_keys:
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
api_key_value = api_key.api_key
|
||||
endpoint_config = {"base_url": endpoint.base_url}
|
||||
break
|
||||
if api_key_value:
|
||||
break
|
||||
if not endpoint.is_active or not endpoint.api_keys:
|
||||
continue
|
||||
|
||||
if not api_key_value:
|
||||
# 找第一个可用的 Key
|
||||
for api_key in endpoint.api_keys:
|
||||
if api_key.is_active:
|
||||
try:
|
||||
api_key_value = crypto_service.decrypt(api_key.api_key)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt API key: {e}")
|
||||
continue # 尝试下一个 Key
|
||||
endpoint_configs.append({
|
||||
"api_key": api_key_value,
|
||||
"base_url": endpoint.base_url,
|
||||
"api_format": endpoint.api_format,
|
||||
"extra_headers": endpoint.headers,
|
||||
})
|
||||
break # 只取第一个可用的 Key
|
||||
|
||||
if not endpoint_configs:
|
||||
raise HTTPException(status_code=400, detail="No active API Key found for this provider")
|
||||
|
||||
# 查询模型
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
# 并发请求所有端点的模型列表
|
||||
all_models: list = []
|
||||
errors: list[str] = []
|
||||
|
||||
if not adapter:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"No query adapter available for provider: {provider.name}"
|
||||
async def fetch_endpoint_models(
|
||||
client: httpx.AsyncClient, config: dict
|
||||
) -> tuple[list, Optional[str]]:
|
||||
base_url = config["base_url"]
|
||||
if not base_url:
|
||||
return [], None
|
||||
base_url = base_url.rstrip("/")
|
||||
api_format = config["api_format"]
|
||||
api_key_value = config["api_key"]
|
||||
extra_headers = config["extra_headers"]
|
||||
|
||||
try:
|
||||
if api_format in ["CLAUDE", "CLAUDE_CLI"]:
|
||||
return await _fetch_claude_models(client, base_url, api_key_value, api_format)
|
||||
elif api_format in ["GEMINI", "GEMINI_CLI"]:
|
||||
return await _fetch_gemini_models(client, base_url, api_key_value, api_format)
|
||||
else:
|
||||
return await _fetch_openai_models(
|
||||
client, base_url, api_key_value, api_format, extra_headers
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching models from {api_format} endpoint: {e}")
|
||||
return [], f"{api_format}: {str(e)}"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
results = await asyncio.gather(
|
||||
*[fetch_endpoint_models(client, c) for c in endpoint_configs]
|
||||
)
|
||||
for models, error in results:
|
||||
all_models.extend(models)
|
||||
if error:
|
||||
errors.append(error)
|
||||
|
||||
query_result = await adapter.query_available_models(
|
||||
api_key=api_key_value, endpoint_config=endpoint_config
|
||||
)
|
||||
# 按 model id 去重(保留第一个)
|
||||
seen_ids: set[str] = set()
|
||||
unique_models: list = []
|
||||
for model in all_models:
|
||||
model_id = model.get("id")
|
||||
if model_id and model_id not in seen_ids:
|
||||
seen_ids.add(model_id)
|
||||
unique_models.append(model)
|
||||
|
||||
error = "; ".join(errors) if errors else None
|
||||
if not unique_models and not error:
|
||||
error = "No models returned from any endpoint"
|
||||
|
||||
return {
|
||||
"success": query_result.success,
|
||||
"data": query_result.to_dict(),
|
||||
"success": len(unique_models) > 0,
|
||||
"data": {"models": unique_models, "error": error},
|
||||
"provider": {
|
||||
"id": provider.id,
|
||||
"name": provider.name,
|
||||
"display_name": provider.display_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/cache/{provider_id}")
|
||||
async def clear_query_cache(
|
||||
provider_id: str,
|
||||
api_key_id: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
清除查询缓存
|
||||
|
||||
Args:
|
||||
provider_id: 提供商 ID
|
||||
api_key_id: 可选,指定清除某个 API Key 的缓存
|
||||
|
||||
Returns:
|
||||
清除结果
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# 获取提供商
|
||||
result = await db.execute(select(Provider).where(Provider.id == provider_id))
|
||||
provider = result.scalar_one_or_none()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
registry = get_query_registry()
|
||||
adapter = registry.get_adapter_for_provider(provider.name)
|
||||
|
||||
if adapter:
|
||||
if api_key_id:
|
||||
# 获取 API Key 值来清除缓存
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await db.execute(select(ProviderAPIKey).where(ProviderAPIKey.id == api_key_id))
|
||||
api_key = result.scalar_one_or_none()
|
||||
if api_key:
|
||||
adapter.clear_cache(api_key.api_key)
|
||||
else:
|
||||
adapter.clear_cache()
|
||||
|
||||
return {"success": True, "message": "Cache cleared successfully"}
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
|
||||
from src.api.base.admin_adapter import AdminApiAdapter
|
||||
@@ -26,7 +25,6 @@ from src.models.pydantic_models import (
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
)
|
||||
from src.models.pydantic_models import (
|
||||
@@ -136,8 +134,7 @@ async def get_provider_available_source_models(
|
||||
获取该 Provider 支持的所有统一模型名(source_model)
|
||||
|
||||
包括:
|
||||
1. 通过 ModelMapping 映射的模型
|
||||
2. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
1. 直连模型(Model.provider_model_name 直接作为统一模型名)
|
||||
"""
|
||||
adapter = AdminGetProviderAvailableSourceModelsAdapter(provider_id=provider_id)
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
@@ -294,10 +291,9 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
"""
|
||||
返回 Provider 支持的所有 GlobalModel
|
||||
|
||||
方案 A 逻辑:
|
||||
逻辑:
|
||||
1. 查询该 Provider 的所有 Model
|
||||
2. 通过 Model.global_model_id 获取 GlobalModel
|
||||
3. 查询所有指向该 GlobalModel 的别名(ModelMapping.alias)
|
||||
"""
|
||||
db = context.db
|
||||
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
|
||||
@@ -324,27 +320,10 @@ class AdminGetProviderAvailableSourceModelsAdapter(AdminApiAdapter):
|
||||
|
||||
# 如果该 GlobalModel 还未处理,初始化
|
||||
if global_model_name not in global_models_dict:
|
||||
# 查询指向该 GlobalModel 的所有别名/映射
|
||||
alias_rows = (
|
||||
db.query(ModelMapping.source_model)
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id == global_model.id,
|
||||
ModelMapping.is_active == True,
|
||||
or_(
|
||||
ModelMapping.provider_id == self.provider_id,
|
||||
ModelMapping.provider_id.is_(None),
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
alias_list = [alias[0] for alias in alias_rows]
|
||||
|
||||
global_models_dict[global_model_name] = {
|
||||
"global_model_name": global_model_name,
|
||||
"display_name": global_model.display_name,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"has_alias": len(alias_list) > 0,
|
||||
"aliases": alias_list,
|
||||
"model_id": model.id,
|
||||
"price": {
|
||||
"input_price_per_1m": model.get_effective_input_price(),
|
||||
|
||||
@@ -91,6 +91,34 @@ async def get_api_formats(request: Request, db: Session = Depends(get_db)):
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
|
||||
|
||||
|
||||
@router.get("/config/export")
async def export_config(request: Request, db: Session = Depends(get_db)):
    """导出提供商和模型配置(管理员)"""
    # Export provider and model configuration (admin). The docstring above is
    # kept verbatim because FastAPI publishes it as the route description.
    export_adapter = AdminExportConfigAdapter()
    return await pipeline.run(
        adapter=export_adapter,
        http_request=request,
        db=db,
        mode=export_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.post("/config/import")
async def import_config(request: Request, db: Session = Depends(get_db)):
    """导入提供商和模型配置(管理员)"""
    # Import provider and model configuration (admin). The docstring above is
    # kept verbatim because FastAPI publishes it as the route description.
    import_adapter = AdminImportConfigAdapter()
    return await pipeline.run(
        adapter=import_adapter,
        http_request=request,
        db=db,
        mode=import_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.get("/users/export")
async def export_users(request: Request, db: Session = Depends(get_db)):
    """导出用户数据(管理员)"""
    # Export user data (admin). The docstring above is kept verbatim because
    # FastAPI publishes it as the route description.
    export_adapter = AdminExportUsersAdapter()
    return await pipeline.run(
        adapter=export_adapter,
        http_request=request,
        db=db,
        mode=export_adapter.mode,
    )
|
||||
|
||||
|
||||
@router.post("/users/import")
async def import_users(request: Request, db: Session = Depends(get_db)):
    """导入用户数据(管理员)"""
    # Import user data (admin). The docstring above is kept verbatim because
    # FastAPI publishes it as the route description.
    import_adapter = AdminImportUsersAdapter()
    return await pipeline.run(
        adapter=import_adapter,
        http_request=request,
        db=db,
        mode=import_adapter.mode,
    )
|
||||
|
||||
|
||||
# -------- 系统设置适配器 --------
|
||||
|
||||
|
||||
@@ -310,3 +338,749 @@ class AdminGetApiFormatsAdapter(AdminApiAdapter):
|
||||
)
|
||||
|
||||
return {"formats": formats}
|
||||
|
||||
|
||||
class AdminExportConfigAdapter(AdminApiAdapter):
    """Admin adapter that dumps global models, providers, endpoints, keys and models."""

    async def handle(self, context):  # type: ignore[override]
        """Export provider and model configuration, with API keys decrypted.

        Returns:
            A dict with ``version`` ("1.0"), ``exported_at`` (UTC ISO-8601
            timestamp), ``global_models`` and ``providers``; each provider
            carries its ``endpoints`` (with nested ``keys``) and ``models``.
        """
        from datetime import datetime, timezone

        from src.core.crypto import crypto_service
        from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint

        db = context.db

        # Export GlobalModels.
        global_models = db.query(GlobalModel).all()
        global_models_data = [
            {
                "name": gm.name,
                "display_name": gm.display_name,
                "default_price_per_request": gm.default_price_per_request,
                "default_tiered_pricing": gm.default_tiered_pricing,
                "supported_capabilities": gm.supported_capabilities,
                "config": gm.config,
                "is_active": gm.is_active,
            }
            for gm in global_models
        ]

        # id -> name lookup built from the rows loaded above, so resolving a
        # model's GlobalModel name does not cost one query per model.
        global_model_names = {gm.id: gm.name for gm in global_models}

        # Export Providers together with their related rows.
        providers_data = []
        for provider in db.query(Provider).all():
            # Endpoints of this provider, each with its API keys.
            endpoints_data = []
            endpoints = (
                db.query(ProviderEndpoint)
                .filter(ProviderEndpoint.provider_id == provider.id)
                .all()
            )
            for ep in endpoints:
                keys_data = []
                keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
                for key in keys:
                    try:
                        decrypted_key = crypto_service.decrypt(key.api_key)
                    except Exception:
                        # NOTE(review): a key that fails to decrypt is exported as an
                        # empty string with no log entry or error; consider surfacing it.
                        decrypted_key = ""

                    keys_data.append(
                        {
                            "api_key": decrypted_key,
                            "name": key.name,
                            "note": key.note,
                            "rate_multiplier": key.rate_multiplier,
                            "internal_priority": key.internal_priority,
                            "global_priority": key.global_priority,
                            "max_concurrent": key.max_concurrent,
                            "rate_limit": key.rate_limit,
                            "daily_limit": key.daily_limit,
                            "monthly_limit": key.monthly_limit,
                            "allowed_models": key.allowed_models,
                            "capabilities": key.capabilities,
                            "is_active": key.is_active,
                        }
                    )

                endpoints_data.append(
                    {
                        "api_format": ep.api_format,
                        "base_url": ep.base_url,
                        "headers": ep.headers,
                        "timeout": ep.timeout,
                        "max_retries": ep.max_retries,
                        "max_concurrent": ep.max_concurrent,
                        "rate_limit": ep.rate_limit,
                        "is_active": ep.is_active,
                        "custom_path": ep.custom_path,
                        "config": ep.config,
                        "keys": keys_data,
                    }
                )

            # Models of this provider; the GlobalModel is referenced by name
            # (None when the model has no matching GlobalModel).
            models_data = [
                {
                    "global_model_name": global_model_names.get(model.global_model_id),
                    "provider_model_name": model.provider_model_name,
                    "provider_model_aliases": model.provider_model_aliases,
                    "price_per_request": model.price_per_request,
                    "tiered_pricing": model.tiered_pricing,
                    "supports_vision": model.supports_vision,
                    "supports_function_calling": model.supports_function_calling,
                    "supports_streaming": model.supports_streaming,
                    "supports_extended_thinking": model.supports_extended_thinking,
                    "supports_image_generation": model.supports_image_generation,
                    "is_active": model.is_active,
                    "config": model.config,
                }
                for model in db.query(Model).filter(Model.provider_id == provider.id).all()
            ]

            providers_data.append(
                {
                    "name": provider.name,
                    "display_name": provider.display_name,
                    "description": provider.description,
                    "website": provider.website,
                    "billing_type": provider.billing_type.value if provider.billing_type else None,
                    "monthly_quota_usd": provider.monthly_quota_usd,
                    "quota_reset_day": provider.quota_reset_day,
                    "rpm_limit": provider.rpm_limit,
                    "provider_priority": provider.provider_priority,
                    "is_active": provider.is_active,
                    "rate_limit": provider.rate_limit,
                    "concurrent_limit": provider.concurrent_limit,
                    "config": provider.config,
                    "endpoints": endpoints_data,
                    "models": models_data,
                }
            )

        return {
            "version": "1.0",
            "exported_at": datetime.now(timezone.utc).isoformat(),
            "global_models": global_models_data,
            "providers": providers_data,
        }
|
||||
|
||||
|
||||
MAX_IMPORT_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
||||
class AdminImportConfigAdapter(AdminApiAdapter):
    """Admin adapter that imports provider and model configuration."""

    async def handle(self, context):  # type: ignore[override]
        """Import global models, providers, endpoints, API keys and models.

        The JSON body must carry ``version == "1.0"`` and may contain
        ``global_models``, ``providers`` and ``merge_mode``. ``merge_mode``
        decides what happens when a record already exists: ``"skip"``
        (default) leaves it untouched, ``"overwrite"`` updates it in place and
        ``"error"`` aborts the import. Children of an existing provider or
        endpoint are still processed in ``"skip"`` mode.

        Returns:
            A dict with a success ``message`` and ``stats`` holding
            created / updated / skipped counters per entity type plus a list
            of non-fatal ``errors``.

        Raises:
            InvalidRequestException: If the body exceeds ``MAX_IMPORT_SIZE``,
                the version or ``merge_mode`` is unsupported, a record already
                exists in ``"error"`` mode, or any other failure occurs while
                importing (the session is rolled back in that case).
        """
        import uuid
        from datetime import datetime, timezone

        from src.core.crypto import crypto_service
        from src.core.enums import ProviderBillingType
        from src.models.database import GlobalModel, Model, ProviderAPIKey, ProviderEndpoint

        # Reject oversized request bodies before parsing them.
        if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
            raise InvalidRequestException("请求体大小不能超过 10MB")

        db = context.db
        payload = context.ensure_json_body()

        # Validate the configuration version.
        version = payload.get("version")
        if version != "1.0":
            raise InvalidRequestException(f"不支持的配置版本: {version}")

        # Import options. An unknown merge_mode used to fall through every
        # branch below (existing rows neither skipped, updated nor counted),
        # so it is rejected up front.
        merge_mode = payload.get("merge_mode", "skip")  # skip, overwrite, error
        if merge_mode not in ("skip", "overwrite", "error"):
            raise InvalidRequestException(f"不支持的 merge_mode: {merge_mode}")

        # `or []` also covers an explicit JSON null for these lists.
        global_models_data = payload.get("global_models") or []
        providers_data = payload.get("providers") or []

        stats = {
            "global_models": {"created": 0, "updated": 0, "skipped": 0},
            "providers": {"created": 0, "updated": 0, "skipped": 0},
            "endpoints": {"created": 0, "updated": 0, "skipped": 0},
            "keys": {"created": 0, "updated": 0, "skipped": 0},
            "models": {"created": 0, "updated": 0, "skipped": 0},
            "errors": [],
        }

        try:
            # ---- GlobalModels ------------------------------------------------
            global_model_map = {}  # name -> id
            for gm_data in global_models_data:
                existing = (
                    db.query(GlobalModel).filter(GlobalModel.name == gm_data["name"]).first()
                )

                if existing:
                    global_model_map[gm_data["name"]] = existing.id
                    if merge_mode == "skip":
                        stats["global_models"]["skipped"] += 1
                        continue
                    elif merge_mode == "error":
                        raise InvalidRequestException(
                            f"GlobalModel '{gm_data['name']}' 已存在"
                        )
                    elif merge_mode == "overwrite":
                        # Update the existing record in place.
                        existing.display_name = gm_data.get(
                            "display_name", existing.display_name
                        )
                        existing.default_price_per_request = gm_data.get(
                            "default_price_per_request"
                        )
                        existing.default_tiered_pricing = gm_data.get(
                            "default_tiered_pricing", existing.default_tiered_pricing
                        )
                        existing.supported_capabilities = gm_data.get(
                            "supported_capabilities"
                        )
                        existing.config = gm_data.get("config")
                        existing.is_active = gm_data.get("is_active", True)
                        existing.updated_at = datetime.now(timezone.utc)
                        stats["global_models"]["updated"] += 1
                else:
                    # Create a new record.
                    new_gm = GlobalModel(
                        id=str(uuid.uuid4()),
                        name=gm_data["name"],
                        display_name=gm_data.get("display_name", gm_data["name"]),
                        default_price_per_request=gm_data.get("default_price_per_request"),
                        default_tiered_pricing=gm_data.get(
                            "default_tiered_pricing",
                            {"tiers": [{"up_to": None, "input_price_per_1m": 0, "output_price_per_1m": 0}]},
                        ),
                        supported_capabilities=gm_data.get("supported_capabilities"),
                        config=gm_data.get("config"),
                        is_active=gm_data.get("is_active", True),
                    )
                    db.add(new_gm)
                    db.flush()
                    global_model_map[gm_data["name"]] = new_gm.id
                    stats["global_models"]["created"] += 1

            # ---- Providers ---------------------------------------------------
            for prov_data in providers_data:
                existing_provider = (
                    db.query(Provider).filter(Provider.name == prov_data["name"]).first()
                )

                if existing_provider:
                    provider_id = existing_provider.id
                    if merge_mode == "skip":
                        stats["providers"]["skipped"] += 1
                        # Endpoints and models below are still processed.
                    elif merge_mode == "error":
                        raise InvalidRequestException(
                            f"Provider '{prov_data['name']}' 已存在"
                        )
                    elif merge_mode == "overwrite":
                        # Update the existing record in place.
                        existing_provider.display_name = prov_data.get(
                            "display_name", existing_provider.display_name
                        )
                        existing_provider.description = prov_data.get("description")
                        existing_provider.website = prov_data.get("website")
                        if prov_data.get("billing_type"):
                            existing_provider.billing_type = ProviderBillingType(
                                prov_data["billing_type"]
                            )
                        existing_provider.monthly_quota_usd = prov_data.get(
                            "monthly_quota_usd"
                        )
                        existing_provider.quota_reset_day = prov_data.get(
                            "quota_reset_day", 30
                        )
                        existing_provider.rpm_limit = prov_data.get("rpm_limit")
                        existing_provider.provider_priority = prov_data.get(
                            "provider_priority", 100
                        )
                        existing_provider.is_active = prov_data.get("is_active", True)
                        existing_provider.rate_limit = prov_data.get("rate_limit")
                        existing_provider.concurrent_limit = prov_data.get(
                            "concurrent_limit"
                        )
                        existing_provider.config = prov_data.get("config")
                        existing_provider.updated_at = datetime.now(timezone.utc)
                        stats["providers"]["updated"] += 1
                else:
                    # Create a new provider.
                    billing_type = ProviderBillingType.PAY_AS_YOU_GO
                    if prov_data.get("billing_type"):
                        billing_type = ProviderBillingType(prov_data["billing_type"])

                    new_provider = Provider(
                        id=str(uuid.uuid4()),
                        name=prov_data["name"],
                        display_name=prov_data.get("display_name", prov_data["name"]),
                        description=prov_data.get("description"),
                        website=prov_data.get("website"),
                        billing_type=billing_type,
                        monthly_quota_usd=prov_data.get("monthly_quota_usd"),
                        quota_reset_day=prov_data.get("quota_reset_day", 30),
                        rpm_limit=prov_data.get("rpm_limit"),
                        provider_priority=prov_data.get("provider_priority", 100),
                        is_active=prov_data.get("is_active", True),
                        rate_limit=prov_data.get("rate_limit"),
                        concurrent_limit=prov_data.get("concurrent_limit"),
                        config=prov_data.get("config"),
                    )
                    db.add(new_provider)
                    db.flush()
                    provider_id = new_provider.id
                    stats["providers"]["created"] += 1

                # ---- Endpoints of this provider ------------------------------
                for ep_data in prov_data.get("endpoints") or []:
                    existing_ep = (
                        db.query(ProviderEndpoint)
                        .filter(
                            ProviderEndpoint.provider_id == provider_id,
                            ProviderEndpoint.api_format == ep_data["api_format"],
                        )
                        .first()
                    )

                    if existing_ep:
                        endpoint_id = existing_ep.id
                        if merge_mode == "skip":
                            stats["endpoints"]["skipped"] += 1
                        elif merge_mode == "error":
                            raise InvalidRequestException(
                                f"Endpoint '{ep_data['api_format']}' 已存在于 Provider '{prov_data['name']}'"
                            )
                        elif merge_mode == "overwrite":
                            existing_ep.base_url = ep_data.get(
                                "base_url", existing_ep.base_url
                            )
                            existing_ep.headers = ep_data.get("headers")
                            existing_ep.timeout = ep_data.get("timeout", 300)
                            existing_ep.max_retries = ep_data.get("max_retries", 3)
                            existing_ep.max_concurrent = ep_data.get("max_concurrent")
                            existing_ep.rate_limit = ep_data.get("rate_limit")
                            existing_ep.is_active = ep_data.get("is_active", True)
                            existing_ep.custom_path = ep_data.get("custom_path")
                            existing_ep.config = ep_data.get("config")
                            existing_ep.updated_at = datetime.now(timezone.utc)
                            stats["endpoints"]["updated"] += 1
                    else:
                        new_ep = ProviderEndpoint(
                            id=str(uuid.uuid4()),
                            provider_id=provider_id,
                            api_format=ep_data["api_format"],
                            base_url=ep_data["base_url"],
                            headers=ep_data.get("headers"),
                            timeout=ep_data.get("timeout", 300),
                            max_retries=ep_data.get("max_retries", 3),
                            max_concurrent=ep_data.get("max_concurrent"),
                            rate_limit=ep_data.get("rate_limit"),
                            is_active=ep_data.get("is_active", True),
                            custom_path=ep_data.get("custom_path"),
                            config=ep_data.get("config"),
                        )
                        db.add(new_ep)
                        db.flush()
                        endpoint_id = new_ep.id
                        stats["endpoints"]["created"] += 1

                    # ---- Keys of this endpoint -------------------------------
                    # Load the keys already stored for the endpoint and decrypt
                    # them so duplicates can be detected by plaintext value.
                    existing_keys = (
                        db.query(ProviderAPIKey)
                        .filter(ProviderAPIKey.endpoint_id == endpoint_id)
                        .all()
                    )
                    existing_key_values = set()
                    for ek in existing_keys:
                        try:
                            decrypted = crypto_service.decrypt(ek.api_key)
                            existing_key_values.add(decrypted)
                        except Exception:
                            # Best effort: a stored key that cannot be decrypted
                            # simply does not take part in de-duplication.
                            pass

                    for key_data in ep_data.get("keys") or []:
                        if not key_data.get("api_key"):
                            stats["errors"].append(
                                f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
                            )
                            continue

                        # Skip keys whose plaintext already exists on the endpoint.
                        if key_data["api_key"] in existing_key_values:
                            stats["keys"]["skipped"] += 1
                            continue

                        encrypted_key = crypto_service.encrypt(key_data["api_key"])

                        new_key = ProviderAPIKey(
                            id=str(uuid.uuid4()),
                            endpoint_id=endpoint_id,
                            api_key=encrypted_key,
                            name=key_data.get("name"),
                            note=key_data.get("note"),
                            rate_multiplier=key_data.get("rate_multiplier", 1.0),
                            internal_priority=key_data.get("internal_priority", 100),
                            global_priority=key_data.get("global_priority"),
                            max_concurrent=key_data.get("max_concurrent"),
                            rate_limit=key_data.get("rate_limit"),
                            daily_limit=key_data.get("daily_limit"),
                            monthly_limit=key_data.get("monthly_limit"),
                            allowed_models=key_data.get("allowed_models"),
                            capabilities=key_data.get("capabilities"),
                            is_active=key_data.get("is_active", True),
                        )
                        db.add(new_key)
                        # Remember it so the same payload cannot add it twice.
                        existing_key_values.add(key_data["api_key"])
                        stats["keys"]["created"] += 1

                # ---- Models of this provider ---------------------------------
                for model_data in prov_data.get("models") or []:
                    global_model_name = model_data.get("global_model_name")
                    if not global_model_name:
                        stats["errors"].append(
                            f"跳过无 global_model_name 的模型 (Provider: {prov_data['name']})"
                        )
                        continue

                    global_model_id = global_model_map.get(global_model_name)
                    if not global_model_id:
                        # Not part of this payload: fall back to the database.
                        existing_gm = (
                            db.query(GlobalModel)
                            .filter(GlobalModel.name == global_model_name)
                            .first()
                        )
                        if existing_gm:
                            global_model_id = existing_gm.id
                            # Cache the hit so later models skip the query.
                            global_model_map[global_model_name] = global_model_id
                        else:
                            stats["errors"].append(
                                f"GlobalModel '{global_model_name}' 不存在,跳过模型"
                            )
                            continue

                    existing_model = (
                        db.query(Model)
                        .filter(
                            Model.provider_id == provider_id,
                            Model.provider_model_name == model_data["provider_model_name"],
                        )
                        .first()
                    )

                    if existing_model:
                        if merge_mode == "skip":
                            stats["models"]["skipped"] += 1
                        elif merge_mode == "error":
                            raise InvalidRequestException(
                                f"Model '{model_data['provider_model_name']}' 已存在于 Provider '{prov_data['name']}'"
                            )
                        elif merge_mode == "overwrite":
                            existing_model.global_model_id = global_model_id
                            existing_model.provider_model_aliases = model_data.get(
                                "provider_model_aliases"
                            )
                            existing_model.price_per_request = model_data.get(
                                "price_per_request"
                            )
                            existing_model.tiered_pricing = model_data.get(
                                "tiered_pricing"
                            )
                            existing_model.supports_vision = model_data.get(
                                "supports_vision"
                            )
                            existing_model.supports_function_calling = model_data.get(
                                "supports_function_calling"
                            )
                            existing_model.supports_streaming = model_data.get(
                                "supports_streaming"
                            )
                            existing_model.supports_extended_thinking = model_data.get(
                                "supports_extended_thinking"
                            )
                            existing_model.supports_image_generation = model_data.get(
                                "supports_image_generation"
                            )
                            existing_model.is_active = model_data.get("is_active", True)
                            existing_model.config = model_data.get("config")
                            existing_model.updated_at = datetime.now(timezone.utc)
                            stats["models"]["updated"] += 1
                    else:
                        new_model = Model(
                            id=str(uuid.uuid4()),
                            provider_id=provider_id,
                            global_model_id=global_model_id,
                            provider_model_name=model_data["provider_model_name"],
                            provider_model_aliases=model_data.get(
                                "provider_model_aliases"
                            ),
                            price_per_request=model_data.get("price_per_request"),
                            tiered_pricing=model_data.get("tiered_pricing"),
                            supports_vision=model_data.get("supports_vision"),
                            supports_function_calling=model_data.get(
                                "supports_function_calling"
                            ),
                            supports_streaming=model_data.get("supports_streaming"),
                            supports_extended_thinking=model_data.get(
                                "supports_extended_thinking"
                            ),
                            supports_image_generation=model_data.get(
                                "supports_image_generation"
                            ),
                            is_active=model_data.get("is_active", True),
                            config=model_data.get("config"),
                        )
                        db.add(new_model)
                        stats["models"]["created"] += 1

            db.commit()

            # Invalidate caches so the imported configuration takes effect.
            # NOTE: this runs after the commit; a failure here is reported as
            # an import failure even though the data is already committed.
            from src.services.cache.invalidation import get_cache_invalidation_service

            cache_service = get_cache_invalidation_service()
            cache_service.clear_all_caches()

            return {
                "message": "配置导入成功",
                "stats": stats,
            }

        except InvalidRequestException:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            # Keep the original exception as the cause for debugging.
            raise InvalidRequestException(f"导入失败: {str(e)}") from e
|
||||
|
||||
|
||||
class AdminExportUsersAdapter(AdminApiAdapter):
    async def handle(self, context):  # type: ignore[override]
        """Export non-admin, non-deleted users together with their API keys.

        Stored hashes and encrypted values (``password_hash``, ``key_hash``,
        ``key_encrypted``) are exported as stored.

        Returns:
            A dict with ``version``, the UTC ``exported_at`` timestamp and the
            ``users`` list, each user carrying its ``api_keys``.
        """
        from datetime import datetime, timezone

        from src.core.enums import UserRole
        from src.models.database import ApiKey, User

        session = context.db

        def serialize_key(api_key):
            # One exported API key record; values are copied as stored.
            return {
                "key_hash": api_key.key_hash,
                "key_encrypted": api_key.key_encrypted,
                "name": api_key.name,
                "is_standalone": api_key.is_standalone,
                "balance_used_usd": api_key.balance_used_usd,
                "current_balance_usd": api_key.current_balance_usd,
                "allowed_providers": api_key.allowed_providers,
                "allowed_endpoints": api_key.allowed_endpoints,
                "allowed_api_formats": api_key.allowed_api_formats,
                "allowed_models": api_key.allowed_models,
                "rate_limit": api_key.rate_limit,
                "concurrent_limit": api_key.concurrent_limit,
                "force_capabilities": api_key.force_capabilities,
                "is_active": api_key.is_active,
                "auto_delete_on_expiry": api_key.auto_delete_on_expiry,
                "total_requests": api_key.total_requests,
                "total_cost_usd": api_key.total_cost_usd,
            }

        # Administrators and soft-deleted accounts are left out of the export.
        exportable_users = (
            session.query(User)
            .filter(User.is_deleted.is_(False), User.role != UserRole.ADMIN)
            .all()
        )

        exported = []
        for account in exportable_users:
            owned_keys = session.query(ApiKey).filter(ApiKey.user_id == account.id).all()
            serialized_keys = [serialize_key(item) for item in owned_keys]
            role_value = account.role.value if account.role else "user"
            exported.append(
                {
                    "email": account.email,
                    "username": account.username,
                    "password_hash": account.password_hash,
                    "role": role_value,
                    "allowed_providers": account.allowed_providers,
                    "allowed_endpoints": account.allowed_endpoints,
                    "allowed_models": account.allowed_models,
                    "model_capability_settings": account.model_capability_settings,
                    "quota_usd": account.quota_usd,
                    "used_usd": account.used_usd,
                    "total_usd": account.total_usd,
                    "is_active": account.is_active,
                    "api_keys": serialized_keys,
                }
            )

        exported_at = datetime.now(timezone.utc).isoformat()
        return {
            "version": "1.0",
            "exported_at": exported_at,
            "users": exported,
        }
|
||||
|
||||
|
||||
class AdminImportUsersAdapter(AdminApiAdapter):
    """Admin adapter that imports user accounts and their API keys."""

    async def handle(self, context):  # type: ignore[override]
        """Import users and their API keys from a JSON payload.

        The JSON body must carry ``version == "1.0"`` and may contain ``users``
        and ``merge_mode`` (``"skip"`` by default, ``"overwrite"`` or
        ``"error"``), which decides how an already existing email is handled.
        Administrator accounts are never imported or modified: entries whose
        role is admin, and entries whose email belongs to an existing admin,
        are skipped and reported in ``stats["errors"]``.

        Returns:
            A dict with a success ``message`` and ``stats`` holding counters
            for users and API keys plus a list of non-fatal ``errors``.

        Raises:
            InvalidRequestException: If the body exceeds ``MAX_IMPORT_SIZE``,
                the version or ``merge_mode`` is unsupported, a user already
                exists in ``"error"`` mode, or any other failure occurs while
                importing (the session is rolled back in that case).
        """
        import uuid
        from datetime import datetime, timezone

        from src.core.enums import UserRole
        from src.models.database import ApiKey, User

        # Reject oversized request bodies before parsing them.
        if context.raw_body and len(context.raw_body) > MAX_IMPORT_SIZE:
            raise InvalidRequestException("请求体大小不能超过 10MB")

        db = context.db
        payload = context.ensure_json_body()

        # Validate the configuration version.
        version = payload.get("version")
        if version != "1.0":
            raise InvalidRequestException(f"不支持的配置版本: {version}")

        # Import options. An unknown merge_mode used to fall through every
        # branch below without skipping, updating or counting, so reject it.
        merge_mode = payload.get("merge_mode", "skip")  # skip, overwrite, error
        if merge_mode not in ("skip", "overwrite", "error"):
            raise InvalidRequestException(f"不支持的 merge_mode: {merge_mode}")

        # `or []` also covers an explicit JSON null.
        users_data = payload.get("users") or []

        stats = {
            "users": {"created": 0, "updated": 0, "skipped": 0},
            "api_keys": {"created": 0, "skipped": 0},
            "errors": [],
        }

        try:
            for user_data in users_data:
                # Never import administrator accounts (case-insensitive check).
                role_str = str(user_data.get("role", "")).lower()
                if role_str == "admin":
                    stats["errors"].append(f"跳过管理员用户: {user_data.get('email')}")
                    stats["users"]["skipped"] += 1
                    continue

                existing_user = (
                    db.query(User).filter(User.email == user_data["email"]).first()
                )

                if existing_user:
                    # An imported record must not be able to replace an
                    # existing admin's password hash / role or attach API keys
                    # to it, whatever the merge mode.
                    if existing_user.role == UserRole.ADMIN:
                        stats["errors"].append(
                            f"跳过已存在的管理员用户: {user_data['email']}"
                        )
                        stats["users"]["skipped"] += 1
                        continue

                    user_id = existing_user.id
                    if merge_mode == "skip":
                        stats["users"]["skipped"] += 1
                    elif merge_mode == "error":
                        raise InvalidRequestException(
                            f"用户 '{user_data['email']}' 已存在"
                        )
                    elif merge_mode == "overwrite":
                        # Update the existing user in place.
                        existing_user.username = user_data.get(
                            "username", existing_user.username
                        )
                        if user_data.get("password_hash"):
                            existing_user.password_hash = user_data["password_hash"]
                        if user_data.get("role"):
                            existing_user.role = UserRole(user_data["role"])
                        existing_user.allowed_providers = user_data.get("allowed_providers")
                        existing_user.allowed_endpoints = user_data.get("allowed_endpoints")
                        existing_user.allowed_models = user_data.get("allowed_models")
                        existing_user.model_capability_settings = user_data.get(
                            "model_capability_settings"
                        )
                        existing_user.quota_usd = user_data.get("quota_usd")
                        existing_user.used_usd = user_data.get("used_usd", 0.0)
                        existing_user.total_usd = user_data.get("total_usd", 0.0)
                        existing_user.is_active = user_data.get("is_active", True)
                        existing_user.updated_at = datetime.now(timezone.utc)
                        stats["users"]["updated"] += 1
                else:
                    # Create a new user.
                    role = UserRole.USER
                    if user_data.get("role"):
                        role = UserRole(user_data["role"])

                    new_user = User(
                        id=str(uuid.uuid4()),
                        email=user_data["email"],
                        username=user_data.get("username", user_data["email"].split("@")[0]),
                        password_hash=user_data.get("password_hash", ""),
                        role=role,
                        allowed_providers=user_data.get("allowed_providers"),
                        allowed_endpoints=user_data.get("allowed_endpoints"),
                        allowed_models=user_data.get("allowed_models"),
                        model_capability_settings=user_data.get("model_capability_settings"),
                        quota_usd=user_data.get("quota_usd"),
                        used_usd=user_data.get("used_usd", 0.0),
                        total_usd=user_data.get("total_usd", 0.0),
                        is_active=user_data.get("is_active", True),
                    )
                    db.add(new_user)
                    db.flush()
                    user_id = new_user.id
                    stats["users"]["created"] += 1

                # Import the user's API keys (also for an existing user that
                # was skipped above).
                for key_data in user_data.get("api_keys") or []:
                    key_hash = key_data.get("key_hash")
                    if not key_hash:
                        # Previously stored with an empty hash; skip and report
                        # it instead, as the config importer does for empty keys.
                        stats["errors"].append(
                            f"跳过无 key_hash 的 API Key (用户: {user_data['email']})"
                        )
                        stats["api_keys"]["skipped"] += 1
                        continue

                    # Skip keys whose hash is already stored.
                    existing_key = (
                        db.query(ApiKey)
                        .filter(ApiKey.key_hash == key_hash)
                        .first()
                    )
                    if existing_key:
                        stats["api_keys"]["skipped"] += 1
                        continue

                    new_key = ApiKey(
                        id=str(uuid.uuid4()),
                        user_id=user_id,
                        key_hash=key_hash,
                        key_encrypted=key_data.get("key_encrypted"),
                        name=key_data.get("name"),
                        is_standalone=key_data.get("is_standalone", False),
                        balance_used_usd=key_data.get("balance_used_usd", 0.0),
                        current_balance_usd=key_data.get("current_balance_usd"),
                        allowed_providers=key_data.get("allowed_providers"),
                        allowed_endpoints=key_data.get("allowed_endpoints"),
                        allowed_api_formats=key_data.get("allowed_api_formats"),
                        allowed_models=key_data.get("allowed_models"),
                        rate_limit=key_data.get("rate_limit", 100),
                        concurrent_limit=key_data.get("concurrent_limit", 5),
                        force_capabilities=key_data.get("force_capabilities"),
                        is_active=key_data.get("is_active", True),
                        auto_delete_on_expiry=key_data.get("auto_delete_on_expiry", False),
                        total_requests=key_data.get("total_requests", 0),
                        total_cost_usd=key_data.get("total_cost_usd", 0.0),
                    )
                    db.add(new_key)
                    stats["api_keys"]["created"] += 1

            db.commit()

            return {
                "message": "用户数据导入成功",
                "stats": stats,
            }

        except InvalidRequestException:
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            # Keep the original exception as the cause for debugging.
            raise InvalidRequestException(f"导入失败: {str(e)}") from e
|
||||
|
||||
@@ -628,6 +628,7 @@ class AdminUsageRecordsAdapter(AdminApiAdapter):
|
||||
"actual_cost": actual_cost,
|
||||
"rate_multiplier": rate_multiplier,
|
||||
"response_time_ms": usage.response_time_ms,
|
||||
"first_byte_time_ms": usage.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
"created_at": usage.created_at.isoformat(),
|
||||
"is_stream": usage.is_stream,
|
||||
"input_price_per_1m": usage.input_price_per_1m,
|
||||
@@ -738,6 +739,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
|
||||
"status_code": usage_record.status_code,
|
||||
"error_message": usage_record.error_message,
|
||||
"response_time_ms": usage_record.response_time_ms,
|
||||
"first_byte_time_ms": usage_record.first_byte_time_ms, # 首字时间 (TTFB)
|
||||
"created_at": usage_record.created_at.isoformat() if usage_record.created_at else None,
|
||||
"request_headers": usage_record.request_headers,
|
||||
"request_body": usage_record.get_request_body(),
|
||||
|
||||
@@ -65,6 +65,21 @@ class ModelInfo:
|
||||
created_at: Optional[str] # ISO 格式
|
||||
created_timestamp: int # Unix 时间戳
|
||||
provider_name: str
|
||||
# 能力配置
|
||||
streaming: bool = True
|
||||
vision: bool = False
|
||||
function_calling: bool = False
|
||||
extended_thinking: bool = False
|
||||
image_generation: bool = False
|
||||
structured_output: bool = False
|
||||
# 规格参数
|
||||
context_limit: Optional[int] = None
|
||||
output_limit: Optional[int] = None
|
||||
# 元信息
|
||||
family: Optional[str] = None
|
||||
knowledge_cutoff: Optional[str] = None
|
||||
input_modalities: Optional[list[str]] = None
|
||||
output_modalities: Optional[list[str]] = None
|
||||
|
||||
|
||||
def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
|
||||
@@ -181,13 +196,19 @@ def _extract_model_info(model: Any) -> ModelInfo:
|
||||
global_model = model.global_model
|
||||
model_id: str = global_model.name if global_model else model.provider_model_name
|
||||
display_name: str = global_model.display_name if global_model else model.provider_model_name
|
||||
description: Optional[str] = global_model.description if global_model else None
|
||||
created_at: Optional[str] = (
|
||||
model.created_at.strftime("%Y-%m-%dT%H:%M:%SZ") if model.created_at else None
|
||||
)
|
||||
created_timestamp: int = int(model.created_at.timestamp()) if model.created_at else 0
|
||||
provider_name: str = model.provider.name if model.provider else "unknown"
|
||||
|
||||
# 从 GlobalModel.config 提取配置信息
|
||||
config: dict = {}
|
||||
description: Optional[str] = None
|
||||
if global_model:
|
||||
config = global_model.config or {}
|
||||
description = config.get("description")
|
||||
|
||||
return ModelInfo(
|
||||
id=model_id,
|
||||
display_name=display_name,
|
||||
@@ -195,6 +216,21 @@ def _extract_model_info(model: Any) -> ModelInfo:
|
||||
created_at=created_at,
|
||||
created_timestamp=created_timestamp,
|
||||
provider_name=provider_name,
|
||||
# 能力配置
|
||||
streaming=config.get("streaming", True),
|
||||
vision=config.get("vision", False),
|
||||
function_calling=config.get("function_calling", False),
|
||||
extended_thinking=config.get("extended_thinking", False),
|
||||
image_generation=config.get("image_generation", False),
|
||||
structured_output=config.get("structured_output", False),
|
||||
# 规格参数
|
||||
context_limit=config.get("context_limit"),
|
||||
output_limit=config.get("output_limit"),
|
||||
# 元信息
|
||||
family=config.get("family"),
|
||||
knowledge_cutoff=config.get("knowledge_cutoff"),
|
||||
input_modalities=config.get("input_modalities"),
|
||||
output_modalities=config.get("output_modalities"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Sequence, Tuple, TypeVar
|
||||
from typing import Any, List, Sequence, Tuple, TypeVar
|
||||
|
||||
from sqlalchemy.orm import Query
|
||||
|
||||
@@ -40,10 +40,10 @@ def paginate_sequence(
|
||||
return sliced, meta
|
||||
|
||||
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra) -> dict:
|
||||
def build_pagination_payload(items: List[dict], meta: PaginationMeta, **extra: Any) -> dict:
|
||||
"""
|
||||
构建标准分页响应 payload。
|
||||
"""
|
||||
payload = {"items": items, "meta": meta.to_dict()}
|
||||
payload: dict = {"items": items, "meta": meta.to_dict()}
|
||||
payload.update(extra)
|
||||
return payload
|
||||
|
||||
@@ -731,8 +731,15 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
)
|
||||
# stats_daily.date 存储的是业务日期对应的 UTC 开始时间
|
||||
# 需要转回业务时区再取日期,才能与日期序列匹配
|
||||
def _to_business_date_str(value: datetime) -> str:
    """Return the ISO date of ``value`` as observed in the ``app_tz`` zone.

    A naive datetime is taken to be UTC; an aware one is converted to UTC
    before being shifted into ``app_tz``.
    """
    is_naive = value.tzinfo is None
    as_utc = value.replace(tzinfo=timezone.utc) if is_naive else value.astimezone(timezone.utc)
    local_day = as_utc.astimezone(app_tz).date()
    return local_day.isoformat()
|
||||
|
||||
stats_map = {
|
||||
stat.date.replace(tzinfo=timezone.utc).astimezone(app_tz).date().isoformat(): {
|
||||
_to_business_date_str(stat.date): {
|
||||
"requests": stat.total_requests,
|
||||
"tokens": stat.input_tokens + stat.output_tokens + stat.cache_creation_tokens + stat.cache_read_tokens,
|
||||
"cost": stat.total_cost,
|
||||
@@ -790,6 +797,38 @@ class DashboardDailyStatsAdapter(DashboardAdapter):
|
||||
"unique_providers": today_unique_providers,
|
||||
"fallback_count": today_fallback_count,
|
||||
}
|
||||
|
||||
# 历史预聚合缺失时兜底:按业务日范围实时计算(仅补最近少量缺失,避免全表扫描)
|
||||
yesterday_date = today_local.date() - timedelta(days=1)
|
||||
historical_end = min(end_date_local.date(), yesterday_date)
|
||||
missing_dates: list[str] = []
|
||||
cursor = start_date_local.date()
|
||||
while cursor <= historical_end:
|
||||
date_str = cursor.isoformat()
|
||||
if date_str not in stats_map:
|
||||
missing_dates.append(date_str)
|
||||
cursor += timedelta(days=1)
|
||||
|
||||
if missing_dates:
|
||||
for date_str in missing_dates[-7:]:
|
||||
target_local = datetime.fromisoformat(date_str).replace(tzinfo=app_tz)
|
||||
computed = StatsAggregatorService.compute_daily_stats(db, target_local)
|
||||
stats_map[date_str] = {
|
||||
"requests": computed["total_requests"],
|
||||
"tokens": (
|
||||
computed["input_tokens"]
|
||||
+ computed["output_tokens"]
|
||||
+ computed["cache_creation_tokens"]
|
||||
+ computed["cache_read_tokens"]
|
||||
),
|
||||
"cost": computed["total_cost"],
|
||||
"avg_response_time": computed["avg_response_time_ms"] / 1000.0
|
||||
if computed["avg_response_time_ms"]
|
||||
else 0,
|
||||
"unique_models": computed["unique_models"],
|
||||
"unique_providers": computed["unique_providers"],
|
||||
"fallback_count": computed["fallback_count"],
|
||||
}
|
||||
else:
|
||||
# 普通用户:仍需实时查询(用户级预聚合可选)
|
||||
query = db.query(Usage).filter(
|
||||
|
||||
@@ -100,6 +100,8 @@ class MessageTelemetry:
|
||||
cache_read_tokens: int = 0,
|
||||
is_stream: bool = False,
|
||||
provider_request_headers: Optional[Dict[str, Any]] = None,
|
||||
# 时间指标
|
||||
first_byte_time_ms: Optional[int] = None, # 首字时间/TTFB
|
||||
# Provider 侧追踪信息(用于记录真实成本)
|
||||
provider_id: Optional[str] = None,
|
||||
provider_endpoint_id: Optional[str] = None,
|
||||
@@ -133,6 +135,7 @@ class MessageTelemetry:
|
||||
api_format=api_format,
|
||||
is_stream=is_stream,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=first_byte_time_ms, # 传递首字时间
|
||||
status_code=status_code,
|
||||
request_headers=request_headers,
|
||||
request_body=request_body,
|
||||
@@ -395,3 +398,24 @@ class BaseMessageHandler:
|
||||
|
||||
# 创建后台任务,不阻塞当前流
|
||||
asyncio.create_task(_do_update())
|
||||
|
||||
def _log_request_error(self, message: str, error: Exception) -> None:
    """Log a request failure, leaving out the stack trace for business errors.

    Args:
        message: Prefix placed before the error text.
        error: The exception being reported.
    """
    from src.core.exceptions import (
        ModelNotSupportedException,
        ProviderException,
        QuotaExceededException,
        RateLimitException,
    )

    business_errors = (
        ProviderException,
        QuotaExceededException,
        RateLimitException,
        ModelNotSupportedException,
    )

    if not isinstance(error, business_errors):
        # Unknown failure: log with the full stack trace.
        logger.exception(f"{message}: {error}")
        return

    # Business error: one concise line, no stack trace.
    logger.error(f"{message}: [{type(error).__name__}] {error}")
|
||||
|
||||
@@ -34,6 +34,7 @@ from src.api.handlers.base.response_parser import ResponseParser
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.stream_processor import StreamProcessor
|
||||
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
|
||||
from src.api.handlers.base.utils import build_sse_headers
|
||||
from src.config.settings import config
|
||||
from src.core.exceptions import (
|
||||
EmbeddedErrorException,
|
||||
@@ -263,7 +264,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
mapping = await mapper.get_mapping(source_model, provider_id)
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
# 传入 api_format 用于过滤适用的别名作用域
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
)
|
||||
logger.debug(f"[Chat] 模型映射: {source_model} -> {mapped_name}")
|
||||
return mapped_name
|
||||
|
||||
@@ -362,7 +369,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
ctx,
|
||||
original_headers,
|
||||
original_request_body,
|
||||
self.elapsed_ms(),
|
||||
self.start_time, # 传入开始时间,让 telemetry 在流结束后计算响应时间
|
||||
)
|
||||
|
||||
# 创建监控流
|
||||
@@ -375,11 +382,12 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
return StreamingResponse(
|
||||
monitored_stream,
|
||||
media_type="text/event-stream",
|
||||
headers=build_sse_headers(),
|
||||
background=background_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"流式请求失败: {e}")
|
||||
self._log_request_error("流式请求失败", e)
|
||||
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
||||
raise
|
||||
|
||||
@@ -470,12 +478,13 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 创建行迭代器
|
||||
line_iterator = stream_response.aiter_lines()
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||
# aiter_raw() 返回原始数据块,无缓冲,实现真正的流式传输
|
||||
byte_iterator = stream_response.aiter_raw()
|
||||
|
||||
# 预读检测嵌套错误
|
||||
prefetched_lines = await stream_processor.prefetch_and_check_error(
|
||||
line_iterator,
|
||||
prefetched_chunks = await stream_processor.prefetch_and_check_error(
|
||||
byte_iterator,
|
||||
provider,
|
||||
endpoint,
|
||||
ctx,
|
||||
@@ -500,13 +509,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
|
||||
await http_client.aclose()
|
||||
raise
|
||||
|
||||
# 创建流生成器
|
||||
# 创建流生成器(传入字节流迭代器)
|
||||
return stream_processor.create_response_stream(
|
||||
ctx,
|
||||
line_iterator,
|
||||
byte_iterator,
|
||||
response_ctx,
|
||||
http_client,
|
||||
prefetched_lines,
|
||||
prefetched_chunks,
|
||||
start_time=self.start_time,
|
||||
)
|
||||
|
||||
async def _record_stream_failure(
|
||||
|
||||
@@ -11,17 +11,15 @@ CLI Message Handler 通用基类
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import httpx
|
||||
@@ -35,6 +33,8 @@ from src.api.handlers.base.base_handler import (
|
||||
)
|
||||
from src.api.handlers.base.parsers import get_parser_for_format
|
||||
from src.api.handlers.base.request_builder import PassthroughRequestBuilder
|
||||
from src.api.handlers.base.stream_context import StreamContext
|
||||
from src.api.handlers.base.utils import build_sse_headers
|
||||
|
||||
# 直接从具体模块导入,避免循环依赖
|
||||
from src.api.handlers.base.response_parser import (
|
||||
@@ -61,63 +61,6 @@ from src.services.provider.transport import build_provider_url
|
||||
from src.utils.sse_parser import SSEEventParser
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext:
|
||||
"""流式请求的上下文信息"""
|
||||
|
||||
# 请求信息
|
||||
model: str = "unknown" # 用户请求的原始模型名
|
||||
mapped_model: Optional[str] = None # 映射后的目标模型名(如果发生了映射)
|
||||
api_format: str = ""
|
||||
request_id: str = ""
|
||||
|
||||
# 用户信息(提前提取避免 Session detached)
|
||||
user_id: int = 0
|
||||
api_key_id: int = 0
|
||||
|
||||
# 统计信息
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cached_tokens: int = 0 # cache_read_input_tokens
|
||||
cache_creation_tokens: int = 0 # cache_creation_input_tokens
|
||||
collected_text: str = ""
|
||||
response_id: Optional[str] = None
|
||||
final_usage: Optional[Dict[str, Any]] = None
|
||||
final_response: Optional[Dict[str, Any]] = None
|
||||
parsed_chunks: list = field(default_factory=list)
|
||||
|
||||
# 流状态
|
||||
start_time: float = field(default_factory=time.time)
|
||||
chunk_count: int = 0
|
||||
data_count: int = 0
|
||||
has_completion: bool = False
|
||||
|
||||
# 响应信息
|
||||
status_code: int = 200
|
||||
response_headers: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# 请求信息(发送给 Provider 的)
|
||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_request_body: Optional[Dict[str, Any]] = None # 实际发送的请求体
|
||||
|
||||
# Provider 信息
|
||||
provider_name: Optional[str] = None
|
||||
provider_id: Optional[str] = None # Provider ID(用于记录真实成本)
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
attempt_id: Optional[str] = None
|
||||
attempt_synced: bool = False
|
||||
error_message: Optional[str] = None
|
||||
|
||||
# 格式转换信息
|
||||
provider_api_format: str = "" # Provider 的 API 格式(用于响应转换)
|
||||
client_api_format: str = "" # 客户端请求的 API 格式
|
||||
|
||||
# Provider 响应元数据(存储 provider 返回的额外信息,如 Gemini 的 modelVersion)
|
||||
response_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"""
|
||||
CLI 格式消息处理器基类
|
||||
@@ -190,14 +133,17 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
"""
|
||||
获取模型映射后的实际模型名
|
||||
|
||||
按优先级查找:映射 → 别名 → 直接匹配 GlobalModel
|
||||
查找逻辑:
|
||||
1. 直接通过 GlobalModel.name 匹配
|
||||
2. 查找该 Provider 的 Model 实现
|
||||
3. 使用 provider_model_name / provider_model_aliases 选择最终名称
|
||||
|
||||
Args:
|
||||
source_model: 用户请求的模型名(可能是别名)
|
||||
source_model: 用户请求的模型名(必须是 GlobalModel.name)
|
||||
provider_id: Provider ID
|
||||
|
||||
Returns:
|
||||
映射后的 provider_model_name,如果没有找到映射则返回 None
|
||||
映射后的 Provider 模型名,如果没有找到映射则返回 None
|
||||
"""
|
||||
from src.services.model.mapper import ModelMapperMiddleware
|
||||
|
||||
@@ -207,7 +153,13 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
logger.debug(f"[CLI] _get_mapped_model: source={source_model}, provider={provider_id[:8]}..., mapping={mapping}")
|
||||
|
||||
if mapping and mapping.model:
|
||||
mapped_name = str(mapping.model.provider_model_name)
|
||||
# 使用 select_provider_model_name 支持别名功能
|
||||
# 传入 api_key.id 作为 affinity_key,实现相同用户稳定选择同一别名
|
||||
# 传入 api_format 用于过滤适用的别名作用域
|
||||
affinity_key = self.api_key.id if self.api_key else None
|
||||
mapped_name = mapping.model.select_provider_model_name(
|
||||
affinity_key, api_format=self.FORMAT_ID
|
||||
)
|
||||
logger.debug(f"[CLI] 模型映射: {source_model} -> {mapped_name} (provider={provider_id[:8]}...)")
|
||||
return mapped_name
|
||||
|
||||
@@ -403,24 +355,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
return StreamingResponse(
|
||||
monitored_stream,
|
||||
media_type="text/event-stream",
|
||||
headers=build_sse_headers(),
|
||||
background=background_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 对于已知的业务异常,只记录简洁的错误信息,不输出完整堆栈
|
||||
from src.core.exceptions import (
|
||||
ProviderException,
|
||||
QuotaExceededException,
|
||||
RateLimitException,
|
||||
ModelNotSupportedException,
|
||||
)
|
||||
|
||||
if isinstance(e, (ProviderException, QuotaExceededException, RateLimitException, ModelNotSupportedException)):
|
||||
# 业务异常:简洁日志
|
||||
logger.error(f"流式请求失败: [{type(e).__name__}] {e}")
|
||||
else:
|
||||
# 未知异常:完整堆栈
|
||||
logger.exception(f"流式请求失败: {e}")
|
||||
self._log_request_error("流式请求失败", e)
|
||||
await self._record_stream_failure(ctx, e, original_headers, original_request_body)
|
||||
raise
|
||||
|
||||
@@ -440,7 +380,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
ctx.chunk_count = 0
|
||||
ctx.data_count = 0
|
||||
ctx.has_completion = False
|
||||
ctx.collected_text = ""
|
||||
ctx._collected_text_parts = [] # 重置文本收集
|
||||
ctx.input_tokens = 0
|
||||
ctx.output_tokens = 0
|
||||
ctx.cached_tokens = 0
|
||||
@@ -528,12 +468,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
stream_response.raise_for_status()
|
||||
|
||||
# 创建行迭代器(只创建一次,后续会继续使用)
|
||||
line_iterator = stream_response.aiter_lines()
|
||||
# 使用字节流迭代器(避免 aiter_lines 的性能问题)
|
||||
byte_iterator = stream_response.aiter_raw()
|
||||
|
||||
# 预读第一个数据块,检测嵌套错误(HTTP 200 但响应体包含错误)
|
||||
prefetched_lines = await self._prefetch_and_check_embedded_error(
|
||||
line_iterator, provider, endpoint, ctx
|
||||
prefetched_chunks = await self._prefetch_and_check_embedded_error(
|
||||
byte_iterator, provider, endpoint, ctx
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
@@ -558,10 +498,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 创建流生成器(带预读数据,使用同一个迭代器)
|
||||
return self._create_response_stream_with_prefetch(
|
||||
ctx,
|
||||
line_iterator,
|
||||
byte_iterator,
|
||||
response_ctx,
|
||||
http_client,
|
||||
prefetched_lines,
|
||||
prefetched_chunks,
|
||||
)
|
||||
|
||||
async def _create_response_stream(
|
||||
@@ -571,58 +511,75 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建响应流生成器"""
|
||||
"""创建响应流生成器(使用字节流)"""
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
streaming_status_updated = False
|
||||
buffer = b""
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
async for line in stream_response.aiter_lines():
|
||||
async for chunk in stream_response.aiter_raw():
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if not streaming_status_updated:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
streaming_status_updated = True
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return # 结束生成器
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return # 结束生成器
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
@@ -696,7 +653,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
|
||||
async def _prefetch_and_check_embedded_error(
|
||||
self,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
ctx: StreamContext,
|
||||
@@ -710,20 +667,25 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
同时检测 HTML 响应(通常是 base_url 配置错误导致返回网页)。
|
||||
|
||||
Args:
|
||||
line_iterator: 行迭代器(aiter_lines() 返回的迭代器)
|
||||
byte_iterator: 字节流迭代器
|
||||
provider: Provider 对象
|
||||
endpoint: Endpoint 对象
|
||||
ctx: 流上下文
|
||||
|
||||
Returns:
|
||||
预读的行列表(需要在后续流中先输出)
|
||||
预读的字节块列表(需要在后续流中先输出)
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
ProviderNotAvailableException: 如果检测到 HTML 响应(配置错误)
|
||||
"""
|
||||
prefetched_lines: list = []
|
||||
prefetched_chunks: list = []
|
||||
max_prefetch_lines = 5 # 最多预读5行来检测错误
|
||||
buffer = b""
|
||||
line_count = 0
|
||||
should_stop = False
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
try:
|
||||
# 获取对应格式的解析器
|
||||
@@ -736,69 +698,86 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
else:
|
||||
provider_parser = self.parser
|
||||
|
||||
line_count = 0
|
||||
async for line in line_iterator:
|
||||
prefetched_lines.append(line)
|
||||
line_count += 1
|
||||
async for chunk in byte_iterator:
|
||||
prefetched_chunks.append(chunk)
|
||||
buffer += chunk
|
||||
|
||||
# 解析数据
|
||||
normalized_line = line.rstrip("\r")
|
||||
# 尝试按行解析缓冲区
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||
lower_line = normalized_line.lower()
|
||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||||
logger.error(
|
||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"base_url={endpoint.base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
line_count += 1
|
||||
normalized_line = line.rstrip("\r")
|
||||
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
# 空行或注释行,继续预读
|
||||
if line_count >= max_prefetch_lines:
|
||||
# 检测 HTML 响应(base_url 配置错误的常见症状)
|
||||
lower_line = normalized_line.lower()
|
||||
if lower_line.startswith("<!doctype") or lower_line.startswith("<html"):
|
||||
logger.error(
|
||||
f" [{self.request_id}] 检测到 HTML 响应,可能是 base_url 配置错误: "
|
||||
f"Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"base_url={endpoint.base_url}"
|
||||
)
|
||||
raise ProviderNotAvailableException(
|
||||
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,请检查 endpoint 的 base_url 配置是否正确"
|
||||
)
|
||||
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
# 空行或注释行,继续预读
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
|
||||
if data_str == "[DONE]":
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
# 不是有效 JSON,可能是部分数据,继续
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
if data_str == "[DONE]":
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||
# 提取错误信息
|
||||
parsed = provider_parser.parse_response(data, 200)
|
||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
# 不是有效 JSON,可能是部分数据,继续
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and provider_parser.is_error_response(data):
|
||||
# 提取错误信息
|
||||
parsed = provider_parser.parse_response(data, 200)
|
||||
logger.warning(f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}")
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
break
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
# 重新抛出嵌套错误
|
||||
@@ -807,112 +786,168 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 其他异常(如网络错误)在预读阶段发生,记录日志但不中断
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
|
||||
return prefetched_lines
|
||||
return prefetched_chunks
|
||||
|
||||
async def _create_response_stream_with_prefetch(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
prefetched_lines: list,
|
||||
prefetched_chunks: list,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""创建响应流生成器(带预读数据)"""
|
||||
"""创建响应流生成器(带预读数据,使用字节流)"""
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
last_data_time = time.time()
|
||||
buffer = b""
|
||||
first_yield = True # 标记是否是第一次 yield
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 检查是否需要格式转换
|
||||
needs_conversion = self._needs_format_conversion(ctx)
|
||||
|
||||
# 在第一次输出数据前更新状态为 streaming
|
||||
if prefetched_lines:
|
||||
if prefetched_chunks:
|
||||
self._update_usage_to_streaming(ctx.request_id)
|
||||
|
||||
# 先处理预读的数据
|
||||
for line in prefetched_lines:
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
# 先处理预读的字节块
|
||||
for chunk in prefetched_chunks:
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
# 记录首字时间 (第一次 yield)
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
|
||||
# 继续处理剩余的流数据(使用同一个迭代器)
|
||||
async for line in line_iterator:
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
async for chunk in byte_iterator:
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
# 记录首字时间 (第一次 yield) - 如果预读数据为空
|
||||
if first_yield:
|
||||
ctx.record_first_byte_time(self.start_time)
|
||||
first_yield = False
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
yield b"\n"
|
||||
continue
|
||||
|
||||
ctx.chunk_count += 1
|
||||
|
||||
# 空流检测:超过阈值且无数据,发送错误事件并结束
|
||||
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
|
||||
elapsed = time.time() - last_data_time
|
||||
if elapsed > self.DATA_TIMEOUT:
|
||||
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "empty_stream_timeout",
|
||||
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
|
||||
},
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
|
||||
return
|
||||
|
||||
# 格式转换或直接透传
|
||||
if needs_conversion:
|
||||
converted_line = self._convert_sse_line(ctx, line, events)
|
||||
if converted_line:
|
||||
yield (converted_line + "\n").encode("utf-8")
|
||||
else:
|
||||
yield (line + "\n").encode("utf-8")
|
||||
|
||||
for event in events:
|
||||
self._handle_sse_event(
|
||||
ctx,
|
||||
event.get("event"),
|
||||
event.get("data") or "",
|
||||
)
|
||||
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
if ctx.data_count > 0:
|
||||
last_data_time = time.time()
|
||||
|
||||
# 处理剩余事件
|
||||
flushed_events = sse_parser.flush()
|
||||
@@ -1041,7 +1076,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
# 提取文本内容
|
||||
text = self.parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
ctx.append_text(text)
|
||||
|
||||
# 检查完成事件
|
||||
if event_type in ("response.completed", "message_stop"):
|
||||
@@ -1093,9 +1128,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
) -> None:
|
||||
"""在流完成后记录统计信息"""
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
# 使用 self.start_time 作为时间基准,与首字时间保持一致
|
||||
# 注意:不要把统计延迟算进响应时间里
|
||||
response_time_ms = int((time.time() - self.start_time) * 1000)
|
||||
|
||||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if not ctx.provider_name:
|
||||
logger.warning(f"[{ctx.request_id}] 流式请求失败,未选中提供商")
|
||||
@@ -1175,6 +1212,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
input_tokens=actual_input_tokens,
|
||||
output_tokens=ctx.output_tokens,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
|
||||
status_code=ctx.status_code,
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
@@ -1195,9 +1233,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
response_metadata=ctx.response_metadata if ctx.response_metadata else None,
|
||||
)
|
||||
logger.debug(f"{self.FORMAT_ID} 流式响应完成")
|
||||
# 简洁的请求完成摘要
|
||||
logger.info(f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name} | {response_time_ms}ms | "
|
||||
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}")
|
||||
# 简洁的请求完成摘要(两行格式)
|
||||
line1 = (
|
||||
f"[OK] {self.request_id[:8]} | {ctx.model} | {ctx.provider_name}"
|
||||
)
|
||||
if ctx.first_byte_time_ms:
|
||||
line1 += f" | TTFB: {ctx.first_byte_time_ms}ms"
|
||||
|
||||
line2 = (
|
||||
f" Total: {response_time_ms}ms | "
|
||||
f"in:{ctx.input_tokens or 0} out:{ctx.output_tokens or 0}"
|
||||
)
|
||||
logger.info(f"{line1}\n{line2}")
|
||||
|
||||
# 更新候选记录的最终状态和延迟时间
|
||||
# 注意:RequestExecutor 会在流开始时过早地标记成功(只记录了连接建立的时间)
|
||||
@@ -1249,7 +1296,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
|
||||
original_request_body: Dict[str, Any],
|
||||
) -> None:
|
||||
"""记录流式请求失败"""
|
||||
response_time_ms = int((time.time() - ctx.start_time) * 1000)
|
||||
# 使用 self.start_time 作为时间基准,与首字时间保持一致
|
||||
response_time_ms = int((time.time() - self.start_time) * 1000)
|
||||
|
||||
status_code = 503
|
||||
if isinstance(error, ProviderAuthException):
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
不再经过 Protocol 抽象层。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from src.api.handlers.base.response_parser import (
|
||||
ParsedChunk,
|
||||
@@ -13,6 +13,7 @@ from src.api.handlers.base.response_parser import (
|
||||
ResponseParser,
|
||||
StreamStats,
|
||||
)
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[str, Any]]]:
|
||||
@@ -60,7 +61,7 @@ def _check_nested_error(response: Dict[str, Any]) -> Tuple[bool, Optional[Dict[s
|
||||
class OpenAIResponseParser(ResponseParser):
|
||||
"""OpenAI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.openai.stream_parser import OpenAIStreamParser
|
||||
|
||||
self._parser = OpenAIStreamParser()
|
||||
@@ -146,7 +147,7 @@ class OpenAIResponseParser(ResponseParser):
|
||||
if choices:
|
||||
message = choices[0].get("message", {})
|
||||
content = message.get("content")
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return ""
|
||||
|
||||
@@ -158,7 +159,7 @@ class OpenAIResponseParser(ResponseParser):
|
||||
class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
"""OpenAI CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "OPENAI_CLI"
|
||||
self.api_format = "OPENAI_CLI"
|
||||
@@ -167,7 +168,7 @@ class OpenAICliResponseParser(OpenAIResponseParser):
|
||||
class ClaudeResponseParser(ResponseParser):
|
||||
"""Claude 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.claude.stream_parser import ClaudeStreamParser
|
||||
|
||||
self._parser = ClaudeStreamParser()
|
||||
@@ -252,7 +253,7 @@ class ClaudeResponseParser(ResponseParser):
|
||||
usage = response.get("usage", {})
|
||||
result.input_tokens = usage.get("input_tokens", 0)
|
||||
result.output_tokens = usage.get("output_tokens", 0)
|
||||
result.cache_creation_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
result.cache_creation_tokens = extract_cache_creation_tokens(usage)
|
||||
result.cache_read_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 检查错误(支持嵌套错误格式)
|
||||
@@ -265,11 +266,16 @@ class ClaudeResponseParser(ResponseParser):
|
||||
return result
|
||||
|
||||
def extract_usage_from_response(self, response: Dict[str, Any]) -> Dict[str, int]:
|
||||
# 对于 message_start 事件,usage 在 message.usage 路径下
|
||||
# 对于其他响应,usage 在顶层
|
||||
usage = response.get("usage", {})
|
||||
if not usage and "message" in response:
|
||||
usage = response.get("message", {}).get("usage", {})
|
||||
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
@@ -291,7 +297,7 @@ class ClaudeResponseParser(ResponseParser):
|
||||
class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
"""Claude CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "CLAUDE_CLI"
|
||||
self.api_format = "CLAUDE_CLI"
|
||||
@@ -300,7 +306,7 @@ class ClaudeCliResponseParser(ClaudeResponseParser):
|
||||
class GeminiResponseParser(ResponseParser):
|
||||
"""Gemini 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
from src.api.handlers.gemini.stream_parser import GeminiStreamParser
|
||||
|
||||
self._parser = GeminiStreamParser()
|
||||
@@ -443,20 +449,20 @@ class GeminiResponseParser(ResponseParser):
|
||||
|
||||
使用增强的错误检测逻辑,支持嵌套在 chunks 中的错误
|
||||
"""
|
||||
return self._parser.is_error_event(response)
|
||||
return bool(self._parser.is_error_event(response))
|
||||
|
||||
|
||||
class GeminiCliResponseParser(GeminiResponseParser):
|
||||
"""Gemini CLI 格式响应解析器"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.name = "GEMINI_CLI"
|
||||
self.api_format = "GEMINI_CLI"
|
||||
|
||||
|
||||
# 解析器注册表
|
||||
_PARSERS = {
|
||||
_PARSERS: Dict[str, Type[ResponseParser]] = {
|
||||
"CLAUDE": ClaudeResponseParser,
|
||||
"CLAUDE_CLI": ClaudeCliResponseParser,
|
||||
"OPENAI": OpenAIResponseParser,
|
||||
@@ -498,6 +504,5 @@ __all__ = [
|
||||
"GeminiResponseParser",
|
||||
"GeminiCliResponseParser",
|
||||
"get_parser_for_format",
|
||||
"get_parser_from_protocol",
|
||||
"is_cli_format",
|
||||
]
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
- 请求/响应数据
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -25,12 +26,18 @@ class StreamContext:
|
||||
model: str
|
||||
api_format: str
|
||||
|
||||
# 请求标识信息(CLI handler 需要)
|
||||
request_id: str = ""
|
||||
user_id: int = 0
|
||||
api_key_id: int = 0
|
||||
|
||||
# Provider 信息(在请求执行时填充)
|
||||
provider_name: Optional[str] = None
|
||||
provider_id: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
key_id: Optional[str] = None
|
||||
attempt_id: Optional[str] = None
|
||||
attempt_synced: bool = False
|
||||
provider_api_format: Optional[str] = None # Provider 的响应格式
|
||||
|
||||
# 模型映射
|
||||
@@ -43,7 +50,14 @@ class StreamContext:
|
||||
cache_creation_tokens: int = 0
|
||||
|
||||
# 响应内容
|
||||
collected_text: str = ""
|
||||
_collected_text_parts: List[str] = field(default_factory=list, repr=False)
|
||||
response_id: Optional[str] = None
|
||||
final_usage: Optional[Dict[str, Any]] = None
|
||||
final_response: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 时间指标
|
||||
first_byte_time_ms: Optional[int] = None # 首字时间 (TTFB - Time To First Byte)
|
||||
start_time: float = field(default_factory=time.time)
|
||||
|
||||
# 响应状态
|
||||
status_code: int = 200
|
||||
@@ -55,6 +69,12 @@ class StreamContext:
|
||||
provider_request_headers: Dict[str, str] = field(default_factory=dict)
|
||||
provider_request_body: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 格式转换信息(CLI handler 需要)
|
||||
client_api_format: str = ""
|
||||
|
||||
# Provider 响应元数据(CLI handler 需要)
|
||||
response_metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 流式处理统计
|
||||
data_count: int = 0
|
||||
chunk_count: int = 0
|
||||
@@ -71,16 +91,30 @@ class StreamContext:
|
||||
self.chunk_count = 0
|
||||
self.data_count = 0
|
||||
self.has_completion = False
|
||||
self.collected_text = ""
|
||||
self._collected_text_parts = []
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
self.cached_tokens = 0
|
||||
self.cache_creation_tokens = 0
|
||||
self.error_message = None
|
||||
self.status_code = 200
|
||||
self.first_byte_time_ms = None
|
||||
self.response_headers = {}
|
||||
self.provider_request_headers = {}
|
||||
self.provider_request_body = None
|
||||
self.response_id = None
|
||||
self.final_usage = None
|
||||
self.final_response = None
|
||||
|
||||
@property
def collected_text(self) -> str:
    """Return the text collected so far as a single string.

    The pieces are kept in ``self._collected_text_parts`` and joined only
    when this property is read, so the streaming path does not copy a
    growing string on every chunk.
    """
    parts = self._collected_text_parts
    return "".join(parts)
|
||||
|
||||
def append_text(self, text: str) -> None:
    """Buffer a piece of response text (call only when text is being collected).

    Falsy values (empty string, ``None``) are ignored; anything else is
    appended to ``self._collected_text_parts``.

    Args:
        text: The text fragment to buffer.
    """
    if not text:
        return
    self._collected_text_parts.append(text)
|
||||
|
||||
def update_provider_info(
|
||||
self,
|
||||
@@ -104,14 +138,40 @@ class StreamContext:
|
||||
cached_tokens: Optional[int] = None,
|
||||
cache_creation_tokens: Optional[int] = None,
|
||||
) -> None:
|
||||
"""更新 Token 使用统计"""
|
||||
if input_tokens is not None:
|
||||
"""
|
||||
更新 Token 使用统计
|
||||
|
||||
采用防御性更新策略:只有当新值 > 0 或当前值为 0 时才更新,避免用 0 覆盖已有的正确值。
|
||||
|
||||
设计原理:
|
||||
- 在流式响应中,某些事件可能不包含完整的 usage 信息(字段为 0 或不存在)
|
||||
- 后续事件可能会提供完整的统计数据
|
||||
- 通过这种策略,确保一旦获得非零值就保留它,不会被后续的 0 值覆盖
|
||||
|
||||
示例场景:
|
||||
- message_start 事件:input_tokens=100, output_tokens=0
|
||||
- message_delta 事件:input_tokens=0, output_tokens=50
|
||||
- 最终结果:input_tokens=100, output_tokens=50
|
||||
|
||||
注意事项:
|
||||
- 此策略假设初始值为 0 是正确的默认状态
|
||||
- 如果需要将已有值重置为 0,请直接修改实例属性(不使用此方法)
|
||||
|
||||
Args:
|
||||
input_tokens: 输入 tokens 数量
|
||||
output_tokens: 输出 tokens 数量
|
||||
cached_tokens: 缓存命中 tokens 数量
|
||||
cache_creation_tokens: 缓存创建 tokens 数量
|
||||
"""
|
||||
if input_tokens is not None and (input_tokens > 0 or self.input_tokens == 0):
|
||||
self.input_tokens = input_tokens
|
||||
if output_tokens is not None:
|
||||
if output_tokens is not None and (output_tokens > 0 or self.output_tokens == 0):
|
||||
self.output_tokens = output_tokens
|
||||
if cached_tokens is not None:
|
||||
if cached_tokens is not None and (cached_tokens > 0 or self.cached_tokens == 0):
|
||||
self.cached_tokens = cached_tokens
|
||||
if cache_creation_tokens is not None:
|
||||
if cache_creation_tokens is not None and (
|
||||
cache_creation_tokens > 0 or self.cache_creation_tokens == 0
|
||||
):
|
||||
self.cache_creation_tokens = cache_creation_tokens
|
||||
|
||||
def mark_failed(self, status_code: int, error_message: str) -> None:
|
||||
@@ -119,6 +179,19 @@ class StreamContext:
|
||||
self.status_code = status_code
|
||||
self.error_message = error_message
|
||||
|
||||
def record_first_byte_time(self, start_time: float) -> None:
    """Record the time to first byte (TTFB) in milliseconds.

    Meant to be called when the first data is sent to the client. Once a
    value has been stored it is never overwritten, so a retry does not
    record it a second time.

    Args:
        start_time: Request start timestamp, as returned by ``time.time()``.
    """
    if self.first_byte_time_ms is not None:
        return
    elapsed_seconds = time.time() - start_time
    self.first_byte_time_ms = int(elapsed_seconds * 1000)
|
||||
|
||||
def is_success(self) -> bool:
    """Return True when the recorded status code is below 400."""
    succeeded = self.status_code < 400
    return succeeded
|
||||
@@ -145,10 +218,22 @@ class StreamContext:
|
||||
获取日志摘要
|
||||
|
||||
用于请求完成/失败时的日志输出。
|
||||
包含首字时间 (TTFB) 和总响应时间,分两行显示。
|
||||
"""
|
||||
status = "OK" if self.is_success() else "FAIL"
|
||||
return (
|
||||
|
||||
# 第一行:基本信息 + 首字时间
|
||||
line1 = (
|
||||
f"[{status}] {request_id[:8]} | {self.model} | "
|
||||
f"{self.provider_name or 'unknown'} | {response_time_ms}ms | "
|
||||
f"{self.provider_name or 'unknown'}"
|
||||
)
|
||||
if self.first_byte_time_ms is not None:
|
||||
line1 += f" | TTFB: {self.first_byte_time_ms}ms"
|
||||
|
||||
# 第二行:总响应时间 + tokens
|
||||
line2 = (
|
||||
f" Total: {response_time_ms}ms | "
|
||||
f"in:{self.input_tokens} out:{self.output_tokens}"
|
||||
)
|
||||
|
||||
return f"{line1}\n{line2}"
|
||||
|
||||
@@ -9,7 +9,9 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Callable, Optional
|
||||
|
||||
import httpx
|
||||
@@ -36,6 +38,8 @@ class StreamProcessor:
|
||||
request_id: str,
|
||||
default_parser: ResponseParser,
|
||||
on_streaming_start: Optional[Callable[[], None]] = None,
|
||||
*,
|
||||
collect_text: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化流处理器
|
||||
@@ -48,6 +52,7 @@ class StreamProcessor:
|
||||
self.request_id = request_id
|
||||
self.default_parser = default_parser
|
||||
self.on_streaming_start = on_streaming_start
|
||||
self.collect_text = collect_text
|
||||
|
||||
def get_parser_for_provider(self, ctx: StreamContext) -> ResponseParser:
|
||||
"""
|
||||
@@ -112,9 +117,10 @@ class StreamProcessor:
|
||||
)
|
||||
|
||||
# 提取文本
|
||||
text = parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
if self.collect_text:
|
||||
text = parser.extract_text_content(data)
|
||||
if text:
|
||||
ctx.append_text(text)
|
||||
|
||||
# 检查完成
|
||||
event_type = event_name or data.get("type", "")
|
||||
@@ -123,7 +129,7 @@ class StreamProcessor:
|
||||
|
||||
async def prefetch_and_check_error(
|
||||
self,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
provider: Provider,
|
||||
endpoint: ProviderEndpoint,
|
||||
ctx: StreamContext,
|
||||
@@ -136,97 +142,126 @@ class StreamProcessor:
|
||||
这种情况需要在流开始输出之前检测,以便触发重试逻辑。
|
||||
|
||||
Args:
|
||||
line_iterator: 行迭代器
|
||||
byte_iterator: 字节流迭代器
|
||||
provider: Provider 对象
|
||||
endpoint: Endpoint 对象
|
||||
ctx: 流式上下文
|
||||
max_prefetch_lines: 最多预读行数
|
||||
|
||||
Returns:
|
||||
预读的行列表
|
||||
预读的字节块列表
|
||||
|
||||
Raises:
|
||||
EmbeddedErrorException: 如果检测到嵌套错误
|
||||
"""
|
||||
prefetched_lines: list = []
|
||||
prefetched_chunks: list = []
|
||||
parser = self.get_parser_for_provider(ctx)
|
||||
buffer = b""
|
||||
line_count = 0
|
||||
should_stop = False
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
try:
|
||||
line_count = 0
|
||||
async for line in line_iterator:
|
||||
prefetched_lines.append(line)
|
||||
line_count += 1
|
||||
async for chunk in byte_iterator:
|
||||
prefetched_chunks.append(chunk)
|
||||
buffer += chunk
|
||||
|
||||
normalized_line = line.rstrip("\r")
|
||||
if not normalized_line or normalized_line.startswith(":"):
|
||||
if line_count >= max_prefetch_lines:
|
||||
# 尝试按行解析缓冲区
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False).rstrip("\r\n")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] 预读时 UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
line_count += 1
|
||||
|
||||
# 跳过空行和注释行
|
||||
if not line or line.startswith(":"):
|
||||
if line_count >= max_prefetch_lines:
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = line
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
|
||||
if data_str == "[DONE]":
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
# 尝试解析 SSE 数据
|
||||
data_str = normalized_line
|
||||
if normalized_line.startswith("data: "):
|
||||
data_str = normalized_line[6:]
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
if line_count >= max_prefetch_lines:
|
||||
should_stop = True
|
||||
break
|
||||
continue
|
||||
|
||||
if data_str == "[DONE]":
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and parser.is_error_response(data):
|
||||
parsed = parser.parse_response(data, 200)
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}"
|
||||
)
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
should_stop = True
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
if line_count >= max_prefetch_lines:
|
||||
break
|
||||
continue
|
||||
|
||||
# 使用解析器检查是否为错误响应
|
||||
if isinstance(data, dict) and parser.is_error_response(data):
|
||||
parsed = parser.parse_response(data, 200)
|
||||
logger.warning(
|
||||
f" [{self.request_id}] 检测到嵌套错误: "
|
||||
f"Provider={provider.name}, "
|
||||
f"error_type={parsed.error_type}, "
|
||||
f"message={parsed.error_message}"
|
||||
)
|
||||
raise EmbeddedErrorException(
|
||||
provider_name=str(provider.name),
|
||||
error_code=(
|
||||
int(parsed.error_type)
|
||||
if parsed.error_type and parsed.error_type.isdigit()
|
||||
else None
|
||||
),
|
||||
error_message=parsed.error_message,
|
||||
error_status=parsed.error_type,
|
||||
)
|
||||
|
||||
# 预读到有效数据,没有错误,停止预读
|
||||
break
|
||||
if should_stop or line_count >= max_prefetch_lines:
|
||||
break
|
||||
|
||||
except EmbeddedErrorException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f" [{self.request_id}] 预读流时发生异常: {e}")
|
||||
|
||||
return prefetched_lines
|
||||
return prefetched_chunks
|
||||
|
||||
async def create_response_stream(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
line_iterator: Any,
|
||||
byte_iterator: Any,
|
||||
response_ctx: Any,
|
||||
http_client: httpx.AsyncClient,
|
||||
prefetched_lines: Optional[list] = None,
|
||||
prefetched_chunks: Optional[list] = None,
|
||||
*,
|
||||
start_time: Optional[float] = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
创建响应流生成器
|
||||
|
||||
统一的流生成器,支持带预读数据和不带预读数据两种情况。
|
||||
从字节流中解析 SSE 数据并转发,支持预读数据。
|
||||
|
||||
Args:
|
||||
ctx: 流式上下文
|
||||
line_iterator: 行迭代器
|
||||
byte_iterator: 字节流迭代器
|
||||
response_ctx: HTTP 响应上下文管理器
|
||||
http_client: HTTP 客户端
|
||||
prefetched_lines: 预读的行列表(可选)
|
||||
prefetched_chunks: 预读的字节块列表(可选)
|
||||
start_time: 请求开始时间,用于计算 TTFB(可选)
|
||||
|
||||
Yields:
|
||||
编码后的响应数据块
|
||||
@@ -234,25 +269,82 @@ class StreamProcessor:
|
||||
try:
|
||||
sse_parser = SSEEventParser()
|
||||
streaming_started = False
|
||||
buffer = b""
|
||||
# 使用增量解码器处理跨 chunk 的 UTF-8 字符
|
||||
decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
|
||||
|
||||
# 处理预读数据
|
||||
if prefetched_lines:
|
||||
if prefetched_chunks:
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
for line in prefetched_lines:
|
||||
for chunk in self._process_line(ctx, sse_parser, line):
|
||||
yield chunk
|
||||
for chunk in prefetched_chunks:
|
||||
# 记录首字时间 (TTFB) - 在 yield 之前记录
|
||||
if start_time is not None:
|
||||
ctx.record_first_byte_time(start_time)
|
||||
start_time = None # 只记录一次
|
||||
|
||||
# 把原始数据转发给客户端
|
||||
yield chunk
|
||||
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False)
|
||||
self._process_line(ctx, sse_parser, line)
|
||||
except Exception as e:
|
||||
# 解码失败,记录警告但继续处理
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 处理剩余的流数据
|
||||
async for line in line_iterator:
|
||||
async for chunk in byte_iterator:
|
||||
if not streaming_started and self.on_streaming_start:
|
||||
self.on_streaming_start()
|
||||
streaming_started = True
|
||||
|
||||
for chunk in self._process_line(ctx, sse_parser, line):
|
||||
yield chunk
|
||||
# 记录首字时间 (TTFB) - 在 yield 之前记录(如果预读数据为空)
|
||||
if start_time is not None:
|
||||
ctx.record_first_byte_time(start_time)
|
||||
start_time = None # 只记录一次
|
||||
|
||||
# 原始数据透传
|
||||
yield chunk
|
||||
|
||||
buffer += chunk
|
||||
# 处理缓冲区中的完整行
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
try:
|
||||
# 使用增量解码器,可以正确处理跨 chunk 的多字节字符
|
||||
line = decoder.decode(line_bytes + b"\n", False)
|
||||
self._process_line(ctx, sse_parser, line)
|
||||
except Exception as e:
|
||||
# 解码失败,记录警告但继续处理
|
||||
logger.warning(
|
||||
f"[{self.request_id}] UTF-8 解码失败: {e}, "
|
||||
f"bytes={line_bytes[:50]!r}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 处理剩余的缓冲区数据(如果有未完成的行)
|
||||
if buffer:
|
||||
try:
|
||||
# 使用 final=True 处理最后的不完整字符
|
||||
line = decoder.decode(buffer, True)
|
||||
self._process_line(ctx, sse_parser, line)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[{self.request_id}] 处理剩余缓冲区失败: {e}, "
|
||||
f"bytes={buffer[:50]!r}"
|
||||
)
|
||||
|
||||
# 处理剩余事件
|
||||
for event in sse_parser.flush():
|
||||
@@ -268,7 +360,7 @@ class StreamProcessor:
|
||||
ctx: StreamContext,
|
||||
sse_parser: SSEEventParser,
|
||||
line: str,
|
||||
) -> list[bytes]:
|
||||
) -> None:
|
||||
"""
|
||||
处理单行数据
|
||||
|
||||
@@ -276,26 +368,17 @@ class StreamProcessor:
|
||||
ctx: 流式上下文
|
||||
sse_parser: SSE 解析器
|
||||
line: 原始行数据
|
||||
|
||||
Returns:
|
||||
要发送的数据块列表
|
||||
"""
|
||||
result: list[bytes] = []
|
||||
normalized_line = line.rstrip("\r")
|
||||
# SSEEventParser 以“去掉换行符”的单行文本作为输入;这里统一剔除 CR/LF,
|
||||
# 避免把空行误判成 "\n" 并导致事件边界解析错误。
|
||||
normalized_line = line.rstrip("\r\n")
|
||||
events = sse_parser.feed_line(normalized_line)
|
||||
|
||||
if normalized_line == "":
|
||||
for event in events:
|
||||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
result.append(b"\n")
|
||||
else:
|
||||
if normalized_line != "":
|
||||
ctx.chunk_count += 1
|
||||
result.append((line + "\n").encode("utf-8"))
|
||||
|
||||
for event in events:
|
||||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
return result
|
||||
for event in events:
|
||||
self.handle_sse_event(ctx, event.get("event"), event.get("data") or "")
|
||||
|
||||
async def create_monitored_stream(
|
||||
self,
|
||||
@@ -317,16 +400,26 @@ class StreamProcessor:
|
||||
响应数据块
|
||||
"""
|
||||
try:
|
||||
# 断连检查频率:每次 await 都会引入调度开销,过于频繁会让流式"发一段停一段"
|
||||
# 这里按时间间隔节流,兼顾及时停止上游读取与吞吐平滑性。
|
||||
next_disconnect_check_at = 0.0
|
||||
disconnect_check_interval_s = 0.25
|
||||
|
||||
async for chunk in stream_generator:
|
||||
if await is_disconnected():
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
ctx.status_code = 499 # Client Closed Request
|
||||
ctx.error_message = "client_disconnected"
|
||||
break
|
||||
now = time.monotonic()
|
||||
if now >= next_disconnect_check_at:
|
||||
next_disconnect_check_at = now + disconnect_check_interval_s
|
||||
if await is_disconnected():
|
||||
logger.warning(f"ID:{self.request_id} | Client disconnected")
|
||||
ctx.status_code = 499 # Client Closed Request
|
||||
ctx.error_message = "client_disconnected"
|
||||
|
||||
break
|
||||
yield chunk
|
||||
except asyncio.CancelledError:
|
||||
ctx.status_code = 499
|
||||
ctx.error_message = "client_disconnected"
|
||||
|
||||
raise
|
||||
except Exception as e:
|
||||
ctx.status_code = 500
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -57,7 +58,7 @@ class StreamTelemetryRecorder:
|
||||
ctx: StreamContext,
|
||||
original_headers: Dict[str, str],
|
||||
original_request_body: Dict[str, Any],
|
||||
response_time_ms: int,
|
||||
start_time: float,
|
||||
) -> None:
|
||||
"""
|
||||
记录流式统计信息
|
||||
@@ -66,11 +67,15 @@ class StreamTelemetryRecorder:
|
||||
ctx: 流式上下文
|
||||
original_headers: 原始请求头
|
||||
original_request_body: 原始请求体
|
||||
response_time_ms: 响应时间(毫秒)
|
||||
start_time: 请求开始时间 (time.time())
|
||||
"""
|
||||
bg_db = None
|
||||
|
||||
try:
|
||||
# 在流结束后计算响应时间,与首字时间使用相同的时间基准
|
||||
# 注意:不要把统计延迟(stream_stats_delay)算进响应时间里
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
await asyncio.sleep(config.stream_stats_delay) # 等待流完全关闭
|
||||
|
||||
if not ctx.provider_name:
|
||||
@@ -155,6 +160,7 @@ class StreamTelemetryRecorder:
|
||||
input_tokens=ctx.input_tokens,
|
||||
output_tokens=ctx.output_tokens,
|
||||
response_time_ms=response_time_ms,
|
||||
first_byte_time_ms=ctx.first_byte_time_ms, # 传递首字时间
|
||||
status_code=ctx.status_code,
|
||||
request_headers=original_headers,
|
||||
request_body=actual_request_body,
|
||||
|
||||
55
src/api/handlers/base/utils.py
Normal file
55
src/api/handlers/base/utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Handler 基础工具函数
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
    """
    Extract the cache-creation token count from a usage dict.

    Two field layouts are recognised:
    - New layout: ``claude_cache_creation_5_m_tokens`` and
      ``claude_cache_creation_1_h_tokens`` (5-minute and 1-hour cache
      respectively); the result is their sum.
    - Legacy layout: ``cache_creation_input_tokens`` holds the total.

    The new layout takes precedence whenever either of its keys is present,
    even if the values are 0: a present key with value 0 is a legitimate
    reading and must not fall back to the legacy field.

    A field that is present but ``None`` (JSON ``null``) or otherwise falsy
    is counted as 0 instead of raising from ``int()``.

    Args:
        usage: The ``usage`` dict taken from an API response.

    Returns:
        Total number of cache-creation tokens.
    """
    new_format_keys = (
        "claude_cache_creation_5_m_tokens",
        "claude_cache_creation_1_h_tokens",
    )

    # Decide on key presence, not on value, so an explicit 0 is respected.
    if any(key in usage for key in new_format_keys):
        # `or 0` turns None / "" into 0 before the int() conversion.
        return sum(int(usage.get(key) or 0) for key in new_format_keys)

    # Fall back to the legacy single-field layout.
    return int(usage.get("cache_creation_input_tokens") or 0)
|
||||
|
||||
|
||||
def build_sse_headers(extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """
    Build the recommended response headers for SSE (text/event-stream),
    intended to reduce stutter / chunked bursts caused by proxy buffering.

    Notes:
    - ``Cache-Control: no-transform`` keeps some proxies from compressing or
      rewriting the stream, which would make them buffer it.
    - ``X-Accel-Buffering: no`` explicitly asks Nginx to disable buffering
      (harmless even if buffering is already disabled globally).

    Args:
        extra_headers: Optional headers merged on top of the defaults; a key
            that matches a default replaces its value.

    Returns:
        A new dict containing the defaults plus any extra headers.
    """
    sse_defaults: Dict[str, str] = {
        "Cache-Control": "no-cache, no-transform",
        "X-Accel-Buffering": "no",
    }
    if not extra_headers:
        return sse_defaults
    return {**sse_defaults, **extra_headers}
|
||||
@@ -8,6 +8,7 @@ Claude Chat Handler - 基于通用 Chat Handler 基类的简化实现
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.api.handlers.base.chat_handler_base import ChatHandlerBase
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeChatHandler(ChatHandlerBase):
|
||||
@@ -63,7 +64,7 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
result["model"] = mapped_model
|
||||
return result
|
||||
|
||||
async def _convert_request(self, request):
|
||||
async def _convert_request(self, request: Any) -> Any:
|
||||
"""
|
||||
将请求转换为 Claude 格式
|
||||
|
||||
@@ -109,30 +110,18 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
Claude 格式使用:
|
||||
- input_tokens / output_tokens
|
||||
- cache_creation_input_tokens / cache_read_input_tokens
|
||||
- 新格式:claude_cache_creation_5_m_tokens / claude_cache_creation_1_h_tokens
|
||||
"""
|
||||
usage = response.get("usage", {})
|
||||
|
||||
input_tokens = usage.get("input_tokens", 0)
|
||||
output_tokens = usage.get("output_tokens", 0)
|
||||
cache_creation_input_tokens = usage.get("cache_creation_input_tokens", 0)
|
||||
cache_read_input_tokens = usage.get("cache_read_input_tokens", 0)
|
||||
|
||||
# 处理新的 cache_creation 格式
|
||||
if "cache_creation" in usage:
|
||||
cache_creation_data = usage.get("cache_creation", {})
|
||||
if not cache_creation_input_tokens:
|
||||
cache_creation_input_tokens = cache_creation_data.get(
|
||||
"ephemeral_5m_input_tokens", 0
|
||||
) + cache_creation_data.get("ephemeral_1h_input_tokens", 0)
|
||||
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"cache_creation_input_tokens": cache_creation_input_tokens,
|
||||
"cache_read_input_tokens": cache_read_input_tokens,
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_input_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_input_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
def _normalize_response(self, response: Dict) -> Dict:
|
||||
def _normalize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
规范化 Claude 响应
|
||||
|
||||
@@ -143,8 +132,9 @@ class ClaudeChatHandler(ChatHandlerBase):
|
||||
规范化后的响应
|
||||
"""
|
||||
if self.response_normalizer and self.response_normalizer.should_normalize(response):
|
||||
return self.response_normalizer.normalize_claude_response(
|
||||
result: Dict[str, Any] = self.response_normalizer.normalize_claude_response(
|
||||
response_data=response,
|
||||
request_id=self.request_id,
|
||||
)
|
||||
return result
|
||||
return response
|
||||
|
||||
@@ -9,6 +9,8 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeStreamParser:
|
||||
"""
|
||||
@@ -108,7 +110,10 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
result = json.loads(line)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -147,7 +152,8 @@ class ClaudeStreamParser:
|
||||
Returns:
|
||||
事件类型字符串
|
||||
"""
|
||||
return event.get("type")
|
||||
event_type = event.get("type")
|
||||
return str(event_type) if event_type is not None else None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -164,7 +170,8 @@ class ClaudeStreamParser:
|
||||
|
||||
delta = event.get("delta", {})
|
||||
if delta.get("type") == self.DELTA_TEXT:
|
||||
return delta.get("text")
|
||||
text = delta.get("text")
|
||||
return str(text) if text is not None else None
|
||||
|
||||
return None
|
||||
|
||||
@@ -188,7 +195,7 @@ class ClaudeStreamParser:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
@@ -199,7 +206,7 @@ class ClaudeStreamParser:
|
||||
return {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"cache_creation_tokens": usage.get("cache_creation_input_tokens", 0),
|
||||
"cache_creation_tokens": extract_cache_creation_tokens(usage),
|
||||
"cache_read_tokens": usage.get("cache_read_input_tokens", 0),
|
||||
}
|
||||
|
||||
@@ -219,7 +226,8 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
message = event.get("message", {})
|
||||
return message.get("id")
|
||||
msg_id = message.get("id")
|
||||
return str(msg_id) if msg_id is not None else None
|
||||
|
||||
def extract_stop_reason(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -235,7 +243,8 @@ class ClaudeStreamParser:
|
||||
return None
|
||||
|
||||
delta = event.get("delta", {})
|
||||
return delta.get("stop_reason")
|
||||
reason = delta.get("stop_reason")
|
||||
return str(reason) if reason is not None else None
|
||||
|
||||
|
||||
__all__ = ["ClaudeStreamParser"]
|
||||
|
||||
@@ -11,6 +11,7 @@ from src.api.handlers.base.cli_handler_base import (
|
||||
CliMessageHandlerBase,
|
||||
StreamContext,
|
||||
)
|
||||
from src.api.handlers.base.utils import extract_cache_creation_tokens
|
||||
|
||||
|
||||
class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
@@ -95,11 +96,12 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
usage = message.get("usage", {})
|
||||
if usage:
|
||||
ctx.input_tokens = usage.get("input_tokens", 0)
|
||||
# Claude 的缓存 tokens 使用不同的字段名
|
||||
|
||||
cache_read = usage.get("cache_read_input_tokens", 0)
|
||||
if cache_read:
|
||||
ctx.cached_tokens = cache_read
|
||||
cache_creation = usage.get("cache_creation_input_tokens", 0)
|
||||
|
||||
cache_creation = extract_cache_creation_tokens(usage)
|
||||
if cache_creation:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
@@ -109,7 +111,7 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
if delta.get("type") == "text_delta":
|
||||
text = delta.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
ctx.append_text(text)
|
||||
|
||||
# 处理消息增量(包含最终 usage)
|
||||
elif event_type == "message_delta":
|
||||
@@ -119,11 +121,15 @@ class ClaudeCliMessageHandler(CliMessageHandlerBase):
|
||||
ctx.input_tokens = usage["input_tokens"]
|
||||
if "output_tokens" in usage:
|
||||
ctx.output_tokens = usage["output_tokens"]
|
||||
# 更新缓存 tokens
|
||||
|
||||
# 更新缓存读取 tokens
|
||||
if "cache_read_input_tokens" in usage:
|
||||
ctx.cached_tokens = usage["cache_read_input_tokens"]
|
||||
if "cache_creation_input_tokens" in usage:
|
||||
ctx.cache_creation_tokens = usage["cache_creation_input_tokens"]
|
||||
|
||||
# 更新缓存创建 tokens
|
||||
cache_creation = extract_cache_creation_tokens(usage)
|
||||
if cache_creation > 0:
|
||||
ctx.cache_creation_tokens = cache_creation
|
||||
|
||||
# 检查是否结束
|
||||
delta = data.get("delta", {})
|
||||
|
||||
@@ -70,7 +70,7 @@ class ClaudeToGeminiConverter:
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
parts: List[Dict[str, Any]] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append({"text": block})
|
||||
@@ -249,6 +249,8 @@ class GeminiToClaudeConverter:
|
||||
"RECITATION": "content_filtered",
|
||||
"OTHER": "stop_sequence",
|
||||
}
|
||||
if finish_reason is None:
|
||||
return "end_turn"
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
def _create_empty_response(self) -> Dict[str, Any]:
|
||||
@@ -365,7 +367,7 @@ class OpenAIToGeminiConverter:
|
||||
return [{"text": content}]
|
||||
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
parts: List[Dict[str, Any]] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append({"text": item})
|
||||
@@ -524,7 +526,7 @@ class GeminiToOpenAIConverter:
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> Optional[str]:
|
||||
def _convert_finish_reason(self, finish_reason: Optional[str]) -> str:
|
||||
"""转换停止原因"""
|
||||
mapping = {
|
||||
"STOP": "stop",
|
||||
@@ -533,6 +535,8 @@ class GeminiToOpenAIConverter:
|
||||
"RECITATION": "content_filter",
|
||||
"OTHER": "stop",
|
||||
}
|
||||
if finish_reason is None:
|
||||
return "stop"
|
||||
return mapping.get(finish_reason, "stop")
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ Gemini API 的流式响应格式与 Claude/OpenAI 不同:
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class GeminiStreamParser:
|
||||
@@ -32,18 +32,18 @@ class GeminiStreamParser:
|
||||
FINISH_REASON_RECITATION = "RECITATION"
|
||||
FINISH_REASON_OTHER = "OTHER"
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""重置解析器状态"""
|
||||
self._buffer = ""
|
||||
self._in_array = False
|
||||
self._brace_depth = 0
|
||||
|
||||
def parse_chunk(self, chunk: bytes | str) -> List[Dict[str, Any]]:
|
||||
def parse_chunk(self, chunk: Union[bytes, str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
解析流式数据块
|
||||
|
||||
@@ -111,7 +111,10 @@ class GeminiStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line.strip().rstrip(","))
|
||||
result = json.loads(line.strip().rstrip(","))
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -216,7 +219,8 @@ class GeminiStreamParser:
|
||||
"""
|
||||
candidates = event.get("candidates", [])
|
||||
if candidates:
|
||||
return candidates[0].get("finishReason")
|
||||
reason = candidates[0].get("finishReason")
|
||||
return str(reason) if reason is not None else None
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, event: Dict[str, Any]) -> Optional[str]:
|
||||
@@ -285,7 +289,8 @@ class GeminiStreamParser:
|
||||
Returns:
|
||||
模型版本,如果没有返回 None
|
||||
"""
|
||||
return event.get("modelVersion")
|
||||
version = event.get("modelVersion")
|
||||
return str(version) if version is not None else None
|
||||
|
||||
def extract_safety_ratings(self, event: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
@@ -301,7 +306,10 @@ class GeminiStreamParser:
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
return candidates[0].get("safetyRatings")
|
||||
ratings = candidates[0].get("safetyRatings")
|
||||
if isinstance(ratings, list):
|
||||
return ratings
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["GeminiStreamParser"]
|
||||
|
||||
@@ -160,7 +160,7 @@ class GeminiCliMessageHandler(CliMessageHandlerBase):
|
||||
parts = content.get("parts", [])
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
ctx.collected_text += part["text"]
|
||||
ctx.append_text(part["text"])
|
||||
|
||||
# 检查结束原因
|
||||
finish_reason = candidate.get("finishReason")
|
||||
|
||||
@@ -78,7 +78,10 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(line)
|
||||
result = json.loads(line)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
@@ -116,7 +119,8 @@ class OpenAIStreamParser:
|
||||
"""
|
||||
choices = chunk.get("choices", [])
|
||||
if choices:
|
||||
return choices[0].get("finish_reason")
|
||||
reason = choices[0].get("finish_reason")
|
||||
return str(reason) if reason is not None else None
|
||||
return None
|
||||
|
||||
def extract_text_delta(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
@@ -156,7 +160,10 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("tool_calls")
|
||||
tool_calls = delta.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
return tool_calls
|
||||
return None
|
||||
|
||||
def extract_role(self, chunk: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
@@ -175,7 +182,8 @@ class OpenAIStreamParser:
|
||||
return None
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
return delta.get("role")
|
||||
role = delta.get("role")
|
||||
return str(role) if role is not None else None
|
||||
|
||||
|
||||
__all__ = ["OpenAIStreamParser"]
|
||||
|
||||
@@ -94,9 +94,9 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
|
||||
if event_type in ["response.output_text.delta", "response.outtext.delta"]:
|
||||
delta = data.get("delta")
|
||||
if isinstance(delta, str):
|
||||
ctx.collected_text += delta
|
||||
ctx.append_text(delta)
|
||||
elif isinstance(delta, dict) and "text" in delta:
|
||||
ctx.collected_text += delta["text"]
|
||||
ctx.append_text(delta["text"])
|
||||
|
||||
# 处理完成事件
|
||||
elif event_type == "response.completed":
|
||||
@@ -124,7 +124,7 @@ class OpenAICliMessageHandler(CliMessageHandlerBase):
|
||||
if content_item.get("type") == "output_text":
|
||||
text = content_item.get("text", "")
|
||||
if text:
|
||||
ctx.collected_text += text
|
||||
ctx.append_text(text)
|
||||
|
||||
# 备用:从顶层 usage 提取
|
||||
usage_obj = data.get("usage")
|
||||
|
||||
@@ -61,15 +61,18 @@ async def get_model_supported_capabilities(
|
||||
获取指定模型支持的能力列表
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514)
|
||||
model_name: 模型名称(如 claude-sonnet-4-20250514,必须是 GlobalModel.name)
|
||||
|
||||
Returns:
|
||||
模型支持的能力列表,以及每个能力的详细定义
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == model_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not global_model:
|
||||
return {
|
||||
|
||||
@@ -20,14 +20,12 @@ from src.models.api import (
|
||||
ProviderStatsResponse,
|
||||
PublicGlobalModelListResponse,
|
||||
PublicGlobalModelResponse,
|
||||
PublicModelMappingResponse,
|
||||
PublicModelResponse,
|
||||
PublicProviderResponse,
|
||||
)
|
||||
from src.models.database import (
|
||||
GlobalModel,
|
||||
Model,
|
||||
ModelMapping,
|
||||
Provider,
|
||||
ProviderEndpoint,
|
||||
RequestCandidate,
|
||||
@@ -72,24 +70,6 @@ async def get_public_models(
|
||||
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
# Public endpoint: list active model mappings, optionally filtered by provider
# and/or alias substring, with skip/limit pagination. The work is delegated to
# PublicModelMappingsAdapter and executed through the shared pipeline in
# PUBLIC mode.
#
# NOTE: deliberately documented with comments instead of a docstring, because
# FastAPI publishes a route function's docstring as the operation description
# in the OpenAPI schema.
@router.get("/model-mappings", response_model=List[PublicModelMappingResponse])
async def get_public_model_mappings(
    request: Request,
    provider_id: Optional[str] = Query(None, description="提供商ID过滤"),
    alias: Optional[str] = Query(None, description="别名过滤(原source_model)"),
    # ge=0: a negative OFFSET/LIMIT is meaningless and is rejected by the
    # database; validating here turns that server error into a 422 response.
    skip: int = Query(0, ge=0, description="跳过记录数"),
    limit: int = Query(100, ge=0, description="返回记录数限制"),
    db: Session = Depends(get_db),
):
    adapter = PublicModelMappingsAdapter(
        provider_id=provider_id,
        alias=alias,
        skip=skip,
        limit=limit,
    )
    return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=ApiMode.PUBLIC)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ProviderStatsResponse)
|
||||
async def get_public_stats(request: Request, db: Session = Depends(get_db)):
|
||||
adapter = PublicStatsAdapter()
|
||||
@@ -176,13 +156,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
||||
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
mappings_count = (
|
||||
db.query(ModelMapping)
|
||||
.filter(
|
||||
and_(ModelMapping.provider_id == provider.id, ModelMapping.is_active.is_(True))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
|
||||
active_endpoints_count = (
|
||||
sum(1 for ep in provider.endpoints if ep.is_active) if provider.endpoints else 0
|
||||
@@ -196,7 +169,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
|
||||
provider_priority=provider.provider_priority,
|
||||
models_count=models_count,
|
||||
active_models_count=active_models_count,
|
||||
mappings_count=mappings_count,
|
||||
endpoints_count=endpoints_count,
|
||||
active_endpoints_count=active_endpoints_count,
|
||||
)
|
||||
@@ -238,9 +210,9 @@ class PublicModelsAdapter(PublicApiAdapter):
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
description=global_model.config.get("description") if global_model and global_model.config else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
@@ -256,78 +228,6 @@ class PublicModelsAdapter(PublicApiAdapter):
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
class PublicModelMappingsAdapter(PublicApiAdapter):
    """Public API adapter that lists active model mappings.

    Attributes:
        provider_id: When set, return mappings bound to this provider plus
            global mappings whose target GlobalModel is served by an active
            model of this (active) provider. When ``None``, return only
            global mappings (``provider_id IS NULL``).
        alias: Optional case-insensitive substring filter on ``source_model``.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
    """

    provider_id: Optional[str]
    alias: Optional[str]  # formerly named source_model
    skip: int
    limit: int

    async def handle(self, context):  # type: ignore[override]
        """Run the query and return a list of serialized mapping dicts."""
        db = context.db
        logger.debug("公共API请求模型映射列表")

        # Only mappings that are active and point at an active GlobalModel.
        query = (
            db.query(ModelMapping, GlobalModel, Provider)
            .join(GlobalModel, ModelMapping.target_global_model_id == GlobalModel.id)
            .outerjoin(Provider, ModelMapping.provider_id == Provider.id)
            .filter(
                and_(
                    ModelMapping.is_active.is_(True),
                    GlobalModel.is_active.is_(True),
                )
            )
        )

        if self.provider_id is not None:
            # GlobalModel ids that this provider actually serves.
            provider_global_model_ids = (
                db.query(Model.global_model_id)
                .join(Provider, Model.provider_id == Provider.id)
                .filter(
                    Provider.id == self.provider_id,
                    Model.is_active.is_(True),
                    Provider.is_active.is_(True),
                    Model.global_model_id.isnot(None),
                )
                .distinct()
            )
            # Provider-specific mappings, or global mappings targeting a
            # GlobalModel that the provider serves.
            query = query.filter(
                or_(
                    ModelMapping.provider_id == self.provider_id,
                    and_(
                        ModelMapping.provider_id.is_(None),
                        ModelMapping.target_global_model_id.in_(provider_global_model_ids),
                    ),
                )
            )
        else:
            query = query.filter(ModelMapping.provider_id.is_(None))

        if self.alias is not None:
            # Escape LIKE metacharacters so user input is matched literally
            # instead of "%" / "_" acting as wildcards.
            escaped_alias = (
                self.alias.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            query = query.filter(
                ModelMapping.source_model.ilike(f"%{escaped_alias}%", escape="\\")
            )

        # OFFSET/LIMIT without ORDER BY yields an unspecified row order, so
        # pages could overlap or skip rows; order by primary key for stability.
        results = query.order_by(ModelMapping.id).offset(self.skip).limit(self.limit).all()

        response = []
        for mapping, global_model, _provider in results:
            scope = "provider" if mapping.provider_id else "global"
            mapping_data = PublicModelMappingResponse(
                id=mapping.id,
                source_model=mapping.source_model,
                target_global_model_id=mapping.target_global_model_id,
                target_global_model_name=global_model.name if global_model else None,
                target_global_model_display_name=(
                    global_model.display_name if global_model else None
                ),
                provider_id=mapping.provider_id,
                scope=scope,
                is_active=mapping.is_active,
            )
            response.append(mapping_data.model_dump())

        logger.debug(f"返回 {len(response)} 个模型映射")
        return response
|
||||
|
||||
|
||||
class PublicStatsAdapter(PublicApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
db = context.db
|
||||
@@ -339,9 +239,6 @@ class PublicStatsAdapter(PublicApiAdapter):
|
||||
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
|
||||
.count()
|
||||
)
|
||||
from ...models.database import ModelMapping
|
||||
|
||||
active_mappings = db.query(ModelMapping).filter(ModelMapping.is_active.is_(True)).count()
|
||||
formats = (
|
||||
db.query(Provider.api_format).filter(Provider.is_active.is_(True)).distinct().all()
|
||||
)
|
||||
@@ -351,7 +248,6 @@ class PublicStatsAdapter(PublicApiAdapter):
|
||||
active_providers=active_providers,
|
||||
total_models=active_models,
|
||||
active_models=active_models,
|
||||
total_mappings=active_mappings,
|
||||
supported_formats=supported_formats,
|
||||
)
|
||||
logger.debug("返回系统统计信息")
|
||||
@@ -378,7 +274,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
|
||||
Model.provider_model_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.display_name.ilike(f"%{self.query}%")
|
||||
| GlobalModel.description.ilike(f"%{self.query}%")
|
||||
)
|
||||
query_stmt = query_stmt.filter(search_filter)
|
||||
if self.provider_id is not None:
|
||||
@@ -397,9 +292,9 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
|
||||
provider_display_name=provider.display_name,
|
||||
name=unified_name,
|
||||
display_name=display_name,
|
||||
description=global_model.description if global_model else None,
|
||||
description=global_model.config.get("description") if global_model and global_model.config else None,
|
||||
tags=None,
|
||||
icon_url=global_model.icon_url if global_model else None,
|
||||
icon_url=global_model.config.get("icon_url") if global_model and global_model.config else None,
|
||||
input_price_per_1m=model.get_effective_input_price(),
|
||||
output_price_per_1m=model.get_effective_output_price(),
|
||||
cache_creation_price_per_1m=model.get_effective_cache_creation_price(),
|
||||
@@ -603,7 +498,6 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
|
||||
or_(
|
||||
GlobalModel.name.ilike(search_term),
|
||||
GlobalModel.display_name.ilike(search_term),
|
||||
GlobalModel.description.ilike(search_term),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -621,21 +515,11 @@ class PublicGlobalModelsAdapter(PublicApiAdapter):
|
||||
id=gm.id,
|
||||
name=gm.name,
|
||||
display_name=gm.display_name,
|
||||
description=gm.description,
|
||||
icon_url=gm.icon_url,
|
||||
is_active=gm.is_active,
|
||||
default_price_per_request=gm.default_price_per_request,
|
||||
default_tiered_pricing=gm.default_tiered_pricing,
|
||||
default_supports_vision=gm.default_supports_vision or False,
|
||||
default_supports_function_calling=gm.default_supports_function_calling or False,
|
||||
default_supports_streaming=(
|
||||
gm.default_supports_streaming
|
||||
if gm.default_supports_streaming is not None
|
||||
else True
|
||||
),
|
||||
default_supports_extended_thinking=gm.default_supports_extended_thinking
|
||||
or False,
|
||||
supported_capabilities=gm.supported_capabilities,
|
||||
config=gm.config,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -251,8 +251,8 @@ def _build_gemini_list_response(
|
||||
"version": "001",
|
||||
"displayName": m.display_name,
|
||||
"description": m.description or f"Model {m.id}",
|
||||
"inputTokenLimit": 128000,
|
||||
"outputTokenLimit": 8192,
|
||||
"inputTokenLimit": m.context_limit if m.context_limit is not None else 128000,
|
||||
"outputTokenLimit": m.output_limit if m.output_limit is not None else 8192,
|
||||
"supportedGenerationMethods": ["generateContent", "countTokens"],
|
||||
"temperature": 1.0,
|
||||
"maxTemperature": 2.0,
|
||||
@@ -297,8 +297,8 @@ def _build_gemini_model_response(model_info: ModelInfo) -> dict:
|
||||
"version": "001",
|
||||
"displayName": model_info.display_name,
|
||||
"description": model_info.description or f"Model {model_info.id}",
|
||||
"inputTokenLimit": 128000,
|
||||
"outputTokenLimit": 8192,
|
||||
"inputTokenLimit": model_info.context_limit if model_info.context_limit is not None else 128000,
|
||||
"outputTokenLimit": model_info.output_limit if model_info.output_limit is not None else 8192,
|
||||
"supportedGenerationMethods": ["generateContent", "countTokens"],
|
||||
"temperature": 1.0,
|
||||
"maxTemperature": 2.0,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy.orm import Session, joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.api.base.adapter import ApiAdapter, ApiMode
|
||||
from src.api.base.authenticated_adapter import AuthenticatedApiAdapter
|
||||
@@ -713,7 +713,7 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
async def handle(self, context): # type: ignore[override]
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from src.models.database import Model, ModelMapping, ProviderEndpoint
|
||||
from src.models.database import Model, ProviderEndpoint
|
||||
|
||||
db = context.db
|
||||
|
||||
@@ -765,53 +765,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
|
||||
}
|
||||
)
|
||||
|
||||
# 查询该 Provider 所有 Model 对应的 GlobalModel 的别名/映射
|
||||
provider_model_global_ids = {
|
||||
m.global_model_id for m in provider.models if m.global_model_id
|
||||
}
|
||||
if provider_model_global_ids:
|
||||
# 查询全局别名 + Provider 特定映射
|
||||
alias_mappings = (
|
||||
db.query(ModelMapping)
|
||||
.options(joinedload(ModelMapping.target_global_model))
|
||||
.filter(
|
||||
ModelMapping.target_global_model_id.in_(provider_model_global_ids),
|
||||
ModelMapping.is_active == True,
|
||||
(ModelMapping.provider_id == provider.id)
|
||||
| (ModelMapping.provider_id == None),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for alias_obj in alias_mappings:
|
||||
# 为这个别名找到该 Provider 的 Model 实现
|
||||
model = next(
|
||||
(
|
||||
m
|
||||
for m in provider.models
|
||||
if m.global_model_id == alias_obj.target_global_model_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if model:
|
||||
models_data.append(
|
||||
{
|
||||
"id": alias_obj.id,
|
||||
"name": alias_obj.source_model,
|
||||
"display_name": (
|
||||
alias_obj.target_global_model.display_name
|
||||
if alias_obj.target_global_model
|
||||
else alias_obj.source_model
|
||||
),
|
||||
"input_price_per_1m": model.input_price_per_1m,
|
||||
"output_price_per_1m": model.output_price_per_1m,
|
||||
"cache_creation_price_per_1m": model.cache_creation_price_per_1m,
|
||||
"cache_read_price_per_1m": model.cache_read_price_per_1m,
|
||||
"supports_vision": model.supports_vision,
|
||||
"supports_function_calling": model.supports_function_calling,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
}
|
||||
)
|
||||
|
||||
result.append(
|
||||
{
|
||||
"id": provider.id,
|
||||
|
||||
@@ -14,7 +14,6 @@ class CacheTTL:
|
||||
# Provider/Model 缓存 - 配置变更不频繁
|
||||
PROVIDER = 300 # 5分钟
|
||||
MODEL = 300 # 5分钟
|
||||
MODEL_MAPPING = 300 # 5分钟
|
||||
|
||||
# 缓存亲和性 - 对应 provider_api_key.cache_ttl_minutes 默认值
|
||||
CACHE_AFFINITY = 300 # 5分钟
|
||||
@@ -33,9 +32,6 @@ class CacheSize:
|
||||
# 默认 LRU 缓存大小
|
||||
DEFAULT = 1000
|
||||
|
||||
# ModelMapping 缓存(可能有较多别名)
|
||||
MODEL_MAPPING = 2000
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# 并发和限流常量
|
||||
|
||||
@@ -120,6 +120,33 @@ class CacheService:
|
||||
logger.warning(f"缓存检查失败: {key} - {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
async def incr(key: str, ttl_seconds: Optional[int] = None) -> int:
    """Increment the value stored under ``key`` in Redis.

    Args:
        key: Cache key to increment.
        ttl_seconds: Optional. When provided, the key's expiry is reset to
            this many seconds after the increment.

    Returns:
        The value after incrementing, or 0 when no Redis client is available
        or any step raises.
    """
    try:
        client = await get_redis_client(require_redis=False)
        if client:
            new_value = await client.incr(key)
            # Refresh the expiry only when the caller asked for one.
            if ttl_seconds is not None:
                await client.expire(key, ttl_seconds)
            return new_value
        return 0
    except Exception as e:
        logger.warning(f"缓存递增失败: {key} - {e}")
        return 0
|
||||
|
||||
|
||||
# 缓存键前缀
|
||||
class CacheKeys:
|
||||
|
||||
@@ -115,7 +115,7 @@ class SyncLRUCache:
|
||||
"""删除缓存值(通过索引)"""
|
||||
self.delete(key)
|
||||
|
||||
def keys(self):
|
||||
def keys(self) -> list:
|
||||
"""返回所有未过期的 key"""
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
|
||||
@@ -67,7 +67,7 @@ FILE_FORMAT = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:
|
||||
logger.remove()
|
||||
|
||||
|
||||
def _log_filter(record):
|
||||
def _log_filter(record: dict) -> bool: # type: ignore[type-arg]
|
||||
return "watchfiles" not in record["name"]
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ if IS_DOCKER:
|
||||
sys.stdout,
|
||||
format=CONSOLE_FORMAT_PROD,
|
||||
level=LOG_LEVEL,
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
colorize=False,
|
||||
)
|
||||
else:
|
||||
@@ -84,7 +84,7 @@ else:
|
||||
sys.stdout,
|
||||
format=CONSOLE_FORMAT_DEV,
|
||||
level=LOG_LEVEL,
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
colorize=True,
|
||||
)
|
||||
|
||||
@@ -97,7 +97,7 @@ if not DISABLE_FILE_LOG:
|
||||
log_dir / "app.log",
|
||||
format=FILE_FORMAT,
|
||||
level="DEBUG",
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
@@ -110,7 +110,7 @@ if not DISABLE_FILE_LOG:
|
||||
log_dir / "error.log",
|
||||
format=FILE_FORMAT,
|
||||
level="ERROR",
|
||||
filter=_log_filter,
|
||||
filter=_log_filter, # type: ignore[arg-type]
|
||||
rotation="00:00",
|
||||
retention="30 days",
|
||||
compression="gz",
|
||||
|
||||
@@ -44,3 +44,24 @@ health_open_circuits = Gauge(
|
||||
"health_open_circuits",
|
||||
"Number of provider keys currently in circuit breaker open state",
|
||||
)
|
||||
|
||||
# Model-mapping resolution metrics
model_mapping_resolution_total = Counter(
    "model_mapping_resolution_total",
    "Total number of model mapping resolutions",
    ["method", "cache_hit"],
    # method: direct_match, provider_model_name, alias, not_found
    # cache_hit: true, false
)

model_mapping_resolution_duration_seconds = Histogram(
    "model_mapping_resolution_duration_seconds",
    "Duration of model mapping resolution in seconds",
    ["method"],
    buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0],  # 1 ms to 1 s
)

model_mapping_conflict_total = Counter(
    "model_mapping_conflict_total",
    "Total number of mapping conflicts detected (same name maps to multiple GlobalModels)",
)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ProviderHealthTracker:
|
||||
@@ -32,7 +32,7 @@ class ProviderHealthTracker:
|
||||
# 存储优先级调整
|
||||
self.priority_adjustments: Dict[str, int] = {}
|
||||
|
||||
def record_success(self, provider_name: str):
|
||||
def record_success(self, provider_name: str) -> None:
|
||||
"""记录成功的请求"""
|
||||
current_time = time.time()
|
||||
|
||||
@@ -47,7 +47,7 @@ class ProviderHealthTracker:
|
||||
if self.priority_adjustments.get(provider_name, 0) < 0:
|
||||
self.priority_adjustments[provider_name] += 1
|
||||
|
||||
def record_failure(self, provider_name: str):
|
||||
def record_failure(self, provider_name: str) -> None:
|
||||
"""记录失败的请求"""
|
||||
current_time = time.time()
|
||||
|
||||
@@ -93,7 +93,7 @@ class ProviderHealthTracker:
|
||||
"status": self._get_status_label(failure_rate, recent_failures),
|
||||
}
|
||||
|
||||
def _cleanup_old_records(self, provider_name: str, current_time: float):
|
||||
def _cleanup_old_records(self, provider_name: str, current_time: float) -> None:
|
||||
"""清理超出时间窗口的记录"""
|
||||
# 清理失败记录
|
||||
self.failures[provider_name] = [
|
||||
@@ -130,7 +130,7 @@ class ProviderHealthTracker:
|
||||
adjustment = self.get_priority_adjustment(provider_name)
|
||||
return adjustment > -3
|
||||
|
||||
def reset_provider_health(self, provider_name: str):
|
||||
def reset_provider_health(self, provider_name: str) -> None:
|
||||
"""重置提供商的健康状态(管理员手动操作)"""
|
||||
self.failures[provider_name] = []
|
||||
self.successes[provider_name] = []
|
||||
@@ -146,7 +146,7 @@ class SimpleProviderSelector:
|
||||
def __init__(self, health_tracker: ProviderHealthTracker):
|
||||
self.health_tracker = health_tracker
|
||||
|
||||
def select_provider(self, providers: list, specified_provider: Optional[str] = None):
|
||||
def select_provider(self, providers: list, specified_provider: Optional[str] = None) -> Any:
|
||||
"""
|
||||
选择提供商
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, TypeVar
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
async def run_in_executor(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
|
||||
"""
|
||||
在线程池中运行同步函数,避免阻塞事件循环
|
||||
|
||||
@@ -21,7 +21,7 @@ async def run_in_executor(func: Callable[..., T], *args, **kwargs) -> T:
|
||||
return await loop.run_in_executor(None, lambda: func(*args, **kwargs))
|
||||
|
||||
|
||||
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
||||
def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Coroutine[Any, Any, T]]:
|
||||
"""
|
||||
装饰器:包装同步数据库函数为异步函数
|
||||
|
||||
@@ -35,7 +35,7 @@ def async_wrap_sync_db(func: Callable[..., T]) -> Callable[..., Any]:
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
return await run_in_executor(func, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -273,16 +273,17 @@ def get_db_url() -> str:
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化数据库"""
|
||||
"""初始化数据库
|
||||
|
||||
注意:数据库表结构由 Alembic 管理,部署时请运行 ./migrate.sh
|
||||
"""
|
||||
logger.info("初始化数据库...")
|
||||
|
||||
# 确保引擎已创建
|
||||
engine = _ensure_engine()
|
||||
_ensure_engine()
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# 数据库表已通过SQLAlchemy自动创建
|
||||
# 数据库表结构由 Alembic 迁移管理
|
||||
# 首次部署或更新后请运行: ./migrate.sh
|
||||
|
||||
db = _SessionLocal()
|
||||
try:
|
||||
@@ -347,7 +348,7 @@ def init_default_models(db: Session):
|
||||
"""初始化默认模型配置"""
|
||||
|
||||
# 注意:作为中转代理服务,不再预设模型配置
|
||||
# 模型配置应该通过 Model 和 ModelMapping 表动态管理
|
||||
# 模型配置应该通过 GlobalModel 和 Model 表动态管理
|
||||
# 这个函数保留用于未来可能的默认模型初始化
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
中间件模块
|
||||
"""
|
||||
|
||||
__all__ = []
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -334,7 +334,6 @@ class ProviderResponse(BaseModel):
|
||||
updated_at: datetime
|
||||
models_count: int = 0
|
||||
active_models_count: int = 0
|
||||
model_mappings_count: int = 0
|
||||
api_keys_count: int = 0
|
||||
|
||||
class Config:
|
||||
@@ -346,7 +345,11 @@ class ModelCreate(BaseModel):
|
||||
"""创建模型请求 - 价格和能力字段可选,为空时使用 GlobalModel 默认值"""
|
||||
|
||||
provider_model_name: str = Field(
|
||||
..., min_length=1, max_length=200, description="Provider 侧的模型名称"
|
||||
..., min_length=1, max_length=200, description="Provider 侧的主模型名称"
|
||||
)
|
||||
provider_model_aliases: Optional[List[dict]] = Field(
|
||||
None,
|
||||
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
)
|
||||
global_model_id: str = Field(..., description="关联的 GlobalModel ID(必填)")
|
||||
# 按次计费配置 - 可选,为空时使用 GlobalModel 默认值
|
||||
@@ -374,6 +377,10 @@ class ModelUpdate(BaseModel):
|
||||
"""更新模型请求"""
|
||||
|
||||
provider_model_name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
provider_model_aliases: Optional[List[dict]] = Field(
|
||||
None,
|
||||
description="模型名称别名列表,格式: [{'name': 'alias1', 'priority': 1}, ...]",
|
||||
)
|
||||
global_model_id: Optional[str] = None
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
@@ -398,6 +405,7 @@ class ModelResponse(BaseModel):
|
||||
provider_id: str
|
||||
global_model_id: Optional[str]
|
||||
provider_model_name: str
|
||||
provider_model_aliases: Optional[List[dict]] = None
|
||||
|
||||
# 按次计费配置
|
||||
price_per_request: Optional[float] = None
|
||||
@@ -465,54 +473,6 @@ class ModelDetailResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
# ========== 模型映射 ==========
|
||||
class ModelMappingCreate(BaseModel):
    """创建模型映射请求(源模型到目标模型的映射)"""

    # Request body for creating a model mapping (source model -> target model).
    # The docstring and the description strings are left untranslated because
    # Pydantic can surface them in the generated schema, i.e. they are runtime
    # text.

    # Source model name or alias (1-200 characters).
    source_model: str = Field(..., min_length=1, max_length=200, description="源模型名或别名")
    # ID of the target GlobalModel.
    target_global_model_id: str = Field(..., description="目标 GlobalModel ID")
    # Provider ID; None means the mapping is a global alias.
    provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
    # Mapping type: "alias" = billed as the target model (alias),
    # "mapping" = billed as the source model (fallback mapping).
    mapping_type: str = Field(
        "alias",
        description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)",
    )
    # Whether the mapping is enabled.
    is_active: bool = Field(True, description="是否启用")
|
||||
|
||||
|
||||
class ModelMappingUpdate(BaseModel):
    """更新模型映射请求"""

    # Request body for updating a model mapping; every field defaults to None.
    # The docstring and the description strings are left untranslated because
    # Pydantic can surface them in the generated schema, i.e. they are runtime
    # text.

    # Source model name or alias (1-200 characters when provided).
    source_model: Optional[str] = Field(
        None, min_length=1, max_length=200, description="源模型名或别名"
    )
    # ID of the target GlobalModel.
    target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
    # Provider ID; per the description, empty means a global alias.
    provider_id: Optional[str] = Field(None, description="Provider ID(为空时表示全局别名)")
    # Mapping type: "alias" = billed as the target model (alias),
    # "mapping" = billed as the source model (fallback mapping).
    mapping_type: Optional[str] = Field(
        None, description="映射类型:alias=按目标模型计费(别名),mapping=按源模型计费(降级映射)"
    )
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class ModelMappingResponse(BaseModel):
    """模型映射响应"""

    # Response schema for a model mapping.
    # The docstring and the description strings are left untranslated because
    # Pydantic can surface them in the generated schema, i.e. they are runtime
    # text.

    id: str
    source_model: str
    target_global_model_id: str
    target_global_model_name: Optional[str]
    target_global_model_display_name: Optional[str]
    provider_id: Optional[str]
    provider_name: Optional[str]
    # Either "global" or "provider".
    scope: str = Field(..., description="global 或 provider")
    # Mapping type: "alias" or "mapping".
    mapping_type: str = Field(..., description="映射类型:alias 或 mapping")
    is_active: bool
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow the model to be populated from object attributes.
        from_attributes = True
|
||||
|
||||
|
||||
# ========== 系统设置 ==========
|
||||
class SystemSettingsRequest(BaseModel):
|
||||
"""系统设置请求"""
|
||||
@@ -558,7 +518,6 @@ class PublicProviderResponse(BaseModel):
|
||||
# 统计信息
|
||||
models_count: int
|
||||
active_models_count: int
|
||||
mappings_count: int
|
||||
endpoints_count: int # 端点总数
|
||||
active_endpoints_count: int # 活跃端点数
|
||||
|
||||
@@ -587,19 +546,6 @@ class PublicModelResponse(BaseModel):
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class PublicModelMappingResponse(BaseModel):
    """公开的模型映射信息响应"""

    # Response schema for publicly exposed model-mapping information.
    # The docstring and the description string are left untranslated because
    # Pydantic can surface them in the generated schema, i.e. they are runtime
    # text.

    id: str
    source_model: str
    target_global_model_id: str
    target_global_model_name: Optional[str]
    target_global_model_display_name: Optional[str]
    provider_id: Optional[str] = None
    # Either "global" or "provider".
    scope: str = Field(..., description="global 或 provider")
    is_active: bool
|
||||
|
||||
|
||||
class ProviderStatsResponse(BaseModel):
|
||||
"""提供商统计信息响应"""
|
||||
|
||||
@@ -607,7 +553,6 @@ class ProviderStatsResponse(BaseModel):
|
||||
active_providers: int
|
||||
total_models: int
|
||||
active_models: int
|
||||
total_mappings: int
|
||||
supported_formats: List[str]
|
||||
|
||||
|
||||
@@ -617,20 +562,15 @@ class PublicGlobalModelResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
icon_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = None
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: Optional[dict] = None
|
||||
# 默认能力
|
||||
default_supports_vision: bool = False
|
||||
default_supports_function_calling: bool = False
|
||||
default_supports_streaming: bool = True
|
||||
default_supports_extended_thinking: bool = False
|
||||
# Key 能力配置
|
||||
supported_capabilities: Optional[List[str]] = None
|
||||
# 模型配置(JSON)
|
||||
config: Optional[dict] = None
|
||||
|
||||
|
||||
class PublicGlobalModelListResponse(BaseModel):
|
||||
|
||||
@@ -7,6 +7,7 @@ import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
from sqlalchemy import (
|
||||
@@ -25,6 +26,7 @@ from sqlalchemy import (
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -306,7 +308,8 @@ class Usage(Base):
|
||||
is_stream = Column(Boolean, default=False) # 是否为流式请求
|
||||
status_code = Column(Integer)
|
||||
error_message = Column(Text, nullable=True)
|
||||
response_time_ms = Column(Integer) # 响应时间(毫秒)
|
||||
response_time_ms = Column(Integer) # 总响应时间(毫秒)
|
||||
first_byte_time_ms = Column(Integer, nullable=True) # 首字时间/TTFB(毫秒)
|
||||
|
||||
# 请求状态追踪
|
||||
# pending: 请求开始处理中
|
||||
@@ -491,9 +494,6 @@ class Provider(Base):
|
||||
|
||||
# 关系
|
||||
models = relationship("Model", back_populates="provider", cascade="all, delete-orphan")
|
||||
model_mappings = relationship(
|
||||
"ModelMapping", back_populates="provider", cascade="all, delete-orphan"
|
||||
)
|
||||
endpoints = relationship(
|
||||
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
|
||||
)
|
||||
@@ -577,11 +577,6 @@ class GlobalModel(Base):
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
name = Column(String(100), unique=True, nullable=False, index=True) # 统一模型名(唯一)
|
||||
display_name = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# 模型元数据
|
||||
icon_url = Column(String(500), nullable=True)
|
||||
official_url = Column(String(500), nullable=True) # 官方文档链接
|
||||
|
||||
# 按次计费配置(每次请求的固定费用,美元)- 可选,与按 token 计费叠加
|
||||
default_price_per_request = Column(Float, nullable=True, default=None) # 每次请求固定费用
|
||||
@@ -607,17 +602,34 @@ class GlobalModel(Base):
|
||||
# }
|
||||
default_tiered_pricing = Column(JSON, nullable=False)
|
||||
|
||||
# 默认能力配置 - Provider 可覆盖
|
||||
default_supports_vision = Column(Boolean, default=False, nullable=True)
|
||||
default_supports_function_calling = Column(Boolean, default=False, nullable=True)
|
||||
default_supports_streaming = Column(Boolean, default=True, nullable=True)
|
||||
default_supports_extended_thinking = Column(Boolean, default=False, nullable=True)
|
||||
default_supports_image_generation = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
# Key 只能启用模型支持的能力
|
||||
supported_capabilities = Column(JSON, nullable=True, default=list)
|
||||
|
||||
# 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
# 结构示例:
|
||||
# {
|
||||
# # 能力配置
|
||||
# "streaming": true,
|
||||
# "vision": true,
|
||||
# "function_calling": true,
|
||||
# "extended_thinking": false,
|
||||
# "image_generation": false,
|
||||
# # 规格参数
|
||||
# "context_limit": 200000,
|
||||
# "output_limit": 8192,
|
||||
# # 元信息
|
||||
# "description": "...",
|
||||
# "icon_url": "...",
|
||||
# "official_url": "...",
|
||||
# "knowledge_cutoff": "2024-04",
|
||||
# "family": "claude-3.5",
|
||||
# "release_date": "2024-10-22",
|
||||
# "input_modalities": ["text", "image"],
|
||||
# "output_modalities": ["text"],
|
||||
# }
|
||||
config = Column(JSONB, nullable=True, default=dict)
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
@@ -656,7 +668,11 @@ class Model(Base):
|
||||
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
|
||||
|
||||
# Provider 映射配置
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的模型名称
|
||||
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
|
||||
# 模型名称别名列表(带优先级),用于同一模型在 Provider 侧有多个名称变体的场景
|
||||
# 格式: [{"name": "Claude-Sonnet-4.5", "priority": 1}, {"name": "Claude-Sonnet-4-5", "priority": 2}]
|
||||
# 为空时只使用 provider_model_name
|
||||
provider_model_aliases = Column(JSON, nullable=True, default=None)
|
||||
|
||||
# 按次计费配置(每次请求的固定费用,美元)- 可为空,为空时使用 GlobalModel 的默认值
|
||||
price_per_request = Column(Float, nullable=True) # 每次请求固定费用
|
||||
@@ -764,11 +780,22 @@ class Model(Base):
|
||||
"""获取有效的能力配置(通用辅助方法)"""
|
||||
local_value = getattr(self, attr_name, None)
|
||||
if local_value is not None:
|
||||
return local_value
|
||||
return bool(local_value)
|
||||
if self.global_model:
|
||||
global_value = getattr(self.global_model, f"default_{attr_name}", None)
|
||||
if global_value is not None:
|
||||
return global_value
|
||||
config_key_map = {
|
||||
"supports_vision": "vision",
|
||||
"supports_function_calling": "function_calling",
|
||||
"supports_streaming": "streaming",
|
||||
"supports_extended_thinking": "extended_thinking",
|
||||
"supports_image_generation": "image_generation",
|
||||
}
|
||||
config_key = config_key_map.get(attr_name)
|
||||
if config_key:
|
||||
global_config = getattr(self.global_model, "config", None)
|
||||
if isinstance(global_config, dict):
|
||||
global_value = global_config.get(config_key)
|
||||
if global_value is not None:
|
||||
return bool(global_value)
|
||||
return default
|
||||
|
||||
def get_effective_supports_vision(self) -> bool:
|
||||
@@ -786,60 +813,93 @@ class Model(Base):
|
||||
def get_effective_supports_image_generation(self) -> bool:
|
||||
return self._get_effective_capability("supports_image_generation", False)
|
||||
|
||||
def select_provider_model_name(
|
||||
self, affinity_key: Optional[str] = None, api_format: Optional[str] = None
|
||||
) -> str:
|
||||
"""按优先级选择要使用的 Provider 模型名称
|
||||
|
||||
class ModelMapping(Base):
|
||||
"""模型映射表 - 统一处理别名与降级策略
|
||||
如果配置了 provider_model_aliases,按优先级选择(数字越小越优先);
|
||||
相同优先级的别名通过哈希分散实现负载均衡(与 Key 调度策略一致);
|
||||
否则返回 provider_model_name。
|
||||
|
||||
设计原则:
|
||||
- source_model 接收用户请求的原始模型名/别名
|
||||
- target_global_model_id 指向真实的 GlobalModel
|
||||
- provider_id 为空表示全局别名,非空表示 Provider 特定映射/降级
|
||||
- 一个 (source_model, provider_id) 组合唯一
|
||||
Args:
|
||||
affinity_key: 用于哈希分散的亲和键(如用户 API Key 哈希),确保同一用户稳定选择同一别名
|
||||
api_format: 当前请求的 API 格式(如 CLAUDE、OPENAI 等),用于过滤适用的别名
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
映射类型 (mapping_type):
|
||||
- alias: 别名模式,按目标模型计费(只是名称简写)
|
||||
- mapping: 映射模式,按源模型计费(模型降级/替代)
|
||||
"""
|
||||
if not self.provider_model_aliases:
|
||||
return self.provider_model_name
|
||||
|
||||
__tablename__ = "model_mappings"
|
||||
raw_aliases = self.provider_model_aliases
|
||||
if not isinstance(raw_aliases, list) or len(raw_aliases) == 0:
|
||||
return self.provider_model_name
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
|
||||
aliases: list[dict] = []
|
||||
for raw in raw_aliases:
|
||||
if not isinstance(raw, dict):
|
||||
continue
|
||||
name = raw.get("name")
|
||||
if not isinstance(name, str) or not name.strip():
|
||||
continue
|
||||
|
||||
# 源模型名称(可能是别名或真实 GlobalModel.name)
|
||||
source_model = Column(String(200), nullable=False, index=True)
|
||||
# 检查 api_formats 作用域(如果配置了且当前有 api_format)
|
||||
alias_api_formats = raw.get("api_formats")
|
||||
if api_format and alias_api_formats:
|
||||
# 如果配置了作用域,只有匹配时才生效
|
||||
if isinstance(alias_api_formats, list) and api_format not in alias_api_formats:
|
||||
continue
|
||||
|
||||
# 目标 GlobalModel
|
||||
target_global_model_id = Column(
|
||||
String(36), ForeignKey("global_models.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
raw_priority = raw.get("priority", 1)
|
||||
try:
|
||||
priority = int(raw_priority)
|
||||
except Exception:
|
||||
priority = 1
|
||||
if priority < 1:
|
||||
priority = 1
|
||||
|
||||
# Provider 关联:NULL 代表全局别名
|
||||
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=True, index=True)
|
||||
aliases.append({"name": name.strip(), "priority": priority})
|
||||
|
||||
# 映射类型:alias=按目标模型计费,mapping=按源模型计费
|
||||
mapping_type = Column(String(20), nullable=False, default="alias", index=True)
|
||||
if not aliases:
|
||||
return self.provider_model_name
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
# 按优先级排序(数字越小越优先)
|
||||
sorted_aliases = sorted(aliases, key=lambda x: x["priority"])
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
nullable=False,
|
||||
)
|
||||
# 获取最高优先级(最小数字)
|
||||
highest_priority = sorted_aliases[0]["priority"]
|
||||
|
||||
# 关系
|
||||
target_global_model = relationship("GlobalModel", foreign_keys=[target_global_model_id])
|
||||
provider = relationship("Provider", back_populates="model_mappings")
|
||||
# 获取所有最高优先级的别名
|
||||
top_priority_aliases = [
|
||||
alias for alias in sorted_aliases
|
||||
if alias["priority"] == highest_priority
|
||||
]
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("source_model", "provider_id", name="uq_model_mapping_source_provider"),
|
||||
)
|
||||
# 如果有多个相同优先级的别名,通过哈希分散选择
|
||||
if len(top_priority_aliases) > 1 and affinity_key:
|
||||
# 为每个别名计算哈希得分,选择得分最小的
|
||||
def hash_score(alias: dict) -> int:
|
||||
combined = f"{affinity_key}:{alias['name']}"
|
||||
return int(hashlib.md5(combined.encode()).hexdigest(), 16)
|
||||
|
||||
selected = min(top_priority_aliases, key=hash_score)
|
||||
elif len(top_priority_aliases) > 1:
|
||||
# 没有 affinity_key 时,使用确定性选择(按名称排序后取第一个)
|
||||
# 避免随机选择导致同一请求重试时选择不同的模型名称
|
||||
selected = min(top_priority_aliases, key=lambda x: x["name"])
|
||||
else:
|
||||
selected = top_priority_aliases[0]
|
||||
|
||||
return selected["name"]
|
||||
|
||||
def get_all_provider_model_names(self) -> list[str]:
    """Return every Provider-side name usable for this model.

    The primary ``provider_model_name`` always comes first, followed by the
    alias names in the order they are stored in ``provider_model_aliases``.

    Alias entries are normalized with the same rules that
    ``select_provider_model_name`` applies, so any name the selector can
    return is guaranteed to appear in this list:

    - a non-list ``provider_model_aliases`` value is ignored;
    - entries that are not dicts are skipped;
    - names that are not strings, or are blank, are skipped;
    - surrounding whitespace is stripped from each name.

    Returns:
        The primary name followed by the normalized alias names.
        Duplicates are not removed.
    """
    names = [self.provider_model_name]

    raw_aliases = self.provider_model_aliases
    # The column is free-form JSON; anything other than a list carries no aliases.
    if not isinstance(raw_aliases, list):
        return names

    for alias in raw_aliases:
        if not isinstance(alias, dict):
            continue
        name = alias.get("name")
        # Mirror the selector: only non-blank strings count, stored stripped.
        if isinstance(name, str) and name.strip():
            names.append(name.strip())

    return names
|
||||
|
||||
|
||||
class ProviderAPIKey(Base):
|
||||
|
||||
@@ -7,8 +7,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from .api import ModelCreate
|
||||
|
||||
|
||||
# ========== 阶梯计费相关模型 ==========
|
||||
|
||||
@@ -131,24 +129,14 @@ class ModelCatalogProviderDetail(BaseModel):
|
||||
supports_function_calling: Optional[bool] = None
|
||||
supports_streaming: Optional[bool] = None
|
||||
is_active: bool
|
||||
mapping_id: Optional[str]
|
||||
|
||||
|
||||
class OrphanedModel(BaseModel):
|
||||
"""孤立的统一模型(Mapping 存在但 GlobalModel 缺失)"""
|
||||
|
||||
alias: str # 别名
|
||||
global_model_name: Optional[str] # 关联的 GlobalModel 名称(如果有)
|
||||
mapping_count: int
|
||||
|
||||
|
||||
class ModelCatalogItem(BaseModel):
|
||||
"""统一模型目录条目(方案 A:基于 GlobalModel)"""
|
||||
"""统一模型目录条目(基于 GlobalModel)"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
description: Optional[str] # GlobalModel.description
|
||||
aliases: List[str] # 所有指向该 GlobalModel 的别名列表
|
||||
providers: List[ModelCatalogProviderDetail] # 支持该模型的 Provider 列表
|
||||
price_range: ModelPriceRange # 价格区间(从所有 Provider 的 Model 中聚合)
|
||||
total_providers: int
|
||||
@@ -160,7 +148,6 @@ class ModelCatalogResponse(BaseModel):
|
||||
|
||||
models: List[ModelCatalogItem]
|
||||
total: int
|
||||
orphaned_models: List[OrphanedModel]
|
||||
|
||||
|
||||
class ProviderModelPriceInfo(BaseModel):
|
||||
@@ -174,13 +161,11 @@ class ProviderModelPriceInfo(BaseModel):
|
||||
|
||||
|
||||
class ProviderAvailableSourceModel(BaseModel):
|
||||
"""Provider 支持的统一模型条目(方案 A)"""
|
||||
"""Provider 支持的统一模型条目"""
|
||||
|
||||
global_model_name: str # GlobalModel.name
|
||||
display_name: str # GlobalModel.display_name
|
||||
provider_model_name: str # Model.provider_model_name (Provider 侧的模型名)
|
||||
has_alias: bool # 是否有别名指向该 GlobalModel
|
||||
aliases: List[str] # 别名列表
|
||||
model_id: Optional[str] # Model.id
|
||||
price: ProviderModelPriceInfo
|
||||
capabilities: ModelCapabilities
|
||||
@@ -194,50 +179,7 @@ class ProviderAvailableSourceModelsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class BatchAssignProviderConfig(BaseModel):
|
||||
"""批量添加映射的 Provider 配置"""
|
||||
|
||||
provider_id: str
|
||||
create_model: bool = Field(False, description="是否需要创建新的 Model")
|
||||
model_data: Optional[ModelCreate] = Field(
|
||||
None, description="create_model=true 时需要提供的模型配置", alias="model_config"
|
||||
)
|
||||
model_id: Optional[str] = Field(None, description="create_model=false 时需要提供的现有模型 ID")
|
||||
|
||||
|
||||
class BatchAssignModelMappingRequest(BaseModel):
|
||||
"""批量添加模型映射请求(方案 A:暂不支持,需要重构)"""
|
||||
|
||||
global_model_id: str # 要分配的 GlobalModel ID
|
||||
providers: List[BatchAssignProviderConfig]
|
||||
|
||||
|
||||
class BatchAssignProviderResult(BaseModel):
|
||||
"""批量映射结果条目"""
|
||||
|
||||
provider_id: str
|
||||
mapping_id: Optional[str]
|
||||
created_model: bool
|
||||
model_id: Optional[str]
|
||||
updated: bool = False
|
||||
|
||||
|
||||
class BatchAssignError(BaseModel):
|
||||
"""批量映射错误信息"""
|
||||
|
||||
provider_id: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchAssignModelMappingResponse(BaseModel):
|
||||
"""批量映射响应"""
|
||||
|
||||
success: bool
|
||||
created_mappings: List[BatchAssignProviderResult]
|
||||
errors: List[BatchAssignError]
|
||||
|
||||
|
||||
# ========== 阶段二:GlobalModel 相关模型 ==========
|
||||
# ========== GlobalModel 相关模型 ==========
|
||||
|
||||
|
||||
class GlobalModelCreate(BaseModel):
|
||||
@@ -245,9 +187,6 @@ class GlobalModelCreate(BaseModel):
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="统一模型名(唯一)")
|
||||
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
official_url: Optional[str] = Field(None, max_length=500, description="官方文档链接")
|
||||
icon_url: Optional[str] = Field(None, max_length=500, description="图标 URL")
|
||||
# 按次计费配置(可选,与阶梯计费叠加)
|
||||
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
# 统一阶梯计费配置(必填)
|
||||
@@ -255,22 +194,15 @@ class GlobalModelCreate(BaseModel):
|
||||
default_tiered_pricing: TieredPricingConfig = Field(
|
||||
..., description="阶梯计费配置(固定价格用单阶梯表示)"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = Field(False, description="默认是否支持视觉")
|
||||
default_supports_function_calling: Optional[bool] = Field(
|
||||
False, description="默认是否支持函数调用"
|
||||
)
|
||||
default_supports_streaming: Optional[bool] = Field(True, description="默认是否支持流式输出")
|
||||
default_supports_extended_thinking: Optional[bool] = Field(
|
||||
False, description="默认是否支持扩展思考"
|
||||
)
|
||||
default_supports_image_generation: Optional[bool] = Field(
|
||||
False, description="默认是否支持图像生成"
|
||||
)
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="模型配置(streaming, vision, context_limit, description 等)"
|
||||
)
|
||||
is_active: Optional[bool] = Field(True, description="是否激活")
|
||||
|
||||
|
||||
@@ -278,9 +210,6 @@ class GlobalModelUpdate(BaseModel):
|
||||
"""更新 GlobalModel 请求"""
|
||||
|
||||
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = None
|
||||
official_url: Optional[str] = Field(None, max_length=500)
|
||||
icon_url: Optional[str] = Field(None, max_length=500)
|
||||
is_active: Optional[bool] = None
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = Field(None, ge=0, description="每次请求固定费用")
|
||||
@@ -288,16 +217,15 @@ class GlobalModelUpdate(BaseModel):
|
||||
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
|
||||
None, description="阶梯计费配置"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool] = None
|
||||
default_supports_function_calling: Optional[bool] = None
|
||||
default_supports_streaming: Optional[bool] = None
|
||||
default_supports_extended_thinking: Optional[bool] = None
|
||||
default_supports_image_generation: Optional[bool] = None
|
||||
# Key 能力配置 - 模型支持的能力列表(如 ["cache_1h", "context_1m"])
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 模型配置(JSON格式)- 包含能力、规格、元信息等
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="模型配置(streaming, vision, context_limit, description 等)"
|
||||
)
|
||||
|
||||
|
||||
class GlobalModelResponse(BaseModel):
|
||||
@@ -306,29 +234,24 @@ class GlobalModelResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
display_name: str
|
||||
description: Optional[str]
|
||||
official_url: Optional[str]
|
||||
icon_url: Optional[str]
|
||||
is_active: bool
|
||||
# 按次计费配置
|
||||
default_price_per_request: Optional[float] = Field(None, description="每次请求固定费用")
|
||||
# 阶梯计费配置
|
||||
default_tiered_pricing: TieredPricingConfig = Field(
|
||||
..., description="阶梯计费配置"
|
||||
default_tiered_pricing: Optional[TieredPricingConfig] = Field(
|
||||
default=None, description="阶梯计费配置"
|
||||
)
|
||||
# 默认能力配置
|
||||
default_supports_vision: Optional[bool]
|
||||
default_supports_function_calling: Optional[bool]
|
||||
default_supports_streaming: Optional[bool]
|
||||
default_supports_extended_thinking: Optional[bool]
|
||||
default_supports_image_generation: Optional[bool]
|
||||
# Key 能力配置 - 模型支持的能力列表
|
||||
supported_capabilities: Optional[List[str]] = Field(
|
||||
default=None, description="支持的 Key 能力列表"
|
||||
)
|
||||
# 模型配置(JSON格式)
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="模型配置(streaming, vision, context_limit, description 等)"
|
||||
)
|
||||
# 统计数据(可选)
|
||||
provider_count: Optional[int] = Field(default=0, description="支持的 Provider 数量")
|
||||
alias_count: Optional[int] = Field(default=0, description="别名数量")
|
||||
usage_count: Optional[int] = Field(default=0, description="调用次数")
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
@@ -355,7 +278,7 @@ class GlobalModelListResponse(BaseModel):
|
||||
class BatchAssignToProvidersRequest(BaseModel):
|
||||
"""批量为 Provider 添加 GlobalModel 实现"""
|
||||
|
||||
provider_ids: List[str] = Field(..., min_items=1, description="Provider ID 列表")
|
||||
provider_ids: List[str] = Field(..., min_length=1, description="Provider ID 列表")
|
||||
create_models: bool = Field(default=False, description="是否自动创建 Model 记录")
|
||||
|
||||
|
||||
@@ -379,43 +302,11 @@ class BatchAssignModelsToProviderResponse(BaseModel):
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
class UpdateModelMappingRequest(BaseModel):
|
||||
"""更新模型映射请求"""
|
||||
|
||||
source_model: Optional[str] = Field(
|
||||
None, min_length=1, max_length=200, description="源模型名或别名"
|
||||
)
|
||||
target_global_model_id: Optional[str] = Field(None, description="目标 GlobalModel ID")
|
||||
provider_id: Optional[str] = Field(None, description="Provider ID(为空时为全局别名)")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
|
||||
|
||||
class UpdateModelMappingResponse(BaseModel):
|
||||
"""更新模型映射响应"""
|
||||
|
||||
success: bool
|
||||
mapping_id: str
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
class DeleteModelMappingResponse(BaseModel):
|
||||
"""删除模型映射响应"""
|
||||
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchAssignError",
|
||||
"BatchAssignModelMappingRequest",
|
||||
"BatchAssignModelMappingResponse",
|
||||
"BatchAssignModelsToProviderRequest",
|
||||
"BatchAssignModelsToProviderResponse",
|
||||
"BatchAssignProviderConfig",
|
||||
"BatchAssignProviderResult",
|
||||
"BatchAssignToProvidersRequest",
|
||||
"BatchAssignToProvidersResponse",
|
||||
"DeleteModelMappingResponse",
|
||||
"GlobalModelCreate",
|
||||
"GlobalModelListResponse",
|
||||
"GlobalModelResponse",
|
||||
@@ -426,10 +317,7 @@ __all__ = [
|
||||
"ModelCatalogProviderDetail",
|
||||
"ModelCatalogResponse",
|
||||
"ModelPriceRange",
|
||||
"OrphanedModel",
|
||||
"ProviderAvailableSourceModel",
|
||||
"ProviderAvailableSourceModelsResponse",
|
||||
"ProviderModelPriceInfo",
|
||||
"UpdateModelMappingRequest",
|
||||
"UpdateModelMappingResponse",
|
||||
]
|
||||
|
||||
251
src/services/cache/aware_scheduler.py
vendored
251
src/services/cache/aware_scheduler.py
vendored
@@ -28,15 +28,17 @@
|
||||
- 失效缓存亲和性,避免重复选择故障资源
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from src.core.logger import logger
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from src.core.enums import APIFormat
|
||||
from src.core.exceptions import ProviderNotAvailableException
|
||||
from src.core.exceptions import ModelNotSupportedException, ProviderNotAvailableException
|
||||
from src.core.logger import logger
|
||||
from src.models.database import (
|
||||
ApiKey,
|
||||
Model,
|
||||
@@ -44,10 +46,15 @@ from src.models.database import (
|
||||
ProviderAPIKey,
|
||||
ProviderEndpoint,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.models.database import GlobalModel
|
||||
|
||||
from src.services.cache.affinity_manager import (
|
||||
CacheAffinityManager,
|
||||
get_affinity_manager,
|
||||
)
|
||||
from src.services.cache.model_cache import ModelCacheService
|
||||
from src.services.health.monitor import health_monitor
|
||||
from src.services.provider.format import normalize_api_format
|
||||
from src.services.rate_limit.adaptive_reservation import (
|
||||
@@ -114,8 +121,17 @@ class CacheAwareScheduler:
|
||||
PRIORITY_MODE_PROVIDER,
|
||||
PRIORITY_MODE_GLOBAL_KEY,
|
||||
}
|
||||
# 调度模式常量
|
||||
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式
|
||||
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式
|
||||
ALLOWED_SCHEDULING_MODES = {
|
||||
SCHEDULING_MODE_FIXED_ORDER,
|
||||
SCHEDULING_MODE_CACHE_AFFINITY,
|
||||
}
|
||||
|
||||
def __init__(self, redis_client=None, priority_mode: Optional[str] = None):
|
||||
def __init__(
|
||||
self, redis_client=None, priority_mode: Optional[str] = None, scheduling_mode: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化调度器
|
||||
|
||||
@@ -125,12 +141,16 @@ class CacheAwareScheduler:
|
||||
Args:
|
||||
redis_client: Redis客户端(可选)
|
||||
priority_mode: 候选排序策略(provider | global_key)
|
||||
scheduling_mode: 调度模式(fixed_order | cache_affinity)
|
||||
"""
|
||||
self.redis = redis_client
|
||||
self.priority_mode = self._normalize_priority_mode(
|
||||
priority_mode or self.PRIORITY_MODE_PROVIDER
|
||||
)
|
||||
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}")
|
||||
self.scheduling_mode = self._normalize_scheduling_mode(
|
||||
scheduling_mode or self.SCHEDULING_MODE_CACHE_AFFINITY
|
||||
)
|
||||
logger.debug(f"[CacheAwareScheduler] 初始化优先级模式: {self.priority_mode}, 调度模式: {self.scheduling_mode}")
|
||||
|
||||
# 初始化子组件(将在第一次使用时异步初始化)
|
||||
self._affinity_manager: Optional[CacheAffinityManager] = None
|
||||
@@ -227,19 +247,6 @@ class CacheAwareScheduler:
|
||||
if provider_offset == 0:
|
||||
# 没有找到任何候选,提供友好的错误提示
|
||||
error_msg = f"模型 '{model_name}' 不可用"
|
||||
|
||||
# 查找相似模型
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
resolver = get_model_mapping_resolver()
|
||||
similar_models = resolver.find_similar_models(db, model_name, limit=3)
|
||||
|
||||
if similar_models:
|
||||
suggestions = [
|
||||
f"{name} (相似度: {score:.0%})" for name, score in similar_models
|
||||
]
|
||||
error_msg += f"\n\n您可能想使用以下模型:\n - " + "\n - ".join(suggestions)
|
||||
|
||||
raise ProviderNotAvailableException(error_msg)
|
||||
break
|
||||
|
||||
@@ -272,9 +279,11 @@ class CacheAwareScheduler:
|
||||
self._metrics["concurrency_denied"] += 1
|
||||
continue
|
||||
|
||||
logger.debug(f" └─ 选择 Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
logger.debug(
|
||||
f" └─ 选择 Provider={provider.name}, Endpoint={endpoint.id[:8]}..., "
|
||||
f"Key=***{key.api_key[-4:]}, 缓存命中={is_cached_user}, "
|
||||
f"并发状态[{snapshot.describe()}]")
|
||||
f"并发状态[{snapshot.describe()}]"
|
||||
)
|
||||
|
||||
if key.cache_ttl_minutes > 0 and global_model_id:
|
||||
ttl = key.cache_ttl_minutes * 60 if key.cache_ttl_minutes > 0 else None
|
||||
@@ -362,7 +371,9 @@ class CacheAwareScheduler:
|
||||
logger.debug(f" -> 无并发管理器,直接通过")
|
||||
snapshot = ConcurrencySnapshot(
|
||||
endpoint_current=0,
|
||||
endpoint_limit=int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None,
|
||||
endpoint_limit=(
|
||||
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
|
||||
),
|
||||
key_current=0,
|
||||
key_limit=effective_key_limit,
|
||||
is_cached_user=is_cached_user,
|
||||
@@ -497,10 +508,12 @@ class CacheAwareScheduler:
|
||||
user = None
|
||||
|
||||
# 调试日志
|
||||
logger.debug(f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
|
||||
logger.debug(
|
||||
f"[_get_effective_restrictions] ApiKey={user_api_key.id[:8]}..., "
|
||||
f"User={user.id[:8] if user else 'None'}..., "
|
||||
f"ApiKey.allowed_models={user_api_key.allowed_models}, "
|
||||
f"User.allowed_models={user.allowed_models if user else 'N/A'}")
|
||||
f"User.allowed_models={user.allowed_models if user else 'N/A'}"
|
||||
)
|
||||
|
||||
def merge_restrictions(key_restriction, user_restriction):
|
||||
"""合并两个限制列表,返回有效的限制集合"""
|
||||
@@ -579,13 +592,17 @@ class CacheAwareScheduler:
|
||||
|
||||
target_format = normalize_api_format(api_format)
|
||||
|
||||
# 0. 解析 model_name 到 GlobalModel(用于缓存亲和性的规范化标识)
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
global_model = await mapping_resolver.get_global_model_by_request(db, model_name, None)
|
||||
# 0. 解析 model_name 到 GlobalModel(支持直接匹配和别名匹配,使用 ModelCacheService)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
|
||||
|
||||
if not global_model:
|
||||
logger.warning(f"GlobalModel not found: {model_name}")
|
||||
raise ModelNotSupportedException(model=model_name)
|
||||
|
||||
# 使用 GlobalModel.id 作为缓存亲和性的模型标识,确保别名和规范名都能命中同一个缓存
|
||||
global_model_id: str = str(global_model.id) if global_model else model_name
|
||||
global_model_id: str = str(global_model.id)
|
||||
requested_model_name = model_name
|
||||
resolved_model_name = str(global_model.name)
|
||||
|
||||
# 获取合并后的访问限制(ApiKey + User)
|
||||
restrictions = self._get_effective_restrictions(user_api_key)
|
||||
@@ -595,16 +612,29 @@ class CacheAwareScheduler:
|
||||
allowed_models = restrictions["allowed_models"]
|
||||
|
||||
# 0.1 检查 API 格式是否被允许
|
||||
if allowed_api_formats:
|
||||
if allowed_api_formats is not None:
|
||||
if target_format.value not in allowed_api_formats:
|
||||
logger.debug(f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
|
||||
f"允许的格式: {allowed_api_formats}")
|
||||
logger.debug(
|
||||
f"API Key {user_api_key.id[:8] if user_api_key else 'N/A'}... 不允许使用 API 格式 {target_format.value}, "
|
||||
f"允许的格式: {allowed_api_formats}"
|
||||
)
|
||||
return [], global_model_id
|
||||
|
||||
# 0.2 检查模型是否被允许
|
||||
if allowed_models:
|
||||
if model_name not in allowed_models:
|
||||
logger.debug(f"用户/API Key 不允许使用模型 {model_name}, " f"允许的模型: {allowed_models}")
|
||||
if allowed_models is not None:
|
||||
if (
|
||||
requested_model_name not in allowed_models
|
||||
and resolved_model_name not in allowed_models
|
||||
):
|
||||
resolved_note = (
|
||||
f" (解析为 {resolved_model_name})"
|
||||
if resolved_model_name != requested_model_name
|
||||
else ""
|
||||
)
|
||||
logger.debug(
|
||||
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
|
||||
f"允许的模型: {allowed_models}"
|
||||
)
|
||||
return [], global_model_id
|
||||
|
||||
# 1. 查询 Providers
|
||||
@@ -618,7 +648,7 @@ class CacheAwareScheduler:
|
||||
return [], global_model_id
|
||||
|
||||
# 1.5 根据 allowed_providers 过滤(合并 ApiKey 和 User 的限制)
|
||||
if allowed_providers:
|
||||
if allowed_providers is not None:
|
||||
original_count = len(providers)
|
||||
# 同时支持 provider id 和 name 匹配
|
||||
providers = [
|
||||
@@ -635,7 +665,8 @@ class CacheAwareScheduler:
|
||||
db=db,
|
||||
providers=providers,
|
||||
target_format=target_format,
|
||||
model_name=model_name,
|
||||
model_name=requested_model_name,
|
||||
resolved_model_name=resolved_model_name,
|
||||
affinity_key=affinity_key,
|
||||
max_candidates=max_candidates,
|
||||
allowed_endpoints=allowed_endpoints,
|
||||
@@ -650,17 +681,24 @@ class CacheAwareScheduler:
|
||||
self._metrics["total_candidates"] += len(candidates)
|
||||
self._metrics["last_candidate_count"] = len(candidates)
|
||||
|
||||
logger.debug(f"预先获取到 {len(candidates)} 个可用组合 "
|
||||
f"(api_format={target_format.value}, model={model_name})")
|
||||
logger.debug(
|
||||
f"预先获取到 {len(candidates)} 个可用组合 "
|
||||
f"(api_format={target_format.value}, model={model_name})"
|
||||
)
|
||||
|
||||
# 4. 应用缓存亲和性排序(使用 global_model_id 作为模型标识)
|
||||
if affinity_key and candidates:
|
||||
candidates = await self._apply_cache_affinity(
|
||||
candidates=candidates,
|
||||
affinity_key=affinity_key,
|
||||
api_format=target_format,
|
||||
global_model_id=global_model_id,
|
||||
)
|
||||
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用)
|
||||
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
|
||||
if affinity_key and candidates:
|
||||
candidates = await self._apply_cache_affinity(
|
||||
candidates=candidates,
|
||||
affinity_key=affinity_key,
|
||||
api_format=target_format,
|
||||
global_model_id=global_model_id,
|
||||
)
|
||||
else:
|
||||
# 固定顺序模式:标记所有候选为非缓存
|
||||
for candidate in candidates:
|
||||
candidate.is_cached = False
|
||||
|
||||
return candidates, global_model_id
|
||||
|
||||
@@ -685,7 +723,6 @@ class CacheAwareScheduler:
|
||||
db.query(Provider)
|
||||
.options(
|
||||
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
|
||||
selectinload(Provider.model_mappings),
|
||||
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值
|
||||
selectinload(Provider.models).selectinload(Model.global_model),
|
||||
)
|
||||
@@ -715,63 +752,31 @@ class CacheAwareScheduler:
|
||||
- 模型支持的能力是全局的,与具体的 Key 无关
|
||||
- 如果模型不支持某能力,整个 Provider 的所有 Key 都应该被跳过
|
||||
|
||||
映射回退逻辑:
|
||||
- 如果存在模型映射(mapping),先尝试映射后的模型
|
||||
- 如果映射后的模型因能力不满足而失败,回退尝试原模型
|
||||
- 其他失败原因(如模型不存在、Provider 未实现等)不触发回退
|
||||
支持两种匹配方式:
|
||||
1. 直接匹配 GlobalModel.name
|
||||
2. 通过 ModelCacheService 匹配别名(全局查找)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: Provider 对象
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或别名)
|
||||
is_stream: 是否是流式请求,如果为 True 则同时检查流式支持
|
||||
capability_requirements: 能力需求(可选),用于检查模型是否支持所需能力
|
||||
|
||||
Returns:
|
||||
(is_supported, skip_reason, supported_capabilities) - 是否支持、跳过原因、模型支持的能力列表
|
||||
"""
|
||||
from src.services.model.mapping_resolver import get_model_mapping_resolver
|
||||
|
||||
mapping_resolver = get_model_mapping_resolver()
|
||||
|
||||
# 获取映射后的模型,同时检查是否发生了映射
|
||||
global_model, is_mapped = await mapping_resolver.get_global_model_with_mapping_info(
|
||||
db, model_name, str(provider.id)
|
||||
)
|
||||
# 使用 ModelCacheService 解析模型名称(支持别名)
|
||||
global_model = await ModelCacheService.resolve_global_model_by_name_or_alias(db, model_name)
|
||||
|
||||
if not global_model:
|
||||
return False, "模型不存在", None
|
||||
# 完全未找到匹配
|
||||
return False, "模型不存在或 Provider 未配置此模型", None
|
||||
|
||||
# 尝试检查映射后的模型
|
||||
# 找到 GlobalModel 后,检查当前 Provider 是否支持
|
||||
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
|
||||
db, provider, global_model, model_name, is_stream, capability_requirements
|
||||
)
|
||||
|
||||
# 如果映射后的模型因能力不满足而失败,且存在映射,则回退尝试原模型
|
||||
if not is_supported and is_mapped and skip_reason and "不支持能力" in skip_reason:
|
||||
logger.debug(
|
||||
f"Provider {provider.name} 映射模型 {global_model.name} 能力不满足,"
|
||||
f"回退尝试原模型 {model_name}"
|
||||
)
|
||||
|
||||
# 获取原模型(不应用映射)
|
||||
original_global_model = await mapping_resolver.get_global_model_direct(db, model_name)
|
||||
|
||||
if original_global_model and original_global_model.id != global_model.id:
|
||||
# 尝试原模型
|
||||
is_supported, skip_reason, caps = await self._check_model_support_for_global_model(
|
||||
db,
|
||||
provider,
|
||||
original_global_model,
|
||||
model_name,
|
||||
is_stream,
|
||||
capability_requirements,
|
||||
)
|
||||
if is_supported:
|
||||
logger.debug(
|
||||
f"Provider {provider.name} 原模型 {original_global_model.name} 支持所需能力"
|
||||
)
|
||||
|
||||
return is_supported, skip_reason, caps
|
||||
|
||||
async def _check_model_support_for_global_model(
|
||||
@@ -798,7 +803,16 @@ class CacheAwareScheduler:
|
||||
(is_supported, skip_reason, supported_capabilities)
|
||||
"""
|
||||
# 确保 global_model 附加到当前 Session
|
||||
global_model = db.merge(global_model, load=False)
|
||||
# 注意:从缓存重建的对象是 transient 状态,不能使用 load=False
|
||||
# 使用 load=True(默认)允许 SQLAlchemy 正确处理 transient 对象
|
||||
from sqlalchemy import inspect
|
||||
insp = inspect(global_model)
|
||||
if insp.transient or insp.detached:
|
||||
# transient/detached 对象:使用默认 merge(会查询 DB 检查是否存在)
|
||||
global_model = db.merge(global_model)
|
||||
else:
|
||||
# persistent 对象:已经附加到 session,无需 merge
|
||||
pass
|
||||
|
||||
# 获取模型支持的能力列表
|
||||
model_supported_capabilities: List[str] = list(global_model.supported_capabilities or [])
|
||||
@@ -806,6 +820,11 @@ class CacheAwareScheduler:
|
||||
# 查询该 Provider 是否有实现这个 GlobalModel
|
||||
for model in provider.models:
|
||||
if model.global_model_id == global_model.id and model.is_active:
|
||||
logger.debug(
|
||||
f"[_check_model_support_for_global_model] Provider={provider.name}, "
|
||||
f"GlobalModel={global_model.name}, "
|
||||
f"provider_model_name={model.provider_model_name}"
|
||||
)
|
||||
# 检查流式支持
|
||||
if is_stream:
|
||||
supports_streaming = model.get_effective_supports_streaming()
|
||||
@@ -833,6 +852,7 @@ class CacheAwareScheduler:
|
||||
key: ProviderAPIKey,
|
||||
model_name: str,
|
||||
capability_requirements: Optional[Dict[str, bool]] = None,
|
||||
resolved_model_name: Optional[str] = None,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检查 API Key 的可用性
|
||||
@@ -844,6 +864,7 @@ class CacheAwareScheduler:
|
||||
key: API Key 对象
|
||||
model_name: 模型名称
|
||||
capability_requirements: 能力需求(可选)
|
||||
resolved_model_name: 解析后的 GlobalModel.name(可选)
|
||||
|
||||
Returns:
|
||||
(is_available, skip_reason)
|
||||
@@ -855,7 +876,10 @@ class CacheAwareScheduler:
|
||||
|
||||
# 模型权限检查:使用 allowed_models 白名单
|
||||
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
|
||||
if key.allowed_models is not None and model_name not in key.allowed_models:
|
||||
if key.allowed_models is not None and (
|
||||
model_name not in key.allowed_models
|
||||
and (not resolved_model_name or resolved_model_name not in key.allowed_models)
|
||||
):
|
||||
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(无)"
|
||||
suffix = "..." if len(key.allowed_models) > 3 else ""
|
||||
return False, f"模型权限不匹配(允许: {allowed_preview}{suffix})"
|
||||
@@ -880,6 +904,7 @@ class CacheAwareScheduler:
|
||||
target_format: APIFormat,
|
||||
model_name: str,
|
||||
affinity_key: Optional[str],
|
||||
resolved_model_name: Optional[str] = None,
|
||||
max_candidates: Optional[int] = None,
|
||||
allowed_endpoints: Optional[set] = None,
|
||||
is_stream: bool = False,
|
||||
@@ -892,8 +917,9 @@ class CacheAwareScheduler:
|
||||
db: 数据库会话
|
||||
providers: Provider 列表
|
||||
target_format: 目标 API 格式
|
||||
model_name: 模型名称
|
||||
model_name: 模型名称(用户请求的名称,可能是别名)
|
||||
affinity_key: 亲和性标识符(通常为API Key ID)
|
||||
resolved_model_name: 解析后的 GlobalModel.name(用于 Key.allowed_models 校验)
|
||||
max_candidates: 最大候选数
|
||||
allowed_endpoints: 允许的 Endpoint ID 集合(None 表示不限制)
|
||||
is_stream: 是否是流式请求,如果为 True 则过滤不支持流式的 Provider
|
||||
@@ -925,7 +951,7 @@ class CacheAwareScheduler:
|
||||
continue
|
||||
|
||||
# 检查 Endpoint 是否在允许列表中
|
||||
if allowed_endpoints and endpoint.id not in allowed_endpoints:
|
||||
if allowed_endpoints is not None and endpoint.id not in allowed_endpoints:
|
||||
logger.debug(
|
||||
f"Endpoint {endpoint.id[:8]}... 不在用户/API Key 的允许列表中,跳过"
|
||||
)
|
||||
@@ -938,7 +964,10 @@ class CacheAwareScheduler:
|
||||
for key in keys:
|
||||
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
|
||||
is_available, skip_reason = self._check_key_availability(
|
||||
key, model_name, capability_requirements
|
||||
key,
|
||||
model_name,
|
||||
capability_requirements,
|
||||
resolved_model_name=resolved_model_name,
|
||||
)
|
||||
|
||||
candidate = ProviderCandidate(
|
||||
@@ -1007,11 +1036,13 @@ class CacheAwareScheduler:
|
||||
candidate.is_cached = True
|
||||
cached_candidates.append(candidate)
|
||||
matched = True
|
||||
logger.debug(f"检测到缓存亲和性: affinity_key={affinity_key[:8]}..., "
|
||||
logger.debug(
|
||||
f"检测到缓存亲和性: affinity_key={affinity_key[:8]}..., "
|
||||
f"api_format={api_format_str}, global_model_id={global_model_id[:8]}..., "
|
||||
f"provider={provider.name}, endpoint={endpoint.id[:8]}..., "
|
||||
f"provider_key=***{key.api_key[-4:]}, "
|
||||
f"使用次数={affinity.request_count}")
|
||||
f"使用次数={affinity.request_count}"
|
||||
)
|
||||
else:
|
||||
candidate.is_cached = False
|
||||
other_candidates.append(candidate)
|
||||
@@ -1047,6 +1078,22 @@ class CacheAwareScheduler:
|
||||
self.priority_mode = normalized
|
||||
logger.debug(f"[CacheAwareScheduler] 切换优先级模式为: {self.priority_mode}")
|
||||
|
||||
def _normalize_scheduling_mode(self, mode: Optional[str]) -> str:
|
||||
normalized = (mode or "").strip().lower()
|
||||
if normalized not in self.ALLOWED_SCHEDULING_MODES:
|
||||
if normalized:
|
||||
logger.warning(f"[CacheAwareScheduler] 无效的调度模式 '{mode}',回退为 cache_affinity")
|
||||
return self.SCHEDULING_MODE_CACHE_AFFINITY
|
||||
return normalized
|
||||
|
||||
def set_scheduling_mode(self, mode: Optional[str]) -> None:
|
||||
"""运行时更新调度模式"""
|
||||
normalized = self._normalize_scheduling_mode(mode)
|
||||
if normalized == self.scheduling_mode:
|
||||
return
|
||||
self.scheduling_mode = normalized
|
||||
logger.debug(f"[CacheAwareScheduler] 切换调度模式为: {self.scheduling_mode}")
|
||||
|
||||
def _apply_priority_mode_sort(
|
||||
self, candidates: List[ProviderCandidate], affinity_key: Optional[str] = None
|
||||
) -> List[ProviderCandidate]:
|
||||
@@ -1117,6 +1164,7 @@ class CacheAwareScheduler:
|
||||
c.key.internal_priority if c.key else 999999,
|
||||
c.key.id if c.key else "",
|
||||
)
|
||||
|
||||
result.extend(sorted(group, key=secondary_sort))
|
||||
|
||||
return result
|
||||
@@ -1293,6 +1341,7 @@ _scheduler: Optional[CacheAwareScheduler] = None
|
||||
async def get_cache_aware_scheduler(
|
||||
redis_client=None,
|
||||
priority_mode: Optional[str] = None,
|
||||
scheduling_mode: Optional[str] = None,
|
||||
) -> CacheAwareScheduler:
|
||||
"""
|
||||
获取全局CacheAwareScheduler实例
|
||||
@@ -1303,6 +1352,7 @@ async def get_cache_aware_scheduler(
|
||||
Args:
|
||||
redis_client: Redis客户端(可选)
|
||||
priority_mode: 外部覆盖的优先级模式(provider | global_key)
|
||||
scheduling_mode: 外部覆盖的调度模式(fixed_order | cache_affinity)
|
||||
|
||||
Returns:
|
||||
CacheAwareScheduler实例
|
||||
@@ -1310,8 +1360,13 @@ async def get_cache_aware_scheduler(
|
||||
global _scheduler
|
||||
|
||||
if _scheduler is None:
|
||||
_scheduler = CacheAwareScheduler(redis_client, priority_mode=priority_mode)
|
||||
elif priority_mode:
|
||||
_scheduler.set_priority_mode(priority_mode)
|
||||
_scheduler = CacheAwareScheduler(
|
||||
redis_client, priority_mode=priority_mode, scheduling_mode=scheduling_mode
|
||||
)
|
||||
else:
|
||||
if priority_mode:
|
||||
_scheduler.set_priority_mode(priority_mode)
|
||||
if scheduling_mode:
|
||||
_scheduler.set_scheduling_mode(scheduling_mode)
|
||||
|
||||
return _scheduler
|
||||
|
||||
3
src/services/cache/backend.py
vendored
3
src/services/cache/backend.py
vendored
@@ -6,8 +6,7 @@
|
||||
2. RedisCache: Redis 缓存(分布式)
|
||||
|
||||
使用场景:
|
||||
- ModelMappingResolver: 模型映射与别名解析缓存
|
||||
- ModelMapper: 模型映射缓存
|
||||
- ModelCacheService: 模型解析缓存
|
||||
- 其他需要缓存的服务
|
||||
"""
|
||||
|
||||
|
||||
42
src/services/cache/invalidation.py
vendored
42
src/services/cache/invalidation.py
vendored
@@ -3,18 +3,14 @@
|
||||
|
||||
统一管理各种缓存的失效逻辑,支持:
|
||||
1. GlobalModel 变更时失效相关缓存
|
||||
2. ModelMapping 变更时失效别名/降级缓存
|
||||
3. Model 变更时失效模型映射缓存
|
||||
4. 支持同步和异步缓存后端
|
||||
2. Model 变更时失效模型映射缓存
|
||||
3. 支持同步和异步缓存后端
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
from src.core.logger import logger
|
||||
|
||||
|
||||
class CacheInvalidationService:
|
||||
"""
|
||||
@@ -25,14 +21,8 @@ class CacheInvalidationService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化缓存失效服务"""
|
||||
self._mapping_resolver = None
|
||||
self._model_mappers = [] # 可能有多个 ModelMapperMiddleware 实例
|
||||
|
||||
def set_mapping_resolver(self, mapping_resolver):
|
||||
"""设置模型映射解析器实例"""
|
||||
self._mapping_resolver = mapping_resolver
|
||||
logger.debug(f"[CacheInvalidation] 模型映射解析器已注册 (实例: {id(mapping_resolver)})")
|
||||
|
||||
def register_model_mapper(self, model_mapper):
|
||||
"""注册 ModelMapper 实例"""
|
||||
if model_mapper not in self._model_mappers:
|
||||
@@ -48,37 +38,12 @@ class CacheInvalidationService:
|
||||
"""
|
||||
logger.info(f"[CacheInvalidation] GlobalModel 变更: {model_name}")
|
||||
|
||||
# 异步失效模型解析器中的缓存
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(self._mapping_resolver.invalidate_global_model_cache())
|
||||
|
||||
# 失效所有 ModelMapper 中与此模型相关的缓存
|
||||
for mapper in self._model_mappers:
|
||||
# 清空所有缓存(因为不知道哪些 provider 使用了这个模型)
|
||||
mapper.clear_cache()
|
||||
logger.debug(f"[CacheInvalidation] 已清空 ModelMapper 缓存")
|
||||
|
||||
def on_model_mapping_changed(self, source_model: str, provider_id: Optional[str] = None):
|
||||
"""
|
||||
ModelMapping 变更时的缓存失效
|
||||
|
||||
Args:
|
||||
source_model: 变更的源模型名
|
||||
provider_id: 相关 Provider(None 表示全局)
|
||||
"""
|
||||
logger.info(f"[CacheInvalidation] ModelMapping 变更: {source_model} (provider={provider_id})")
|
||||
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(
|
||||
self._mapping_resolver.invalidate_mapping_cache(source_model, provider_id)
|
||||
)
|
||||
|
||||
for mapper in self._model_mappers:
|
||||
if provider_id:
|
||||
mapper.refresh_cache(provider_id)
|
||||
else:
|
||||
mapper.clear_cache()
|
||||
|
||||
def on_model_changed(self, provider_id: str, global_model_id: str):
|
||||
"""
|
||||
Model 变更时的缓存失效
|
||||
@@ -98,9 +63,6 @@ class CacheInvalidationService:
|
||||
"""清空所有缓存"""
|
||||
logger.info("[CacheInvalidation] 清空所有缓存")
|
||||
|
||||
if self._mapping_resolver:
|
||||
asyncio.create_task(self._mapping_resolver.clear_cache())
|
||||
|
||||
for mapper in self._model_mappers:
|
||||
mapper.clear_cache()
|
||||
|
||||
|
||||
295
src/services/cache/model_cache.py
vendored
295
src/services/cache/model_cache.py
vendored
@@ -1,16 +1,21 @@
|
||||
"""
|
||||
Model 映射缓存服务 - 减少模型映射和别名查询
|
||||
Model 映射缓存服务 - 减少模型查询
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.config.constants import CacheTTL
|
||||
from src.core.cache_service import CacheService
|
||||
from src.core.logger import logger
|
||||
from src.models.database import GlobalModel, Model, ModelMapping
|
||||
|
||||
from src.core.metrics import (
|
||||
model_mapping_conflict_total,
|
||||
model_mapping_resolution_duration_seconds,
|
||||
model_mapping_resolution_total,
|
||||
)
|
||||
from src.models.database import GlobalModel, Model
|
||||
|
||||
|
||||
class ModelCacheService:
|
||||
@@ -99,11 +104,16 @@ class ModelCacheService:
|
||||
Model 对象或 None
|
||||
"""
|
||||
cache_key = f"model:provider_global:{provider_id}:{global_model_id}"
|
||||
hit_count_key = f"model:provider_global:hits:{provider_id}:{global_model_id}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug(f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
|
||||
logger.debug(
|
||||
f"Model 缓存命中(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||
)
|
||||
# 递增命中计数,同时刷新 TTL
|
||||
await CacheService.incr(hit_count_key, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
return ModelCacheService._dict_to_model(cached_data)
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
@@ -121,7 +131,11 @@ class ModelCacheService:
|
||||
if model:
|
||||
model_dict = ModelCacheService._model_to_dict(model)
|
||||
await CacheService.set(cache_key, model_dict, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
logger.debug(f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}...")
|
||||
# 重置命中计数(新缓存从1开始)
|
||||
await CacheService.set(hit_count_key, 1, ttl_seconds=ModelCacheService.CACHE_TTL)
|
||||
logger.debug(
|
||||
f"Model 已缓存(provider+global): {provider_id[:8]}...+{global_model_id[:8]}..."
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@@ -158,95 +172,206 @@ class ModelCacheService:
|
||||
|
||||
return global_model
|
||||
|
||||
@staticmethod
|
||||
async def resolve_alias(
|
||||
db: Session, source_model: str, provider_id: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
解析模型别名(带缓存)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
source_model: 源模型名称或别名
|
||||
provider_id: Provider ID(可选,用于 Provider 特定别名)
|
||||
|
||||
Returns:
|
||||
目标 GlobalModel ID 或 None
|
||||
"""
|
||||
# 构造缓存键
|
||||
if provider_id:
|
||||
cache_key = f"alias:provider:{provider_id}:{source_model}"
|
||||
else:
|
||||
cache_key = f"alias:global:{source_model}"
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
cached_result = await CacheService.get(cache_key)
|
||||
if cached_result:
|
||||
logger.debug(f"别名缓存命中: {source_model} (provider: {provider_id or 'global'})")
|
||||
return cached_result
|
||||
|
||||
# 2. 缓存未命中,查询数据库
|
||||
query = db.query(ModelMapping).filter(ModelMapping.source_model == source_model)
|
||||
|
||||
if provider_id:
|
||||
# Provider 特定别名优先
|
||||
query = query.filter(ModelMapping.provider_id == provider_id)
|
||||
else:
|
||||
# 全局别名
|
||||
query = query.filter(ModelMapping.provider_id.is_(None))
|
||||
|
||||
mapping = query.first()
|
||||
|
||||
# 3. 写入缓存
|
||||
target_global_model_id = mapping.target_global_model_id if mapping else None
|
||||
await CacheService.set(
|
||||
cache_key, target_global_model_id, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
|
||||
if mapping:
|
||||
logger.debug(f"别名已缓存: {source_model} → {target_global_model_id}")
|
||||
|
||||
return target_global_model_id
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_model_cache(
|
||||
model_id: str, provider_id: Optional[str] = None, global_model_id: Optional[str] = None
|
||||
):
|
||||
model_id: str,
|
||||
provider_id: Optional[str] = None,
|
||||
global_model_id: Optional[str] = None,
|
||||
provider_model_name: Optional[str] = None,
|
||||
provider_model_aliases: Optional[list] = None,
|
||||
) -> None:
|
||||
"""清除 Model 缓存
|
||||
|
||||
Args:
|
||||
model_id: Model ID
|
||||
provider_id: Provider ID(用于清除 provider_global 缓存)
|
||||
global_model_id: GlobalModel ID(用于清除 provider_global 缓存)
|
||||
provider_model_name: provider_model_name(用于清除 resolve 缓存)
|
||||
provider_model_aliases: 映射名称列表(用于清除 resolve 缓存)
|
||||
"""
|
||||
# 清除 model:id 缓存
|
||||
await CacheService.delete(f"model:id:{model_id}")
|
||||
|
||||
# 清除 provider_global 缓存(如果提供了必要参数)
|
||||
# 清除 provider_global 缓存及其命中计数(如果提供了必要参数)
|
||||
if provider_id and global_model_id:
|
||||
await CacheService.delete(f"model:provider_global:{provider_id}:{global_model_id}")
|
||||
logger.debug(f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}...")
|
||||
await CacheService.delete(f"model:provider_global:hits:{provider_id}:{global_model_id}")
|
||||
logger.debug(
|
||||
f"Model 缓存已清除: {model_id}, provider_global:{provider_id[:8]}...:{global_model_id[:8]}..."
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Model 缓存已清除: {model_id}")
|
||||
|
||||
# 清除 resolve 缓存(provider_model_name 和 aliases 可能都被用作解析 key)
|
||||
resolve_keys_to_clear = []
|
||||
if provider_model_name:
|
||||
resolve_keys_to_clear.append(provider_model_name)
|
||||
if provider_model_aliases:
|
||||
for alias_entry in provider_model_aliases:
|
||||
if isinstance(alias_entry, dict):
|
||||
alias_name = alias_entry.get("name", "").strip()
|
||||
if alias_name:
|
||||
resolve_keys_to_clear.append(alias_name)
|
||||
|
||||
for key in resolve_keys_to_clear:
|
||||
await CacheService.delete(f"global_model:resolve:{key}")
|
||||
|
||||
if resolve_keys_to_clear:
|
||||
logger.debug(f"Model resolve 缓存已清除: {resolve_keys_to_clear}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None):
|
||||
async def invalidate_global_model_cache(global_model_id: str, name: Optional[str] = None) -> None:
|
||||
"""清除 GlobalModel 缓存"""
|
||||
await CacheService.delete(f"global_model:id:{global_model_id}")
|
||||
if name:
|
||||
await CacheService.delete(f"global_model:name:{name}")
|
||||
# 同时清除 resolve 缓存,因为 GlobalModel.name 也是一个 resolve key
|
||||
await CacheService.delete(f"global_model:resolve:{name}")
|
||||
logger.debug(f"GlobalModel 缓存已清除: {global_model_id}")
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_alias_cache(source_model: str, provider_id: Optional[str] = None):
|
||||
"""清除别名缓存"""
|
||||
if provider_id:
|
||||
cache_key = f"alias:provider:{provider_id}:{source_model}"
|
||||
else:
|
||||
cache_key = f"alias:global:{source_model}"
|
||||
async def resolve_global_model_by_name_or_alias(
|
||||
db: Session, model_name: str
|
||||
) -> Optional[GlobalModel]:
|
||||
"""
|
||||
通过名称解析 GlobalModel(带缓存)
|
||||
|
||||
await CacheService.delete(cache_key)
|
||||
logger.debug(f"别名缓存已清除: {source_model}")
|
||||
查找顺序:
|
||||
1. 检查缓存
|
||||
2. 通过 provider_model_name 匹配(查询 Model 表)
|
||||
3. 直接匹配 GlobalModel.name(兜底)
|
||||
|
||||
注意:此方法不使用 provider_model_aliases 进行全局解析。
|
||||
provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效,
|
||||
由 resolve_provider_model() 处理。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
model_name: 模型名称(可以是 GlobalModel.name 或 provider_model_name)
|
||||
|
||||
Returns:
|
||||
GlobalModel 对象或 None
|
||||
"""
|
||||
start_time = time.time()
|
||||
resolution_method = "not_found"
|
||||
cache_hit = False
|
||||
|
||||
normalized_name = model_name.strip()
|
||||
if not normalized_name:
|
||||
return None
|
||||
|
||||
cache_key = f"global_model:resolve:{normalized_name}"
|
||||
|
||||
try:
|
||||
# 1. 尝试从缓存获取
|
||||
cached_data = await CacheService.get(cache_key)
|
||||
if cached_data:
|
||||
if cached_data == "NOT_FOUND":
|
||||
# 缓存的负结果
|
||||
cache_hit = True
|
||||
resolution_method = "not_found"
|
||||
logger.debug(f"GlobalModel 缓存命中(映射解析-未找到): {normalized_name}")
|
||||
return None
|
||||
if isinstance(cached_data, dict) and "supported_capabilities" not in cached_data:
|
||||
# 兼容旧缓存:字段不全时视为未命中,走 DB 刷新
|
||||
logger.debug(f"GlobalModel 缓存命中但 schema 过旧,刷新: {normalized_name}")
|
||||
else:
|
||||
cache_hit = True
|
||||
resolution_method = "direct_match" # 缓存命中时无法区分原始解析方式
|
||||
logger.debug(f"GlobalModel 缓存命中(映射解析): {normalized_name}")
|
||||
return ModelCacheService._dict_to_global_model(cached_data)
|
||||
|
||||
# 2. 通过 provider_model_name 匹配(不考虑 provider_model_aliases)
|
||||
# 重要:provider_model_aliases 是 Provider 级别的别名配置,只在特定 Provider 上下文中生效
|
||||
# 全局解析不应该受到某个 Provider 别名配置的影响
|
||||
# 例如:Provider A 把 "haiku" 映射到 "sonnet",不应该影响 Provider B 的 "haiku" 解析
|
||||
from src.models.database import Provider
|
||||
|
||||
models_with_global = (
|
||||
db.query(Model, GlobalModel)
|
||||
.join(Provider, Model.provider_id == Provider.id)
|
||||
.join(GlobalModel, Model.global_model_id == GlobalModel.id)
|
||||
.filter(
|
||||
Provider.is_active == True,
|
||||
Model.is_active == True,
|
||||
GlobalModel.is_active == True,
|
||||
Model.provider_model_name == normalized_name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# 收集匹配的 GlobalModel(只通过 provider_model_name 匹配)
|
||||
matched_global_models: List[GlobalModel] = []
|
||||
seen_global_model_ids: set[str] = set()
|
||||
for model, gm in models_with_global:
|
||||
if gm.id not in seen_global_model_ids:
|
||||
seen_global_model_ids.add(gm.id)
|
||||
matched_global_models.append(gm)
|
||||
logger.debug(
|
||||
f"模型名称 '{normalized_name}' 通过 provider_model_name 匹配到 "
|
||||
f"GlobalModel: {gm.name} (Model: {model.id[:8]}...)"
|
||||
)
|
||||
|
||||
# 如果通过 provider_model_name 找到了,返回
|
||||
if matched_global_models:
|
||||
resolution_method = "provider_model_name"
|
||||
|
||||
if len(matched_global_models) > 1:
|
||||
# 检测到冲突(多个不同的 GlobalModel 有相同的 provider_model_name)
|
||||
model_names = [gm.name for gm in matched_global_models if gm.name]
|
||||
logger.warning(
|
||||
f"模型映射冲突: 名称 '{normalized_name}' 匹配到多个不同的 GlobalModel: "
|
||||
f"{', '.join(model_names)},使用第一个匹配结果"
|
||||
)
|
||||
# 记录冲突指标
|
||||
model_mapping_conflict_total.inc()
|
||||
|
||||
# 返回第一个匹配的 GlobalModel
|
||||
result_global_model = matched_global_models[0]
|
||||
global_model_dict = ModelCacheService._global_model_to_dict(result_global_model)
|
||||
await CacheService.set(
|
||||
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(
|
||||
f"GlobalModel 已缓存(映射解析-{resolution_method}): {normalized_name} -> {result_global_model.name}"
|
||||
)
|
||||
return result_global_model
|
||||
|
||||
# 3. 如果通过 provider 映射没找到,最后尝试直接通过 GlobalModel.name 查找
|
||||
global_model = (
|
||||
db.query(GlobalModel)
|
||||
.filter(GlobalModel.name == normalized_name, GlobalModel.is_active == True)
|
||||
.first()
|
||||
)
|
||||
|
||||
if global_model:
|
||||
resolution_method = "direct_match"
|
||||
# 缓存结果
|
||||
global_model_dict = ModelCacheService._global_model_to_dict(global_model)
|
||||
await CacheService.set(
|
||||
cache_key, global_model_dict, ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"GlobalModel 已缓存(映射解析-直接匹配): {normalized_name}")
|
||||
return global_model
|
||||
|
||||
# 4. 完全未找到
|
||||
resolution_method = "not_found"
|
||||
# 未找到匹配,缓存负结果
|
||||
await CacheService.set(
|
||||
cache_key, "NOT_FOUND", ttl_seconds=ModelCacheService.CACHE_TTL
|
||||
)
|
||||
logger.debug(f"GlobalModel 未找到(映射解析): {normalized_name}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
# 记录监控指标
|
||||
duration = time.time() - start_time
|
||||
model_mapping_resolution_total.labels(
|
||||
method=resolution_method, cache_hit=str(cache_hit).lower()
|
||||
).inc()
|
||||
model_mapping_resolution_duration_seconds.labels(method=resolution_method).observe(
|
||||
duration
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _model_to_dict(model: Model) -> dict:
|
||||
@@ -256,16 +381,18 @@ class ModelCacheService:
|
||||
"provider_id": model.provider_id,
|
||||
"global_model_id": model.global_model_id,
|
||||
"provider_model_name": model.provider_model_name,
|
||||
"provider_model_aliases": getattr(model, "provider_model_aliases", None),
|
||||
"is_active": model.is_active,
|
||||
"is_available": model.is_available if hasattr(model, "is_available") else True,
|
||||
"price_per_request": (
|
||||
float(model.price_per_request) if model.price_per_request else None
|
||||
float(model.price_per_request) if model.price_per_request is not None else None
|
||||
),
|
||||
"tiered_pricing": model.tiered_pricing,
|
||||
"supports_vision": model.supports_vision,
|
||||
"supports_function_calling": model.supports_function_calling,
|
||||
"supports_streaming": model.supports_streaming,
|
||||
"supports_extended_thinking": model.supports_extended_thinking,
|
||||
"supports_image_generation": getattr(model, "supports_image_generation", None),
|
||||
"config": model.config,
|
||||
}
|
||||
|
||||
@@ -277,6 +404,7 @@ class ModelCacheService:
|
||||
provider_id=model_dict["provider_id"],
|
||||
global_model_id=model_dict["global_model_id"],
|
||||
provider_model_name=model_dict["provider_model_name"],
|
||||
provider_model_aliases=model_dict.get("provider_model_aliases"),
|
||||
is_active=model_dict["is_active"],
|
||||
is_available=model_dict.get("is_available", True),
|
||||
price_per_request=model_dict.get("price_per_request"),
|
||||
@@ -285,6 +413,7 @@ class ModelCacheService:
|
||||
supports_function_calling=model_dict.get("supports_function_calling"),
|
||||
supports_streaming=model_dict.get("supports_streaming"),
|
||||
supports_extended_thinking=model_dict.get("supports_extended_thinking"),
|
||||
supports_image_generation=model_dict.get("supports_image_generation"),
|
||||
config=model_dict.get("config"),
|
||||
)
|
||||
return model
|
||||
@@ -296,14 +425,15 @@ class ModelCacheService:
|
||||
"id": global_model.id,
|
||||
"name": global_model.name,
|
||||
"display_name": global_model.display_name,
|
||||
"family": global_model.family,
|
||||
"group_id": global_model.group_id,
|
||||
"supports_vision": global_model.supports_vision,
|
||||
"supports_thinking": global_model.supports_thinking,
|
||||
"context_window": global_model.context_window,
|
||||
"max_output_tokens": global_model.max_output_tokens,
|
||||
"supported_capabilities": global_model.supported_capabilities,
|
||||
"config": global_model.config,
|
||||
"default_tiered_pricing": global_model.default_tiered_pricing,
|
||||
"default_price_per_request": (
|
||||
float(global_model.default_price_per_request)
|
||||
if global_model.default_price_per_request is not None
|
||||
else None
|
||||
),
|
||||
"is_active": global_model.is_active,
|
||||
"description": global_model.description,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -313,13 +443,10 @@ class ModelCacheService:
|
||||
id=global_model_dict["id"],
|
||||
name=global_model_dict["name"],
|
||||
display_name=global_model_dict.get("display_name"),
|
||||
family=global_model_dict.get("family"),
|
||||
group_id=global_model_dict.get("group_id"),
|
||||
supports_vision=global_model_dict.get("supports_vision", False),
|
||||
supports_thinking=global_model_dict.get("supports_thinking", False),
|
||||
context_window=global_model_dict.get("context_window"),
|
||||
max_output_tokens=global_model_dict.get("max_output_tokens"),
|
||||
supported_capabilities=global_model_dict.get("supported_capabilities") or [],
|
||||
config=global_model_dict.get("config"),
|
||||
default_tiered_pricing=global_model_dict.get("default_tiered_pricing"),
|
||||
default_price_per_request=global_model_dict.get("default_price_per_request"),
|
||||
is_active=global_model_dict.get("is_active", True),
|
||||
description=global_model_dict.get("description"),
|
||||
)
|
||||
return global_model
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user