17 Commits

Author SHA1 Message Date
fawney19
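7faca5512a feat(ui): improve the key-adding flow and the dashboard empty state
- KeyFormDialog: in add mode, keep the dialog open after saving and clear the form so keys can be added back to back
- KeyFormDialog: button labels change dynamically between edit and add mode
- Dashboard: improve the loading state of the stat cards and the placeholder shown for empty data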
2026-01-10 19:32:36 +08:00
fawney19
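ad84272084 fix: fix the permission issue that prevented regular users from accessing the dashboard endpoints
Change the mode of DashboardAdapter from ApiMode.ADMIN to ApiMode.USER,
so that regular users can access the /api/dashboard/stats and /api/dashboard/daily-stats endpoints.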
2026-01-10 19:31:19 +08:00
fawney19
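09e0f594ff refactor: rework the rate-limiting system and health monitoring to distinguish by API format
- Rename adaptive_concurrency to adaptive_rpm, switching from concurrency control to RPM control
- The health monitor now manages health scores and circuit-breaker state independently per API format
- Add the model_permissions module, which supports configuring allowed models per format
- Refactor the frontend provider form components and add a Collapsible UI component
- Add a database migration script to support the new data structures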
2026-01-10 18:48:35 +08:00
fawney19
dd2fbf4424 style(ui): adjust the column widths of the associated-providers table in the model detail drawer 2026-01-08 13:37:41 +08:00
fawney19
99b12a49c6 Merge pull request #78 from fawney19/perf/optimize
perf: optimize HTTP client connection pool reuse
2026-01-08 13:37:13 +08:00
fawney19
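ea35efe440 perf: optimize HTTP client connection pool reuse
- Add get_proxy_client(), which reuses a single client for identical proxy configurations
- Add an LRU eviction policy capping proxy clients at 50 to prevent memory leaks
- Add get_default_client_async(), an async, thread-safe variant
- Use a module-level lock to avoid race conditions when initializing class attributes
- Optimize ConcurrencyManager to batch reads with Redis MGET, reducing round trips
- Add get_pool_stats(), an interface for connection pool statistics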
2026-01-08 13:34:59 +08:00
fawney19
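bf09e740e9 fix(ui): improve the interaction experience of the provider detail page
- The delete button in the model list turns red only on hover
- Batch model association dialog: expand when only global models are present; collapse all groups when there are multiple groups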
2026-01-08 11:25:52 +08:00
fawney19
60c77cec56 Merge pull request #77 from AAEE86/ui
style(ui): improve text visibility in dark mode for model badges
2026-01-08 10:52:54 +08:00
fawney19
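0e4a1dddb5 refactor(ui): improve the UI and performance of batch endpoint creation
- Adjust layout: move the API URL to the top and the API format selection below it
- Improve checkbox styling: use a custom checkbox instead of the native style
- Sort API formats by column: base formats and their corresponding CLI formats align vertically
- Switch the request configuration to a more compact 4-column layout
- Create endpoints concurrently with Promise.allSettled to improve performance
- Improve error messages: on failure, show the specific error to the user directly
- Remove the unused Select component import and the selectOpen variable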
2026-01-08 10:50:25 +08:00
AAEE86
1cf18b6e12 feat(ui): support batch endpoint creation with multiple API formats (#76)
Replace single API format selector with multi-select checkbox interface in endpoint creation dialog. Users can now select multiple API formats to create multiple endpoints simultaneously with shared configuration (URL, path, timeout, etc.).

- Change API format selection from dropdown to checkbox grid layout
- Add selectedFormats array to track multiple format selections
- Implement batch creation logic with individual error handling
- Update submit button to show endpoint count being created
- Adjust form layout to improve visual hierarchy
- Display appropriate success/failure messages for batch operations
- Reset selectedFormats on form reset
2026-01-08 10:42:14 +08:00
AAEE86
f9a8be898a style(ui): improve text visibility in dark mode for model badges 2026-01-08 10:26:58 +08:00
fawney19
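1521ce5a96 feat: add a load-balancing scheduling mode
- Add the load_balance scheduling mode, which rotates randomly within the same priority
- The frontend supports switching among three scheduling modes: cache affinity, load balancing, and fixed order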
2026-01-08 03:20:04 +08:00
fawney19
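f2e62dd197 feat: add a version update check
- Backend: add the /api/admin/system/check-update endpoint, which fetches the latest version from GitHub Tags
- Frontend: add the UpdateDialog component, which checks for updates automatically after an administrator logs in and shows a prompt
- Check only once per session; after clicking "Remind me later", do not prompt again for 24 hours
- CI and deploy.sh generate the _version.py version file automatically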
2026-01-08 03:01:54 +08:00
fawney19
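d378630b38 perf: add multi-layer caching to reduce database queries
- Add ProviderCacheService to cache Provider and ProviderAPIKey data
- Add an in-process cache to SystemConfigService (TTL 60 seconds)
- Throttle API Key last_used_at updates (60-second interval)
- Make the HTTP connection pool settings configurable, with automatic calculation based on the number of workers
- The frontend priority management now uses health_score to display health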
2026-01-08 02:34:59 +08:00
fawney19
d9e6346911 fix: lower the minimum API Key length to 3 characters 2026-01-08 01:53:16 +08:00
fawney19
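238788e0e9 fix: unify the default endpoint retry count to 2
Synchronize the default endpoint retry count across the frontend form, mock data, and backend import configuration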
2026-01-08 01:40:40 +08:00
fawney19
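68ff828505 feat: run database migrations automatically when the container starts
- Add entrypoint.sh to run alembic upgrade head before the container starts
- Update Dockerfile.app and Dockerfile.app.local to use the new entrypoint script
- Remove the manual migration script migrate.sh
- Simplify the deployment instructions in the README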
2026-01-08 01:28:36 +08:00
116 changed files with 7932 additions and 4365 deletions
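
Commit ea35efe440 describes reusing one HTTP client per proxy configuration, with LRU eviction capped at 50 entries. The sketch below illustrates that idea in Python under stated assumptions: `ProxyClientCache`, `make_client`, and the JSON-based key are hypothetical names chosen for illustration, not the project's actual `get_proxy_client()` implementation.

```python
import json
import threading
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional


class ProxyClientCache:
    """Reuse one client per proxy configuration, evicting the least recently used."""

    def __init__(self, make_client: Callable[[Optional[Dict[str, Any]]], Any], max_size: int = 50):
        self._make_client = make_client  # hypothetical factory that builds an HTTP client
        self._max_size = max_size
        self._clients: "OrderedDict[str, Any]" = OrderedDict()
        self._lock = threading.Lock()
        self.evicted: List[Any] = []  # evicted clients, so the caller can close them

    @staticmethod
    def _key(proxy: Optional[Dict[str, Any]]) -> str:
        # Identical proxy settings map to the same key regardless of dict order.
        return json.dumps(proxy or {}, sort_keys=True)

    def get(self, proxy: Optional[Dict[str, Any]] = None) -> Any:
        key = self._key(proxy)
        with self._lock:
            if key in self._clients:
                self._clients.move_to_end(key)  # mark as most recently used
                return self._clients[key]
            client = self._make_client(proxy)
            self._clients[key] = client
            if len(self._clients) > self._max_size:
                _, old = self._clients.popitem(last=False)  # drop the LRU entry
                self.evicted.append(old)
            return client

    def stats(self) -> Dict[str, int]:
        with self._lock:
            return {"cached": len(self._clients), "max_size": self._max_size}
```

Keeping evicted clients in a list rather than closing them inside the lock is a design choice of this sketch; a real implementation would likely close them asynchronously.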


@@ -146,10 +146,33 @@ jobs:
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=
- name: Extract version from tag
id: version
run: |
# Extract the version from the tag, e.g. v0.2.5 -> 0.2.5
VERSION="${GITHUB_REF#refs/tags/v}"
if [ "$VERSION" = "$GITHUB_REF" ]; then
# Not triggered by a tag; fall back to git describe
VERSION=$(git describe --tags --always | sed 's/^v//')
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "Extracted version: $VERSION"
- name: Update Dockerfile.app to use registry base image
run: |
sed -i "s|FROM aether-base:latest AS builder|FROM ${{ env.REGISTRY }}/${{ env.BASE_IMAGE_NAME }}:latest AS builder|g" Dockerfile.app
- name: Generate version file
run: |
# Generate the _version.py file
cat > src/_version.py << EOF
# Auto-generated by CI
__version__ = '${{ steps.version.outputs.version }}'
__version_tuple__ = tuple(int(x) for x in '${{ steps.version.outputs.version }}'.split('.') if x.isdigit())
version = __version__
version_tuple = __version_tuple__
EOF
- name: Build and push app image
uses: docker/build-push-action@v5
with:
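
The "Extract version from tag" step above strips the `refs/tags/v` prefix from `GITHUB_REF` and falls back to `git describe` when the ref is not a tag. A minimal Python sketch of the same decision follows; `version_from_ref` and its `describe_output` parameter are hypothetical, and the `git describe` output is passed in rather than executed so the logic can be checked in isolation.

```python
def version_from_ref(github_ref: str, describe_output: str) -> str:
    """Mirror the CI step: strip refs/tags/v, else fall back to git describe output."""
    prefix = "refs/tags/v"
    if github_ref.startswith(prefix):
        # Tag-triggered build: refs/tags/v0.2.5 becomes 0.2.5.
        return github_ref[len(prefix):]
    # Not a v-prefixed tag ref: use the `git describe --tags --always` output,
    # with one leading "v" removed (the sed 's/^v//' part of the step).
    return describe_output[1:] if describe_output.startswith("v") else describe_output
```

A branch build would therefore yield something like `0.2.5-3-gabc1234`, or a bare commit hash when the repository has no tags.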

3
.gitignore vendored

@@ -224,3 +224,6 @@ extracted_*.ts
.deps-hash
.code-hash
.migration-hash
# Version file (auto-generated by hatch-vcs)
src/_version.py


@@ -147,6 +147,10 @@ RUN printf '%s\n' \
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# Entrypoint script (runs migrations before startup)
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
@@ -161,4 +165,5 @@ EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
ENTRYPOINT ["/entrypoint.sh"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]


@@ -139,6 +139,10 @@ RUN printf '%s\n' \
# Create directories
RUN mkdir -p /var/log/supervisor /app/logs /app/data
# Entrypoint script (runs migrations before startup)
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh
# Environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
@@ -152,4 +156,5 @@ EXPOSE 80
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost/health || exit 1
ENTRYPOINT ["/entrypoint.sh"]
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]


@@ -57,14 +57,8 @@ cd Aether
cp .env.example .env
python generate_keys.py # Generate keys, then fill the generated keys into .env
# 3. Deploy # 3. Deploy / update (database migrations run automatically)
docker compose up -d docker compose pull && docker compose up -d
# 4. On first deployment, initialize the database
./migrate.sh
# 5. Update
docker compose pull && docker compose up -d && ./migrate.sh
```
### Docker Compose (build the image locally)


@@ -0,0 +1,530 @@
"""consolidated schema updates
Revision ID: m4n5o6p7q8r9
Revises: 02a45b66b7c4
Create Date: 2026-01-10 20:00:00.000000
This migration consolidates all schema changes from 2026-01-08 to 2026-01-10:
1. provider_api_keys: associate keys directly with a Provider (provider_id, api_formats)
2. provider_api_keys: add the rate_multipliers JSON field (per-format rate multipliers)
3. models: make global_model_id nullable (to support standalone ProviderModel)
4. providers: add timeout, max_retries, proxy (migrated from endpoint)
5. providers: rename display_name to name (dropping the original name)
6. provider_api_keys: max_concurrent -> rpm_limit (concurrency changed to RPM)
7. provider_api_keys: store health per format (health_by_format, circuit_breaker_by_format)
8. provider_endpoints: drop the deprecated rate_limit column
9. usage: add the client_response_headers field
10. provider_api_keys: drop endpoint_id (keys are no longer bound to an Endpoint)
11. provider_endpoints: drop the deprecated max_concurrent column
12. providers: drop the deprecated rpm_limit, rpm_used, rpm_reset_at columns
"""
import logging
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import inspect
# Configure logging
alembic_logger = logging.getLogger("alembic.runtime.migration")
revision = "m4n5o6p7q8r9"
down_revision = "02a45b66b7c4"
branch_labels = None
depends_on = None
def _column_exists(table_name: str, column_name: str) -> bool:
"""Check if a column exists in the table"""
bind = op.get_bind()
inspector = inspect(bind)
columns = [col["name"] for col in inspector.get_columns(table_name)]
return column_name in columns
def _constraint_exists(table_name: str, constraint_name: str) -> bool:
"""Check if a constraint exists"""
bind = op.get_bind()
inspector = inspect(bind)
fks = inspector.get_foreign_keys(table_name)
return any(fk.get("name") == constraint_name for fk in fks)
def _index_exists(table_name: str, index_name: str) -> bool:
"""Check if an index exists"""
bind = op.get_bind()
inspector = inspect(bind)
indexes = inspector.get_indexes(table_name)
return any(idx.get("name") == index_name for idx in indexes)
def upgrade() -> None:
"""Apply all consolidated schema changes"""
bind = op.get_bind()
# ========== 1. provider_api_keys: add provider_id and api_formats ==========
if not _column_exists("provider_api_keys", "provider_id"):
op.add_column("provider_api_keys", sa.Column("provider_id", sa.String(36), nullable=True))
# Data migration: take provider_id from the endpoint
op.execute("""
UPDATE provider_api_keys k
SET provider_id = e.provider_id
FROM provider_endpoints e
WHERE k.endpoint_id = e.id AND k.provider_id IS NULL
""")
# Check for orphan keys that cannot be associated with a provider
result = bind.execute(sa.text(
"SELECT COUNT(*) FROM provider_api_keys WHERE provider_id IS NULL"
))
orphan_count = result.scalar() or 0
if orphan_count > 0:
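# Use the logger so the warning stands out
alembic_logger.warning("=" * 60)
alembic_logger.warning(f"[MIGRATION WARNING] Found {orphan_count} orphan key(s) that cannot be associated with a Provider")
alembic_logger.warning("=" * 60)
alembic_logger.info("Backing up orphan keys to the _orphan_api_keys_backup table...")
# Back up the orphan rows to a temporary table first to avoid data loss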
op.execute("""
CREATE TABLE IF NOT EXISTS _orphan_api_keys_backup AS
SELECT *, NOW() as backup_at
FROM provider_api_keys
WHERE provider_id IS NULL
""")
# Log the IDs of the backed-up keys
orphan_ids = bind.execute(sa.text(
"SELECT id, name FROM provider_api_keys WHERE provider_id IS NULL"
)).fetchall()
alembic_logger.info("Backed-up orphan keys:")
for key_id, key_name in orphan_ids:
alembic_logger.info(f" - Key: {key_name} (ID: {key_id})")
# Delete the orphan rows
op.execute("DELETE FROM provider_api_keys WHERE provider_id IS NULL")
alembic_logger.info(f"Backed up and deleted {orphan_count} orphan key(s)")
# Provide a recovery guide
alembic_logger.warning("-" * 60)
alembic_logger.warning("[Recovery guide] To restore orphan keys:")
alembic_logger.warning(" 1. Query the backup table: SELECT * FROM _orphan_api_keys_backup;")
alembic_logger.warning(" 2. Determine the correct provider_id")
alembic_logger.warning(" 3. Run the restore:")
alembic_logger.warning(" INSERT INTO provider_api_keys (...)")
alembic_logger.warning(" SELECT ... FROM _orphan_api_keys_backup WHERE ...;")
alembic_logger.warning("-" * 60)
# Set NOT NULL and create the foreign key
op.alter_column("provider_api_keys", "provider_id", nullable=False)
if not _constraint_exists("provider_api_keys", "fk_provider_api_keys_provider"):
op.create_foreign_key(
"fk_provider_api_keys_provider",
"provider_api_keys",
"providers",
["provider_id"],
["id"],
ondelete="CASCADE",
)
if not _index_exists("provider_api_keys", "idx_provider_api_keys_provider_id"):
op.create_index("idx_provider_api_keys_provider_id", "provider_api_keys", ["provider_id"])
if not _column_exists("provider_api_keys", "api_formats"):
op.add_column("provider_api_keys", sa.Column("api_formats", sa.JSON(), nullable=True))
# Data migration: take api_format from the endpoint
op.execute("""
UPDATE provider_api_keys k
SET api_formats = json_build_array(e.api_format)
FROM provider_endpoints e
WHERE k.endpoint_id = e.id AND k.api_formats IS NULL
""")
op.alter_column("provider_api_keys", "api_formats", nullable=False, server_default="[]")
# Make endpoint_id nullable and change the foreign key to SET NULL
if _constraint_exists("provider_api_keys", "provider_api_keys_endpoint_id_fkey"):
op.drop_constraint("provider_api_keys_endpoint_id_fkey", "provider_api_keys", type_="foreignkey")
op.alter_column("provider_api_keys", "endpoint_id", nullable=True)
# Do not recreate the foreign key, because this column is dropped later
# ========== 2. provider_api_keys: add rate_multipliers ==========
if not _column_exists("provider_api_keys", "rate_multipliers"):
op.add_column(
"provider_api_keys",
sa.Column("rate_multipliers", postgresql.JSON(astext_type=sa.Text()), nullable=True),
)
# Data migration: expand rate_multiplier into one entry per format in api_formats
op.execute("""
UPDATE provider_api_keys
SET rate_multipliers = (
SELECT jsonb_object_agg(elem, rate_multiplier)
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
)
WHERE api_formats IS NOT NULL
AND api_formats::text != '[]'
AND api_formats::text != 'null'
AND rate_multipliers IS NULL
""")
# ========== 3. models: make global_model_id nullable ==========
op.alter_column("models", "global_model_id", existing_type=sa.String(36), nullable=True)
# ========== 4. providers: add timeout, max_retries, proxy ==========
if not _column_exists("providers", "timeout"):
op.add_column(
"providers",
sa.Column("timeout", sa.Integer(), nullable=True, comment="请求超时(秒)"),
)
if not _column_exists("providers", "max_retries"):
op.add_column(
"providers",
sa.Column("max_retries", sa.Integer(), nullable=True, comment="最大重试次数"),
)
if not _column_exists("providers", "proxy"):
op.add_column(
"providers",
sa.Column("proxy", postgresql.JSONB(), nullable=True, comment="代理配置"),
)
# Migrate data from the endpoints to the provider
op.execute("""
UPDATE providers p
SET
timeout = COALESCE(
p.timeout,
(SELECT MAX(e.timeout) FROM provider_endpoints e WHERE e.provider_id = p.id AND e.timeout IS NOT NULL),
300
),
max_retries = COALESCE(
p.max_retries,
(SELECT MAX(e.max_retries) FROM provider_endpoints e WHERE e.provider_id = p.id AND e.max_retries IS NOT NULL),
2
),
proxy = COALESCE(
p.proxy,
(SELECT e.proxy FROM provider_endpoints e WHERE e.provider_id = p.id AND e.proxy IS NOT NULL ORDER BY e.created_at LIMIT 1)
)
WHERE p.timeout IS NULL OR p.max_retries IS NULL
""")
# ========== 5. providers: display_name -> name ==========
# Note: display_name may already have been renamed to name
# If display_name still exists, the rename has to be performed
if _column_exists("providers", "display_name"):
# Drop the old name index
if _index_exists("providers", "ix_providers_name"):
op.drop_index("ix_providers_name", table_name="providers")
# If an old name column exists, drop it first
if _column_exists("providers", "name"):
op.drop_column("providers", "name")
# Rename display_name to name
op.alter_column("providers", "display_name", new_column_name="name")
# Create the new index
op.create_index("ix_providers_name", "providers", ["name"], unique=True)
# ========== 6. provider_api_keys: max_concurrent -> rpm_limit ==========
if _column_exists("provider_api_keys", "max_concurrent"):
op.alter_column("provider_api_keys", "max_concurrent", new_column_name="rpm_limit")
if _column_exists("provider_api_keys", "learned_max_concurrent"):
op.alter_column("provider_api_keys", "learned_max_concurrent", new_column_name="learned_rpm_limit")
if _column_exists("provider_api_keys", "last_concurrent_peak"):
op.alter_column("provider_api_keys", "last_concurrent_peak", new_column_name="last_rpm_peak")
# Drop deprecated columns
for col in ["rate_limit", "daily_limit", "monthly_limit"]:
if _column_exists("provider_api_keys", col):
op.drop_column("provider_api_keys", col)
# ========== 7. provider_api_keys: store health per format ==========
if not _column_exists("provider_api_keys", "health_by_format"):
op.add_column(
"provider_api_keys",
sa.Column(
"health_by_format",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
comment="按API格式存储的健康度数据",
),
)
if not _column_exists("provider_api_keys", "circuit_breaker_by_format"):
op.add_column(
"provider_api_keys",
sa.Column(
"circuit_breaker_by_format",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
comment="按API格式存储的熔断器状态",
),
)
# Data migration: if the old columns exist, migrate their data into the new structure
if _column_exists("provider_api_keys", "health_score"):
op.execute("""
UPDATE provider_api_keys
SET health_by_format = (
SELECT jsonb_object_agg(
elem,
jsonb_build_object(
'health_score', COALESCE(health_score, 1.0),
'consecutive_failures', COALESCE(consecutive_failures, 0),
'last_failure_at', last_failure_at,
'request_results_window', COALESCE(request_results_window::jsonb, '[]'::jsonb)
)
)
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
)
WHERE api_formats IS NOT NULL
AND api_formats::text != '[]'
AND health_by_format IS NULL
""")
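# Circuit breaker migration strategy:
# Do not copy the old circuit_breaker_open state to every format; reset all of them to closed instead.
# Reason: the old single circuit breaker may have opened because one format was failing,
# and copying it to every format would wrongly mark other, healthy formats as unavailable.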
if _column_exists("provider_api_keys", "circuit_breaker_open"):
op.execute("""
UPDATE provider_api_keys
SET circuit_breaker_by_format = (
SELECT jsonb_object_agg(
elem,
jsonb_build_object(
'open', false,
'open_at', NULL,
'next_probe_at', NULL,
'half_open_until', NULL,
'half_open_successes', 0,
'half_open_failures', 0
)
)
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
)
WHERE api_formats IS NOT NULL
AND api_formats::text != '[]'
AND circuit_breaker_by_format IS NULL
""")
# Default to an empty object
op.execute("""
UPDATE provider_api_keys
SET health_by_format = '{}'::jsonb
WHERE health_by_format IS NULL
""")
op.execute("""
UPDATE provider_api_keys
SET circuit_breaker_by_format = '{}'::jsonb
WHERE circuit_breaker_by_format IS NULL
""")
# Create GIN indexes
if not _index_exists("provider_api_keys", "ix_provider_api_keys_health_by_format"):
op.create_index(
"ix_provider_api_keys_health_by_format",
"provider_api_keys",
["health_by_format"],
postgresql_using="gin",
)
if not _index_exists("provider_api_keys", "ix_provider_api_keys_circuit_breaker_by_format"):
op.create_index(
"ix_provider_api_keys_circuit_breaker_by_format",
"provider_api_keys",
["circuit_breaker_by_format"],
postgresql_using="gin",
)
# Drop the old columns
old_health_columns = [
"health_score",
"consecutive_failures",
"last_failure_at",
"request_results_window",
"circuit_breaker_open",
"circuit_breaker_open_at",
"next_probe_at",
"half_open_until",
"half_open_successes",
"half_open_failures",
]
for col in old_health_columns:
if _column_exists("provider_api_keys", col):
op.drop_column("provider_api_keys", col)
# ========== 8. provider_endpoints: drop the deprecated rate_limit column ==========
if _column_exists("provider_endpoints", "rate_limit"):
op.drop_column("provider_endpoints", "rate_limit")
# ========== 9. usage: add client_response_headers ==========
if not _column_exists("usage", "client_response_headers"):
op.add_column(
"usage",
sa.Column("client_response_headers", sa.JSON(), nullable=True),
)
# ========== 10. provider_api_keys: drop endpoint_id ==========
# Keys are no longer bound to an Endpoint; they are associated via provider_id + api_formats
if _column_exists("provider_api_keys", "endpoint_id"):
# Make sure the foreign key is gone (it may already have been dropped above)
try:
bind = op.get_bind()
inspector = inspect(bind)
for fk in inspector.get_foreign_keys("provider_api_keys"):
constrained = fk.get("constrained_columns") or []
if "endpoint_id" in constrained:
name = fk.get("name")
if name:
op.drop_constraint(name, "provider_api_keys", type_="foreignkey")
except Exception:
pass # The foreign key may no longer exist
op.drop_column("provider_api_keys", "endpoint_id")
# ========== 11. provider_endpoints: drop the deprecated max_concurrent column ==========
if _column_exists("provider_endpoints", "max_concurrent"):
op.drop_column("provider_endpoints", "max_concurrent")
# ========== 12. providers: drop the deprecated RPM-related columns ==========
if _column_exists("providers", "rpm_limit"):
op.drop_column("providers", "rpm_limit")
if _column_exists("providers", "rpm_used"):
op.drop_column("providers", "rpm_used")
if _column_exists("providers", "rpm_reset_at"):
op.drop_column("providers", "rpm_reset_at")
alembic_logger.info("[OK] Consolidated migration completed successfully")
def downgrade() -> None:
"""
Downgrade is complex due to data migrations.
For safety, this only removes new columns without restoring old structure.
Manual intervention may be required for full rollback.
"""
bind = op.get_bind()
# 12. Restore the providers RPM-related columns
if not _column_exists("providers", "rpm_limit"):
op.add_column("providers", sa.Column("rpm_limit", sa.Integer(), nullable=True))
if not _column_exists("providers", "rpm_used"):
op.add_column(
"providers",
sa.Column("rpm_used", sa.Integer(), server_default="0", nullable=True),
)
if not _column_exists("providers", "rpm_reset_at"):
op.add_column(
"providers",
sa.Column("rpm_reset_at", sa.DateTime(timezone=True), nullable=True),
)
# 11. Restore provider_endpoints.max_concurrent
if not _column_exists("provider_endpoints", "max_concurrent"):
op.add_column("provider_endpoints", sa.Column("max_concurrent", sa.Integer(), nullable=True))
# 10. Restore endpoint_id
if not _column_exists("provider_api_keys", "endpoint_id"):
op.add_column("provider_api_keys", sa.Column("endpoint_id", sa.String(36), nullable=True))
# 9. Drop client_response_headers
if _column_exists("usage", "client_response_headers"):
op.drop_column("usage", "client_response_headers")
# 8. Restore provider_endpoints.rate_limit (if needed)
if not _column_exists("provider_endpoints", "rate_limit"):
op.add_column("provider_endpoints", sa.Column("rate_limit", sa.Integer(), nullable=True))
# 7. Drop the health JSON columns
bind.execute(sa.text("DROP INDEX IF EXISTS ix_provider_api_keys_health_by_format"))
bind.execute(sa.text("DROP INDEX IF EXISTS ix_provider_api_keys_circuit_breaker_by_format"))
if _column_exists("provider_api_keys", "health_by_format"):
op.drop_column("provider_api_keys", "health_by_format")
if _column_exists("provider_api_keys", "circuit_breaker_by_format"):
op.drop_column("provider_api_keys", "circuit_breaker_by_format")
# 6. rpm_limit -> max_concurrent (simplified: rename only)
if _column_exists("provider_api_keys", "rpm_limit"):
op.alter_column("provider_api_keys", "rpm_limit", new_column_name="max_concurrent")
if _column_exists("provider_api_keys", "learned_rpm_limit"):
op.alter_column("provider_api_keys", "learned_rpm_limit", new_column_name="learned_max_concurrent")
if _column_exists("provider_api_keys", "last_rpm_peak"):
op.alter_column("provider_api_keys", "last_rpm_peak", new_column_name="last_concurrent_peak")
# Restore the dropped columns
if not _column_exists("provider_api_keys", "rate_limit"):
op.add_column("provider_api_keys", sa.Column("rate_limit", sa.Integer(), nullable=True))
if not _column_exists("provider_api_keys", "daily_limit"):
op.add_column("provider_api_keys", sa.Column("daily_limit", sa.Integer(), nullable=True))
if not _column_exists("provider_api_keys", "monthly_limit"):
op.add_column("provider_api_keys", sa.Column("monthly_limit", sa.Integer(), nullable=True))
# 5. name -> display_name (the index must be dropped first)
if _index_exists("providers", "ix_providers_name"):
op.drop_index("ix_providers_name", table_name="providers")
op.alter_column("providers", "name", new_column_name="display_name")
# Re-add the original name column
op.add_column("providers", sa.Column("name", sa.String(100), nullable=True))
op.execute("""
UPDATE providers
SET name = LOWER(REPLACE(REPLACE(display_name, ' ', '_'), '-', '_'))
""")
op.alter_column("providers", "name", nullable=False)
op.create_index("ix_providers_name", "providers", ["name"], unique=True)
# 4. Drop timeout, max_retries, proxy from providers
if _column_exists("providers", "proxy"):
op.drop_column("providers", "proxy")
if _column_exists("providers", "max_retries"):
op.drop_column("providers", "max_retries")
if _column_exists("providers", "timeout"):
op.drop_column("providers", "timeout")
# 3. models: change global_model_id back to NOT NULL
result = bind.execute(sa.text(
"SELECT COUNT(*) FROM models WHERE global_model_id IS NULL"
))
orphan_model_count = result.scalar() or 0
if orphan_model_count > 0:
alembic_logger.warning(f"[WARN] Found {orphan_model_count} standalone model(s) without global_model_id; they will be deleted")
op.execute("DELETE FROM models WHERE global_model_id IS NULL")
alembic_logger.info(f"Deleted {orphan_model_count} standalone model(s)")
op.alter_column("models", "global_model_id", nullable=False)
# 2. Drop rate_multipliers
if _column_exists("provider_api_keys", "rate_multipliers"):
op.drop_column("provider_api_keys", "rate_multipliers")
# 1. Drop provider_id and api_formats
if _index_exists("provider_api_keys", "idx_provider_api_keys_provider_id"):
op.drop_index("idx_provider_api_keys_provider_id", table_name="provider_api_keys")
if _constraint_exists("provider_api_keys", "fk_provider_api_keys_provider"):
op.drop_constraint("fk_provider_api_keys_provider", "provider_api_keys", type_="foreignkey")
if _column_exists("provider_api_keys", "api_formats"):
op.drop_column("provider_api_keys", "api_formats")
if _column_exists("provider_api_keys", "provider_id"):
op.drop_column("provider_api_keys", "provider_id")
# Restore the endpoint_id foreign key (simplified: only create the foreign key, do not enforce NOT NULL)
if _column_exists("provider_api_keys", "endpoint_id"):
if not _constraint_exists("provider_api_keys", "provider_api_keys_endpoint_id_fkey"):
op.create_foreign_key(
"provider_api_keys_endpoint_id_fkey",
"provider_api_keys",
"provider_endpoints",
["endpoint_id"],
["id"],
ondelete="SET NULL",
)
alembic_logger.info("[OK] Downgrade completed (simplified version)")
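
The per-format data migrations in steps 2 and 7 can be hard to read as SQL. The following Python sketch restates what the `jsonb_object_agg` updates appear to compute for a single row; the function names are hypothetical and the code is an illustration of the transformation, not part of the migration itself.

```python
from typing import Any, Dict, List, Optional


def expand_rate_multipliers(api_formats: List[str], rate_multiplier: float) -> Optional[Dict[str, float]]:
    """Step 2: copy the single rate_multiplier to every format the key supports."""
    if not api_formats:
        return None  # the SQL WHERE clause skips rows with empty or null api_formats
    return {fmt: rate_multiplier for fmt in api_formats}


def initial_circuit_breakers(api_formats: List[str]) -> Dict[str, Dict[str, Any]]:
    """Step 7: every format starts closed; the old shared open flag is not copied."""
    closed = {
        "open": False,
        "open_at": None,
        "next_probe_at": None,
        "half_open_until": None,
        "half_open_successes": 0,
        "half_open_failures": 0,
    }
    # Each format gets its own copy so later updates stay independent.
    return {fmt: dict(closed) for fmt in api_formats}
```

For a key with `api_formats = ["CLAUDE", "OPENAI"]` and `rate_multiplier = 1.5`, the first function yields one `1.5` entry per format, while rows with no formats keep a null `rate_multipliers` and end up with an empty `{}` circuit-breaker map.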


@@ -88,9 +88,28 @@ build_base() {
save_deps_hash
}
# Generate the version file
generate_version_file() {
# Get the version from git
local version
version=$(git describe --tags --always 2>/dev/null | sed 's/^v//')
if [ -z "$version" ]; then
version="unknown"
fi
echo ">>> Generating version file: $version"
cat > src/_version.py << EOF
# Auto-generated by deploy.sh - do not edit
__version__ = '$version'
__version_tuple__ = tuple(int(x) for x in '$version'.split('-')[0].split('.') if x.isdigit())
version = __version__
version_tuple = __version_tuple__
EOF
}
# Build the app image
build_app() {
echo ">>> Building app image (code only)..."
generate_version_file
docker build -f Dockerfile.app.local -t aether-app:latest .
save_code_hash
}
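
The `__version_tuple__` expression written by deploy.sh differs from the one in the CI workflow: deploy.sh first cuts the version at the first `-`. The sketch below evaluates both expressions on the kind of string `git describe` can return; the function names `tuple_ci` and `tuple_deploy` are hypothetical wrappers around the two expressions from the diff.

```python
from typing import Tuple


def tuple_ci(version: str) -> Tuple[int, ...]:
    # Expression from the CI-generated file: split on dots only.
    return tuple(int(x) for x in version.split(".") if x.isdigit())


def tuple_deploy(version: str) -> Tuple[int, ...]:
    # Expression from deploy.sh: drop everything after the first "-" before splitting.
    return tuple(int(x) for x in version.split("-")[0].split(".") if x.isdigit())
```

For a plain tag such as `0.2.5` both give `(0, 2, 5)`. For a describe-style string such as `0.2.5-3-gabc1234`, the CI form drops the last component and gives `(0, 2)`, while the deploy.sh form keeps `(0, 2, 5)`.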

8
entrypoint.sh Normal file

@@ -0,0 +1,8 @@
#!/bin/bash
set -e
echo "Running database migrations..."
alembic upgrade head
echo "Starting application..."
exec "$@"
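
For readers less familiar with the shell idiom, here is a rough Python equivalent of the entrypoint script: run the migration, abort on failure, then replace the current process with the container command. The `entrypoint` function and its injectable `run` and `execvp` parameters are hypothetical and exist only so the flow can be exercised without Alembic installed.

```python
import os
import subprocess
from typing import Callable, Sequence


def entrypoint(
    argv: Sequence[str],
    run: Callable[..., object] = subprocess.run,
    execvp: Callable[[str, Sequence[str]], object] = os.execvp,
) -> None:
    """Run migrations, then hand the process over to the container command."""
    print("Running database migrations...")
    # check=True raises CalledProcessError on a non-zero exit, like `set -e`.
    run(["alembic", "upgrade", "head"], check=True)
    print("Starting application...")
    # execvp replaces the current process, like `exec "$@"`, so the command
    # (supervisord in the Dockerfile) receives signals directly.
    execvp(argv[0], list(argv))
```

In a container this would be called with the `CMD` arguments, for example `entrypoint(sys.argv[1:])`; if the migration fails, the application never starts.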


@@ -67,7 +67,6 @@ export interface GlobalModelExport {
export interface ProviderExport { export interface ProviderExport {
name: string name: string
display_name: string
description?: string | null description?: string | null
website?: string | null website?: string | null
billing_type?: string | null billing_type?: string | null
@@ -76,10 +75,13 @@ export interface ProviderExport {
rpm_limit?: number | null rpm_limit?: number | null
provider_priority?: number provider_priority?: number
is_active: boolean is_active: boolean
rate_limit?: number | null
concurrent_limit?: number | null concurrent_limit?: number | null
timeout?: number | null
max_retries?: number | null
proxy?: any
config?: any config?: any
endpoints: EndpointExport[] endpoints: EndpointExport[]
api_keys: ProviderKeyExport[]
models: ModelExport[] models: ModelExport[]
} }
@@ -89,27 +91,26 @@ export interface EndpointExport {
headers?: any headers?: any
timeout?: number timeout?: number
max_retries?: number max_retries?: number
max_concurrent?: number | null
rate_limit?: number | null
is_active: boolean is_active: boolean
custom_path?: string | null custom_path?: string | null
config?: any config?: any
keys: KeyExport[] proxy?: any
} }
export interface KeyExport { export interface ProviderKeyExport {
api_key: string api_key: string
name?: string | null name?: string | null
note?: string | null note?: string | null
api_formats: string[]
rate_multiplier?: number rate_multiplier?: number
rate_multipliers?: Record<string, number> | null
internal_priority?: number internal_priority?: number
global_priority?: number | null global_priority?: number | null
max_concurrent?: number | null rpm_limit?: number | null
rate_limit?: number | null allowed_models?: any
daily_limit?: number | null
monthly_limit?: number | null
allowed_models?: string[] | null
capabilities?: any capabilities?: any
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
is_active: boolean is_active: boolean
} }
@@ -159,6 +160,15 @@ export interface EmailTemplateResetResponse {
} }
} }
// Check-update response
export interface CheckUpdateResponse {
current_version: string
latest_version: string | null
has_update: boolean
release_url: string | null
error: string | null
}
// LDAP configuration response
export interface LdapConfigResponse {
server_url: string | null
@@ -526,6 +536,14 @@ export const adminApi = {
return response.data
},
// Check for system updates
async checkUpdate(): Promise<CheckUpdateResponse> {
const response = await apiClient.get<CheckUpdateResponse>(
'/api/admin/system/check-update'
)
return response.data
},
// LDAP configuration
// Get the LDAP configuration
async getLdapConfig(): Promise<LdapConfigResponse> {
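
The `CheckUpdateResponse` shape above suggests the backend compares the running version with the newest tag. The diff shown here does not include that backend code, so the following Python sketch is only an assumption of how such a response could be assembled from a list of tag names; `parse_version`, `build_check_update`, `repo_url`, and the release URL pattern are all hypothetical.

```python
from typing import Any, Dict, List, Tuple


def parse_version(tag: str) -> Tuple[int, ...]:
    # "v0.2.5" -> (0, 2, 5); suffixes after "-" and non-numeric parts are ignored.
    core = tag.lstrip("v").split("-")[0]
    return tuple(int(p) for p in core.split(".") if p.isdigit())


def build_check_update(current: str, tags: List[str], repo_url: str) -> Dict[str, Any]:
    """Assemble a CheckUpdateResponse-shaped dict from a list of tag names."""
    versions = [(parse_version(t), t) for t in tags if parse_version(t)]
    if not versions:
        return {"current_version": current, "latest_version": None,
                "has_update": False, "release_url": None,
                "error": "no version tags found"}
    latest_tuple, latest_tag = max(versions)  # numeric comparison, so 0.2.10 > 0.2.4
    has_update = latest_tuple > parse_version(current)
    return {
        "current_version": current,
        "latest_version": latest_tag.lstrip("v"),
        "has_update": has_update,
        "release_url": f"{repo_url}/releases/tag/{latest_tag}" if has_update else None,
        "error": None,
    }
```

Comparing integer tuples rather than strings avoids ranking `0.2.10` below `0.2.4`.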


@@ -155,6 +155,7 @@ export interface RequestDetail {
request_body?: Record<string, any>
provider_request_headers?: Record<string, any>
response_headers?: Record<string, any>
client_response_headers?: Record<string, any>
response_body?: Record<string, any>
metadata?: Record<string, any>
// Tiered billing information


@@ -14,7 +14,7 @@ export async function toggleAdaptiveMode(
message: string message: string
key_id: string key_id: string
is_adaptive: boolean is_adaptive: boolean
max_concurrent: number | null rpm_limit: number | null
effective_limit: number | null effective_limit: number | null
}> { }> {
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/mode`, data) const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/mode`, data)
@@ -22,16 +22,16 @@ export async function toggleAdaptiveMode(
} }
/** /**
* Set a fixed concurrency limit for the key * Set a fixed RPM limit for the key
*/ */
export async function setConcurrentLimit( export async function setRpmLimit(
keyId: string, keyId: string,
limit: number limit: number
): Promise<{ ): Promise<{
message: string message: string
key_id: string key_id: string
is_adaptive: boolean is_adaptive: boolean
max_concurrent: number rpm_limit: number
previous_mode: string previous_mode: string
}> { }> {
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/limit`, null, { const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/limit`, null, {


@@ -27,15 +27,9 @@ export async function createEndpoint(
api_format: string api_format: string
base_url: string base_url: string
custom_path?: string custom_path?: string
auth_type?: string
auth_header?: string
headers?: Record<string, string> headers?: Record<string, string>
timeout?: number timeout?: number
max_retries?: number max_retries?: number
priority?: number
weight?: number
max_concurrent?: number
rate_limit?: number
is_active?: boolean is_active?: boolean
config?: Record<string, any> config?: Record<string, any>
proxy?: ProxyConfig | null proxy?: ProxyConfig | null
@@ -52,16 +46,10 @@ export async function updateEndpoint(
endpointId: string, endpointId: string,
data: Partial<{ data: Partial<{
base_url: string base_url: string
custom_path: string custom_path: string | null
auth_type: string
auth_header: string
headers: Record<string, string> headers: Record<string, string>
timeout: number timeout: number
max_retries: number max_retries: number
priority: number
weight: number
max_concurrent: number
rate_limit: number
is_active: boolean is_active: boolean
config: Record<string, any> config: Record<string, any>
proxy: ProxyConfig | null proxy: ProxyConfig | null
@@ -74,7 +62,7 @@ export async function updateEndpoint(
/**
* Delete the Endpoint
*/
export async function deleteEndpoint(endpointId: string): Promise<{ message: string; deleted_keys_count: number }> { export async function deleteEndpoint(endpointId: string): Promise<{ message: string; affected_keys_count: number }> {
const response = await client.delete(`/api/admin/endpoints/${endpointId}`) const response = await client.delete(`/api/admin/endpoints/${endpointId}`)
return response.data return response.data
} }


@@ -32,16 +32,21 @@ export async function getKeyHealth(keyId: string): Promise<HealthStatus> {
/** /**
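* Restore a Key's health status (one-click recovery: reset health score + close circuit breaker + cancel auto-disable) * Restore a Key's health status (one-click recovery: reset health score + close circuit breaker + cancel auto-disable)
* @param keyId Key ID
* @param apiFormat Optional; the API format to recover (e.g. CLAUDE, OPENAI). If omitted, all formats are recovered.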
*/ */
export async function recoverKeyHealth(keyId: string): Promise<{ export async function recoverKeyHealth(keyId: string, apiFormat?: string): Promise<{
message: string message: string
details: { details: {
api_format?: string
health_score: number health_score: number
circuit_breaker_open: boolean circuit_breaker_open: boolean
is_active: boolean is_active: boolean
} }
}> { }> {
const response = await client.patch(`/api/admin/endpoints/health/keys/${keyId}`) const response = await client.patch(`/api/admin/endpoints/health/keys/${keyId}`, null, {
params: apiFormat ? { api_format: apiFormat } : undefined
})
return response.data return response.data
} }


@@ -1,5 +1,5 @@
import client from '../client' import client from '../client'
import type { EndpointAPIKey } from './types' import type { EndpointAPIKey, AllowedModels } from './types'
/** /**
* Capability definition type * Capability definition type
@@ -49,67 +49,6 @@ export async function getModelCapabilities(modelName: string): Promise<ModelCapa
return response.data return response.data
} }
/**
* Get all Keys of an Endpoint
*/
export async function getEndpointKeys(endpointId: string): Promise<EndpointAPIKey[]> {
const response = await client.get(`/api/admin/endpoints/${endpointId}/keys`)
return response.data
}
/**
* Add a Key to an Endpoint
*/
export async function addEndpointKey(
endpointId: string,
data: {
endpoint_id: string
api_key: string
name: string // Key name (required)
rate_multiplier?: number // Cost multiplier (default 1.0)
internal_priority?: number // Priority within the Endpoint (lower number = higher priority)
max_concurrent?: number // Max concurrency (empty = adaptive mode)
rate_limit?: number
daily_limit?: number
monthly_limit?: number
cache_ttl_minutes?: number // Cache TTL in minutes (0 = disabled)
max_probe_interval_minutes?: number // Circuit-breaker probe interval (minutes)
allowed_models?: string[] // Models this Key is allowed to use
capabilities?: Record<string, boolean> // Capability tag configuration
note?: string // Note (optional)
}
): Promise<EndpointAPIKey> {
const response = await client.post(`/api/admin/endpoints/${endpointId}/keys`, data)
return response.data
}
/**
* Update an Endpoint Key
*/
export async function updateEndpointKey(
keyId: string,
data: Partial<{
api_key: string
name: string // Key name
rate_multiplier: number // Cost multiplier
internal_priority: number // Priority within the Endpoint (provider-first mode; lower number = higher priority)
global_priority: number // Global Key priority (global-Key-first mode; lower number = higher priority)
max_concurrent: number // Max concurrency (empty = adaptive mode)
rate_limit: number
daily_limit: number
monthly_limit: number
cache_ttl_minutes: number // Cache TTL in minutes (0 = disabled)
max_probe_interval_minutes: number // Circuit-breaker probe interval (minutes)
allowed_models: string[] | null // Models this Key is allowed to use; null allows all
capabilities: Record<string, boolean> | null // Capability tag configuration
is_active: boolean
note: string // Note
}>
): Promise<EndpointAPIKey> {
const response = await client.put(`/api/admin/endpoints/keys/${keyId}`, data)
return response.data
}
/** /**
* Get the full API Key (for viewing and copying) * Get the full API Key (for viewing and copying)
*/ */
@@ -119,22 +58,71 @@ export async function revealEndpointKey(keyId: string): Promise<{ api_key: strin
} }
/** /**
* Delete an Endpoint Key * Delete a Key
*/ */
export async function deleteEndpointKey(keyId: string): Promise<{ message: string }> { export async function deleteEndpointKey(keyId: string): Promise<{ message: string }> {
const response = await client.delete(`/api/admin/endpoints/keys/${keyId}`) const response = await client.delete(`/api/admin/endpoints/keys/${keyId}`)
return response.data return response.data
} }
// ========== Provider-level Keys API ==========
/** /**
* Batch-update Endpoint Key priorities (for drag-and-drop sorting) * Get all Keys of a Provider
*/ */
export async function batchUpdateKeyPriority( export async function getProviderKeys(providerId: string): Promise<EndpointAPIKey[]> {
endpointId: string, const response = await client.get(`/api/admin/endpoints/providers/${providerId}/keys`)
priorities: Array<{ key_id: string; internal_priority: number }> return response.data
): Promise<{ message: string; updated_count: number }> { }
const response = await client.put(`/api/admin/endpoints/${endpointId}/keys/batch-priority`, {
priorities /**
}) * Add a Key to a Provider
*/
export async function addProviderKey(
providerId: string,
data: {
api_formats: string[] // Supported API formats (required)
api_key: string
name: string
rate_multiplier?: number // Default cost multiplier
rate_multipliers?: Record<string, number> | null // Per-API-format cost multipliers
internal_priority?: number
rpm_limit?: number | null // RPM limit (empty = adaptive mode)
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
allowed_models?: AllowedModels
capabilities?: Record<string, boolean>
note?: string
}
): Promise<EndpointAPIKey> {
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/keys`, data)
return response.data
}
/**
* Update a Key
*/
export async function updateProviderKey(
keyId: string,
data: Partial<{
api_formats: string[] // Supported API formats
api_key: string
name: string
rate_multiplier: number // Default cost multiplier
rate_multipliers: Record<string, number> | null // Per-API-format cost multipliers
internal_priority: number
global_priority: number | null
rpm_limit: number | null // RPM limit (empty = adaptive mode)
cache_ttl_minutes: number
max_probe_interval_minutes: number
allowed_models: AllowedModels
capabilities: Record<string, boolean> | null
is_active: boolean
note: string
}>
): Promise<EndpointAPIKey> {
const response = await client.put(`/api/admin/endpoints/keys/${keyId}`, data)
return response.data return response.data
} }


@@ -147,14 +147,26 @@ export async function queryProviderUpstreamModels(
/** /**
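* Import models from the upstream provider * Import models from the upstream provider
* @param providerId Provider ID
* @param modelIds List of model IDs
* @param options Optional settings
* @param options.tiered_pricing Tiered pricing configuration
* @param options.price_per_request Per-request price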
*/ */
export async function importModelsFromUpstream( export async function importModelsFromUpstream(
providerId: string, providerId: string,
modelIds: string[] modelIds: string[],
options?: {
tiered_pricing?: object
price_per_request?: number
}
): Promise<ImportFromUpstreamResponse> { ): Promise<ImportFromUpstreamResponse> {
const response = await client.post( const response = await client.post(
`/api/admin/providers/${providerId}/import-from-upstream`, `/api/admin/providers/${providerId}/import-from-upstream`,
{ model_ids: modelIds } {
model_ids: modelIds,
...options
}
) )
return response.data return response.data
} }
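Reviewer note: the new request body spreads the optional `options` object after `model_ids`. The sketch below isolates that payload construction so the behavior when `options` is omitted is easy to see. `buildImportPayload` and `ImportOptions` are hypothetical names used only for illustration; they are not part of this diff.

```typescript
// Hypothetical names for illustration; only the spread pattern comes from the diff.
interface ImportOptions {
  tiered_pricing?: object
  price_per_request?: number
}

function buildImportPayload(
  modelIds: string[],
  options?: ImportOptions
): { model_ids: string[] } & ImportOptions {
  // Spreading `undefined` adds no properties, so callers that omit
  // `options` get a body that contains only `model_ids`.
  return { model_ids: modelIds, ...options }
}

const withoutOptions = buildImportPayload(['model-a'])
const withPrice = buildImportPayload(['model-a'], { price_per_request: 0.01 })
```

Because the spread comes last, any key in `options` with the same name as an earlier key would override it; the declared option type keeps `model_ids` out of `options`, so that cannot happen through typed callers.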


@@ -1,5 +1,5 @@
import client from '../client' import client from '../client'
import type { ProviderWithEndpointsSummary } from './types' import type { ProviderWithEndpointsSummary, ProxyConfig } from './types'
/** /**
* Get the Providers summary (including Endpoint statistics) * Get the Providers summary (including Endpoint statistics)
@@ -23,7 +23,7 @@ export async function getProvider(providerId: string): Promise<ProviderWithEndpo
export async function updateProvider( export async function updateProvider(
providerId: string, providerId: string,
data: Partial<{ data: Partial<{
display_name: string name: string
description: string description: string
website: string website: string
provider_priority: number provider_priority: number
@@ -33,6 +33,10 @@ export async function updateProvider(
quota_last_reset_at: string // Period start time quota_last_reset_at: string // Period start time
quota_expires_at: string quota_expires_at: string
rpm_limit: number | null rpm_limit: number | null
// Request configuration (migrated from Endpoint)
timeout: number
max_retries: number
proxy: ProxyConfig | null
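cache_ttl_minutes: number // 0 = caching not supported; >0 = caching supported, value is the TTL in minutes cache_ttl_minutes: number // 0 = caching not supported; >0 = caching supported, value is the TTL in minutes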
max_probe_interval_minutes: number max_probe_interval_minutes: number
is_active: boolean is_active: boolean
@@ -83,7 +87,6 @@ export interface TestModelResponse {
provider?: { provider?: {
id: string id: string
name: string name: string
display_name: string
} }
model?: string model?: string
} }
@@ -92,4 +95,3 @@ export async function testModel(data: TestModelRequest): Promise<TestModelRespon
const response = await client.post('/api/admin/provider-query/test-model', data) const response = await client.post('/api/admin/provider-query/test-model', data)
return response.data return response.data
} }


@@ -20,6 +20,38 @@ export const API_FORMAT_LABELS: Record<string, string> = {
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI', [API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
} }
// Abbreviated labels for API formats (for space-constrained displays)
export const API_FORMAT_SHORT: Record<string, string> = {
[API_FORMATS.OPENAI]: 'O',
[API_FORMATS.OPENAI_CLI]: 'OC',
[API_FORMATS.CLAUDE]: 'C',
[API_FORMATS.CLAUDE_CLI]: 'CC',
[API_FORMATS.GEMINI]: 'G',
[API_FORMATS.GEMINI_CLI]: 'GC',
}
// Sort order for API formats (the single, shared display order)
export const API_FORMAT_ORDER: string[] = [
API_FORMATS.OPENAI,
API_FORMATS.OPENAI_CLI,
API_FORMATS.CLAUDE,
API_FORMATS.CLAUDE_CLI,
API_FORMATS.GEMINI,
API_FORMATS.GEMINI_CLI,
]
// Utility: sort an array of API formats into the standard order
export function sortApiFormats(formats: string[]): string[] {
return [...formats].sort((a, b) => {
const aIdx = API_FORMAT_ORDER.indexOf(a)
const bIdx = API_FORMAT_ORDER.indexOf(b)
if (aIdx === -1 && bIdx === -1) return 0
if (aIdx === -1) return 1
if (bIdx === -1) return -1
return aIdx - bIdx
})
}
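Reviewer note: a small, self-contained sketch of how `sortApiFormats` treats formats that are missing from `API_FORMAT_ORDER`. The string values assigned to `API_FORMATS` here are assumptions (only `OPENAI` and `CLAUDE` appear as literals elsewhere in this diff), and `CUSTOM_A` / `CUSTOM_B` are made-up inputs.

```typescript
// Assumed enum values; the real API_FORMATS definition is outside this hunk.
const API_FORMATS = {
  OPENAI: 'OPENAI',
  OPENAI_CLI: 'OPENAI_CLI',
  CLAUDE: 'CLAUDE',
  CLAUDE_CLI: 'CLAUDE_CLI',
  GEMINI: 'GEMINI',
  GEMINI_CLI: 'GEMINI_CLI',
} as const

const API_FORMAT_ORDER: string[] = [
  API_FORMATS.OPENAI,
  API_FORMATS.OPENAI_CLI,
  API_FORMATS.CLAUDE,
  API_FORMATS.CLAUDE_CLI,
  API_FORMATS.GEMINI,
  API_FORMATS.GEMINI_CLI,
]

// Same comparator as in the diff: known formats follow API_FORMAT_ORDER,
// unknown formats go last and compare equal to each other.
function sortApiFormats(formats: string[]): string[] {
  return [...formats].sort((a, b) => {
    const aIdx = API_FORMAT_ORDER.indexOf(a)
    const bIdx = API_FORMAT_ORDER.indexOf(b)
    if (aIdx === -1 && bIdx === -1) return 0
    if (aIdx === -1) return 1
    if (bIdx === -1) return -1
    return aIdx - bIdx
  })
}

const input = ['GEMINI', 'CUSTOM_B', 'OPENAI', 'CUSTOM_A', 'CLAUDE_CLI']
const sorted = sortApiFormats(input)
// sorted is ['OPENAI', 'CLAUDE_CLI', 'GEMINI', 'CUSTOM_B', 'CUSTOM_A']
```

The copy via `[...formats]` leaves the caller's array untouched, and because `Array.prototype.sort` is stable (ES2019 and later), unknown formats keep their original relative order at the end.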
/** /**
* Proxy configuration type * Proxy configuration type
*/ */
@@ -37,18 +69,9 @@ export interface ProviderEndpoint {
api_format: string api_format: string
base_url: string base_url: string
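custom_path?: string // Custom request path (optional; if empty, the API format's default path is used) custom_path?: string // Custom request path (optional; if empty, the API format's default path is used)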
auth_type: string
auth_header?: string
headers?: Record<string, string> headers?: Record<string, string>
timeout: number timeout: number
max_retries: number max_retries: number
priority: number
weight: number
max_concurrent?: number
rate_limit?: number
health_score: number
consecutive_failures: number
last_failure_at?: string
is_active: boolean is_active: boolean
config?: Record<string, any> config?: Record<string, any>
proxy?: ProxyConfig | null proxy?: ProxyConfig | null
@@ -58,25 +81,55 @@ export interface ProviderEndpoint {
updated_at: string updated_at: string
} }
/**
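* Model permission configuration type (supports a simple list or a per-format dictionary)
*
* Examples:
* 1. Unrestricted (all models allowed): null
* 2. Simple list (all API formats share one whitelist): ["gpt-4", "claude-3-opus"]
* 3. Per-format dictionary (each API format has its own whitelist):
* { "OPENAI": ["gpt-4"], "CLAUDE": ["claude-3-opus"] }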
*/
export type AllowedModels = string[] | Record<string, string[]> | null
// Type guards for AllowedModels
export function isAllowedModelsList(value: AllowedModels): value is string[] {
return Array.isArray(value)
}
export function isAllowedModelsDict(value: AllowedModels): value is Record<string, string[]> {
if (value === null || typeof value !== 'object' || Array.isArray(value)) {
return false
}
// Verify that every value is an array of strings
return Object.values(value).every(
(v) => Array.isArray(v) && v.every((item) => typeof item === 'string')
)
}
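Reviewer note: one way a caller might consume the three `AllowedModels` shapes described in the doc comment. This is a sketch, not code from the diff: `isModelAllowed` is a hypothetical helper, and the rule that a format with no dictionary entry allows nothing is an assumption, since the backend's handling of missing keys is not shown here.

```typescript
type AllowedModels = string[] | Record<string, string[]> | null

function isAllowedModelsList(value: AllowedModels): value is string[] {
  return Array.isArray(value)
}

// Hypothetical helper: may `model` be used through `apiFormat`?
function isModelAllowed(allowed: AllowedModels, apiFormat: string, model: string): boolean {
  // null: no restriction at all
  if (allowed === null) return true
  // Simple list: one whitelist shared by every API format
  if (isAllowedModelsList(allowed)) return allowed.includes(model)
  // Dictionary: per-format whitelist.
  // Assumption: a format without an entry has no allowed models.
  const perFormat = allowed[apiFormat]
  return perFormat !== undefined && perFormat.includes(model)
}

const perFormatRules: AllowedModels = { OPENAI: ['gpt-4'], CLAUDE: ['claude-3-opus'] }
const openaiOk = isModelAllowed(perFormatRules, 'OPENAI', 'gpt-4')
const claudeDenied = isModelAllowed(perFormatRules, 'CLAUDE', 'gpt-4')
```

After the `null` check and the list guard, TypeScript narrows `allowed` to the dictionary shape, so no cast is needed in the last branch.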
export interface EndpointAPIKey { export interface EndpointAPIKey {
id: string id: string
endpoint_id: string provider_id: string
api_formats: string[] // Supported API formats
api_key_masked: string api_key_masked: string
api_key_plain?: string | null api_key_plain?: string | null
name: string // Key name (required, used for identification) name: string // Key name (required, used for identification)
rate_multiplier: number // Cost multiplier (real cost = nominal cost × multiplier) rate_multiplier: number // Default cost multiplier (real cost = nominal cost × multiplier)
internal_priority: number // Priority within the Endpoint rate_multipliers?: Record<string, number> | null // Per-API-format cost multipliers, e.g. {"CLAUDE": 1.0, "OPENAI": 0.8}
internal_priority: number // Internal Key priority
global_priority?: number | null // Global Key priority global_priority?: number | null // Global Key priority
max_concurrent?: number rpm_limit?: number | null // RPM rate limit (1-10000); null means adaptive mode
rate_limit?: number allowed_models?: AllowedModels // Allowed models (null = unrestricted, list = simple whitelist, dictionary = per-format)
daily_limit?: number
monthly_limit?: number
allowed_models?: string[] | null // Allowed models (null = all models supported)
capabilities?: Record<string, boolean> | null // Capability tag configuration (e.g. cache_1h, context_1m) capabilities?: Record<string, boolean> | null // Capability tag configuration (e.g. cache_1h, context_1m)
// Cache and circuit-breaker configuration // Cache and circuit-breaker configuration
cache_ttl_minutes: number // Cache TTL in minutes (0 = disabled) cache_ttl_minutes: number // Cache TTL in minutes (0 = disabled)
max_probe_interval_minutes: number // Circuit-breaker probe interval (minutes) max_probe_interval_minutes: number // Circuit-breaker probe interval (minutes)
// Per-format health data
health_by_format?: Record<string, FormatHealthData>
circuit_breaker_by_format?: Record<string, FormatCircuitBreakerData>
// Aggregated fields (computed from health_by_format, used in list views)
health_score: number health_score: number
circuit_breaker_open?: boolean
consecutive_failures: number consecutive_failures: number
last_failure_at?: string last_failure_at?: string
request_count: number request_count: number
@@ -89,10 +142,10 @@ export interface EndpointAPIKey {
last_used_at?: string last_used_at?: string
created_at: string created_at: string
updated_at: string updated_at: string
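// Adaptive concurrency fields // Adaptive RPM fields
is_adaptive?: boolean // Whether adaptive mode is on (max_concurrent=NULL) is_adaptive?: boolean // Whether adaptive mode is on (rpm_limit=NULL)
effective_limit?: number // Current effective limit (adaptive uses the learned value, fixed uses the configured value) effective_limit?: number // Current effective RPM limit (adaptive uses the learned value, fixed uses the configured value)
learned_max_concurrent?: number learned_rpm_limit?: number // Learned RPM limit
// Sliding-window utilization sampling // Sliding-window utilization sampling
utilization_samples?: Array<{ ts: number; util: number }> // Utilization sample window utilization_samples?: Array<{ ts: number; util: number }> // Utilization sample window
last_probe_increase_at?: string // Time of the last probing increase last_probe_increase_at?: string // Time of the last probing increase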
@@ -100,8 +153,7 @@ export interface EndpointAPIKey {
rpm_429_count?: number rpm_429_count?: number
last_429_at?: string last_429_at?: string
last_429_type?: string last_429_type?: string
// Circuit-breaker fields (sliding window + half-open mode) // Circuit-breaker fields for the single-format case
circuit_breaker_open?: boolean
circuit_breaker_open_at?: string circuit_breaker_open_at?: string
next_probe_at?: string next_probe_at?: string
half_open_until?: string half_open_until?: string
@@ -110,17 +162,36 @@ export interface EndpointAPIKey {
request_results_window?: Array<{ ts: number; ok: boolean }> // Sliding window of request results request_results_window?: Array<{ ts: number; ok: boolean }> // Sliding window of request results
} }
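Reviewer note: `EndpointAPIKey` now carries both a default `rate_multiplier` and an optional per-format `rate_multipliers` map. The sketch below shows one plausible way to resolve the two. The fallback rule (a per-format entry overrides the default) is an assumption inferred from the field comments, and both helper names are hypothetical.

```typescript
// Only the two fields relevant here; the full interface is in the diff.
interface KeyCostFields {
  rate_multiplier: number
  rate_multipliers?: Record<string, number> | null
}

// Hypothetical helper. Assumption: a per-format entry overrides the default.
function effectiveRateMultiplier(key: KeyCostFields, apiFormat: string): number {
  // `??` only falls back on null/undefined, so an explicit 0 is respected.
  return key.rate_multipliers?.[apiFormat] ?? key.rate_multiplier
}

// Real cost = nominal cost × multiplier, as stated in the field comment.
function realCost(key: KeyCostFields, apiFormat: string, nominalCost: number): number {
  return nominalCost * effectiveRateMultiplier(key, apiFormat)
}

const sampleKey: KeyCostFields = { rate_multiplier: 1.0, rate_multipliers: { OPENAI: 0.5 } }
const openaiCost = realCost(sampleKey, 'OPENAI', 10) // 5
const claudeCost = realCost(sampleKey, 'CLAUDE', 10) // 10
```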
// Per-format health data
export interface FormatHealthData {
health_score: number
error_rate: number
window_size: number
consecutive_failures: number
last_failure_at?: string | null
circuit_breaker: FormatCircuitBreakerData
}
// Per-format circuit-breaker data
export interface FormatCircuitBreakerData {
open: boolean
open_at?: string | null
next_probe_at?: string | null
half_open_until?: string | null
half_open_successes: number
half_open_failures: number
}
export interface EndpointAPIKeyUpdate { export interface EndpointAPIKeyUpdate {
api_formats?: string[] // Supported API formats
name?: string name?: string
api_key?: string // Provide only when it needs updating api_key?: string // Provide only when it needs updating
rate_multiplier?: number rate_multiplier?: number // Default cost multiplier
rate_multipliers?: Record<string, number> | null // Per-API-format cost multipliers
internal_priority?: number internal_priority?: number
global_priority?: number | null global_priority?: number | null
max_concurrent?: number | null // null switches to adaptive mode rpm_limit?: number | null // RPM rate limit (1-10000); null switches to adaptive mode
rate_limit?: number allowed_models?: AllowedModels
daily_limit?: number
monthly_limit?: number
allowed_models?: string[] | null
capabilities?: Record<string, boolean> | null capabilities?: Record<string, boolean> | null
cache_ttl_minutes?: number cache_ttl_minutes?: number
max_probe_interval_minutes?: number max_probe_interval_minutes?: number
@@ -198,7 +269,6 @@ export interface PublicEndpointStatusMonitorResponse {
export interface ProviderWithEndpointsSummary { export interface ProviderWithEndpointsSummary {
id: string id: string
name: string name: string
display_name: string
description?: string description?: string
website?: string website?: string
provider_priority: number provider_priority: number
@@ -208,9 +278,10 @@ export interface ProviderWithEndpointsSummary {
quota_reset_day?: number quota_reset_day?: number
quota_last_reset_at?: string // Start of the current period quota_last_reset_at?: string // Start of the current period
quota_expires_at?: string quota_expires_at?: string
rpm_limit?: number | null // Request configuration (migrated from Endpoint)
rpm_used?: number timeout?: number // Request timeout (seconds)
rpm_reset_at?: string max_retries?: number // Maximum number of retries
proxy?: ProxyConfig | null // Proxy configuration
is_active: boolean is_active: boolean
total_endpoints: number total_endpoints: number
active_endpoints: number active_endpoints: number
@@ -253,13 +324,10 @@ export interface HealthSummary {
} }
} }
export interface ConcurrencyStatus { export interface KeyRpmStatus {
endpoint_id?: string key_id: string
endpoint_current_concurrency: number current_rpm: number
endpoint_max_concurrent?: number rpm_limit?: number
key_id?: string
key_current_concurrency: number
key_max_concurrent?: number
} }
export interface ProviderModelMapping { export interface ProviderModelMapping {
@@ -361,7 +429,6 @@ export interface ModelPriceRange {
export interface ModelCatalogProviderDetail { export interface ModelCatalogProviderDetail {
provider_id: string provider_id: string
provider_name: string provider_name: string
provider_display_name?: string | null
model_id?: string | null model_id?: string | null
target_model: string target_model: string
input_price_per_1m?: number | null input_price_per_1m?: number | null
@@ -534,10 +601,10 @@ export interface UpstreamModel {
*/ */
export interface ImportFromUpstreamSuccessItem { export interface ImportFromUpstreamSuccessItem {
model_id: string model_id: string
global_model_id: string
global_model_name: string
provider_model_id: string provider_model_id: string
created_global_model: boolean global_model_id?: string // Optional; empty string when not linked
global_model_name?: string // Optional; empty string when not linked
created_global_model: boolean // Always false; GlobalModel is no longer created automatically
} }
/** /**


@@ -0,0 +1,112 @@
<template>
<Dialog
v-model="isOpen"
size="md"
title=""
>
<div class="flex flex-col items-center text-center py-2">
<!-- Logo -->
<HeaderLogo
size="h-16 w-16"
class-name="text-primary"
/>
<!-- Title -->
<h2 class="text-xl font-semibold text-foreground mt-4 mb-2">
发现新版本
</h2>
<!-- Version Info -->
<div class="flex items-center gap-3 mb-4">
<span class="px-3 py-1.5 rounded-lg bg-muted text-sm font-mono text-muted-foreground">
v{{ currentVersion }}
</span>
<svg
class="h-4 w-4 text-muted-foreground"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M13 7l5 5m0 0l-5 5m5-5H6"
/>
</svg>
<span class="px-3 py-1.5 rounded-lg bg-primary/10 text-sm font-mono font-medium text-primary">
v{{ latestVersion }}
</span>
</div>
<!-- Description -->
<p class="text-sm text-muted-foreground max-w-xs">
新版本已发布建议更新以获得最新功能和安全修复
</p>
</div>
<template #footer>
<div class="flex w-full gap-3">
<Button
variant="outline"
class="flex-1"
@click="handleLater"
>
稍后提醒
</Button>
<Button
class="flex-1"
@click="handleViewRelease"
>
查看更新
</Button>
</div>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue'
import HeaderLogo from '@/components/HeaderLogo.vue'
const props = defineProps<{
modelValue: boolean
currentVersion: string
latestVersion: string
releaseUrl: string | null
}>()
const emit = defineEmits<{
'update:modelValue': [value: boolean]
}>()
const isOpen = ref(props.modelValue)
watch(() => props.modelValue, (val) => {
isOpen.value = val
})
watch(isOpen, (val) => {
emit('update:modelValue', val)
})
function handleLater() {
// Record the ignored version; no reminder for the next 24 hours
const ignoreKey = 'aether_update_ignore'
const ignoreData = {
version: props.latestVersion,
until: Date.now() + 24 * 60 * 60 * 1000 // 24 hours
}
localStorage.setItem(ignoreKey, JSON.stringify(ignoreData))
isOpen.value = false
}
function handleViewRelease() {
if (props.releaseUrl) {
window.open(props.releaseUrl, '_blank')
}
isOpen.value = false
}
</script>
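Reviewer note: `handleLater()` only writes the ignore record; the code that reads it back is not part of this file. The sketch below is one possible reader, written as a pure function so it can be tested without `localStorage`. `isUpdateIgnored` is a hypothetical name, and the rule that a different `latestVersion` cancels the ignore is an assumption.

```typescript
interface UpdateIgnoreRecord {
  version: string
  until: number // epoch milliseconds, as written by handleLater()
}

// Hypothetical reader for the record stored under 'aether_update_ignore'.
// `raw` is the value returned by localStorage.getItem(...), or null if absent.
function isUpdateIgnored(raw: string | null, latestVersion: string, now: number): boolean {
  if (raw === null) return false
  try {
    const record = JSON.parse(raw) as Partial<UpdateIgnoreRecord>
    // Assumption: only the exact ignored version is suppressed, and only
    // until the stored deadline passes.
    return (
      record.version === latestVersion &&
      typeof record.until === 'number' &&
      now < record.until
    )
  } catch {
    // Malformed or non-object JSON: treat as "not ignored".
    return false
  }
}

const now = 1_700_000_000_000
const stored = JSON.stringify({ version: '1.2.0', until: now + 24 * 60 * 60 * 1000 })
const stillIgnored = isUpdateIgnored(stored, '1.2.0', now)
const newerVersionShown = !isUpdateIgnored(stored, '1.3.0', now)
```

Passing `now` in as a parameter instead of calling `Date.now()` inside keeps the expiry check deterministic in tests.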


@@ -0,0 +1,15 @@
<script setup lang="ts">
import { CollapsibleContent, type CollapsibleContentProps } from 'radix-vue'
import { cn } from '@/lib/utils'
const props = defineProps<CollapsibleContentProps & { class?: string }>()
</script>
<template>
<CollapsibleContent
v-bind="props"
:class="cn('overflow-hidden data-[state=closed]:animate-collapsible-up data-[state=open]:animate-collapsible-down', props.class)"
>
<slot />
</CollapsibleContent>
</template>


@@ -0,0 +1,11 @@
<script setup lang="ts">
import { CollapsibleTrigger, type CollapsibleTriggerProps } from 'radix-vue'
const props = defineProps<CollapsibleTriggerProps>()
</script>
<template>
<CollapsibleTrigger v-bind="props" as-child>
<slot />
</CollapsibleTrigger>
</template>


@@ -0,0 +1,15 @@
<script setup lang="ts">
import { CollapsibleRoot, type CollapsibleRootEmits, type CollapsibleRootProps } from 'radix-vue'
import { useForwardPropsEmits } from 'radix-vue'
const props = defineProps<CollapsibleRootProps>()
const emits = defineEmits<CollapsibleRootEmits>()
const forwarded = useForwardPropsEmits(props, emits)
</script>
<template>
<CollapsibleRoot v-bind="forwarded">
<slot />
</CollapsibleRoot>
</template>


@@ -65,3 +65,8 @@ export { default as RefreshButton } from './refresh-button.vue'
// Tooltip components // Tooltip components
export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './tooltip' export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './tooltip'
// Collapsible components
export { default as Collapsible } from './collapsible.vue'
export { default as CollapsibleTrigger } from './collapsible-trigger.vue'
export { default as CollapsibleContent } from './collapsible-content.vue'


@@ -186,7 +186,7 @@
@click.stop @click.stop
@change="toggleSelection('allowed_providers', provider.id)" @change="toggleSelection('allowed_providers', provider.id)"
> >
<span class="text-sm">{{ provider.display_name || provider.name }}</span> <span class="text-sm">{{ provider.name }}</span>
</div> </div>
<div <div
v-if="providers.length === 0" v-if="providers.length === 0"


@@ -460,13 +460,13 @@
<TableHead class="h-10 font-semibold"> <TableHead class="h-10 font-semibold">
Provider Provider
</TableHead> </TableHead>
<TableHead class="w-[120px] h-10 font-semibold"> <TableHead class="w-[100px] h-10 font-semibold">
能力 能力
</TableHead> </TableHead>
<TableHead class="w-[180px] h-10 font-semibold"> <TableHead class="w-[200px] h-10 font-semibold">
价格 ($/M) 价格 ($/M)
</TableHead> </TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center"> <TableHead class="w-[100px] h-10 font-semibold text-center">
操作 操作
</TableHead> </TableHead>
</TableRow> </TableRow>
@@ -484,7 +484,7 @@
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'" :class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
:title="provider.is_active ? '活跃' : '停用'" :title="provider.is_active ? '活跃' : '停用'"
/> />
<span class="font-medium truncate">{{ provider.display_name }}</span> <span class="font-medium truncate">{{ provider.name }}</span>
</div> </div>
</TableCell> </TableCell>
<TableCell class="py-3"> <TableCell class="py-3">
@@ -595,7 +595,7 @@
class="w-2 h-2 rounded-full shrink-0" class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'" :class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
/> />
<span class="font-medium truncate">{{ provider.display_name }}</span> <span class="font-medium truncate">{{ provider.name }}</span>
</div> </div>
<div class="flex items-center gap-1 shrink-0"> <div class="flex items-center gap-1 shrink-0">
<Button <Button


@@ -531,20 +531,23 @@ watch(() => props.open, async (isOpen) => {
// Load data // Load data
async function loadData() { async function loadData() {
await Promise.all([loadGlobalModels(), loadExistingModels()]) await Promise.all([loadGlobalModels(), loadExistingModels()])
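// Collapse the global model group by default
collapsedGroups.value = new Set(['global'])
// Check the cache; if cached data exists, use it directly // Check the cache; if cached data exists, use it directly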
const cachedModels = getCachedModels(props.providerId) const cachedModels = getCachedModels(props.providerId)
if (cachedModels) { if (cachedModels && cachedModels.length > 0) {
upstreamModels.value = cachedModels upstreamModels.value = cachedModels
upstreamModelsLoaded.value = true upstreamModelsLoaded.value = true
// Collapse all upstream model groups // With multiple groups, collapse them all
const allGroups = new Set(['global'])
for (const model of cachedModels) { for (const model of cachedModels) {
if (model.api_format) { if (model.api_format) {
collapsedGroups.value.add(model.api_format) allGroups.add(model.api_format)
} }
} }
collapsedGroups.value = allGroups
} else {
// With only the global models, keep the group expanded
collapsedGroups.value = new Set()
} }
} }
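Reviewer note: the new branch in `loadData()` can be read as a pure function from the cached upstream models to the initial set of collapsed groups. The sketch below restates it that way. `initialCollapsedGroups` and `UpstreamModelLike` are hypothetical names, and accepting `null` or `undefined` for "no cache" is an assumption about what `getCachedModels` returns.

```typescript
// Minimal shape needed here; the real upstream model type has more fields.
interface UpstreamModelLike {
  api_format?: string
}

// Hypothetical pure restatement of the new collapse logic in loadData().
function initialCollapsedGroups(
  cachedModels: UpstreamModelLike[] | null | undefined
): Set<string> {
  // No cached upstream models: only the global group is shown, so expand it.
  if (!cachedModels || cachedModels.length === 0) {
    return new Set<string>()
  }
  // Cached upstream models exist: collapse the global group and every
  // API-format group found in the cache.
  const groups = new Set<string>(['global'])
  for (const model of cachedModels) {
    if (model.api_format) groups.add(model.api_format)
  }
  return groups
}

const noCache = initialCollapsedGroups(null)
const withCache = initialCollapsedGroups([
  { api_format: 'OPENAI' },
  { api_format: 'CLAUDE' },
  {},
])
```

Building the whole set first and assigning it once, as the diff now does, avoids mutating the reactive `Set` in place.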
@@ -585,8 +588,8 @@ async function fetchUpstreamModels(forceRefresh = false) {
} else { } else {
upstreamModels.value = result.models upstreamModels.value = result.models
upstreamModelsLoaded.value = true upstreamModelsLoaded.value = true
// Collapse all upstream model groups // With multiple groups, collapse them all
const allGroups = new Set(collapsedGroups.value) const allGroups = new Set(['global'])
for (const model of result.models) { for (const model of result.models) {
if (model.api_format) { if (model.api_format) {
allGroups.add(model.api_format) allGroups.add(model.api_format)


@@ -1,245 +1,200 @@
<template> <template>
<Dialog <Dialog
:model-value="internalOpen" :model-value="internalOpen"
:title="isEditMode ? '编辑 API 端点' : '添加 API 端点'" title="端点管理"
:description="isEditMode ? `修改 ${provider?.display_name} 的端点配置` : '为提供商添加新的 API 端点'" :description="`管理 ${provider?.name} 的 API 端点`"
:icon="isEditMode ? SquarePen : Link" :icon="Settings"
size="xl" size="2xl"
@update:model-value="handleDialogUpdate" @update:model-value="handleDialogUpdate"
> >
<form <div class="space-y-4">
class="space-y-6" <!-- Existing endpoints -->
@submit.prevent="handleSubmit()" <div
> v-if="localEndpoints.length > 0"
<!-- API configuration --> class="space-y-2"
<div class="space-y-4"> >
<h3 <Label class="text-muted-foreground">已配置的端点</Label>
v-if="isEditMode"
class="text-sm font-medium"
>
API 配置
</h3>
<div class="grid grid-cols-2 gap-4">
<!-- API format -->
<div class="space-y-2">
<Label for="api_format">API 格式 *</Label>
<template v-if="isEditMode">
<Input
id="api_format"
v-model="form.api_format"
disabled
class="bg-muted"
/>
<p class="text-xs text-muted-foreground">
API 格式创建后不可修改
</p>
</template>
<template v-else>
<Select
v-model="form.api_format"
v-model:open="selectOpen"
required
>
<SelectTrigger>
<SelectValue placeholder="请选择 API 格式" />
</SelectTrigger>
<SelectContent>
<SelectItem
v-for="format in apiFormats"
:key="format.value"
:value="format.value"
>
{{ format.label }}
</SelectItem>
</SelectContent>
</Select>
</template>
</div>
<!-- API URL -->
<div class="space-y-2">
<Label for="base_url">API URL *</Label>
<Input
id="base_url"
v-model="form.base_url"
placeholder="https://api.example.com"
required
/>
</div>
</div>
<!-- Custom path -->
<div class="space-y-2"> <div class="space-y-2">
<Label for="custom_path">自定义请求路径可选</Label> <div
<Input v-for="endpoint in localEndpoints"
id="custom_path" :key="endpoint.id"
v-model="form.custom_path" class="rounded-md border px-3 py-2"
:placeholder="defaultPathPlaceholder" :class="{ 'opacity-50': !endpoint.is_active }"
/> >
</div> <!-- Edit mode -->
</div> <template v-if="editingEndpointId === endpoint.id">
<div class="space-y-2">
<!-- Request configuration --> <div class="flex items-center gap-2">
<div class="space-y-4"> <span class="text-sm font-medium w-24 shrink-0">{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
<h3 class="text-sm font-medium"> <div class="flex items-center gap-1 ml-auto">
请求配置 <Button
</h3> variant="ghost"
size="icon"
<div class="grid grid-cols-3 gap-4"> class="h-7 w-7"
<div class="space-y-2"> title="保存"
<Label for="timeout">超时</Label> :disabled="savingEndpointId === endpoint.id"
<Input @click="saveEndpointUrl(endpoint)"
id="timeout" >
v-model.number="form.timeout" <Check class="w-3.5 h-3.5" />
type="number" </Button>
placeholder="300" <Button
/> variant="ghost"
</div> size="icon"
class="h-7 w-7"
<div class="space-y-2"> title="取消"
<Label for="max_retries">最大重试</Label> @click="cancelEdit"
<Input >
id="max_retries" <X class="w-3.5 h-3.5" />
v-model.number="form.max_retries" </Button>
type="number" </div>
placeholder="3" </div>
/> <div class="grid grid-cols-2 gap-2">
</div> <div class="space-y-1">
<Label class="text-xs text-muted-foreground">Base URL</Label>
<div class="space-y-2"> <Input
<Label for="max_concurrent">最大并发</Label> v-model="editingUrl"
<Input class="h-8 text-sm"
id="max_concurrent" placeholder="https://api.example.com"
:model-value="form.max_concurrent ?? ''" @keyup.escape="cancelEdit"
type="number" />
placeholder="无限制" </div>
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)" <div class="space-y-1">
/> <Label class="text-xs text-muted-foreground">自定义路径 (可选)</Label>
</div> <Input
</div> v-model="editingPath"
class="h-8 text-sm"
<div class="grid grid-cols-2 gap-4"> :placeholder="editingDefaultPath || '留空使用默认路径'"
<div class="space-y-2"> @keyup.escape="cancelEdit"
<Label for="rate_limit">速率限制(请求/分钟)</Label> />
<Input </div>
id="rate_limit" </div>
:model-value="form.rate_limit ?? ''" </div>
type="number" </template>
placeholder="无限制" <!-- View mode -->
@update:model-value="(v) => form.rate_limit = parseNumberInput(v)" <template v-else>
/> <div class="flex items-center gap-3">
<div class="w-24 shrink-0">
<span class="text-sm font-medium">{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
</div>
<div class="flex-1 min-w-0">
<span class="text-sm text-muted-foreground truncate block">
{{ endpoint.base_url }}{{ endpoint.custom_path ? endpoint.custom_path : '' }}
</span>
</div>
<div class="flex items-center gap-1 shrink-0">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="编辑"
@click="startEdit(endpoint)"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
:title="endpoint.is_active ? '停用' : '启用'"
:disabled="togglingEndpointId === endpoint.id"
@click="handleToggleEndpoint(endpoint)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7 text-destructive hover:text-destructive"
title="删除"
:disabled="deletingEndpointId === endpoint.id"
@click="handleDeleteEndpoint(endpoint)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</div>
</template>
</div> </div>
</div> </div>
</div> </div>
<!-- Proxy configuration --> <!-- Add a new endpoint -->
<div class="space-y-4"> <div
<div class="flex items-center justify-between"> v-if="availableFormats.length > 0"
<h3 class="text-sm font-medium"> class="space-y-3 pt-3 border-t"
代理配置 >
</h3> <Label class="text-muted-foreground">添加新端点</Label>
<div class="flex items-center gap-2"> <div class="flex items-end gap-3">
<Switch v-model="proxyEnabled" /> <div class="w-32 shrink-0 space-y-1.5">
<span class="text-sm text-muted-foreground">启用代理</span> <Label class="text-xs">API 格式</Label>
</div> <Select
</div> v-model="newEndpoint.api_format"
v-model:open="formatSelectOpen"
<div
v-if="proxyEnabled"
class="space-y-4 rounded-lg border p-4"
>
<div class="space-y-2">
<Label for="proxy_url">代理 URL *</Label>
<Input
id="proxy_url"
v-model="form.proxy_url"
placeholder="http://host:port 或 socks5://host:port"
required
:class="proxyUrlError ? 'border-red-500' : ''"
/>
<p
v-if="proxyUrlError"
class="text-xs text-red-500"
> >
{{ proxyUrlError }} <SelectTrigger class="h-9">
</p> <SelectValue placeholder="选择格式" />
<p </SelectTrigger>
v-else <SelectContent>
class="text-xs text-muted-foreground" <SelectItem
> v-for="format in availableFormats"
支持 HTTPHTTPSSOCKS5 代理 :key="format.value"
</p> :value="format.value"
>
{{ format.label }}
</SelectItem>
</SelectContent>
</Select>
</div> </div>
<div class="flex-1 space-y-1.5">
<div class="grid grid-cols-2 gap-4"> <Label class="text-xs">Base URL</Label>
<div class="space-y-2"> <Input
<Label for="proxy_user">用户名可选</Label> v-model="newEndpoint.base_url"
<Input placeholder="https://api.example.com"
:id="`proxy_user_${formId}`" class="h-9"
v-model="form.proxy_username" />
:name="`proxy_user_${formId}`"
placeholder="代理认证用户名"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-2">
<Label :for="`proxy_pass_${formId}`">密码可选</Label>
<Input
:id="`proxy_pass_${formId}`"
v-model="form.proxy_password"
:name="`proxy_pass_${formId}`"
type="text"
:placeholder="passwordPlaceholder"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
/>
</div>
</div> </div>
<div class="w-40 shrink-0 space-y-1.5">
<Label class="text-xs">自定义路径</Label>
<Input
v-model="newEndpoint.custom_path"
:placeholder="newEndpointDefaultPath || '可选'"
class="h-9"
/>
</div>
<Button
variant="outline"
size="sm"
class="h-9 shrink-0"
:disabled="!newEndpoint.api_format || !newEndpoint.base_url || addingEndpoint"
@click="handleAddEndpoint"
>
{{ addingEndpoint ? '添加中...' : '添加' }}
</Button>
</div> </div>
</div> </div>
</form>
<!-- 空状态 -->
<div
v-if="localEndpoints.length === 0 && availableFormats.length === 0"
class="text-center py-8 text-muted-foreground"
>
<p>所有 API 格式都已配置</p>
</div>
</div>
<template #footer> <template #footer>
<Button <Button
type="button"
variant="outline" variant="outline"
:disabled="loading" @click="handleClose"
@click="handleCancel"
> >
取消 关闭
</Button>
<Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
@click="handleSubmit()"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
</Button> </Button>
</template> </template>
</Dialog> </Dialog>
<!-- 确认清空凭据对话框 -->
<AlertDialog
v-model="showClearCredentialsDialog"
title="清空代理凭据"
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
type="warning"
confirm-text="清空并保存"
cancel-text="返回编辑"
@confirm="confirmClearCredentials"
@cancel="showClearCredentialsDialog = false"
/>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted, watch } from 'vue'
import { import {
Dialog, Dialog,
Button, Button,
@@ -250,17 +205,15 @@ import {
SelectValue, SelectValue,
SelectContent, SelectContent,
SelectItem, SelectItem,
Switch,
} from '@/components/ui' } from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue' import { Settings, Edit, Trash2, Check, X, Power } from 'lucide-vue-next'
import { Link, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
import { parseNumberInput } from '@/utils/form'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
import { import {
createEndpoint, createEndpoint,
updateEndpoint, updateEndpoint,
deleteEndpoint,
API_FORMAT_LABELS,
type ProviderEndpoint, type ProviderEndpoint,
type ProviderWithEndpointsSummary type ProviderWithEndpointsSummary
} from '@/api/endpoints' } from '@/api/endpoints'
@@ -269,7 +222,7 @@ import { adminApi } from '@/api/admin'
const props = defineProps<{ const props = defineProps<{
modelValue: boolean modelValue: boolean
provider: ProviderWithEndpointsSummary | null provider: ProviderWithEndpointsSummary | null
endpoint?: ProviderEndpoint | null // 编辑模式时传入 endpoints?: ProviderEndpoint[]
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
@@ -279,258 +232,184 @@ const emit = defineEmits<{
}>() }>()
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
// 生成随机 ID 防止浏览器自动填充 // 状态
const formId = Math.random().toString(36).substring(2, 10) const addingEndpoint = ref(false)
const editingEndpointId = ref<string | null>(null)
const editingUrl = ref('')
const editingPath = ref('')
const savingEndpointId = ref<string | null>(null)
const deletingEndpointId = ref<string | null>(null)
const togglingEndpointId = ref<string | null>(null)
const formatSelectOpen = ref(false)
// 内部状态 // 内部状态
const internalOpen = computed(() => props.modelValue) const internalOpen = computed(() => props.modelValue)
// 表单数据 // 新端点表单
const form = ref({ const newEndpoint = ref({
api_format: '', api_format: '',
base_url: '', base_url: '',
custom_path: '', custom_path: '',
timeout: 300,
max_retries: 3,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true,
// 代理配置
proxy_url: '',
proxy_username: '',
proxy_password: '',
}) })
// API 格式列表 // API 格式列表
const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([]) const apiFormats = ref<Array<{ value: string; label: string; default_path: string }>>([])
// 加载API格式列表 // 本地端点列表
const localEndpoints = ref<ProviderEndpoint[]>([])
// 可用的格式(未添加的)
const availableFormats = computed(() => {
const existingFormats = localEndpoints.value.map(e => e.api_format)
return apiFormats.value.filter(f => !existingFormats.includes(f.value))
})
// 获取指定 API 格式的默认路径
function getDefaultPath(apiFormat: string): string {
const format = apiFormats.value.find(f => f.value === apiFormat)
return format?.default_path || ''
}
// 当前编辑端点的默认路径
const editingDefaultPath = computed(() => {
const endpoint = localEndpoints.value.find(e => e.id === editingEndpointId.value)
return endpoint ? getDefaultPath(endpoint.api_format) : ''
})
// 新端点选择的格式的默认路径
const newEndpointDefaultPath = computed(() => {
return getDefaultPath(newEndpoint.value.api_format)
})
// 加载 API 格式列表
const loadApiFormats = async () => { const loadApiFormats = async () => {
try { try {
const response = await adminApi.getApiFormats() const response = await adminApi.getApiFormats()
apiFormats.value = response.formats apiFormats.value = response.formats
} catch (error) { } catch (error) {
log.error('加载API格式失败:', error) log.error('加载API格式失败:', error)
if (!isEditMode.value) {
showError('加载API格式失败', '错误')
}
} }
} }
// 根据选择的 API 格式计算默认路径
const defaultPath = computed(() => {
const format = apiFormats.value.find(f => f.value === form.value.api_format)
return format?.default_path || '/'
})
// 动态 placeholder
const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}`
})
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
const hasExistingPassword = computed(() => {
if (!props.endpoint?.proxy) return false
const proxy = props.endpoint.proxy as { password?: string }
return proxy?.password === MASKED_PASSWORD
})
// 密码输入框的 placeholder
const passwordPlaceholder = computed(() => {
if (hasExistingPassword.value) {
return '已保存密码,留空保持不变'
}
return '代理认证密码'
})
// 代理 URL 验证
const proxyUrlError = computed(() => {
// 只有启用代理且填写了 URL 时才验证
if (!proxyEnabled.value || !form.value.proxy_url) {
return ''
}
const url = form.value.proxy_url.trim()
// 检查禁止的特殊字符
if (/[\n\r]/.test(url)) {
return '代理 URL 包含非法字符'
}
// 验证协议(不支持 SOCKS4)
if (!/^(http|https|socks5):\/\//i.test(url)) {
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
}
try {
const parsed = new URL(url)
if (!parsed.host) {
return '代理 URL 必须包含有效的 host'
}
// 禁止 URL 中内嵌认证信息
if (parsed.username || parsed.password) {
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
}
} catch {
return '代理 URL 格式无效'
}
return ''
})
// 组件挂载时加载API格式
onMounted(() => { onMounted(() => {
loadApiFormats() loadApiFormats()
}) })
// 重置表单 // 监听 props 变化
function resetForm() { watch(() => props.modelValue, (open) => {
form.value = { if (open) {
api_format: '', localEndpoints.value = [...(props.endpoints || [])]
base_url: '', // 重置编辑状态
custom_path: '', editingEndpointId.value = null
timeout: 300, editingUrl.value = ''
max_retries: 3, editingPath.value = ''
max_concurrent: undefined, } else {
rate_limit: undefined, // 关闭对话框时完全清空新端点表单
is_active: true, newEndpoint.value = { api_format: '', base_url: '', custom_path: '' }
proxy_url: '',
proxy_username: '',
proxy_password: '',
} }
proxyEnabled.value = false }, { immediate: true })
watch(() => props.endpoints, (endpoints) => {
if (props.modelValue) {
localEndpoints.value = [...(endpoints || [])]
}
}, { deep: true })
// 开始编辑
function startEdit(endpoint: ProviderEndpoint) {
editingEndpointId.value = endpoint.id
editingUrl.value = endpoint.base_url
editingPath.value = endpoint.custom_path || ''
} }
// 原始密码占位符(后端返回的脱敏标记) // 取消编辑
const MASKED_PASSWORD = '***' function cancelEdit() {
editingEndpointId.value = null
// 加载端点数据(编辑模式) editingUrl.value = ''
function loadEndpointData() { editingPath.value = ''
if (!props.endpoint) return
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
form.value = {
api_format: props.endpoint.api_format,
base_url: props.endpoint.base_url,
custom_path: props.endpoint.custom_path || '',
timeout: props.endpoint.timeout,
max_retries: props.endpoint.max_retries,
max_concurrent: props.endpoint.max_concurrent || undefined,
rate_limit: props.endpoint.rate_limit || undefined,
is_active: props.endpoint.is_active,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
}
// 根据 enabled 字段或 url 存在判断是否启用代理
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
} }
// 使用 useFormDialog 统一处理对话框逻辑 // 保存端点
const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({ async function saveEndpointUrl(endpoint: ProviderEndpoint) {
isOpen: () => props.modelValue, if (!editingUrl.value) return
entity: () => props.endpoint,
isLoading: loading,
onClose: () => emit('update:modelValue', false),
loadData: loadEndpointData,
resetForm,
})
// 构建代理配置 savingEndpointId.value = endpoint.id
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
// - 无 URL 时返回 null
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
if (!form.value.proxy_url) {
// 没填 URL,无代理配置
return null
}
return {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: proxyEnabled.value, // 开关状态决定是否启用
}
}
// 提交表单
const handleSubmit = async (skipCredentialCheck = false) => {
if (!props.provider && !props.endpoint) return
// 只在开关开启且填写了 URL 时验证
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
showError(proxyUrlError.value, '代理配置错误')
return
}
// 检查:开关开启但没有 URL,却有用户名或密码
const hasOrphanedCredentials = proxyEnabled.value
&& !form.value.proxy_url
&& (form.value.proxy_username || form.value.proxy_password)
if (hasOrphanedCredentials && !skipCredentialCheck) {
// 弹出确认对话框
showClearCredentialsDialog.value = true
return
}
loading.value = true
try { try {
const proxyConfig = buildProxyConfig() await updateEndpoint(endpoint.id, {
base_url: editingUrl.value,
if (isEditMode.value && props.endpoint) { custom_path: editingPath.value || null, // 空字符串时传 null 清空
// 更新端点 })
await updateEndpoint(props.endpoint.id, { success('端点已更新')
base_url: form.value.base_url, emit('endpointUpdated')
custom_path: form.value.custom_path || undefined, cancelEdit()
timeout: form.value.timeout,
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点已更新', '保存成功')
emit('endpointUpdated')
} else if (props.provider) {
// 创建端点
await createEndpoint(props.provider.id, {
provider_id: props.provider.id,
api_format: form.value.api_format,
base_url: form.value.base_url,
custom_path: form.value.custom_path || undefined,
timeout: form.value.timeout,
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
})
success('端点创建成功', '成功')
emit('endpointCreated')
resetForm()
}
emit('update:modelValue', false)
} catch (error: any) { } catch (error: any) {
const action = isEditMode.value ? '更新' : '创建' showError(error.response?.data?.detail || '更新失败', '错误')
showError(error.response?.data?.detail || `${action}端点失败`, '错误')
} finally { } finally {
loading.value = false savingEndpointId.value = null
} }
} }
// 确认清空凭据并继续保存 // 添加端点
const confirmClearCredentials = () => { async function handleAddEndpoint() {
form.value.proxy_username = '' if (!props.provider || !newEndpoint.value.api_format || !newEndpoint.value.base_url) return
form.value.proxy_password = ''
showClearCredentialsDialog.value = false addingEndpoint.value = true
handleSubmit(true) // 跳过凭据检查,直接提交 try {
await createEndpoint(props.provider.id, {
provider_id: props.provider.id,
api_format: newEndpoint.value.api_format,
base_url: newEndpoint.value.base_url,
custom_path: newEndpoint.value.custom_path || undefined,
is_active: true,
})
success(`已添加 ${API_FORMAT_LABELS[newEndpoint.value.api_format] || newEndpoint.value.api_format} 端点`)
// 重置表单,保留 URL
const url = newEndpoint.value.base_url
newEndpoint.value = { api_format: '', base_url: url, custom_path: '' }
emit('endpointCreated')
} catch (error: any) {
showError(error.response?.data?.detail || '添加失败', '错误')
} finally {
addingEndpoint.value = false
}
}
// 切换端点启用状态
async function handleToggleEndpoint(endpoint: ProviderEndpoint) {
togglingEndpointId.value = endpoint.id
try {
const newStatus = !endpoint.is_active
await updateEndpoint(endpoint.id, { is_active: newStatus })
success(newStatus ? '端点已启用' : '端点已停用')
emit('endpointUpdated')
} catch (error: any) {
showError(error.response?.data?.detail || '操作失败', '错误')
} finally {
togglingEndpointId.value = null
}
}
// 删除端点
async function handleDeleteEndpoint(endpoint: ProviderEndpoint) {
deletingEndpointId.value = endpoint.id
try {
await deleteEndpoint(endpoint.id)
success(`已删除 ${API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format} 端点`)
emit('endpointUpdated')
} catch (error: any) {
showError(error.response?.data?.detail || '删除失败', '错误')
} finally {
deletingEndpointId.value = null
}
}
// 关闭对话框
function handleDialogUpdate(value: boolean) {
emit('update:modelValue', value)
}
function handleClose() {
emit('update:modelValue', false)
} }
</script> </script>
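
Reviewer sketch (not part of the diff): the reworked endpoint dialog only offers API formats that do not have an endpoint yet, and it uses each format's `default_path` as the placeholder for the custom-path input. The standalone TypeScript below illustrates that logic outside of Vue. The `ApiFormat` and `EndpointLike` shapes are simplified assumptions, and the sample format names and paths are made up for illustration rather than taken from the repository.

```typescript
// Simplified, hypothetical shapes; the component itself uses ProviderEndpoint
// from '@/api/endpoints' and the list returned by adminApi.getApiFormats().
interface ApiFormat {
  value: string
  label: string
  default_path: string
}

interface EndpointLike {
  id: string
  api_format: string
}

// Formats that do not have an endpoint yet (same idea as `availableFormats`).
function getAvailableFormats(all: ApiFormat[], endpoints: EndpointLike[]): ApiFormat[] {
  const existing = new Set(endpoints.map(e => e.api_format))
  return all.filter(f => !existing.has(f.value))
}

// Default path for a format, or '' when the format is unknown
// (same idea as `getDefaultPath`).
function getDefaultPath(all: ApiFormat[], apiFormat: string): string {
  const format = all.find(f => f.value === apiFormat)
  return format?.default_path || ''
}

// Sample data: these format names and paths are invented for the example.
const formats: ApiFormat[] = [
  { value: 'FORMAT_A', label: 'Format A', default_path: '/v1/a' },
  { value: 'FORMAT_B', label: 'Format B', default_path: '/v1/b' },
]
const endpoints: EndpointLike[] = [{ id: 'ep-1', api_format: 'FORMAT_A' }]

const available = getAvailableFormats(formats, endpoints)
console.log(available.map(f => f.value))
console.log(getDefaultPath(formats, 'FORMAT_B'))
```

Once every format has an endpoint, `getAvailableFormats` returns an empty array, which corresponds to the `v-if="availableFormats.length > 0"` guard hiding the add-endpoint row in the template.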


@@ -1,165 +1,179 @@
<template> <template>
<Dialog <Dialog
:model-value="isOpen" :model-value="isOpen"
title="配置允许的模型" title="获取上游模型"
description="选择该 API Key 允许访问的模型,留空则允许访问所有模型" :description="`使用密钥 ${props.apiKey?.name || props.apiKey?.api_key_masked || ''} 从上游获取模型列表。导入的模型需要关联全局模型后才能参与路由。`"
:icon="Settings2" :icon="Layers"
size="2xl" size="2xl"
@update:model-value="handleDialogUpdate" @update:model-value="handleDialogUpdate"
> >
<div class="space-y-4 py-2"> <div class="space-y-4 py-2">
<!-- 已选模型展示 --> <!-- 操作区域 -->
<div <div class="flex items-center justify-between">
v-if="selectedModels.length > 0" <div class="text-sm text-muted-foreground">
class="space-y-2" <span v-if="!hasQueried">点击获取按钮查询上游可用模型</span>
> <span v-else-if="upstreamModels.length > 0">
<div class="flex items-center justify-between px-1"> {{ upstreamModels.length }} 个模型,已选 {{ selectedModels.length }}
<div class="text-xs font-medium text-muted-foreground"> </span>
已选模型 ({{ selectedModels.length }}) <span v-else>未找到可用模型</span>
</div>
<Button
type="button"
variant="ghost"
size="sm"
class="h-6 text-xs hover:text-destructive"
@click="clearModels"
>
清空
</Button>
</div>
<div class="flex flex-wrap gap-1.5 p-2 bg-muted/20 rounded-lg border border-border/40 min-h-[40px]">
<Badge
v-for="modelName in selectedModels"
:key="modelName"
variant="secondary"
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm"
>
{{ getModelLabel(modelName) }}
<button
class="ml-0.5 hover:text-destructive focus:outline-none"
@click.stop="toggleModel(modelName, false)"
>
&times;
</button>
</Badge>
</div> </div>
<Button
variant="outline"
size="sm"
:disabled="loading"
@click="fetchUpstreamModels"
>
<RefreshCw
class="w-3.5 h-3.5 mr-1.5"
:class="{ 'animate-spin': loading }"
/>
{{ hasQueried ? '刷新' : '获取模型' }}
</Button>
</div> </div>
<!-- 模型列表区域 --> <!-- 加载状态 -->
<div class="space-y-2"> <div
v-if="loading"
class="flex flex-col items-center justify-center py-12 space-y-3"
>
<div class="animate-spin rounded-full h-8 w-8 border-2 border-primary/20 border-t-primary" />
<span class="text-xs text-muted-foreground">正在从上游获取模型列表...</span>
</div>
<!-- 错误状态 -->
<div
v-else-if="errorMessage"
class="flex flex-col items-center justify-center py-12 text-destructive border border-dashed border-destructive/30 rounded-lg bg-destructive/5"
>
<AlertCircle class="w-10 h-10 mb-2 opacity-50" />
<span class="text-sm text-center px-4">{{ errorMessage }}</span>
<Button
variant="outline"
size="sm"
class="mt-3"
@click="fetchUpstreamModels"
>
重试
</Button>
</div>
<!-- 未查询状态 -->
<div
v-else-if="!hasQueried"
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
>
<Layers class="w-10 h-10 mb-2 opacity-20" />
<span class="text-sm">点击上方按钮获取模型列表</span>
</div>
<!-- 无模型 -->
<div
v-else-if="upstreamModels.length === 0"
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
>
<Box class="w-10 h-10 mb-2 opacity-20" />
<span class="text-sm">上游 API 未返回可用模型</span>
</div>
<!-- 模型列表 -->
<div v-else class="space-y-2">
<!-- 全选/取消 -->
<div class="flex items-center justify-between px-1"> <div class="flex items-center justify-between px-1">
<div class="text-xs font-medium text-muted-foreground"> <div class="flex items-center gap-2">
可选模型列表 <Checkbox
:checked="isAllSelected"
:indeterminate="isPartiallySelected"
@update:checked="toggleSelectAll"
/>
<span class="text-xs text-muted-foreground">
{{ isAllSelected ? '取消全选' : '全选' }}
</span>
</div> </div>
<div <div class="text-xs text-muted-foreground">
v-if="!loadingModels && availableModels.length > 0" {{ newModelsCount }} 个新模型不在本地
class="text-[10px] text-muted-foreground/60"
>
{{ availableModels.length }} 个模型
</div> </div>
</div> </div>
<!-- 加载状态 --> <div class="max-h-[320px] overflow-y-auto pr-1 space-y-1 custom-scrollbar">
<div
v-if="loadingModels"
class="flex flex-col items-center justify-center py-12 space-y-3"
>
<div class="animate-spin rounded-full h-8 w-8 border-2 border-primary/20 border-t-primary" />
<span class="text-xs text-muted-foreground">正在加载模型列表...</span>
</div>
<!-- 无模型 -->
<div
v-else-if="availableModels.length === 0"
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
>
<Box class="w-10 h-10 mb-2 opacity-20" />
<span class="text-sm">暂无可选模型</span>
</div>
<!-- 模型列表 -->
<div
v-else
class="max-h-[320px] overflow-y-auto pr-1 space-y-1.5 custom-scrollbar"
>
<div <div
v-for="model in availableModels" v-for="model in upstreamModels"
:key="model.global_model_name" :key="`${model.id}:${model.api_format || ''}`"
class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200 cursor-pointer select-none" class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200 cursor-pointer select-none"
:class="[ :class="[
selectedModels.includes(model.global_model_name) selectedModels.includes(model.id)
? 'border-primary/40 bg-primary/5 shadow-sm' ? 'border-primary/40 bg-primary/5 shadow-sm'
: 'border-border/40 bg-background hover:border-primary/20 hover:bg-muted/30' : 'border-border/40 bg-background hover:border-primary/20 hover:bg-muted/30'
]" ]"
@click="toggleModel(model.global_model_name, !selectedModels.includes(model.global_model_name))" @click="toggleModel(model.id)"
> >
<!-- Checkbox -->
<Checkbox <Checkbox
:checked="selectedModels.includes(model.global_model_name)" :checked="selectedModels.includes(model.id)"
class="data-[state=checked]:bg-primary data-[state=checked]:border-primary" class="data-[state=checked]:bg-primary data-[state=checked]:border-primary"
@click.stop @click.stop
@update:checked="checked => toggleModel(model.global_model_name, checked)" @update:checked="checked => toggleModel(model.id, checked)"
/> />
<!-- Info -->
<div class="flex-1 min-w-0"> <div class="flex-1 min-w-0">
<div class="flex items-center justify-between gap-2"> <div class="flex items-center gap-2">
<span class="text-sm font-medium truncate text-foreground/90">{{ model.display_name }}</span> <span class="text-sm font-medium truncate text-foreground/90">
<span {{ model.display_name || model.id }}
v-if="hasPricing(model)"
class="text-[10px] font-mono text-muted-foreground/80 bg-muted/30 px-1.5 py-0.5 rounded border border-border/30 shrink-0"
>
{{ formatPricingShort(model) }}
</span> </span>
<Badge
v-if="model.api_format"
variant="outline"
class="text-[10px] px-1.5 py-0 shrink-0"
>
{{ API_FORMAT_LABELS[model.api_format] || model.api_format }}
</Badge>
<Badge
v-if="isModelExisting(model.id)"
variant="secondary"
class="text-[10px] px-1.5 py-0 shrink-0"
>
已存在
</Badge>
</div> </div>
<div class="text-[11px] text-muted-foreground/60 font-mono truncate mt-0.5"> <div class="text-[11px] text-muted-foreground/60 font-mono truncate mt-0.5">
{{ model.global_model_name }} {{ model.id }}
</div> </div>
</div> </div>
<div
<!-- 测试按钮 --> v-if="model.owned_by"
<Button class="text-[10px] text-muted-foreground/50 shrink-0"
variant="ghost"
size="icon"
class="h-7 w-7 shrink-0"
title="测试模型连接"
:disabled="testingModelName === model.global_model_name"
@click.stop="testModelConnection(model)"
> >
<Loader2 {{ model.owned_by }}
v-if="testingModelName === model.global_model_name" </div>
class="w-3.5 h-3.5 animate-spin"
/>
<Play
v-else
class="w-3.5 h-3.5"
/>
</Button>
</div> </div>
</div> </div>
</div> </div>
</div> </div>
<template #footer> <template #footer>
<div class="flex items-center justify-end gap-2 w-full pt-2"> <div class="flex items-center justify-between w-full pt-2">
<Button <div class="text-xs text-muted-foreground">
variant="outline" <span v-if="selectedModels.length > 0 && newSelectedCount > 0">
class="h-9" 将导入 {{ newSelectedCount }} 个新模型
@click="handleCancel" </span>
> </div>
取消 <div class="flex items-center gap-2">
</Button> <Button
<Button variant="outline"
:disabled="saving" class="h-9"
class="h-9 min-w-[80px]" @click="handleCancel"
@click="handleSave" >
> 取消
<Loader2 </Button>
v-if="saving" <Button
class="w-3.5 h-3.5 mr-1.5 animate-spin" :disabled="importing || selectedModels.length === 0 || newSelectedCount === 0"
/> class="h-9 min-w-[100px]"
{{ saving ? '保存中' : '保存配置' }} @click="handleImport"
</Button> >
<Loader2
v-if="importing"
class="w-3.5 h-3.5 mr-1.5 animate-spin"
/>
{{ importing ? '导入中' : `导入 ${newSelectedCount} 个模型` }}
</Button>
</div>
</div> </div>
</template> </template>
</Dialog> </Dialog>
@@ -167,19 +181,19 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, watch } from 'vue' import { ref, computed, watch } from 'vue'
import { Box, Loader2, Settings2, Play } from 'lucide-vue-next' import { Box, Layers, Loader2, RefreshCw, AlertCircle } from 'lucide-vue-next'
import { Dialog } from '@/components/ui' import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import Checkbox from '@/components/ui/checkbox.vue' import Checkbox from '@/components/ui/checkbox.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { parseApiError, parseTestModelError } from '@/utils/errorParser' import { adminApi } from '@/api/admin'
import { import {
updateEndpointKey, importModelsFromUpstream,
getProviderAvailableSourceModels, getProviderModels,
testModel,
type EndpointAPIKey, type EndpointAPIKey,
type ProviderAvailableSourceModel type UpstreamModel,
API_FORMAT_LABELS,
} from '@/api/endpoints' } from '@/api/endpoints'
const props = defineProps<{ const props = defineProps<{
@@ -196,130 +210,116 @@ const emit = defineEmits<{
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
const isOpen = computed(() => props.open) const isOpen = computed(() => props.open)
const saving = ref(false) const loading = ref(false)
const loadingModels = ref(false) const importing = ref(false)
const availableModels = ref<ProviderAvailableSourceModel[]>([]) const hasQueried = ref(false)
const errorMessage = ref('')
const upstreamModels = ref<UpstreamModel[]>([])
const selectedModels = ref<string[]>([]) const selectedModels = ref<string[]>([])
const initialModels = ref<string[]>([]) const existingModelIds = ref<Set<string>>(new Set())
const testingModelName = ref<string | null>(null)
// 计算属性
const isAllSelected = computed(() =>
upstreamModels.value.length > 0 &&
selectedModels.value.length === upstreamModels.value.length
)
const isPartiallySelected = computed(() =>
selectedModels.value.length > 0 &&
selectedModels.value.length < upstreamModels.value.length
)
const newModelsCount = computed(() =>
upstreamModels.value.filter(m => !existingModelIds.value.has(m.id)).length
)
const newSelectedCount = computed(() =>
selectedModels.value.filter(id => !existingModelIds.value.has(id)).length
)
// 检查模型是否已存在
function isModelExisting(modelId: string): boolean {
return existingModelIds.value.has(modelId)
}
// 监听对话框打开 // 监听对话框打开
watch(() => props.open, (open) => { watch(() => props.open, (open) => {
if (open) { if (open) {
loadData() resetState()
loadExistingModels()
} }
}) })
async function loadData() { function resetState() {
// 初始化已选模型 hasQueried.value = false
if (props.apiKey?.allowed_models) { errorMessage.value = ''
selectedModels.value = [...props.apiKey.allowed_models] upstreamModels.value = []
initialModels.value = [...props.apiKey.allowed_models]
} else {
selectedModels.value = []
initialModels.value = []
}
// 加载可选模型
if (props.providerId) {
await loadAvailableModels()
}
}
async function loadAvailableModels() {
if (!props.providerId) return
try {
loadingModels.value = true
const response = await getProviderAvailableSourceModels(props.providerId)
availableModels.value = response.models
} catch (err: any) {
const errorMessage = parseApiError(err, '加载模型列表失败')
showError(errorMessage, '错误')
} finally {
loadingModels.value = false
}
}
const modelLabelMap = computed(() => {
const map = new Map<string, string>()
availableModels.value.forEach(model => {
map.set(model.global_model_name, model.display_name || model.global_model_name)
})
return map
})
function getModelLabel(modelName: string): string {
return modelLabelMap.value.get(modelName) ?? modelName
}
function hasPricing(model: ProviderAvailableSourceModel): boolean {
const input = model.price.input_price_per_1m ?? 0
const output = model.price.output_price_per_1m ?? 0
return input > 0 || output > 0
}
function formatPricingShort(model: ProviderAvailableSourceModel): string {
const input = model.price.input_price_per_1m ?? 0
const output = model.price.output_price_per_1m ?? 0
if (input > 0 || output > 0) {
return `$${formatPrice(input)}/$${formatPrice(output)}`
}
return ''
}
function formatPrice(value?: number | null): string {
if (value === undefined || value === null || value === 0) return '0'
if (value >= 1) {
return value.toFixed(2)
}
return value.toFixed(2)
}
function toggleModel(modelName: string, checked: boolean) {
if (checked) {
if (!selectedModels.value.includes(modelName)) {
selectedModels.value = [...selectedModels.value, modelName]
}
} else {
selectedModels.value = selectedModels.value.filter(name => name !== modelName)
}
}
function clearModels() {
selectedModels.value = [] selectedModels.value = []
} }
// 测试模型连接 // 加载已存在的模型列表
async function testModelConnection(model: ProviderAvailableSourceModel) { async function loadExistingModels() {
if (!props.providerId || !props.apiKey || testingModelName.value) return if (!props.providerId) return
testingModelName.value = model.global_model_name
try { try {
const result = await testModel({ const models = await getProviderModels(props.providerId)
provider_id: props.providerId, existingModelIds.value = new Set(
model_name: model.provider_model_name, models.map((m: { provider_model_name: string }) => m.provider_model_name)
api_key_id: props.apiKey.id, )
message: "hello" } catch {
}) existingModelIds.value = new Set()
if (result.success) {
success(`模型 "${model.display_name}" 测试成功`)
} else {
showError(`模型测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`模型测试失败: ${errorMsg}`)
} finally {
testingModelName.value = null
} }
} }
function areArraysEqual(a: string[], b: string[]): boolean { // 获取上游模型
if (a.length !== b.length) return false async function fetchUpstreamModels() {
const sortedA = [...a].sort() if (!props.providerId || !props.apiKey) return
const sortedB = [...b].sort()
return sortedA.every((value, index) => value === sortedB[index]) loading.value = true
errorMessage.value = ''
try {
const response = await adminApi.queryProviderModels(props.providerId, props.apiKey.id)
if (response.success && response.data?.models) {
upstreamModels.value = response.data.models
// 默认选中所有新模型
selectedModels.value = response.data.models
.filter((m: UpstreamModel) => !existingModelIds.value.has(m.id))
.map((m: UpstreamModel) => m.id)
hasQueried.value = true
// 如果有部分失败,显示警告提示
if (response.data.error) {
showError(`部分格式获取失败: ${response.data.error}`, '警告')
}
} else {
errorMessage.value = response.data?.error || '获取上游模型失败'
}
} catch (err: any) {
errorMessage.value = err.response?.data?.detail || '获取上游模型失败'
} finally {
loading.value = false
}
}
// 切换模型选择
function toggleModel(modelId: string, checked?: boolean) {
const shouldSelect = checked !== undefined ? checked : !selectedModels.value.includes(modelId)
if (shouldSelect) {
if (!selectedModels.value.includes(modelId)) {
selectedModels.value = [...selectedModels.value, modelId]
}
} else {
selectedModels.value = selectedModels.value.filter(id => id !== modelId)
}
}
// 全选/取消全选
function toggleSelectAll(checked: boolean) {
if (checked) {
selectedModels.value = upstreamModels.value.map(m => m.id)
} else {
selectedModels.value = []
}
} }
function handleDialogUpdate(value: boolean) { function handleDialogUpdate(value: boolean) {
@@ -332,30 +332,44 @@ function handleCancel() {
emit('close') emit('close')
} }
async function handleSave() { // 导入选中的模型
if (!props.apiKey) return async function handleImport() {
if (!props.providerId || selectedModels.value.length === 0) return
// 检查是否有变化 // 过滤出新模型(不在已存在列表中的)
const hasChanged = !areArraysEqual(selectedModels.value, initialModels.value) const modelsToImport = selectedModels.value.filter(id => !existingModelIds.value.has(id))
if (!hasChanged) { if (modelsToImport.length === 0) {
emit('close') showError('所选模型都已存在', '提示')
return return
} }
saving.value = true importing.value = true
try { try {
await updateEndpointKey(props.apiKey.id, { const response = await importModelsFromUpstream(props.providerId, modelsToImport)
// 空数组时发送 null,表示允许所有模型
allowed_models: selectedModels.value.length > 0 ? [...selectedModels.value] : null const successCount = response.success?.length || 0
}) const errorCount = response.errors?.length || 0
success('允许的模型已更新', '成功')
emit('saved') if (successCount > 0 && errorCount === 0) {
emit('close') success(`成功导入 ${successCount} 个模型`, '导入成功')
emit('saved')
emit('close')
} else if (successCount > 0 && errorCount > 0) {
success(`成功导入 ${successCount} 个模型,${errorCount} 个失败`, '部分成功')
emit('saved')
// 刷新列表以更新已存在状态
await loadExistingModels()
// 更新选中列表,移除已成功导入的
const successIds = new Set(response.success?.map((s: { model_id: string }) => s.model_id) || [])
selectedModels.value = selectedModels.value.filter(id => !successIds.has(id))
} else {
const errorMsg = response.errors?.[0]?.error || '导入失败'
showError(errorMsg, '导入失败')
}
} catch (err: any) { } catch (err: any) {
const errorMessage = parseApiError(err, '保存失败') showError(err.response?.data?.detail || '导入失败', '错误')
showError(errorMessage, '错误')
} finally { } finally {
saving.value = false importing.value = false
} }
} }
</script> </script>
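
Reviewer sketch (not part of the diff): `handleImport` does three things that are easy to check in isolation. It filters the selection down to models that are not already present, classifies the response as full success, partial success, or failure, and on partial success removes the imported IDs from the selection. The TypeScript below restates those steps as pure functions. The `ImportResponse` shape is inferred from how the handler reads `response.success` and `response.errors`, so treat it as an assumption. The function names and sample model IDs are hypothetical.

```typescript
// Response shape inferred from the handler's field accesses; an assumption,
// not the declared return type of importModelsFromUpstream.
interface ImportResponse {
  success?: Array<{ model_id: string }>
  errors?: Array<{ error: string }>
}

type ImportOutcome = 'all-imported' | 'partially-imported' | 'failed'

// Keep only the selected IDs that are not already known locally.
function pickNewModels(selected: string[], existing: Set<string>): string[] {
  return selected.filter(id => !existing.has(id))
}

// Mirror the three branches of the handler's if / else-if / else chain.
function classifyImport(response: ImportResponse): ImportOutcome {
  const successCount = response.success?.length ?? 0
  const errorCount = response.errors?.length ?? 0
  if (successCount > 0 && errorCount === 0) return 'all-imported'
  if (successCount > 0) return 'partially-imported'
  return 'failed'
}

// After a partial success, drop the IDs that were imported so only the
// failed (or already existing) ones stay selected.
function dropImported(selected: string[], response: ImportResponse): string[] {
  const imported = new Set((response.success ?? []).map(s => s.model_id))
  return selected.filter(id => !imported.has(id))
}

// Sample data, invented for the example.
const existing = new Set(['model-a'])
const selected = ['model-a', 'model-b', 'model-c']
const toImport = pickNewModels(selected, existing)
const response: ImportResponse = {
  success: [{ model_id: 'model-b' }],
  errors: [{ error: 'upstream rejected model-c' }],
}
const outcome = classifyImport(response)
const remaining = dropImported(selected, response)
console.log(toImport, outcome, remaining)
```

Note that a response with neither successes nor errors falls into the `failed` branch, matching the handler's final `else`, which shows a generic import-failed message.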


@@ -0,0 +1,696 @@
<template>
<Dialog
:model-value="isOpen"
title="模型权限"
:description="`管理密钥 ${props.apiKey?.name || ''} 可访问的模型,清空右侧列表表示允许全部`"
:icon="Shield"
size="4xl"
@update:model-value="handleDialogUpdate"
>
<template #default>
<div class="space-y-4">
<!-- 字典模式警告 -->
<div
v-if="isDictMode"
class="rounded-lg border border-amber-500/50 bg-amber-50 dark:bg-amber-950/30 p-3"
>
<p class="text-sm text-amber-700 dark:text-amber-400">
<strong>注意:</strong>此密钥使用按 API 格式区分的模型权限配置。
编辑后将转换为统一列表模式,原有的格式区分信息将丢失。
</p>
</div>
<!-- Key info header -->
<div class="rounded-lg border bg-muted/30 p-4">
<div class="flex items-start justify-between">
<div>
<p class="font-semibold text-lg">{{ apiKey?.name }}</p>
<p class="text-sm text-muted-foreground font-mono">
{{ apiKey?.api_key_masked }}
</p>
</div>
<Badge
:variant="allowedModels.length === 0 ? 'default' : 'outline'"
class="text-xs"
>
{{ allowedModels.length === 0 ? '允许全部' : `限制 ${allowedModels.length} 个模型` }}
</Badge>
</div>
</div>
<!-- Side-by-side (left/right) layout -->
<div class="flex gap-2 items-stretch">
<!-- Left: models available to add -->
<div class="flex-1 space-y-2">
<div class="flex items-center justify-between gap-2">
<p class="text-sm font-medium shrink-0">可添加</p>
<div class="flex-1 relative">
<Search class="absolute left-2 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-muted-foreground" />
<Input
v-model="searchQuery"
placeholder="搜索模型..."
class="pl-7 h-7 text-xs"
/>
</div>
<button
v-if="upstreamModelsLoaded"
type="button"
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
title="刷新上游模型"
:disabled="fetchingUpstreamModels"
@click="fetchUpstreamModels()"
>
<RefreshCw
class="w-3.5 h-3.5"
:class="{ 'animate-spin': fetchingUpstreamModels }"
/>
</button>
<button
v-else-if="!fetchingUpstreamModels"
type="button"
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
title="从提供商获取模型"
@click="fetchUpstreamModels()"
>
<Zap class="w-3.5 h-3.5" />
</button>
<Loader2
v-else
class="w-3.5 h-3.5 animate-spin text-muted-foreground shrink-0"
/>
</div>
<div class="border rounded-lg h-80 overflow-y-auto">
<div
v-if="loadingGlobalModels"
class="flex items-center justify-center h-full"
>
<Loader2 class="w-6 h-6 animate-spin text-primary" />
</div>
<div
v-else-if="totalAvailableCount === 0 && !upstreamModelsLoaded"
class="flex flex-col items-center justify-center h-full text-muted-foreground"
>
<Shield class="w-10 h-10 mb-2 opacity-30" />
<p class="text-sm">{{ searchQuery ? '无匹配结果' : '暂无可添加模型' }}</p>
</div>
<div v-else class="p-2 space-y-2">
<!-- Collapsible group: global models -->
<div
v-if="availableGlobalModels.length > 0 || !upstreamModelsLoaded"
class="border rounded-lg overflow-hidden"
>
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
<button
type="button"
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
@click="toggleGroupCollapse('global')"
>
<ChevronDown
class="w-4 h-4 transition-transform shrink-0"
:class="collapsedGroups.has('global') ? '-rotate-90' : ''"
/>
<span class="text-xs font-medium">全局模型</span>
<span class="text-xs text-muted-foreground">
({{ availableGlobalModels.length }})
</span>
</button>
<button
v-if="availableGlobalModels.length > 0"
type="button"
class="text-xs text-primary hover:underline shrink-0"
@click.stop="selectAllGlobalModels"
>
{{ isAllGlobalModelsSelected ? '取消' : '全选' }}
</button>
</div>
<div
v-show="!collapsedGroups.has('global')"
class="p-2 space-y-1 border-t"
>
<div
v-if="availableGlobalModels.length === 0"
class="py-4 text-center text-xs text-muted-foreground"
>
所有全局模型均已添加
</div>
<div
v-for="model in availableGlobalModels"
v-else
:key="model.name"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedLeftIds.includes(model.name)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleLeftSelection(model.name)"
>
<Checkbox
:checked="selectedLeftIds.includes(model.name)"
@update:checked="toggleLeftSelection(model.name)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">{{ model.display_name }}</p>
<p class="text-xs text-muted-foreground truncate font-mono">{{ model.name }}</p>
</div>
</div>
</div>
</div>
<!-- Collapsible groups: models fetched from the provider -->
<div
v-for="group in upstreamModelGroups"
:key="group.api_format"
class="border rounded-lg overflow-hidden"
>
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
<button
type="button"
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
@click="toggleGroupCollapse(group.api_format)"
>
<ChevronDown
class="w-4 h-4 transition-transform shrink-0"
:class="collapsedGroups.has(group.api_format) ? '-rotate-90' : ''"
/>
<span class="text-xs font-medium">
{{ API_FORMAT_LABELS[group.api_format] || group.api_format }}
</span>
<span class="text-xs text-muted-foreground">
({{ group.models.length }})
</span>
</button>
<button
type="button"
class="text-xs text-primary hover:underline shrink-0"
@click.stop="selectAllUpstreamModels(group.api_format)"
>
{{ isUpstreamGroupAllSelected(group.api_format) ? '取消' : '全选' }}
</button>
</div>
<div
v-show="!collapsedGroups.has(group.api_format)"
class="p-2 space-y-1 border-t"
>
<div
v-for="model in group.models"
:key="model.id"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedLeftIds.includes(model.id)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleLeftSelection(model.id)"
>
<Checkbox
:checked="selectedLeftIds.includes(model.id)"
@update:checked="toggleLeftSelection(model.id)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">{{ model.id }}</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ model.owned_by || model.id }}
</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Middle: action buttons -->
<div class="flex flex-col items-center justify-center w-12 shrink-0 gap-2">
<Button
variant="outline"
size="sm"
class="w-9 h-8"
:class="selectedLeftIds.length > 0 ? 'border-primary' : ''"
:disabled="selectedLeftIds.length === 0"
title="添加选中"
@click="addSelected"
>
<ChevronRight
class="w-6 h-6 stroke-[3]"
:class="selectedLeftIds.length > 0 ? 'text-primary' : ''"
/>
</Button>
<Button
variant="outline"
size="sm"
class="w-9 h-8"
:class="selectedRightIds.length > 0 ? 'border-primary' : ''"
:disabled="selectedRightIds.length === 0"
title="移除选中"
@click="removeSelected"
>
<ChevronLeft
class="w-6 h-6 stroke-[3]"
:class="selectedRightIds.length > 0 ? 'text-primary' : ''"
/>
</Button>
</div>
<!-- Right: allowed models already added -->
<div class="flex-1 space-y-2">
<div class="flex items-center justify-between">
<p class="text-sm font-medium">已添加</p>
<Button
v-if="allowedModels.length > 0"
variant="ghost"
size="sm"
class="h-6 px-2 text-xs"
@click="toggleSelectAllRight"
>
{{ isAllRightSelected ? '取消' : '全选' }}
</Button>
</div>
<div class="border rounded-lg h-80 overflow-y-auto">
<div
v-if="allowedModels.length === 0"
class="flex flex-col items-center justify-center h-full text-muted-foreground"
>
<Shield class="w-10 h-10 mb-2 opacity-30" />
<p class="text-sm">允许访问全部模型</p>
<p class="text-xs mt-1">添加模型以限制访问范围</p>
</div>
<div v-else class="p-2 space-y-1">
<div
v-for="modelName in allowedModels"
:key="'allowed-' + modelName"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedRightIds.includes(modelName)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleRightSelection(modelName)"
>
<Checkbox
:checked="selectedRightIds.includes(modelName)"
@update:checked="toggleRightSelection(modelName)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ getModelDisplayName(modelName) }}
</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ modelName }}
</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</template>
<template #footer>
<div class="flex items-center justify-between w-full">
<p class="text-xs text-muted-foreground">
{{ hasChanges ? '有未保存的更改' : '' }}
</p>
<div class="flex items-center gap-2">
<Button variant="outline" @click="handleCancel">取消</Button>
<Button :disabled="saving || !hasChanges" @click="handleSave">
{{ saving ? '保存中...' : '保存' }}
</Button>
</div>
</div>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, computed, watch, onUnmounted } from 'vue'
import {
Shield,
Search,
RefreshCw,
Loader2,
Zap,
ChevronRight,
ChevronLeft,
ChevronDown
} from 'lucide-vue-next'
import { Dialog, Button, Input, Checkbox, Badge } from '@/components/ui'
import { useToast } from '@/composables/useToast'
import { parseApiError } from '@/utils/errorParser'
import {
updateProviderKey,
API_FORMAT_LABELS,
type EndpointAPIKey,
type AllowedModels,
} from '@/api/endpoints'
import { getGlobalModels, type GlobalModelResponse } from '@/api/global-models'
import { adminApi } from '@/api/admin'
import type { UpstreamModel } from '@/api/endpoints/types'
interface AvailableModel {
name: string
display_name: string
}
const props = defineProps<{
open: boolean
apiKey: EndpointAPIKey | null
providerId: string
}>()
const emit = defineEmits<{
close: []
saved: []
}>()
const { success, error: showError } = useToast()
const isOpen = computed(() => props.open)
const saving = ref(false)
const loadingGlobalModels = ref(false)
const fetchingUpstreamModels = ref(false)
const upstreamModelsLoaded = ref(false)
// Flag used to cancel in-flight async operations
let loadingCancelled = false
// Search
const searchQuery = ref('')
// Collapse state
const collapsedGroups = ref<Set<string>>(new Set())
// Available models (global models)
const allGlobalModels = ref<AvailableModel[]>([])
// Upstream models
const upstreamModels = ref<UpstreamModel[]>([])
// Allowed models already added (right-hand list)
const allowedModels = ref<string[]>([])
const initialAllowedModels = ref<string[]>([])
// Selection state
const selectedLeftIds = ref<string[]>([])
const selectedRightIds = ref<string[]>([])
// Whether there are unsaved changes
const hasChanges = computed(() => {
if (allowedModels.value.length !== initialAllowedModels.value.length) return true
const sorted1 = [...allowedModels.value].sort()
const sorted2 = [...initialAllowedModels.value].sort()
return sorted1.some((v, i) => v !== sorted2[i])
})
// Global models available to add (excluding those already added)
const availableGlobalModelsBase = computed(() => {
const allowedSet = new Set(allowedModels.value)
return allGlobalModels.value.filter(m => !allowedSet.has(m.name))
})
// Global models after applying the search filter
const availableGlobalModels = computed(() => {
if (!searchQuery.value.trim()) return availableGlobalModelsBase.value
const query = searchQuery.value.toLowerCase()
return availableGlobalModelsBase.value.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name.toLowerCase().includes(query)
)
})
// Upstream models available to add (excluding those already added)
const availableUpstreamModelsBase = computed(() => {
const allowedSet = new Set(allowedModels.value)
return upstreamModels.value.filter(m => !allowedSet.has(m.id))
})
// Upstream models after applying the search filter
const availableUpstreamModels = computed(() => {
if (!searchQuery.value.trim()) return availableUpstreamModelsBase.value
const query = searchQuery.value.toLowerCase()
return availableUpstreamModelsBase.value.filter(m =>
m.id.toLowerCase().includes(query) ||
(m.owned_by && m.owned_by.toLowerCase().includes(query))
)
})
// Upstream models grouped by API format
const upstreamModelGroups = computed(() => {
const groups: Record<string, UpstreamModel[]> = {}
for (const model of availableUpstreamModels.value) {
const format = model.api_format || 'unknown'
if (!groups[format]) groups[format] = []
groups[format].push(model)
}
const order = Object.keys(API_FORMAT_LABELS)
return Object.entries(groups)
.map(([api_format, models]) => ({ api_format, models }))
.sort((a, b) => {
const aIndex = order.indexOf(a.api_format)
const bIndex = order.indexOf(b.api_format)
if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
if (aIndex === -1) return 1
if (bIndex === -1) return -1
return aIndex - bIndex
})
})
// Total number of models available to add
const totalAvailableCount = computed(() => {
return availableGlobalModels.value.length + availableUpstreamModels.value.length
})
// Whether everything in the right-hand list is selected
const isAllRightSelected = computed(() =>
allowedModels.value.length > 0 &&
selectedRightIds.value.length === allowedModels.value.length
)
// Whether all global models are selected
const isAllGlobalModelsSelected = computed(() => {
if (availableGlobalModels.value.length === 0) return false
return availableGlobalModels.value.every(m => selectedLeftIds.value.includes(m.name))
})
// Check whether every model in an upstream group is selected
function isUpstreamGroupAllSelected(apiFormat: string): boolean {
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
if (!group || group.models.length === 0) return false
return group.models.every(m => selectedLeftIds.value.includes(m.id))
}
// Resolve a model's display name
function getModelDisplayName(name: string): string {
const globalModel = allGlobalModels.value.find(m => m.name === name)
if (globalModel) return globalModel.display_name
const upstreamModel = upstreamModels.value.find(m => m.id === name)
if (upstreamModel) return upstreamModel.id
return name
}
// Load global models
async function loadGlobalModels() {
loadingGlobalModels.value = true
try {
const response = await getGlobalModels({ limit: 1000 })
// Bail out if cancelled (the dialog was closed)
if (loadingCancelled) return
allGlobalModels.value = response.models.map((m: GlobalModelResponse) => ({
name: m.name,
display_name: m.display_name
}))
} catch (err) {
if (loadingCancelled) return
showError('加载全局模型失败', '错误')
} finally {
loadingGlobalModels.value = false
}
}
// Fetch models from the provider (using the current key)
async function fetchUpstreamModels() {
if (!props.providerId || !props.apiKey) return
try {
fetchingUpstreamModels.value = true
// Query upstream models using the current key's ID
const response = await adminApi.queryProviderModels(props.providerId, props.apiKey.id)
// Bail out if cancelled
if (loadingCancelled) return
if (response.success && response.data?.models) {
upstreamModels.value = response.data.models
upstreamModelsLoaded.value = true
const allGroups = new Set(['global'])
for (const model of response.data.models) {
if (model.api_format) allGroups.add(model.api_format)
}
collapsedGroups.value = allGroups
} else {
showError(response.data?.error || '获取上游模型失败', '错误')
}
} catch (err: any) {
if (loadingCancelled) return
showError(err.response?.data?.detail || '获取上游模型失败', '错误')
} finally {
fetchingUpstreamModels.value = false
}
}
// Toggle a group's collapsed state
function toggleGroupCollapse(group: string) {
if (collapsedGroups.value.has(group)) {
collapsedGroups.value.delete(group)
} else {
collapsedGroups.value.add(group)
}
collapsedGroups.value = new Set(collapsedGroups.value)
}
// Whether allowed_models is in dict mode (keyed by API format)
const isDictMode = ref(false)
// Parse allowed_models
function parseAllowedModels(allowed: AllowedModels): string[] {
if (allowed === null || allowed === undefined) {
isDictMode.value = false
return []
}
if (Array.isArray(allowed)) {
isDictMode.value = false
return [...allowed]
}
// Dict mode: merge the models of every format and set the warning flag
isDictMode.value = true
const all = new Set<string>()
for (const models of Object.values(allowed)) {
models.forEach(m => all.add(m))
}
return Array.from(all)
}
// Toggle selection in the left-hand list
function toggleLeftSelection(name: string) {
const idx = selectedLeftIds.value.indexOf(name)
if (idx === -1) {
selectedLeftIds.value.push(name)
} else {
selectedLeftIds.value.splice(idx, 1)
}
}
// Toggle selection in the right-hand list
function toggleRightSelection(name: string) {
const idx = selectedRightIds.value.indexOf(name)
if (idx === -1) {
selectedRightIds.value.push(name)
} else {
selectedRightIds.value.splice(idx, 1)
}
}
// Toggle select-all for the right-hand list
function toggleSelectAllRight() {
if (isAllRightSelected.value) {
selectedRightIds.value = []
} else {
selectedRightIds.value = [...allowedModels.value]
}
}
// Select (or deselect) all global models
function selectAllGlobalModels() {
const allNames = availableGlobalModels.value.map(m => m.name)
const allSelected = allNames.every(name => selectedLeftIds.value.includes(name))
if (allSelected) {
selectedLeftIds.value = selectedLeftIds.value.filter(id => !allNames.includes(id))
} else {
const newNames = allNames.filter(name => !selectedLeftIds.value.includes(name))
selectedLeftIds.value.push(...newNames)
}
}
// Select (or deselect) all upstream models of one API format
function selectAllUpstreamModels(apiFormat: string) {
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
if (!group) return
const allIds = group.models.map(m => m.id)
const allSelected = allIds.every(id => selectedLeftIds.value.includes(id))
if (allSelected) {
selectedLeftIds.value = selectedLeftIds.value.filter(id => !allIds.includes(id))
} else {
const newIds = allIds.filter(id => !selectedLeftIds.value.includes(id))
selectedLeftIds.value.push(...newIds)
}
}
// Move the selected models to the right-hand list
function addSelected() {
for (const name of selectedLeftIds.value) {
if (!allowedModels.value.includes(name)) {
allowedModels.value.push(name)
}
}
selectedLeftIds.value = []
}
// Remove the selected models from the right-hand list
function removeSelected() {
allowedModels.value = allowedModels.value.filter(
name => !selectedRightIds.value.includes(name)
)
selectedRightIds.value = []
}
// Watch for the dialog opening
watch(() => props.open, async (open) => {
if (open && props.apiKey) {
// Reset the cancellation flag
loadingCancelled = false
const parsed = parseAllowedModels(props.apiKey.allowed_models ?? null)
allowedModels.value = [...parsed]
initialAllowedModels.value = [...parsed]
selectedLeftIds.value = []
selectedRightIds.value = []
searchQuery.value = ''
upstreamModels.value = []
upstreamModelsLoaded.value = false
collapsedGroups.value = new Set()
await loadGlobalModels()
} else {
// Set the cancellation flag when the dialog closes
loadingCancelled = true
}
})
// Cancel all async operations when the component unmounts
onUnmounted(() => {
loadingCancelled = true
})
function handleDialogUpdate(value: boolean) {
if (!value) emit('close')
}
function handleCancel() {
emit('close')
}
async function handleSave() {
if (!props.apiKey) return
saving.value = true
try {
// Empty list = null (allow all)
const newAllowed: AllowedModels = allowedModels.value.length > 0
? [...allowedModels.value]
: null
await updateProviderKey(props.apiKey.id, { allowed_models: newAllowed })
success('模型权限已更新', '成功')
emit('saved')
emit('close')
} catch (err: any) {
showError(parseApiError(err, '保存失败'), '错误')
} finally {
saving.value = false
}
}
</script>
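The dialog above accepts `allowed_models` in three shapes (null, a flat list, or a dictionary keyed by API format), flattens it for editing, and sends null back when the list is empty. A minimal standalone sketch of that round trip follows. The `AllowedModels` type is redeclared here as an assumption (the real one is imported from '@/api/endpoints'), and `hasSameMembers` and `toPayload` are hypothetical helper names; unlike the component, the parser returns the dict-mode flag instead of writing to a ref.

```typescript
// Assumed shape of the union type; the real definition lives in '@/api/endpoints'.
type AllowedModels = string[] | Record<string, string[]> | null

interface ParsedAllowedModels {
  models: string[]
  isDictMode: boolean
}

// Flatten any accepted shape into a de-duplicated list of model names.
function parseAllowedModels(allowed: AllowedModels | undefined): ParsedAllowedModels {
  if (allowed === null || allowed === undefined) return { models: [], isDictMode: false }
  if (Array.isArray(allowed)) return { models: [...allowed], isDictMode: false }
  // Dict mode: merge the lists of every API format.
  const merged = new Set<string>()
  for (const key of Object.keys(allowed)) {
    (allowed[key] ?? []).forEach(m => merged.add(m))
  }
  return { models: Array.from(merged), isDictMode: true }
}

// Order-insensitive comparison, equivalent to the negation of hasChanges.
function hasSameMembers(a: string[], b: string[]): boolean {
  if (a.length !== b.length) return false
  const sortedA = [...a].sort()
  const sortedB = [...b].sort()
  return sortedA.every((v, i) => v === sortedB[i])
}

// An empty list is sent as null, which the component treats as "allow all".
function toPayload(models: string[]): string[] | null {
  return models.length > 0 ? [...models] : null
}
```

Returning the flag alongside the list keeps the parser free of side effects, so the dict-mode warning can be driven from the return value rather than from a ref mutated during parsing.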

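The `upstreamModelGroups` computed property groups upstream models by API format and sorts known formats ahead of unknown ones. A minimal sketch of the same ordering rule as a pure function is shown below. The trimmed `UpstreamModel` interface, the `ModelGroup` type, and the `groupByFormat` name are assumptions for illustration; the `order` argument stands in for `Object.keys(API_FORMAT_LABELS)`.

```typescript
// Hypothetical minimal shape; the real type comes from '@/api/endpoints/types'.
interface UpstreamModel {
  id: string
  api_format?: string
}

interface ModelGroup {
  api_format: string
  models: UpstreamModel[]
}

// Group models by API format, then sort: formats listed in `order` come first
// (in that order), and all remaining formats follow alphabetically.
function groupByFormat(models: UpstreamModel[], order: string[]): ModelGroup[] {
  const groups = new Map<string, UpstreamModel[]>()
  for (const model of models) {
    const format = model.api_format || 'unknown'
    const bucket = groups.get(format)
    if (bucket) bucket.push(model)
    else groups.set(format, [model])
  }
  const result: ModelGroup[] = []
  groups.forEach((groupModels, api_format) => {
    result.push({ api_format, models: groupModels })
  })
  return result.sort((a, b) => {
    const aIndex = order.indexOf(a.api_format)
    const bIndex = order.indexOf(b.api_format)
    if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
    if (aIndex === -1) return 1
    if (bIndex === -1) return -1
    return aIndex - bIndex
  })
}
```

This sketch uses a `Map` rather than a plain object so that a format name such as `constructor` cannot collide with an inherited property; the component's `Record`-based version behaves the same for ordinary format names.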

@@ -2,57 +2,36 @@
<Dialog <Dialog
:model-value="isOpen" :model-value="isOpen"
:title="isEditMode ? '编辑密钥' : '添加密钥'" :title="isEditMode ? '编辑密钥' : '添加密钥'"
:description="isEditMode ? '修改 API 密钥配置' : '为端点添加新的 API 密钥'" :description="isEditMode ? '修改 API 密钥配置' : '为提供商添加新的 API 密钥'"
:icon="isEditMode ? SquarePen : Key" :icon="isEditMode ? SquarePen : Key"
size="2xl" size="2xl"
@update:model-value="handleDialogUpdate" @update:model-value="handleDialogUpdate"
> >
<form <form
class="space-y-5" class="space-y-4"
autocomplete="off" autocomplete="off"
@submit.prevent="handleSave" @submit.prevent="handleSave"
> >
<!-- Basic info --> <!-- Basic info -->
<div class="space-y-3"> <div class="grid grid-cols-2 gap-3">
<h3 class="text-sm font-medium border-b pb-2"> <div>
基本信息 <Label :for="keyNameInputId">密钥名称 *</Label>
</h3> <Input
<div class="grid grid-cols-2 gap-4"> :id="keyNameInputId"
<div> v-model="form.name"
<Label :for="keyNameInputId">密钥名称 *</Label> :name="keyNameFieldName"
<Input required
:id="keyNameInputId" placeholder="例如:主 Key、备用 Key 1"
v-model="form.name" maxlength="100"
:name="keyNameFieldName" autocomplete="off"
required autocapitalize="none"
placeholder="例如:主 Key、备用 Key 1" autocorrect="off"
maxlength="100" spellcheck="false"
autocomplete="off" data-form-type="other"
autocapitalize="none" data-lpignore="true"
autocorrect="off" data-1p-ignore="true"
spellcheck="false" />
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div>
<Label for="rate_multiplier">成本倍率 *</Label>
<Input
id="rate_multiplier"
v-model.number="form.rate_multiplier"
type="number"
step="0.01"
min="0.01"
required
placeholder="1.0"
/>
<p class="text-xs text-muted-foreground mt-1">
真实成本 = 表面成本 × 倍率
</p>
</div>
</div> </div>
<div> <div>
<Label :for="apiKeyInputId">API 密钥 {{ editingKey ? '' : '*' }}</Label> <Label :for="apiKeyInputId">API 密钥 {{ editingKey ? '' : '*' }}</Label>
<Input <Input
@@ -83,148 +62,161 @@
v-else-if="editingKey" v-else-if="editingKey"
class="text-xs text-muted-foreground mt-1" class="text-xs text-muted-foreground mt-1"
> >
留空表示不修改输入新值则覆盖 留空表示不修改
</p> </p>
</div> </div>
</div>
<!-- Note -->
<div>
<Label for="note">备注</Label>
<Input
id="note"
v-model="form.note"
placeholder="可选的备注信息"
/>
</div>
<!-- API format selection -->
<div v-if="sortedApiFormats.length > 0">
<Label class="mb-1.5 block">支持的 API 格式 *</Label>
<div class="grid grid-cols-2 sm:grid-cols-3 gap-2">
<div
v-for="format in sortedApiFormats"
:key="format"
class="flex items-center justify-between rounded-md border px-2 py-1.5 transition-colors cursor-pointer"
:class="form.api_formats.includes(format)
? 'bg-primary/5 border-primary/30'
: 'bg-muted/30 border-border hover:border-muted-foreground/30'"
@click="toggleApiFormat(format)"
>
<div class="flex items-center gap-1.5 min-w-0">
<span
class="w-4 h-4 rounded border flex items-center justify-center text-xs shrink-0"
:class="form.api_formats.includes(format)
? 'bg-primary border-primary text-primary-foreground'
: 'border-muted-foreground/30'"
>
<span v-if="form.api_formats.includes(format)"></span>
</span>
<span
class="text-sm whitespace-nowrap"
:class="form.api_formats.includes(format) ? 'text-primary' : 'text-muted-foreground'"
>{{ API_FORMAT_LABELS[format] || format }}</span>
</div>
<div
class="flex items-center shrink-0 ml-2 text-xs text-muted-foreground gap-1"
@click.stop
>
<span>×</span>
<input
:value="form.rate_multipliers[format] ?? ''"
type="number"
step="0.01"
min="0.01"
placeholder="1"
class="w-9 bg-transparent text-right outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
:class="form.api_formats.includes(format) ? 'text-primary' : 'text-muted-foreground'"
title="成本倍率"
@input="(e) => updateRateMultiplier(format, (e.target as HTMLInputElement).value)"
>
</div>
</div>
</div>
</div>
<!-- Settings -->
<div class="grid grid-cols-4 gap-3">
<div> <div>
<Label for="note">备注</Label> <Label
for="internal_priority"
class="text-xs"
>优先级</Label>
<Input <Input
id="note" id="internal_priority"
v-model="form.note" v-model.number="form.internal_priority"
placeholder="可选的备注信息" type="number"
min="0"
class="h-8"
/> />
<p class="text-xs text-muted-foreground mt-0.5">
越小越优先
</p>
</div>
<div>
<Label
for="rpm_limit"
class="text-xs"
>RPM 限制</Label>
<Input
id="rpm_limit"
:model-value="form.rpm_limit ?? ''"
type="number"
min="1"
max="10000"
placeholder="自适应"
class="h-8"
@update:model-value="(v) => form.rpm_limit = parseNullableNumberInput(v, { min: 1, max: 10000 })"
/>
<p class="text-xs text-muted-foreground mt-0.5">
留空自适应
</p>
</div>
<div>
<Label
for="cache_ttl_minutes"
class="text-xs"
>缓存 TTL</Label>
<Input
id="cache_ttl_minutes"
:model-value="form.cache_ttl_minutes ?? ''"
type="number"
min="0"
max="60"
class="h-8"
@update:model-value="(v) => form.cache_ttl_minutes = parseNumberInput(v, { min: 0, max: 60 }) ?? 5"
/>
<p class="text-xs text-muted-foreground mt-0.5">
分钟0禁用
</p>
</div>
<div>
<Label
for="max_probe_interval_minutes"
class="text-xs"
>熔断探测</Label>
<Input
id="max_probe_interval_minutes"
:model-value="form.max_probe_interval_minutes ?? ''"
type="number"
min="2"
max="32"
placeholder="32"
class="h-8"
@update:model-value="(v) => form.max_probe_interval_minutes = parseNumberInput(v, { min: 2, max: 32 }) ?? 32"
/>
<p class="text-xs text-muted-foreground mt-0.5">
分钟2-32
</p>
</div> </div>
</div> </div>
<!-- Scheduling and rate limiting --> <!-- Capability tags -->
<div class="space-y-3"> <div v-if="availableCapabilities.length > 0">
<h3 class="text-sm font-medium border-b pb-2"> <Label class="text-xs mb-1.5 block">能力标签</Label>
调度与限流 <div class="flex flex-wrap gap-1.5">
</h3> <button
<div class="grid grid-cols-2 gap-4">
<div>
<Label for="internal_priority">内部优先级</Label>
<Input
id="internal_priority"
v-model.number="form.internal_priority"
type="number"
min="0"
/>
<p class="text-xs text-muted-foreground mt-1">
数字越小越优先
</p>
</div>
<div>
<Label for="max_concurrent">最大并发</Label>
<Input
id="max_concurrent"
:model-value="form.max_concurrent ?? ''"
type="number"
min="1"
placeholder="留空启用自适应"
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
/>
<p class="text-xs text-muted-foreground mt-1">
留空 = 自适应模式
</p>
</div>
</div>
<div class="grid grid-cols-3 gap-4">
<div>
<Label for="rate_limit">速率限制(/分钟)</Label>
<Input
id="rate_limit"
:model-value="form.rate_limit ?? ''"
type="number"
min="1"
@update:model-value="(v) => form.rate_limit = parseNumberInput(v)"
/>
</div>
<div>
<Label for="daily_limit">每日限制</Label>
<Input
id="daily_limit"
:model-value="form.daily_limit ?? ''"
type="number"
min="1"
@update:model-value="(v) => form.daily_limit = parseNumberInput(v)"
/>
</div>
<div>
<Label for="monthly_limit">每月限制</Label>
<Input
id="monthly_limit"
:model-value="form.monthly_limit ?? ''"
type="number"
min="1"
@update:model-value="(v) => form.monthly_limit = parseNumberInput(v)"
/>
</div>
</div>
</div>
<!-- Caching and circuit breaking -->
<div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2">
缓存与熔断
</h3>
<div class="grid grid-cols-2 gap-4">
<div>
<Label for="cache_ttl_minutes">缓存 TTL (分钟)</Label>
<Input
id="cache_ttl_minutes"
:model-value="form.cache_ttl_minutes ?? ''"
type="number"
min="0"
max="60"
@update:model-value="(v) => form.cache_ttl_minutes = parseNumberInput(v, { min: 0, max: 60 }) ?? 5"
/>
<p class="text-xs text-muted-foreground mt-1">
0 = 禁用缓存亲和性
</p>
</div>
<div>
<Label for="max_probe_interval_minutes">熔断探测间隔 (分钟)</Label>
<Input
id="max_probe_interval_minutes"
:model-value="form.max_probe_interval_minutes ?? ''"
type="number"
min="2"
max="32"
placeholder="32"
@update:model-value="(v) => form.max_probe_interval_minutes = parseNumberInput(v, { min: 2, max: 32 }) ?? 32"
/>
<p class="text-xs text-muted-foreground mt-1">
范围 2-32 分钟
</p>
</div>
</div>
</div>
<!-- Capability tag configuration -->
<div
v-if="availableCapabilities.length > 0"
class="space-y-3"
>
<h3 class="text-sm font-medium border-b pb-2">
能力标签
</h3>
<div class="flex flex-wrap gap-2">
<label
v-for="cap in availableCapabilities" v-for="cap in availableCapabilities"
:key="cap.name" :key="cap.name"
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm" type="button"
class="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-md border text-sm transition-colors"
:class="form.capabilities[cap.name]
? 'bg-primary/10 border-primary/50 text-primary'
: 'bg-card border-border hover:bg-muted/50 text-muted-foreground'"
@click="form.capabilities[cap.name] = !form.capabilities[cap.name]"
> >
<input {{ cap.display_name }}
type="checkbox" </button>
:checked="form.capabilities[cap.name] || false"
class="rounded"
@change="form.capabilities[cap.name] = !form.capabilities[cap.name]"
>
<span>{{ cap.display_name }}</span>
</label>
</div> </div>
</div> </div>
</form> </form>
@@ -240,25 +232,27 @@
:disabled="saving" :disabled="saving"
@click="handleSave" @click="handleSave"
> >
{{ saving ? '保存中...' : '保存' }} {{ saving ? (isEditMode ? '保存中...' : '添加中...') : (isEditMode ? '保存' : '添加') }}
</Button> </Button>
</template> </template>
</Dialog> </Dialog>
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed, onMounted } from 'vue' import { ref, computed, onMounted, watch } from 'vue'
import { Dialog, Button, Input, Label } from '@/components/ui' import { Dialog, Button, Input, Label } from '@/components/ui'
import { Key, SquarePen } from 'lucide-vue-next' import { Key, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog' import { useFormDialog } from '@/composables/useFormDialog'
import { parseApiError } from '@/utils/errorParser' import { parseApiError } from '@/utils/errorParser'
import { parseNumberInput } from '@/utils/form' import { parseNumberInput, parseNullableNumberInput } from '@/utils/form'
import { log } from '@/utils/logger' import { log } from '@/utils/logger'
import { import {
addEndpointKey, addProviderKey,
updateEndpointKey, updateProviderKey,
getAllCapabilities, getAllCapabilities,
API_FORMAT_LABELS,
sortApiFormats,
type EndpointAPIKey, type EndpointAPIKey,
type EndpointAPIKeyUpdate, type EndpointAPIKeyUpdate,
type ProviderEndpoint, type ProviderEndpoint,
@@ -270,6 +264,7 @@ const props = defineProps<{
endpoint: ProviderEndpoint | null endpoint: ProviderEndpoint | null
editingKey: EndpointAPIKey | null editingKey: EndpointAPIKey | null
providerId: string | null providerId: string | null
availableApiFormats: string[] // All API formats supported by the provider
}>() }>()
const emit = defineEmits<{ const emit = defineEmits<{
@@ -279,6 +274,9 @@ const emit = defineEmits<{
const { success, error: showError } = useToast() const { success, error: showError } = useToast()
// Available API formats, sorted
const sortedApiFormats = computed(() => sortApiFormats(props.availableApiFormats))
const isOpen = computed(() => props.open) const isOpen = computed(() => props.open)
const saving = ref(false) const saving = ref(false)
const formNonce = ref(createFieldNonce()) const formNonce = ref(createFieldNonce())
@@ -297,12 +295,10 @@ const availableCapabilities = ref<CapabilityDefinition[]>([])
const form = ref({ const form = ref({
name: '', name: '',
api_key: '', api_key: '',
rate_multiplier: 1.0, api_formats: [] as string[], // Supported API formats
internal_priority: 50, rate_multipliers: {} as Record<string, number>, // Cost multiplier per API format
max_concurrent: undefined as number | undefined, internal_priority: 10,
rate_limit: undefined as number | undefined, rpm_limit: undefined as number | null | undefined, // RPM limit (null = adaptive, undefined = keep the current value)
daily_limit: undefined as number | undefined,
monthly_limit: undefined as number | undefined,
cache_ttl_minutes: 5, cache_ttl_minutes: 5,
max_probe_interval_minutes: 32, max_probe_interval_minutes: 32,
note: '', note: '',
@@ -323,6 +319,43 @@ onMounted(() => {
loadCapabilities() loadCapabilities()
}) })
// Toggle an API format
function toggleApiFormat(format: string) {
const index = form.value.api_formats.indexOf(format)
if (index === -1) {
// Add the format
form.value.api_formats.push(format)
} else {
// Before removing, make sure at least one format stays selected
if (form.value.api_formats.length <= 1) {
showError('至少需要选择一个 API 格式', '验证失败')
return
}
// Remove the format but keep its multiplier (the user may only be deselecting it temporarily)
form.value.api_formats.splice(index, 1)
}
}
// Update the cost multiplier for a given format
function updateRateMultiplier(format: string, value: string | number) {
// Replace the object so Vue 3 reactivity picks up the change
const newMultipliers = { ...form.value.rate_multipliers }
if (value === '' || value === null || value === undefined) {
// When cleared, delete this format's entry (fall back to the default multiplier)
delete newMultipliers[format]
} else {
const numValue = typeof value === 'string' ? parseFloat(value) : value
// Only accept multipliers within the range 0.01 - 100; other values are ignored
if (!isNaN(numValue) && numValue >= 0.01 && numValue <= 100) {
newMultipliers[format] = numValue
}
}
// Replace the whole object to trigger a reactive update
form.value.rate_multipliers = newMultipliers
}
// Compute the API key input's classes // Compute the API key input's classes
function getApiKeyInputClass(): string { function getApiKeyInputClass(): string {
const classes = [] const classes = []
@@ -349,8 +382,8 @@ const apiKeyError = computed(() => {
} }
// If a value was entered, check its length // If a value was entered, check its length
if (apiKey.length < 10) { if (apiKey.length < 3) {
return 'API 密钥至少需要 10 个字符' return 'API 密钥至少需要 3 个字符'
} }
return '' return ''
@@ -363,12 +396,10 @@ function resetForm() {
form.value = { form.value = {
name: '', name: '',
api_key: '', api_key: '',
rate_multiplier: 1.0, api_formats: [], // No format selected by default
internal_priority: 50, rate_multipliers: {},
max_concurrent: undefined, internal_priority: 10,
rate_limit: undefined, rpm_limit: undefined,
daily_limit: undefined,
monthly_limit: undefined,
cache_ttl_minutes: 5, cache_ttl_minutes: 5,
max_probe_interval_minutes: 32, max_probe_interval_minutes: 32,
note: '', note: '',
@@ -377,6 +408,14 @@ function resetForm() {
} }
} }
// After a successful add, clear some fields so another key can be added
function clearForNextAdd() {
formNonce.value = createFieldNonce()
apiKeyFocused.value = false
form.value.name = ''
form.value.api_key = ''
}
// Load key data (edit mode) // Load key data (edit mode)
function loadKeyData() { function loadKeyData() {
if (!props.editingKey) return if (!props.editingKey) return
@@ -385,13 +424,13 @@ function loadKeyData() {
form.value = { form.value = {
name: props.editingKey.name, name: props.editingKey.name,
api_key: '', api_key: '',
rate_multiplier: props.editingKey.rate_multiplier || 1.0, api_formats: props.editingKey.api_formats?.length > 0
internal_priority: props.editingKey.internal_priority ?? 50, ? [...props.editingKey.api_formats]
: [], // In edit mode keep the existing selection; do not default to selecting all
rate_multipliers: { ...(props.editingKey.rate_multipliers || {}) },
internal_priority: props.editingKey.internal_priority ?? 10,
// Preserve the original null/undefined state; null means adaptive mode // Preserve the original null/undefined state; null means adaptive mode
max_concurrent: props.editingKey.max_concurrent ?? undefined, rpm_limit: props.editingKey.rpm_limit ?? undefined,
rate_limit: props.editingKey.rate_limit ?? undefined,
daily_limit: props.editingKey.daily_limit ?? undefined,
monthly_limit: props.editingKey.monthly_limit ?? undefined,
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5, cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32, max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
note: props.editingKey.note || '', note: props.editingKey.note || '',
@@ -415,7 +454,11 @@ function createFieldNonce(): string {
} }
async function handleSave() { async function handleSave() {
if (!props.endpoint) return // providerId is required
if (!props.providerId) {
showError('无法保存:缺少提供商信息', '错误')
return
}
// Validate before submitting // Validate before submitting
if (apiKeyError.value) { if (apiKeyError.value) {
@@ -429,6 +472,12 @@ async function handleSave() {
return return
} }
// Require at least one API format to be selected
if (form.value.api_formats.length === 0) {
showError('请至少选择一个 API 格式', '验证失败')
return
}
// Keep only the active capabilities (those whose value is true) // Keep only the active capabilities (those whose value is true)
const activeCapabilities: Record<string, boolean> = {} const activeCapabilities: Record<string, boolean> = {}
for (const [key, value] of Object.entries(form.value.capabilities)) { for (const [key, value] of Object.entries(form.value.capabilities)) {
@@ -440,21 +489,27 @@ async function handleSave() {
saving.value = true saving.value = true
try { try {
// Prepare rate_multipliers: keep only the multipliers for the selected formats
const filteredMultipliers: Record<string, number> = {}
for (const format of form.value.api_formats) {
if (form.value.rate_multipliers[format] !== undefined) {
filteredMultipliers[format] = form.value.rate_multipliers[format]
}
}
const rateMultipliersData = Object.keys(filteredMultipliers).length > 0
? filteredMultipliers
: null
if (props.editingKey) { if (props.editingKey) {
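// Update mode // Update mode
// Note: max_concurrent must be sent as an explicit null to switch to adaptive mode // Note: rpm_limit uses null to mean adaptive mode
// undefined is dropped from the JSON, so null is used to mean "clear / adaptive" // undefined means "keep the current value" and is dropped during JSON serialization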
const updateData: EndpointAPIKeyUpdate = { const updateData: EndpointAPIKeyUpdate = {
api_formats: form.value.api_formats,
name: form.value.name, name: form.value.name,
rate_multiplier: form.value.rate_multiplier, rate_multipliers: rateMultipliersData,
internal_priority: form.value.internal_priority, internal_priority: form.value.internal_priority,
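// Send an explicit null for adaptive mode so the backend can tell "not provided" from "set to null" rpm_limit: form.value.rpm_limit,
// Note: only max_concurrent needs this handling, because it has an "adaptive mode"
// The other limit fields (rate_limit, etc.) do not support "clearing"; undefined is dropped from the JSON, so they are left unchanged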
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit,
cache_ttl_minutes: form.value.cache_ttl_minutes, cache_ttl_minutes: form.value.cache_ttl_minutes,
max_probe_interval_minutes: form.value.max_probe_interval_minutes, max_probe_interval_minutes: form.value.max_probe_interval_minutes,
note: form.value.note, note: form.value.note,
@@ -466,26 +521,27 @@ async function handleSave() {
updateData.api_key = form.value.api_key updateData.api_key = form.value.api_key
} }
await updateEndpointKey(props.editingKey.id, updateData) await updateProviderKey(props.editingKey.id, updateData)
success('密钥已更新', '成功') success('密钥已更新', '成功')
} else { } else {
// Add // Add mode
await addEndpointKey(props.endpoint.id, { await addProviderKey(props.providerId, {
endpoint_id: props.endpoint.id, api_formats: form.value.api_formats,
api_key: form.value.api_key, api_key: form.value.api_key,
name: form.value.name, name: form.value.name,
rate_multiplier: form.value.rate_multiplier, rate_multipliers: rateMultipliersData,
internal_priority: form.value.internal_priority, internal_priority: form.value.internal_priority,
max_concurrent: form.value.max_concurrent, rpm_limit: form.value.rpm_limit,
rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit,
cache_ttl_minutes: form.value.cache_ttl_minutes, cache_ttl_minutes: form.value.cache_ttl_minutes,
max_probe_interval_minutes: form.value.max_probe_interval_minutes, max_probe_interval_minutes: form.value.max_probe_interval_minutes,
note: form.value.note, note: form.value.note,
capabilities: capabilitiesData || undefined capabilities: capabilitiesData || undefined
}) })
success('密钥已添加', '成功') success('密钥已添加', '成功')
// Add mode: keep the dialog open and clear only the name and key so another key can be added
emit('saved')
clearForNextAdd()
return
} }
emit('saved') emit('saved')
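
The two ideas in the save path above (keep multipliers only for the selected formats, and rely on `null` versus `undefined` in the JSON body) can be sketched in isolation. This is a minimal sketch, not the component's code: `filterMultipliers` is a hypothetical helper name, and the format names `OPENAI`, `CLAUDE` and `GEMINI` are placeholders chosen for illustration.

```typescript
// Hypothetical standalone version of the filtering step in handleSave().
type RateMultipliers = Record<string, number>

function filterMultipliers(
  selectedFormats: string[],
  multipliers: RateMultipliers,
): RateMultipliers | null {
  const filtered: RateMultipliers = {}
  for (const format of selectedFormats) {
    const value = multipliers[format]
    // Skip formats that have no multiplier configured.
    if (value !== undefined) filtered[format] = value
  }
  // An empty result is sent as null rather than as {}.
  return Object.keys(filtered).length > 0 ? filtered : null
}

// JSON.stringify keeps null properties and drops undefined ones, which is
// what lets a backend tell "set to null" apart from "not provided".
const body = JSON.stringify({
  rate_multipliers: filterMultipliers(['OPENAI'], { CLAUDE: 2 }),
  rpm_limit: undefined,
})
```

Here `body` contains `rate_multipliers` as `null` and has no `rpm_limit` key at all, so a field that should be cleared on the server has to be sent as an explicit `null`.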


@@ -95,7 +95,7 @@
<!-- Provider info --> <!-- Provider info -->
<div class="flex-1 min-w-0 flex items-center gap-2"> <div class="flex-1 min-w-0 flex items-center gap-2">
<span class="font-medium text-sm truncate">{{ provider.display_name }}</span> <span class="font-medium text-sm truncate">{{ provider.name }}</span>
<Badge <Badge
v-if="!provider.is_active" v-if="!provider.is_active"
variant="secondary" variant="secondary"
@@ -262,17 +262,17 @@
<div class="shrink-0 flex items-center gap-3"> <div class="shrink-0 flex items-center gap-3">
<!-- Health score --> <!-- Health score -->
<div <div
v-if="key.success_rate !== null" v-if="key.health_score != null"
class="text-xs text-right" class="text-xs text-right"
> >
<div <div
class="font-medium tabular-nums" class="font-medium tabular-nums"
:class="[ :class="[
key.success_rate >= 0.95 ? 'text-green-600' : key.health_score >= 0.95 ? 'text-green-600' :
key.success_rate >= 0.8 ? 'text-yellow-600' : 'text-red-500' key.health_score >= 0.5 ? 'text-yellow-600' : 'text-red-500'
]" ]"
> >
{{ (key.success_rate * 100).toFixed(0) }}% {{ ((key.health_score || 0) * 100).toFixed(0) }}%
</div> </div>
<div class="text-[10px] text-muted-foreground opacity-70"> <div class="text-[10px] text-muted-foreground opacity-70">
{{ key.request_count }} reqs {{ key.request_count }} reqs
@@ -319,19 +319,6 @@
<div class="flex items-center gap-2 pl-4 border-l border-border"> <div class="flex items-center gap-2 pl-4 border-l border-border">
<span class="text-xs text-muted-foreground">调度:</span> <span class="text-xs text-muted-foreground">调度:</span>
<div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md"> <div class="flex gap-0.5 p-0.5 bg-muted/40 rounded-md">
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
<button <button
type="button" type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all" class="px-2 py-1 text-xs font-medium rounded transition-all"
@@ -345,6 +332,32 @@
> >
缓存亲和 缓存亲和
</button> </button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'load_balance'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="同优先级内随机轮换,不考虑缓存"
@click="schedulingMode = 'load_balance'"
>
负载均衡
</button>
<button
type="button"
class="px-2 py-1 text-xs font-medium rounded transition-all"
:class="[
schedulingMode === 'fixed_order'
? 'bg-primary text-primary-foreground shadow-sm'
: 'text-muted-foreground hover:text-foreground hover:bg-muted/50'
]"
title="严格按优先级顺序,不考虑缓存"
@click="schedulingMode = 'fixed_order'"
>
固定顺序
</button>
</div> </div>
</div> </div>
</div> </div>
@@ -382,7 +395,7 @@ import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue' import Badge from '@/components/ui/badge.vue'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
import { updateProvider, updateEndpointKey } from '@/api/endpoints' import { updateProvider, updateProviderKey } from '@/api/endpoints'
import type { ProviderWithEndpointsSummary } from '@/api/endpoints' import type { ProviderWithEndpointsSummary } from '@/api/endpoints'
import { adminApi } from '@/api/admin' import { adminApi } from '@/api/admin'
@@ -400,6 +413,7 @@ interface KeyWithMeta {
endpoint_base_url: string endpoint_base_url: string
api_format: string api_format: string
capabilities: string[] capabilities: string[]
health_score: number | null
success_rate: number | null success_rate: number | null
avg_response_time_ms: number | null avg_response_time_ms: number | null
request_count: number request_count: number
@@ -444,7 +458,7 @@ const saving = ref(false)
const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId const editingKeyPriority = ref<Record<string, string | null>>({}) // format -> keyId
// Scheduling mode state // Scheduling mode state
const schedulingMode = ref<'fixed_order' | 'cache_affinity'>('cache_affinity') const schedulingMode = ref<'fixed_order' | 'load_balance' | 'cache_affinity'>('cache_affinity')
// Available API formats // Available API formats
const availableFormats = computed(() => { const availableFormats = computed(() => {
@@ -477,7 +491,11 @@ async function loadCurrentPriorityMode() {
activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider' activeMainTab.value = currentMode === 'global_key' ? 'key' : 'provider'
const currentSchedulingMode = schedulingResponse.value || 'cache_affinity' const currentSchedulingMode = schedulingResponse.value || 'cache_affinity'
schedulingMode.value = currentSchedulingMode === 'fixed_order' ? 'fixed_order' : 'cache_affinity' if (currentSchedulingMode === 'fixed_order' || currentSchedulingMode === 'load_balance' || currentSchedulingMode === 'cache_affinity') {
schedulingMode.value = currentSchedulingMode
} else {
schedulingMode.value = 'cache_affinity'
}
} catch { } catch {
activeMainTab.value = 'provider' activeMainTab.value = 'provider'
schedulingMode.value = 'cache_affinity' schedulingMode.value = 'cache_affinity'
@@ -678,7 +696,7 @@ async function save() {
const keys = keysByFormat.value[format] const keys = keysByFormat.value[format]
keys.forEach((key) => { keys.forEach((key) => {
// Use the priority set by the user; keys with the same priority are load-balanced // Use the priority set by the user; keys with the same priority are load-balanced
keyUpdates.push(updateEndpointKey(key.id, { global_priority: key.priority })) keyUpdates.push(updateProviderKey(key.id, { global_priority: key.priority }))
}) })
} }
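
The scheduling-mode fallback in `loadCurrentPriorityMode()` accepts three known values and treats anything else as `cache_affinity`. A minimal sketch of the same check as a reusable type guard follows; `SCHEDULING_MODES`, `isSchedulingMode` and `normalizeSchedulingMode` are hypothetical names that do not appear in the diff.

```typescript
// The three modes offered by the dialog after this change.
const SCHEDULING_MODES = ['fixed_order', 'load_balance', 'cache_affinity'] as const
type SchedulingMode = (typeof SCHEDULING_MODES)[number]

// Type guard: true only for one of the known mode strings.
function isSchedulingMode(value: unknown): value is SchedulingMode {
  return (
    typeof value === 'string' &&
    (SCHEDULING_MODES as readonly string[]).includes(value)
  )
}

// Unknown, empty or missing values fall back to cache_affinity.
function normalizeSchedulingMode(raw: unknown): SchedulingMode {
  return isSchedulingMode(raw) ? raw : 'cache_affinity'
}
```

Keeping the list of modes in one constant means a future fourth mode would only need to be added in one place, instead of in both the union type and the `if` chain.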


@@ -4,47 +4,29 @@
:title="isEditMode ? '编辑提供商' : '添加提供商'" :title="isEditMode ? '编辑提供商' : '添加提供商'"
:description="isEditMode ? '更新提供商配置。API 端点和密钥需在详情页面单独管理。' : '创建新的提供商配置。创建后可以为其添加 API 端点和密钥。'" :description="isEditMode ? '更新提供商配置。API 端点和密钥需在详情页面单独管理。' : '创建新的提供商配置。创建后可以为其添加 API 端点和密钥。'"
:icon="isEditMode ? SquarePen : Server" :icon="isEditMode ? SquarePen : Server"
size="2xl" size="xl"
@update:model-value="handleDialogUpdate" @update:model-value="handleDialogUpdate"
> >
<form <form
class="space-y-6" class="space-y-5"
@submit.prevent="handleSubmit" @submit.prevent="handleSubmit"
> >
<!-- Basic info --> <!-- Basic info -->
<div class="space-y-4"> <div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2"> <h3 class="text-sm font-medium border-b pb-2">
基本信息 基本信息
</h3> </h3>
<!-- In add mode, show the provider identifier -->
<div
v-if="!isEditMode"
class="space-y-2"
>
<Label for="name">提供商标识 *</Label>
<Input
id="name"
v-model="form.name"
placeholder="例如: openai-primary"
required
/>
<p class="text-xs text-muted-foreground">
唯一ID创建后不可修改
</p>
</div>
<div class="grid grid-cols-2 gap-4"> <div class="grid grid-cols-2 gap-4">
<div class="space-y-2"> <div class="space-y-1.5">
<Label for="display_name">显示名称 *</Label> <Label for="name">名称 *</Label>
<Input <Input
id="display_name" id="name"
v-model="form.display_name" v-model="form.name"
placeholder="例如: OpenAI 主账号" placeholder="例如: OpenAI 主账号"
required
/> />
</div> </div>
<div class="space-y-2"> <div class="space-y-1.5">
<Label for="website">主站链接</Label> <Label for="website">主站链接</Label>
<Input <Input
id="website" id="website"
@@ -55,24 +37,28 @@
</div> </div>
</div> </div>
<div class="space-y-2"> <div class="space-y-1.5">
<Label for="description">描述</Label> <Label for="description">描述</Label>
<Textarea <Input
id="description" id="description"
v-model="form.description" v-model="form.description"
placeholder="提供商描述(可选)" placeholder="提供商描述(可选)"
rows="2"
/> />
</div> </div>
</div> </div>
<!-- Billing and rate limiting --> <!-- Billing and rate limiting / Request settings -->
<div class="space-y-4"> <div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2">
计费与限流
</h3>
<div class="grid grid-cols-2 gap-4"> <div class="grid grid-cols-2 gap-4">
<div class="space-y-2"> <h3 class="text-sm font-medium border-b pb-2">
计费与限流
</h3>
<h3 class="text-sm font-medium border-b pb-2">
请求配置
</h3>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-1.5">
<Label>计费类型</Label> <Label>计费类型</Label>
<Select <Select
v-model="form.billing_type" v-model="form.billing_type"
@@ -82,27 +68,35 @@
<SelectValue /> <SelectValue />
</SelectTrigger> </SelectTrigger>
<SelectContent> <SelectContent>
<SelectItem value="monthly_quota"> <SelectItem value="monthly_quota">月卡额度</SelectItem>
月卡额度 <SelectItem value="pay_as_you_go">按量付费</SelectItem>
</SelectItem> <SelectItem value="free_tier">免费套餐</SelectItem>
<SelectItem value="pay_as_you_go">
按量付费
</SelectItem>
<SelectItem value="free_tier">
免费套餐
</SelectItem>
</SelectContent> </SelectContent>
</Select> </Select>
</div> </div>
<div class="space-y-2"> <div class="grid grid-cols-2 gap-4">
<Label>RPM 限制</Label> <div class="space-y-1.5">
<Input <Label>超时时间 ()</Label>
:model-value="form.rpm_limit ?? ''" <Input
type="number" :model-value="form.timeout ?? ''"
min="0" type="number"
placeholder="不限制请留空" min="1"
@update:model-value="(v) => form.rpm_limit = parseNumberInput(v)" max="600"
/> placeholder="默认 300"
@update:model-value="(v) => form.timeout = parseNumberInput(v)"
/>
</div>
<div class="space-y-1.5">
<Label>最大重试次数</Label>
<Input
:model-value="form.max_retries ?? ''"
type="number"
min="0"
max="10"
placeholder="默认 2"
@update:model-value="(v) => form.max_retries = parseNumberInput(v)"
/>
</div>
</div> </div>
</div> </div>
@@ -111,52 +105,94 @@
v-if="form.billing_type === 'monthly_quota'" v-if="form.billing_type === 'monthly_quota'"
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50" class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
> >
<div class="space-y-2"> <div class="space-y-1.5">
<Label class="text-xs">周期额度 (USD)</Label> <Label class="text-xs">周期额度 (USD)</Label>
<Input <Input
:model-value="form.monthly_quota_usd ?? ''" :model-value="form.monthly_quota_usd ?? ''"
type="number" type="number"
step="0.01" step="0.01"
min="0" min="0"
class="h-9"
@update:model-value="(v) => form.monthly_quota_usd = parseNumberInput(v, { allowFloat: true })" @update:model-value="(v) => form.monthly_quota_usd = parseNumberInput(v, { allowFloat: true })"
/> />
</div> </div>
<div class="space-y-2"> <div class="space-y-1.5">
<Label class="text-xs">重置周期 (天)</Label> <Label class="text-xs">重置周期 (天)</Label>
<Input <Input
:model-value="form.quota_reset_day ?? ''" :model-value="form.quota_reset_day ?? ''"
type="number" type="number"
min="1" min="1"
max="365" max="365"
class="h-9"
@update:model-value="(v) => form.quota_reset_day = parseNumberInput(v) ?? 30" @update:model-value="(v) => form.quota_reset_day = parseNumberInput(v) ?? 30"
/> />
</div> </div>
<div class="space-y-2"> <div class="space-y-1.5">
<Label class="text-xs"> <Label class="text-xs">
周期开始时间 周期开始时间 <span class="text-red-500">*</span>
<span class="text-red-500">*</span>
</Label> </Label>
<Input <Input
v-model="form.quota_last_reset_at" v-model="form.quota_last_reset_at"
type="datetime-local" type="datetime-local"
class="h-9"
/> />
<p class="text-xs text-muted-foreground">
系统会自动统计从该时间点开始的使用量
</p>
</div> </div>
<div class="space-y-2"> <div class="space-y-1.5">
<Label class="text-xs">过期时间</Label> <Label class="text-xs">过期时间</Label>
<Input <Input
v-model="form.quota_expires_at" v-model="form.quota_expires_at"
type="datetime-local" type="datetime-local"
class="h-9"
/> />
<p class="text-xs text-muted-foreground"> </div>
留空表示永久有效 </div>
</p> </div>
<!-- Proxy settings -->
<div class="space-y-3">
<div class="flex items-center justify-between">
<h3 class="text-sm font-medium">
代理配置
</h3>
<div class="flex items-center gap-2">
<Switch
:model-value="form.proxy_enabled"
@update:model-value="(v: boolean) => form.proxy_enabled = v"
/>
<span class="text-sm text-muted-foreground">启用代理</span>
</div>
</div>
<div
v-if="form.proxy_enabled"
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
>
<div class="space-y-1.5">
<Label class="text-xs">代理地址 *</Label>
<Input
v-model="form.proxy_url"
placeholder="http://proxy:port 或 socks5://proxy:port"
/>
</div>
<div class="grid grid-cols-2 gap-3">
<div class="space-y-1.5">
<Label class="text-xs">用户名</Label>
<Input
v-model="form.proxy_username"
placeholder="可选"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-1.5">
<Label class="text-xs">密码</Label>
<Input
v-model="form.proxy_password"
type="password"
placeholder="可选"
autocomplete="new-password"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
</div> </div>
</div> </div>
</div> </div>
@@ -172,7 +208,7 @@
取消 取消
</Button> </Button>
<Button <Button
:disabled="loading || !form.display_name || (!isEditMode && !form.name)" :disabled="loading || !form.name"
@click="handleSubmit" @click="handleSubmit"
> >
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存' : '创建') }} {{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存' : '创建') }}
@@ -187,13 +223,13 @@ import {
Dialog, Dialog,
Button, Button,
Input, Input,
Textarea,
Label, Label,
Select, Select,
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
SelectContent, SelectContent,
SelectItem, SelectItem,
Switch,
} from '@/components/ui' } from '@/components/ui'
import { Server, SquarePen } from 'lucide-vue-next' import { Server, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast' import { useToast } from '@/composables/useToast'
@@ -223,7 +259,6 @@ const internalOpen = computed(() => props.modelValue)
// Form data // Form data
const form = ref({ const form = ref({
name: '', name: '',
display_name: '',
description: '', description: '',
website: '', website: '',
// Billing settings // Billing settings
@@ -232,19 +267,25 @@ const form = ref({
quota_reset_day: 30, quota_reset_day: 30,
quota_last_reset_at: '', // Start of the billing cycle quota_last_reset_at: '', // Start of the billing cycle
quota_expires_at: '', quota_expires_at: '',
rpm_limit: undefined as string | number | undefined,
provider_priority: 999, provider_priority: 999,
// Status settings // Status settings
is_active: true, is_active: true,
rate_limit: undefined as number | undefined, rate_limit: undefined as number | undefined,
concurrent_limit: undefined as number | undefined, concurrent_limit: undefined as number | undefined,
// Request settings
timeout: undefined as number | undefined,
max_retries: undefined as number | undefined,
// Proxy settings (flattened for easier form binding)
proxy_enabled: false,
proxy_url: '',
proxy_username: '',
proxy_password: '',
}) })
// Reset the form // Reset the form
function resetForm() { function resetForm() {
form.value = { form.value = {
name: '', name: '',
display_name: '',
description: '', description: '',
website: '', website: '',
billing_type: 'pay_as_you_go', billing_type: 'pay_as_you_go',
@@ -252,11 +293,18 @@ function resetForm() {
quota_reset_day: 30, quota_reset_day: 30,
quota_last_reset_at: '', quota_last_reset_at: '',
quota_expires_at: '', quota_expires_at: '',
rpm_limit: undefined,
provider_priority: 999, provider_priority: 999,
is_active: true, is_active: true,
rate_limit: undefined, rate_limit: undefined,
concurrent_limit: undefined, concurrent_limit: undefined,
// Request settings
timeout: undefined,
max_retries: undefined,
// Proxy settings
proxy_enabled: false,
proxy_url: '',
proxy_username: '',
proxy_password: '',
} }
} }
@@ -264,9 +312,9 @@ function resetForm() {
function loadProviderData() { function loadProviderData() {
if (!props.provider) return if (!props.provider) return
const proxy = props.provider.proxy
form.value = { form.value = {
name: props.provider.name, name: props.provider.name,
display_name: props.provider.display_name,
description: props.provider.description || '', description: props.provider.description || '',
website: props.provider.website || '', website: props.provider.website || '',
billing_type: (props.provider.billing_type as 'monthly_quota' | 'pay_as_you_go' | 'free_tier') || 'pay_as_you_go', billing_type: (props.provider.billing_type as 'monthly_quota' | 'pay_as_you_go' | 'free_tier') || 'pay_as_you_go',
@@ -276,11 +324,18 @@ function loadProviderData() {
new Date(props.provider.quota_last_reset_at).toISOString().slice(0, 16) : '', new Date(props.provider.quota_last_reset_at).toISOString().slice(0, 16) : '',
quota_expires_at: props.provider.quota_expires_at ? quota_expires_at: props.provider.quota_expires_at ?
new Date(props.provider.quota_expires_at).toISOString().slice(0, 16) : '', new Date(props.provider.quota_expires_at).toISOString().slice(0, 16) : '',
rpm_limit: props.provider.rpm_limit ?? undefined,
provider_priority: props.provider.provider_priority || 999, provider_priority: props.provider.provider_priority || 999,
is_active: props.provider.is_active, is_active: props.provider.is_active,
rate_limit: undefined, rate_limit: undefined,
concurrent_limit: undefined, concurrent_limit: undefined,
// Request settings
timeout: props.provider.timeout ?? undefined,
max_retries: props.provider.max_retries ?? undefined,
// Proxy settings
proxy_enabled: proxy?.enabled ?? false,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
proxy_password: proxy?.password || '',
} }
} }
@@ -302,17 +357,37 @@ const handleSubmit = async () => {
return return
} }
// A proxy URL is required when the proxy is enabled
if (form.value.proxy_enabled && !form.value.proxy_url) {
showError('启用代理时必须填写代理地址', '验证失败')
return
}
loading.value = true loading.value = true
try { try {
// Build the proxy config
const proxy = form.value.proxy_enabled ? {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: true,
} : null
const payload = { const payload = {
...form.value, name: form.value.name,
rpm_limit: description: form.value.description || undefined,
form.value.rpm_limit === undefined || form.value.rpm_limit === '' website: form.value.website || undefined,
? null billing_type: form.value.billing_type,
: Number(form.value.rpm_limit), monthly_quota_usd: form.value.monthly_quota_usd,
// Do not send empty strings quota_reset_day: form.value.quota_reset_day,
quota_last_reset_at: form.value.quota_last_reset_at || undefined, quota_last_reset_at: form.value.quota_last_reset_at || undefined,
quota_expires_at: form.value.quota_expires_at || undefined, quota_expires_at: form.value.quota_expires_at || undefined,
provider_priority: form.value.provider_priority,
is_active: form.value.is_active,
// Request settings
timeout: form.value.timeout ?? undefined,
max_retries: form.value.max_retries ?? undefined,
proxy,
} }
if (isEditMode.value && props.provider) { if (isEditMode.value && props.provider) {
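
The submit handler above turns four flat form fields into either a nested proxy object or `null`. A minimal sketch of that mapping as a pure function follows, assuming the field names shown in the diff; `ProxyFormFields`, `ProxyConfig` and `buildProxyConfig` are hypothetical names, and throwing on a missing URL is a simplification of the dialog's `showError` plus early return.

```typescript
// Flat fields as they are bound in the form.
interface ProxyFormFields {
  proxy_enabled: boolean
  proxy_url: string
  proxy_username: string
  proxy_password: string
}

// Nested shape sent to the backend when the proxy is enabled.
interface ProxyConfig {
  url: string
  username?: string
  password?: string
  enabled: true
}

function buildProxyConfig(form: ProxyFormFields): ProxyConfig | null {
  // A disabled proxy is sent as null, not as an object with enabled: false.
  if (!form.proxy_enabled) return null
  if (!form.proxy_url) {
    throw new Error('proxy_url is required when the proxy is enabled')
  }
  return {
    url: form.proxy_url,
    // Empty credentials become undefined so they are left out of the JSON.
    username: form.proxy_username || undefined,
    password: form.proxy_password || undefined,
    enabled: true,
  }
}
```

Keeping the form flat and nesting only at submit time lets each input use a plain `v-model` while the request body still carries a single `proxy` object.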


@@ -2,6 +2,7 @@ export { default as ProviderFormDialog } from './ProviderFormDialog.vue'
export { default as EndpointFormDialog } from './EndpointFormDialog.vue' export { default as EndpointFormDialog } from './EndpointFormDialog.vue'
export { default as KeyFormDialog } from './KeyFormDialog.vue' export { default as KeyFormDialog } from './KeyFormDialog.vue'
export { default as KeyAllowedModelsDialog } from './KeyAllowedModelsDialog.vue' export { default as KeyAllowedModelsDialog } from './KeyAllowedModelsDialog.vue'
export { default as KeyAllowedModelsEditDialog } from './KeyAllowedModelsEditDialog.vue'
export { default as PriorityManagementDialog } from './PriorityManagementDialog.vue' export { default as PriorityManagementDialog } from './PriorityManagementDialog.vue'
export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vue' export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vue'
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue' export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'


@@ -178,7 +178,7 @@
<Button <Button
variant="ghost" variant="ghost"
size="icon" size="icon"
class="h-8 w-8 text-destructive hover:text-destructive" class="h-8 w-8 hover:text-destructive"
title="删除" title="删除"
@click="deleteModel(model)" @click="deleteModel(model)"
> >


@@ -289,14 +289,14 @@
/> />
</div> </div>
<!-- Error message card --> <!-- Client response error card -->
<Card <Card
v-if="detail.error_message" v-if="detail.error_message"
class="border-red-200 dark:border-red-800" class="border-red-200 dark:border-red-800"
> >
<div class="p-4"> <div class="p-4">
<h4 class="text-sm font-semibold text-red-600 dark:text-red-400 mb-2"> <h4 class="text-sm font-semibold text-red-600 dark:text-red-400 mb-2">
错误信息 响应客户端错误
</h4> </h4>
<div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-3"> <div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-3">
<p class="text-sm text-red-800 dark:text-red-300"> <p class="text-sm text-red-800 dark:text-red-300">
@@ -431,7 +431,7 @@
<TabsContent value="response-headers"> <TabsContent value="response-headers">
<JsonContent <JsonContent
:data="detail.response_headers" :data="actualResponseHeaders"
:view-mode="viewMode" :view-mode="viewMode"
:expand-depth="currentExpandDepth" :expand-depth="currentExpandDepth"
:is-dark="isDark" :is-dark="isDark"
@@ -614,6 +614,25 @@ const tabs = [
{ name: 'metadata', label: '元数据' }, { name: 'metadata', label: '元数据' },
] ]
// Check whether the data has actual content (not an empty object/array)
function hasContent(data: unknown): boolean {
if (data === null || data === undefined) return false
if (typeof data === 'object') {
return Object.keys(data as object).length > 0
}
return true
}
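// Get the actual response headers (prefer client_response_headers, fall back to response_headers)
const actualResponseHeaders = computed(() => {
if (!detail.value) return null
// Prefer the client response headers; fall back to the provider response headers if there are none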
if (hasContent(detail.value.client_response_headers)) {
return detail.value.client_response_headers
}
return detail.value.response_headers
})
// Decide which tabs to show based on the actual data // Decide which tabs to show based on the actual data
const visibleTabs = computed(() => { const visibleTabs = computed(() => {
if (!detail.value) return [] if (!detail.value) return []
@@ -621,15 +640,15 @@ const visibleTabs = computed(() => {
return tabs.filter(tab => { return tabs.filter(tab => {
switch (tab.name) { switch (tab.name) {
case 'request-headers': case 'request-headers':
return detail.value!.request_headers && Object.keys(detail.value!.request_headers).length > 0 return hasContent(detail.value!.request_headers)
case 'request-body': case 'request-body':
return detail.value!.request_body !== null && detail.value!.request_body !== undefined return hasContent(detail.value!.request_body)
case 'response-headers': case 'response-headers':
return detail.value!.response_headers && Object.keys(detail.value!.response_headers).length > 0 return hasContent(actualResponseHeaders.value)
case 'response-body': case 'response-body':
return detail.value!.response_body !== null && detail.value!.response_body !== undefined return hasContent(detail.value!.response_body)
case 'metadata': case 'metadata':
return detail.value!.metadata && Object.keys(detail.value!.metadata).length > 0 return hasContent(detail.value!.metadata)
default: default:
return false return false
} }
@@ -775,7 +794,7 @@ function copyJsonToClipboard(tabName: string) {
data = detail.value.request_body data = detail.value.request_body
break break
case 'response-headers': case 'response-headers':
data = detail.value.response_headers data = actualResponseHeaders.value
break break
case 'response-body': case 'response-body':
data = detail.value.response_body data = detail.value.response_body
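
The `hasContent` check and the header fallback introduced above can be exercised outside the component. This is a minimal sketch with the `computed` wrapper removed; `RequestDetail` is a hypothetical, trimmed-down type and `pickResponseHeaders` is a hypothetical name for the body of `actualResponseHeaders`.

```typescript
// Same rule as in the diff: null/undefined and empty objects/arrays are empty.
function hasContent(data: unknown): boolean {
  if (data === null || data === undefined) return false
  if (typeof data === 'object') {
    return Object.keys(data as object).length > 0
  }
  // Any other primitive counts as content, including '' and 0.
  return true
}

type Headers = Record<string, string>

interface RequestDetail {
  client_response_headers?: Headers | null
  response_headers?: Headers | null
}

// Prefer the headers returned to the client; otherwise use the provider's.
function pickResponseHeaders(detail: RequestDetail | null): Headers | null | undefined {
  if (!detail) return null
  return hasContent(detail.client_response_headers)
    ? detail.client_response_headers
    : detail.response_headers
}
```

One consequence of this rule is that an empty string or `0` still counts as content, so a tab backed by an empty-string body would remain visible.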


@@ -252,7 +252,7 @@
@click.stop @click.stop
@change="toggleSelection('allowed_providers', provider.id)" @change="toggleSelection('allowed_providers', provider.id)"
> >
<span class="text-sm">{{ provider.display_name || provider.name }}</span> <span class="text-sm">{{ provider.name }}</span>
</div> </div>
<div <div
v-if="providers.length === 0" v-if="providers.length === 0"


@@ -295,6 +295,15 @@
</template> </template>
<RouterView /> <RouterView />
<!-- Update prompt dialog -->
<UpdateDialog
v-if="updateInfo"
v-model="showUpdateDialog"
:current-version="updateInfo.current_version"
:latest-version="updateInfo.latest_version || ''"
:release-url="updateInfo.release_url"
/>
</AppShell> </AppShell>
</template> </template>
@@ -304,10 +313,12 @@ import { useRoute, useRouter } from 'vue-router'
import { useAuthStore } from '@/stores/auth' import { useAuthStore } from '@/stores/auth'
import { useDarkMode } from '@/composables/useDarkMode' import { useDarkMode } from '@/composables/useDarkMode'
import { isDemoMode } from '@/config/demo' import { isDemoMode } from '@/config/demo'
import { adminApi, type CheckUpdateResponse } from '@/api/admin'
import Button from '@/components/ui/button.vue' import Button from '@/components/ui/button.vue'
import AppShell from '@/components/layout/AppShell.vue' import AppShell from '@/components/layout/AppShell.vue'
import SidebarNav from '@/components/layout/SidebarNav.vue' import SidebarNav from '@/components/layout/SidebarNav.vue'
import HeaderLogo from '@/components/HeaderLogo.vue' import HeaderLogo from '@/components/HeaderLogo.vue'
import UpdateDialog from '@/components/common/UpdateDialog.vue'
import { import {
Home, Home,
Users, Users,
@@ -345,17 +356,67 @@ const showAuthError = ref(false)
const mobileMenuOpen = ref(false) const mobileMenuOpen = ref(false)
let authCheckInterval: number | null = null let authCheckInterval: number | null = null
// Update check state
const showUpdateDialog = ref(false)
const updateInfo = ref<CheckUpdateResponse | null>(null)
// Close the mobile menu automatically when the route changes // Close the mobile menu automatically when the route changes
watch(() => route.path, () => { watch(() => route.path, () => {
mobileMenuOpen.value = false mobileMenuOpen.value = false
}) })
// Check whether the update prompt should be shown
function shouldShowUpdatePrompt(latestVersion: string): boolean {
const ignoreKey = 'aether_update_ignore'
const ignoreData = localStorage.getItem(ignoreKey)
if (!ignoreData) return true
try {
const { version, until } = JSON.parse(ignoreData)
// Do not show the prompt if the same version was ignored and the ignore has not expired
if (version === latestVersion && Date.now() < until) {
return false
}
} catch {
// Parsing failed; show the prompt
}
return true
}
// Check for updates
async function checkForUpdate() {
// Only admins check for updates
if (authStore.user?.role !== 'admin') return
// Check only once per session
const sessionKey = 'aether_update_checked'
if (sessionStorage.getItem(sessionKey)) return
sessionStorage.setItem(sessionKey, '1')
try {
const result = await adminApi.checkUpdate()
if (result.has_update && result.latest_version) {
if (shouldShowUpdatePrompt(result.latest_version)) {
updateInfo.value = result
showUpdateDialog.value = true
}
}
} catch {
// Fail silently so the user experience is unaffected
}
}
onMounted(() => { onMounted(() => {
authCheckInterval = setInterval(() => { authCheckInterval = setInterval(() => {
if (authStore.user && !authStore.token) { if (authStore.user && !authStore.token) {
showAuthError.value = true showAuthError.value = true
} }
}, 5000) }, 5000)
// Delay the update check so it does not slow down page load
setTimeout(() => {
checkForUpdate()
}, 2000)
}) })
onUnmounted(() => { onUnmounted(() => {
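
The ignore check in `shouldShowUpdatePrompt` depends on `localStorage` and the current time, which makes it awkward to test directly. A minimal sketch of the same logic with both passed in as parameters follows; the `KeyValueStore` interface and the extra `store` and `now` parameters are assumptions added for illustration, and the stored shape `{ version, until }` is inferred from the destructuring in the diff.

```typescript
// Minimal read-only view of a storage backend such as localStorage.
interface KeyValueStore {
  getItem(key: string): string | null
}

const IGNORE_KEY = 'aether_update_ignore'

function shouldShowUpdatePrompt(
  store: KeyValueStore,
  latestVersion: string,
  now: number = Date.now(),
): boolean {
  const raw = store.getItem(IGNORE_KEY)
  if (!raw) return true
  try {
    const { version, until } = JSON.parse(raw)
    // Suppress the prompt only for the same version, and only until expiry.
    if (version === latestVersion && now < until) return false
  } catch {
    // Malformed data: fall through and show the prompt.
  }
  return true
}
```

In a browser the call would be `shouldShowUpdatePrompt(localStorage, latestVersion)`; in a test, any object with a `getItem` method can stand in for the store.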


@@ -424,8 +424,7 @@ export const MOCK_ADMIN_API_KEYS: AdminApiKeysResponse = {
export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
{ {
id: 'provider-001', id: 'provider-001',
name: 'duck_coding_free', name: 'DuckCodingFree',
display_name: 'DuckCodingFree',
description: '', description: '',
website: 'https://duckcoding.com', website: 'https://duckcoding.com',
provider_priority: 1, provider_priority: 1,
@@ -451,8 +450,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
}, },
{ {
id: 'provider-002', id: 'provider-002',
name: 'open_claude_code', name: 'OpenClaudeCode',
display_name: 'OpenClaudeCode',
description: '', description: '',
website: 'https://www.openclaudecode.cn', website: 'https://www.openclaudecode.cn',
provider_priority: 2, provider_priority: 2,
@@ -477,8 +475,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
}, },
{ {
id: 'provider-003', id: 'provider-003',
name: '88_code', name: '88Code',
display_name: '88Code',
description: '', description: '',
website: 'https://www.88code.org/', website: 'https://www.88code.org/',
provider_priority: 3, provider_priority: 3,
@@ -503,8 +500,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
}, },
{ {
id: 'provider-004', id: 'provider-004',
name: 'ikun_code', name: 'IKunCode',
display_name: 'IKunCode',
description: '', description: '',
website: 'https://api.ikuncode.cc', website: 'https://api.ikuncode.cc',
provider_priority: 4, provider_priority: 4,
@@ -531,8 +527,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
}, },
{ {
id: 'provider-005', id: 'provider-005',
name: 'duck_coding', name: 'DuckCoding',
display_name: 'DuckCoding',
description: '', description: '',
website: 'https://duckcoding.com', website: 'https://duckcoding.com',
provider_priority: 5, provider_priority: 5,
@@ -561,8 +556,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
}, },
{ {
id: 'provider-006', id: 'provider-006',
name: 'privnode', name: 'Privnode',
display_name: 'Privnode',
description: '', description: '',
website: 'https://privnode.com', website: 'https://privnode.com',
provider_priority: 6, provider_priority: 6,
@@ -584,8 +578,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
}, },
{ {
id: 'provider-007', id: 'provider-007',
name: 'undying_api', name: 'UndyingAPI',
display_name: 'UndyingAPI',
description: '', description: '',
website: 'https://vip.undyingapi.com', website: 'https://vip.undyingapi.com',
provider_priority: 7, provider_priority: 7,

@@ -418,16 +418,16 @@ const MOCK_ALIASES = [
// Mock Endpoint Keys // Mock Endpoint Keys
const MOCK_ENDPOINT_KEYS = [ const MOCK_ENDPOINT_KEYS = [
{ id: 'ekey-001', endpoint_id: 'ep-001', api_key_masked: 'sk-ant...abc1', name: 'Primary Key', rate_multiplier: 1.0, internal_priority: 1, health_score: 98, consecutive_failures: 0, request_count: 5000, success_count: 4950, error_count: 50, success_rate: 99, avg_response_time_ms: 1200, is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() }, { id: 'ekey-001', provider_id: 'provider-001', api_formats: ['CLAUDE'], api_key_masked: 'sk-ant...abc1', name: 'Primary Key', rate_multiplier: 1.0, internal_priority: 1, health_score: 0.98, consecutive_failures: 0, request_count: 5000, success_count: 4950, error_count: 50, success_rate: 0.99, avg_response_time_ms: 1200, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ekey-002', endpoint_id: 'ep-001', api_key_masked: 'sk-ant...def2', name: 'Backup Key', rate_multiplier: 1.0, internal_priority: 2, health_score: 95, consecutive_failures: 1, request_count: 2000, success_count: 1950, error_count: 50, success_rate: 97.5, avg_response_time_ms: 1350, is_active: true, created_at: '2024-02-01T00:00:00Z', updated_at: new Date().toISOString() }, { id: 'ekey-002', provider_id: 'provider-001', api_formats: ['CLAUDE'], api_key_masked: 'sk-ant...def2', name: 'Backup Key', rate_multiplier: 1.0, internal_priority: 2, health_score: 0.95, consecutive_failures: 1, request_count: 2000, success_count: 1950, error_count: 50, success_rate: 0.975, avg_response_time_ms: 1350, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-02-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ekey-003', endpoint_id: 'ep-002', api_key_masked: 'sk-oai...ghi3', name: 'OpenAI Main', rate_multiplier: 1.0, internal_priority: 1, health_score: 97, consecutive_failures: 0, request_count: 3500, success_count: 3450, error_count: 50, success_rate: 98.6, avg_response_time_ms: 900, is_active: true, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() } { id: 'ekey-003', provider_id: 'provider-002', api_formats: ['OPENAI'], api_key_masked: 'sk-oai...ghi3', name: 'OpenAI Main', rate_multiplier: 1.0, internal_priority: 1, health_score: 0.97, consecutive_failures: 0, request_count: 3500, success_count: 3450, error_count: 50, success_rate: 0.986, avg_response_time_ms: 900, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
] ]
// Mock Endpoints // Mock Endpoints
const MOCK_ENDPOINTS = [ const MOCK_ENDPOINTS = [
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 3, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() }, { id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'CLAUDE', base_url: 'https://api.anthropic.com', timeout: 300, max_retries: 2, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 3, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() }, { id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'OPENAI', base_url: 'https://api.openai.com', timeout: 60, max_retries: 2, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 3, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() } { id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'GEMINI', base_url: 'https://generativelanguage.googleapis.com', timeout: 60, max_retries: 2, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
] ]
// Mock capability definitions // Mock capability definitions
@@ -581,7 +581,6 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
return createMockResponse(MOCK_PROVIDERS.map(p => ({ return createMockResponse(MOCK_PROVIDERS.map(p => ({
id: p.id, id: p.id,
name: p.name, name: p.name,
display_name: p.display_name,
is_active: p.is_active is_active: p.is_active
}))) })))
}, },
@@ -1222,13 +1221,8 @@ function generateMockEndpointsForProvider(providerId: string) {
base_url: format.includes('CLAUDE') ? 'https://api.anthropic.com' : base_url: format.includes('CLAUDE') ? 'https://api.anthropic.com' :
format.includes('OPENAI') ? 'https://api.openai.com' : format.includes('OPENAI') ? 'https://api.openai.com' :
'https://generativelanguage.googleapis.com', 'https://generativelanguage.googleapis.com',
auth_type: format.includes('GEMINI') ? 'api_key' : 'bearer', timeout: 300,
timeout: 120, max_retries: 2,
max_retries: 3,
priority: 100 - index * 10,
weight: 100,
health_score: healthDetail?.health_score ?? 1.0,
consecutive_failures: healthDetail?.health_score && healthDetail.health_score < 0.7 ? 2 : 0,
is_active: healthDetail?.is_active ?? true, is_active: healthDetail?.is_active ?? true,
total_keys: Math.ceil(Math.random() * 3) + 1, total_keys: Math.ceil(Math.random() * 3) + 1,
active_keys: Math.ceil(Math.random() * 2) + 1, active_keys: Math.ceil(Math.random() * 2) + 1,
@@ -1238,11 +1232,16 @@ function generateMockEndpointsForProvider(providerId: string) {
}) })
} }
// Generate keys for an endpoint // Generate keys for a provider (a Key belongs to a Provider and is linked via api_formats)
function generateMockKeysForEndpoint(endpointId: string, count: number = 2) { const PROVIDER_KEYS_CACHE: Record<string, any[]> = {}
function generateMockKeysForProvider(providerId: string, count: number = 2) {
const provider = MOCK_PROVIDERS.find(p => p.id === providerId)
const formats = provider?.api_formats || []
return Array.from({ length: count }, (_, i) => ({ return Array.from({ length: count }, (_, i) => ({
id: `key-${endpointId}-${i + 1}`, id: `key-${providerId}-${i + 1}`,
endpoint_id: endpointId, provider_id: providerId,
api_formats: i === 0 ? formats : formats.slice(0, 1),
api_key_masked: `sk-***...${Math.random().toString(36).substring(2, 6)}`, api_key_masked: `sk-***...${Math.random().toString(36).substring(2, 6)}`,
name: i === 0 ? 'Primary Key' : `Backup Key ${i}`, name: i === 0 ? 'Primary Key' : `Backup Key ${i}`,
rate_multiplier: 1.0, rate_multiplier: 1.0,
@@ -1254,6 +1253,8 @@ function generateMockKeysForEndpoint(endpointId: string, count: number = 2) {
error_count: Math.floor(Math.random() * 100), error_count: Math.floor(Math.random() * 100),
success_rate: 0.95 + Math.random() * 0.04, // 0.95-0.99 success_rate: 0.95 + Math.random() * 0.04, // 0.95-0.99
avg_response_time_ms: 800 + Math.floor(Math.random() * 600), avg_response_time_ms: 800 + Math.floor(Math.random() * 600),
cache_ttl_minutes: 5,
max_probe_interval_minutes: 32,
is_active: true, is_active: true,
created_at: '2024-01-01T00:00:00Z', created_at: '2024-01-01T00:00:00Z',
updated_at: new Date().toISOString() updated_at: new Date().toISOString()
@@ -1463,29 +1464,63 @@ registerDynamicRoute('PUT', '/api/admin/endpoints/:endpointId', async (config, p
registerDynamicRoute('DELETE', '/api/admin/endpoints/:endpointId', async (_config, _params) => { registerDynamicRoute('DELETE', '/api/admin/endpoints/:endpointId', async (_config, _params) => {
await delay() await delay()
requireAdmin() requireAdmin()
return createMockResponse({ message: '删除成功(演示模式)' }) return createMockResponse({ message: '删除成功(演示模式)', affected_keys_count: 0 })
}) })
// Endpoint Keys list // Provider Keys list
registerDynamicRoute('GET', '/api/admin/endpoints/:endpointId/keys', async (_config, params) => { registerDynamicRoute('GET', '/api/admin/endpoints/providers/:providerId/keys', async (_config, params) => {
await delay() await delay()
requireAdmin() requireAdmin()
const keys = generateMockKeysForEndpoint(params.endpointId, 2) if (!PROVIDER_KEYS_CACHE[params.providerId]) {
return createMockResponse(keys) PROVIDER_KEYS_CACHE[params.providerId] = generateMockKeysForProvider(params.providerId, 2)
}
return createMockResponse(PROVIDER_KEYS_CACHE[params.providerId])
}) })
// Create a Key // Create a Key for a Provider
registerDynamicRoute('POST', '/api/admin/endpoints/:endpointId/keys', async (config, params) => { registerDynamicRoute('POST', '/api/admin/endpoints/providers/:providerId/keys', async (config, params) => {
await delay() await delay()
requireAdmin() requireAdmin()
const body = JSON.parse(config.data || '{}') const body = JSON.parse(config.data || '{}')
return createMockResponse({ const apiKeyPlain = body.api_key || 'sk-demo'
const masked = apiKeyPlain.length >= 12
? `${apiKeyPlain.slice(0, 8)}***${apiKeyPlain.slice(-4)}`
: 'sk-***...demo'
const newKey = {
id: `key-demo-${Date.now()}`, id: `key-demo-${Date.now()}`,
endpoint_id: params.endpointId, provider_id: params.providerId,
api_key_masked: 'sk-***...demo', api_formats: body.api_formats || [],
...body, api_key_masked: masked,
created_at: new Date().toISOString() api_key_plain: null,
}) name: body.name || 'New Key',
note: body.note,
rate_multiplier: body.rate_multiplier ?? 1.0,
rate_multipliers: body.rate_multipliers ?? null,
internal_priority: body.internal_priority ?? 50,
global_priority: body.global_priority ?? null,
rpm_limit: body.rpm_limit ?? null,
allowed_models: body.allowed_models ?? null,
capabilities: body.capabilities ?? null,
cache_ttl_minutes: body.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: body.max_probe_interval_minutes ?? 32,
health_score: 1.0,
consecutive_failures: 0,
request_count: 0,
success_count: 0,
error_count: 0,
success_rate: 0.0,
avg_response_time_ms: 0.0,
is_active: true,
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
}
if (!PROVIDER_KEYS_CACHE[params.providerId]) {
PROVIDER_KEYS_CACHE[params.providerId] = []
}
PROVIDER_KEYS_CACHE[params.providerId].push(newKey)
return createMockResponse(newKey)
}) })
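
The masking rule used by the new mock handler can be read in isolation. A minimal sketch, assuming a hypothetical standalone helper named `maskApiKey` (the commit inlines this logic rather than defining such a function):

```typescript
// Mirrors the masking rule in the mock handler: keys of 12 or more characters
// keep their first 8 and last 4 characters; shorter keys get a fixed placeholder.
function maskApiKey(apiKeyPlain: string): string {
  return apiKeyPlain.length >= 12
    ? `${apiKeyPlain.slice(0, 8)}***${apiKeyPlain.slice(-4)}`
    : 'sk-***...demo'
}
```

The length threshold keeps short demo keys from being shown almost in full, since 8 leading plus 4 trailing characters would otherwise cover the whole value.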
// Key update // Key update
@@ -1503,6 +1538,50 @@ registerDynamicRoute('DELETE', '/api/admin/endpoints/keys/:keyId', async (_confi
return createMockResponse({ message: '删除成功(演示模式)' }) return createMockResponse({ message: '删除成功(演示模式)' })
}) })
// Key Reveal
registerDynamicRoute('GET', '/api/admin/endpoints/keys/:keyId/reveal', async (_config, _params) => {
await delay()
requireAdmin()
return createMockResponse({ api_key: 'sk-demo-reveal' })
})
// Keys grouped by format
mockHandlers['GET /api/admin/endpoints/keys/grouped-by-format'] = async () => {
await delay()
requireAdmin()
// Make sure every provider has key data
for (const provider of MOCK_PROVIDERS) {
if (!PROVIDER_KEYS_CACHE[provider.id]) {
PROVIDER_KEYS_CACHE[provider.id] = generateMockKeysForProvider(provider.id, 2)
}
}
const grouped: Record<string, any[]> = {}
for (const provider of MOCK_PROVIDERS) {
const endpoints = generateMockEndpointsForProvider(provider.id)
const baseUrlByFormat = Object.fromEntries(endpoints.map(e => [e.api_format, e.base_url]))
const keys = PROVIDER_KEYS_CACHE[provider.id] || []
for (const key of keys) {
const formats: string[] = key.api_formats || []
for (const fmt of formats) {
if (!grouped[fmt]) grouped[fmt] = []
grouped[fmt].push({
...key,
api_format: fmt,
provider_name: provider.name,
endpoint_base_url: baseUrlByFormat[fmt],
global_priority: key.global_priority ?? null,
circuit_breaker_open: false,
capabilities: [],
})
}
}
}
return createMockResponse(grouped)
}
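
The core of the grouped-by-format handler is a fan-out: a key that lists several `api_formats` appears once under each format. A minimal sketch of that step alone, assuming simplified `ProviderKey` and `GroupedKey` types and a hypothetical `groupKeysByFormat` function (the real handler also attaches provider and endpoint fields):

```typescript
// Simplified key shape; the mock data carries many more fields.
interface ProviderKey {
  id: string
  provider_id: string
  api_formats: string[]
}

// A key as it appears inside one format bucket.
interface GroupedKey extends ProviderKey {
  api_format: string
}

// Place each key into one bucket per format it supports.
function groupKeysByFormat(keys: ProviderKey[]): Record<string, GroupedKey[]> {
  const grouped: Record<string, GroupedKey[]> = {}
  for (const key of keys) {
    for (const fmt of key.api_formats) {
      const bucket = grouped[fmt] ?? (grouped[fmt] = [])
      bucket.push({ ...key, api_format: fmt })
    }
  }
  return grouped
}
```

A key with an empty `api_formats` array lands in no bucket, which matches the handler's behavior of iterating over `key.api_formats || []`.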
// Provider Models list // Provider Models list
registerDynamicRoute('GET', '/api/admin/providers/:providerId/models', async (_config, params) => { registerDynamicRoute('GET', '/api/admin/providers/:providerId/models', async (_config, params) => {
await delay() await delay()

@@ -20,7 +20,7 @@ interface ValidationError {
const fieldNameMap: Record<string, string> = { const fieldNameMap: Record<string, string> = {
'api_key': 'API 密钥', 'api_key': 'API 密钥',
'priority': '优先级', 'priority': '优先级',
'max_concurrent': '最大并发', 'rpm_limit': 'RPM 限制',
'rate_limit': '速率限制', 'rate_limit': '速率限制',
'daily_limit': '每日限制', 'daily_limit': '每日限制',
'monthly_limit': '每月限制', 'monthly_limit': '每月限制',
@@ -44,7 +44,6 @@ const fieldNameMap: Record<string, string> = {
'monthly_quota_usd': '月度配额', 'monthly_quota_usd': '月度配额',
'quota_reset_day': '配额重置日', 'quota_reset_day': '配额重置日',
'quota_expires_at': '配额过期时间', 'quota_expires_at': '配额过期时间',
'rpm_limit': 'RPM 限制',
'cache_ttl_minutes': '缓存 TTL', 'cache_ttl_minutes': '缓存 TTL',
'max_probe_interval_minutes': '最大探测间隔', 'max_probe_interval_minutes': '最大探测间隔',
} }
@@ -54,7 +53,7 @@ const fieldNameMap: Record<string, string> = {
*/ */
const errorTypeMap: Record<string, (error: ValidationError) => string> = { const errorTypeMap: Record<string, (error: ValidationError) => string> = {
'string_too_short': (error) => { 'string_too_short': (error) => {
const minLength = error.ctx?.min_length || 10 const minLength = error.ctx?.min_length || 3
return `长度不能少于 ${minLength} 个字符` return `长度不能少于 ${minLength} 个字符`
}, },
'string_too_long': (error) => { 'string_too_long': (error) => {
@@ -151,11 +150,18 @@ export function parseApiError(err: unknown, defaultMessage: string = '操作失
return '无法连接到服务器,请检查网络连接' return '无法连接到服务器,请检查网络连接'
} }
const detail = err.response?.data?.detail const data = err.response?.data
// 1. Handle the {error: {type, message}} format (the format returned by ProxyException)
if (data?.error?.message) {
return data.error.message
}
const detail = data?.detail
// If there is no detail field // If there is no detail field
if (!detail) { if (!detail) {
return err.response?.data?.message || err.message || defaultMessage return data?.message || err.message || defaultMessage
} }
// 1. Handle Pydantic validation errors (array format) // 1. Handle Pydantic validation errors (array format)
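
The lookup order introduced in this hunk can be sketched on its own: the nested `error.message` wins, then `detail`, then the top-level `message`, then the default. The `ErrorBody` type and `pickErrorMessage` name below are hypothetical; the sketch leaves out the `err.message` fallback and the Pydantic array handling, and returning a string `detail` unchanged is an assumption about code the diff does not show:

```typescript
// Assumed shape of an error response body; only the fields used here are listed.
interface ErrorBody {
  error?: { type?: string; message?: string }
  detail?: unknown
  message?: string
}

function pickErrorMessage(data: ErrorBody | undefined, fallback: string): string {
  // 1. {error: {type, message}} takes precedence.
  if (data?.error?.message) return data.error.message
  const detail = data?.detail
  // 2. Without a detail field, fall back to the top-level message or the default.
  if (!detail) return data?.message || fallback
  // 3. Assumption: a string detail is returned as is; other shapes are out of scope here.
  return typeof detail === 'string' ? detail : fallback
}
```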

@@ -54,6 +54,57 @@ export function parseNumberInput(
return result return result
} }
/**
* Parse number input value for nullable fields (like rpm_limit)
* Returns `null` when empty (to signal "use adaptive/default mode")
* Returns `undefined` when not provided (to signal "keep original value")
*
* @param value - Input value (string or number)
* @param options - Parse options
* @returns Parsed number, null (for empty/adaptive), or undefined
*/
export function parseNullableNumberInput(
value: string | number | null | undefined,
options: {
allowFloat?: boolean
min?: number
max?: number
} = {}
): number | null | undefined {
const { allowFloat = false, min, max } = options
// Empty string means "null" (adaptive mode)
if (value === '') {
return null
}
// null/undefined means "keep original value"
if (value === null || value === undefined) {
return undefined
}
// Parse the value
const num = typeof value === 'string'
? (allowFloat ? parseFloat(value) : parseInt(value, 10))
: value
// Handle NaN - treat as null (adaptive mode)
if (isNaN(num)) {
return null
}
// Apply min/max constraints
let result = num
if (min !== undefined && result < min) {
result = min
}
if (max !== undefined && result > max) {
result = max
}
return result
}
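
The point of the three-way return value is easiest to see when the result is serialized. A minimal sketch, assuming the parsed value ends up in a JSON request body; `parseNullable` is a compact restatement of the function above (integer mode, no clamping) and `buildPayload` is a hypothetical helper:

```typescript
// Compact restatement of parseNullableNumberInput: integer parsing only, min/max omitted.
function parseNullable(value: string | number | null | undefined): number | null | undefined {
  if (value === '') return null                                   // empty input: adaptive mode
  if (value === null || value === undefined) return undefined     // not provided: keep original
  const num = typeof value === 'string' ? parseInt(value, 10) : value
  return isNaN(num) ? null : num                                  // unparseable input: adaptive mode
}

// JSON.stringify drops properties whose value is undefined but keeps null,
// so the two "no number" cases produce different request bodies.
function buildPayload(rpmInput: string | number | null | undefined): string {
  return JSON.stringify({ rpm_limit: parseNullable(rpmInput) })
}
```

With this shape, clearing the field sends an explicit `null`, while a field that was never touched is left out of the body entirely.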
/** /**
* Create a handler function for number input with specific field * Create a handler function for number input with specific field
* Useful for creating inline handlers in templates * Useful for creating inline handlers in templates

@@ -530,9 +530,6 @@
/> />
<div class="flex-1 min-w-0"> <div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate"> <p class="font-medium text-sm truncate">
{{ provider.display_name || provider.name }}
</p>
<p class="text-xs text-muted-foreground truncate">
{{ provider.name }} {{ provider.name }}
</p> </p>
</div> </div>
@@ -645,10 +642,7 @@
/> />
<div class="flex-1 min-w-0"> <div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate"> <p class="font-medium text-sm truncate">
{{ provider.display_name }} {{ provider.name }}
</p>
<p class="text-xs text-muted-foreground truncate">
{{ provider.identifier }}
</p> </p>
</div> </div>
<Badge <Badge
@@ -679,7 +673,7 @@
<ProviderModelFormDialog <ProviderModelFormDialog
:open="editProviderDialogOpen" :open="editProviderDialogOpen"
:provider-id="editingProvider?.id || ''" :provider-id="editingProvider?.id || ''"
:provider-name="editingProvider?.display_name || ''" :provider-name="editingProvider?.name || ''"
:editing-model="editingProviderModel" :editing-model="editingProviderModel"
@update:open="handleEditProviderDialogUpdate" @update:open="handleEditProviderDialogUpdate"
@saved="handleEditProviderSaved" @saved="handleEditProviderSaved"
@@ -939,7 +933,7 @@ async function batchAddSelectedProviders() {
const errorMessages = result.errors const errorMessages = result.errors
.map(e => { .map(e => {
const provider = providerOptions.value.find(p => p.id === e.provider_id) const provider = providerOptions.value.find(p => p.id === e.provider_id)
const providerName = provider?.display_name || provider?.name || e.provider_id const providerName = provider?.name || e.provider_id
return `${providerName}: ${e.error}` return `${providerName}: ${e.error}`
}) })
.join('\n') .join('\n')
@@ -977,7 +971,7 @@ async function batchRemoveSelectedProviders() {
await deleteModel(providerId, provider.model_id) await deleteModel(providerId, provider.model_id)
successCount++ successCount++
} catch (err: any) { } catch (err: any) {
errors.push(`${provider.display_name}: ${parseApiError(err, '删除失败')}`) errors.push(`${provider.name}: ${parseApiError(err, '删除失败')}`)
} }
} }
@@ -1088,8 +1082,7 @@ async function loadModelProviders(_globalModelId: string) {
selectedModelProviders.value = response.providers.map(p => ({ selectedModelProviders.value = response.providers.map(p => ({
id: p.provider_id, id: p.provider_id,
model_id: p.model_id, model_id: p.model_id,
display_name: p.provider_display_name || p.provider_name, name: p.provider_name,
identifier: p.provider_name,
provider_type: 'API', provider_type: 'API',
target_model: p.target_model, target_model: p.target_model,
is_active: p.is_active, is_active: p.is_active,
@@ -1219,7 +1212,7 @@ async function confirmDeleteProviderImplementation(provider: any) {
} }
const confirmed = await confirmDanger( const confirmed = await confirmDanger(
`确定要删除 ${provider.display_name} 的模型关联吗?\n\n模型: ${provider.target_model}\n\n此操作不可恢复`, `确定要删除 ${provider.name} 的模型关联吗?\n\n模型: ${provider.target_model}\n\n此操作不可恢复`,
'删除关联提供商' '删除关联提供商'
) )
if (!confirmed) return if (!confirmed) return
@@ -1227,7 +1220,7 @@ async function confirmDeleteProviderImplementation(provider: any) {
try { try {
const { deleteModel } = await import('@/api/endpoints') const { deleteModel } = await import('@/api/endpoints')
await deleteModel(provider.id, provider.model_id) await deleteModel(provider.id, provider.model_id)
success(`已删除 ${provider.display_name} 的模型实现`) success(`已删除 ${provider.name} 的模型实现`)
// Reload the Provider list // Reload the Provider list
if (selectedModel.value) { if (selectedModel.value) {
await loadModelProviders(selectedModel.value.id) await loadModelProviders(selectedModel.value.id)

@@ -134,10 +134,7 @@
@click="handleRowClick($event, provider.id)" @click="handleRowClick($event, provider.id)"
> >
<TableCell class="py-3.5"> <TableCell class="py-3.5">
<div class="flex flex-col gap-0.5"> <span class="text-sm font-medium text-foreground">{{ provider.name }}</span>
<span class="text-sm font-medium text-foreground">{{ provider.display_name }}</span>
<span class="text-xs text-muted-foreground/70 font-mono">{{ provider.name }}</span>
</div>
</TableCell> </TableCell>
<TableCell class="py-3.5"> <TableCell class="py-3.5">
<Badge <Badge
@@ -219,17 +216,10 @@
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / <span class="font-medium">${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}</span> >${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / <span class="font-medium">${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}</span>
</div> </div>
<div <div
v-if="rpmUsage(provider)" v-else
class="flex items-center gap-1"
>
<span class="text-muted-foreground/70">RPM:</span>
<span class="font-medium text-foreground/80">{{ rpmUsage(provider) }}</span>
</div>
<div
v-if="provider.billing_type !== 'monthly_quota' && !rpmUsage(provider)"
class="text-muted-foreground/50" class="text-muted-foreground/50"
> >
无限制 按量付费
</div> </div>
</div> </div>
</TableCell> </TableCell>
@@ -304,7 +294,7 @@
<div class="flex items-start justify-between gap-3"> <div class="flex items-start justify-between gap-3">
<div class="flex-1 min-w-0"> <div class="flex-1 min-w-0">
<div class="flex items-center gap-2"> <div class="flex items-center gap-2">
<span class="font-medium text-foreground truncate">{{ provider.display_name }}</span> <span class="font-medium text-foreground truncate">{{ provider.name }}</span>
<Badge <Badge
:variant="provider.is_active ? 'success' : 'secondary'" :variant="provider.is_active ? 'success' : 'secondary'"
class="text-xs shrink-0" class="text-xs shrink-0"
@@ -312,7 +302,6 @@
{{ provider.is_active ? '活跃' : '停用' }} {{ provider.is_active ? '活跃' : '停用' }}
</Badge> </Badge>
</div> </div>
<span class="text-xs text-muted-foreground/70 font-mono">{{ provider.name }}</span>
</div> </div>
<div <div
class="flex items-center gap-0.5 shrink-0" class="flex items-center gap-0.5 shrink-0"
@@ -383,20 +372,17 @@
</span> </span>
</div> </div>
<!-- Row 4: quota / rate limiting --> <!-- Row 4: quota -->
<div <div
v-if="provider.billing_type === 'monthly_quota' || rpmUsage(provider)" v-if="provider.billing_type === 'monthly_quota'"
class="flex items-center gap-3 text-xs text-muted-foreground" class="flex items-center gap-3 text-xs text-muted-foreground"
> >
<span v-if="provider.billing_type === 'monthly_quota'"> <span>
配额: <span 配额: <span
class="font-semibold" class="font-semibold"
:class="getQuotaUsedColorClass(provider)" :class="getQuotaUsedColorClass(provider)"
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / ${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }} >${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / ${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}
</span> </span>
<span v-if="rpmUsage(provider)">
RPM: {{ rpmUsage(provider) }}
</span>
</div> </div>
</div> </div>
</div> </div>
@@ -509,7 +495,7 @@ const filteredProviders = computed(() => {
if (searchQuery.value.trim()) { if (searchQuery.value.trim()) {
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0) const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(p => { result = result.filter(p => {
const searchableText = `${p.display_name} ${p.name}`.toLowerCase() const searchableText = `${p.name}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword)) return keywords.every(keyword => searchableText.includes(keyword))
}) })
} }
@@ -525,7 +511,7 @@ const filteredProviders = computed(() => {
return a.provider_priority - b.provider_priority return a.provider_priority - b.provider_priority
} }
// 3. Sort by name // 3. Sort by name
return a.display_name.localeCompare(b.display_name) return a.name.localeCompare(b.name)
}) })
}) })
@@ -586,7 +572,10 @@ function sortEndpoints(endpoints: any[]) {
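// Determine whether the endpoint is available (has a key) // Determine whether the endpoint is available (has a key)
function isEndpointAvailable(endpoint: any, _provider: ProviderWithEndpointsSummary): boolean { function isEndpointAvailable(endpoint: any, _provider: ProviderWithEndpointsSummary): boolean {
// Check whether the endpoint has active keys // Check whether the endpoint is enabled and has active keys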
if (endpoint.is_active === false) {
return false
}
return (endpoint.active_keys ?? 0) > 0 return (endpoint.active_keys ?? 0) > 0
} }
@@ -639,21 +628,6 @@ function getQuotaUsedColorClass(provider: ProviderWithEndpointsSummary): string
return 'text-foreground' return 'text-foreground'
} }
function rpmUsage(provider: ProviderWithEndpointsSummary): string | null {
const rpmLimit = provider.rpm_limit
const rpmUsed = provider.rpm_used ?? 0
if (rpmLimit === null || rpmLimit === undefined) {
return rpmUsed > 0 ? `${rpmUsed}` : null
}
if (rpmLimit === 0) {
return '已完全禁止'
}
return `${rpmUsed} / ${rpmLimit}`
}
// Use the shared row-click logic // Use the shared row-click logic
const { handleMouseDown, shouldTriggerRowClick } = useRowClick() const { handleMouseDown, shouldTriggerRowClick } = useRowClick()
@@ -706,7 +680,7 @@ function handleProviderAdded() {
async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) { async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
const confirmed = await confirmDanger( const confirmed = await confirmDanger(
'删除提供商', '删除提供商',
`确定要删除提供商 "${provider.display_name}" 吗?\n\n这将同时删除其所有端点、密钥和配置。此操作不可恢复` `确定要删除提供商 "${provider.name}" 吗?\n\n这将同时删除其所有端点、密钥和配置。此操作不可恢复`
) )
if (!confirmed) return if (!confirmed) return

@@ -511,7 +511,7 @@
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }} 端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }}
</li> </li>
<li> <li>
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }} API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.api_keys?.length || 0), 0) }}
</li> </li>
</ul> </ul>
</div> </div>
@@ -1144,7 +1144,7 @@ function handleConfigFileSelect(event: Event) {
const data = JSON.parse(content) as ConfigExportData const data = JSON.parse(content) as ConfigExportData
// Validate the version // Validate the version
if (data.version !== '1.0') { if (data.version !== '2.0') {
error(`不支持的配置版本: ${data.version}`) error(`不支持的配置版本: ${data.version}`)
return return
} }

@@ -16,7 +16,8 @@
<!-- Main stat cards --> <!-- Main stat cards -->
<div class="grid grid-cols-2 gap-3 sm:gap-4 xl:grid-cols-4"> <div class="grid grid-cols-2 gap-3 sm:gap-4 xl:grid-cols-4">
<template v-if="loading && stats.length === 0"> <!-- Loading skeleton -->
<template v-if="loading">
<Card <Card
v-for="i in 4" v-for="i in 4"
:key="'skeleton-' + i" :key="'skeleton-' + i"
@@ -27,62 +28,98 @@
<Skeleton class="h-4 w-16" /> <Skeleton class="h-4 w-16" />
</Card> </Card>
</template> </template>
<Card <!-- Show stat cards when data is available -->
v-for="(stat, index) in stats" <template v-else-if="stats.length > 0">
v-else <Card
:key="stat.name" v-for="(stat, index) in stats"
class="relative overflow-hidden p-3 sm:p-5" :key="stat.name"
:class="statCardBorders[index % statCardBorders.length]" class="relative overflow-hidden p-3 sm:p-5"
> :class="statCardBorders[index % statCardBorders.length]"
<div
class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-40"
:class="statCardGlows[index % statCardGlows.length]"
/>
<!-- Icon pinned to the top-right corner -->
<div
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
:class="getStatIconColor(index)"
> >
<component
:is="stat.icon"
class="h-4 w-4 sm:h-5 sm:w-5"
/>
</div>
<!-- Content area -->
<div>
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
{{ stat.name }}
</p>
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-foreground">
{{ stat.value }}
</p>
<p
v-if="stat.subValue"
class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground"
>
{{ stat.subValue }}
</p>
<div <div
v-if="stat.change || stat.extraBadge" class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-40"
class="mt-1.5 sm:mt-2 flex items-center gap-1 sm:gap-1.5 flex-wrap" :class="statCardGlows[index % statCardGlows.length]"
/>
<!-- Icon pinned to the top-right corner -->
<div
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
:class="getStatIconColor(index)"
> >
<Badge <component
v-if="stat.change" :is="stat.icon"
variant="secondary" class="h-4 w-4 sm:h-5 sm:w-5"
class="text-[9px] sm:text-xs" />
>
{{ stat.change }}
</Badge>
<Badge
v-if="stat.extraBadge"
variant="secondary"
class="text-[9px] sm:text-xs"
>
{{ stat.extraBadge }}
</Badge>
</div> </div>
</div> <!-- Content area -->
</Card> <div>
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
{{ stat.name }}
</p>
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-foreground">
{{ stat.value }}
</p>
<p
v-if="stat.subValue"
class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground"
>
{{ stat.subValue }}
</p>
<div
v-if="stat.change || stat.extraBadge"
class="mt-1.5 sm:mt-2 flex items-center gap-1 sm:gap-1.5 flex-wrap"
>
<Badge
v-if="stat.change"
variant="secondary"
class="text-[9px] sm:text-xs"
>
{{ stat.change }}
</Badge>
<Badge
v-if="stat.extraBadge"
variant="secondary"
class="text-[9px] sm:text-xs"
>
{{ stat.extraBadge }}
</Badge>
</div>
</div>
</Card>
</template>
<!-- Show placeholder cards when there is no data -->
<template v-else>
<Card
v-for="(placeholder, index) in emptyStatPlaceholders"
:key="'empty-' + index"
class="relative overflow-hidden p-3 sm:p-5"
:class="statCardBorders[index % statCardBorders.length]"
>
<div
class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-20"
:class="statCardGlows[index % statCardGlows.length]"
/>
<div
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
:class="getStatIconColor(index)"
>
<component
:is="placeholder.icon"
class="h-4 w-4 sm:h-5 sm:w-5"
/>
</div>
<div>
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
{{ placeholder.name }}
</p>
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-muted-foreground/50">
--
</p>
<p class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground/50">
暂无数据
</p>
</div>
</Card>
</template>
</div> </div>
<!-- Admin system health summary --> <!-- Admin system health summary -->
@@ -872,6 +909,24 @@ const iconMap: Record<string, any> = {
Users, Activity, TrendingUp, DollarSign, Key, Hash, Database Users, Activity, TrendingUp, DollarSign, Key, Hash, Database
} }
// Placeholder cards for the empty state
const emptyStatPlaceholders = computed(() => {
if (isAdmin.value) {
return [
{ name: '今日请求', icon: Activity },
{ name: '今日 Tokens', icon: Hash },
{ name: '活跃用户', icon: Users },
{ name: '今日费用', icon: DollarSign }
]
}
return [
{ name: '今日请求', icon: Activity },
{ name: '今日 Tokens', icon: Hash },
{ name: 'API Keys', icon: Key },
{ name: '今日费用', icon: DollarSign }
]
})
const totalStats = computed(() => { const totalStats = computed(() => {
if (dailyStats.value.length === 0) { if (dailyStats.value.length === 0) {
return { requests: 0, tokens: 0, cost: 0, avgResponseTime: 0 } return { requests: 0, tokens: 0, cost: 0, avgResponseTime: 0 }


@@ -78,6 +78,20 @@ export default {
md: "calc(var(--radius) - 2px)", md: "calc(var(--radius) - 2px)",
sm: "calc(var(--radius) - 4px)", sm: "calc(var(--radius) - 4px)",
}, },
keyframes: {
"collapsible-down": {
from: { height: "0" },
to: { height: "var(--radix-collapsible-content-height)" },
},
"collapsible-up": {
from: { height: "var(--radix-collapsible-content-height)" },
to: { height: "0" },
},
},
animation: {
"collapsible-down": "collapsible-down 0.2s ease-out",
"collapsible-up": "collapsible-up 0.2s ease-out",
},
}, },
}, },
plugins: [require("tailwindcss-animate")], plugins: [require("tailwindcss-animate")],


@@ -1,12 +0,0 @@
#!/bin/bash
# Database migration script - runs Alembic migrations inside the Docker container
set -e
CONTAINER_NAME="aether-app"
echo "Running database migrations in container: $CONTAINER_NAME"
docker exec $CONTAINER_NAME alembic upgrade head
echo "Database migration completed successfully"


@@ -1,34 +0,0 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.2.5'
__version_tuple__ = version_tuple = (0, 2, 5)
__commit_id__ = commit_id = None


@@ -1,12 +1,12 @@
""" """
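Adaptive concurrency management API endpoints Adaptive RPM management API endpoints
Design principles: Design principles:
- Adaptive mode is determined by the max_concurrent field: - Adaptive mode is determined by the rpm_limit field:
- max_concurrent = NULL: adaptive mode is enabled and the system learns and adjusts the concurrency limit automatically - rpm_limit = NULL: adaptive mode is enabled and the system learns and adjusts the RPM limit automatically
- max_concurrent = number: fixed-limit mode, using the concurrency limit set by the user - rpm_limit = number: fixed-limit mode, using the RPM limit set by the user
- learned_max_concurrent: the concurrency limit learned in adaptive mode - learned_rpm_limit: the RPM limit learned in adaptive mode
- adaptive_mode is a computed field based on whether max_concurrent is NULL - adaptive_mode is a computed field based on whether rpm_limit is NULL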
""" """
from dataclasses import dataclass from dataclasses import dataclass
@@ -18,12 +18,13 @@ from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.config.constants import RPMDefaults
from src.core.exceptions import InvalidRequestException, translate_pydantic_error from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.database import get_db from src.database import get_db
from src.models.database import ProviderAPIKey from src.models.database import ProviderAPIKey
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"]) router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive RPM"])
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
@@ -35,19 +36,19 @@ class EnableAdaptiveRequest(BaseModel):
enabled: bool = Field(..., description="是否启用自适应模式true=自适应false=固定限制)") enabled: bool = Field(..., description="是否启用自适应模式true=自适应false=固定限制)")
fixed_limit: Optional[int] = Field( fixed_limit: Optional[int] = Field(
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)" None, ge=1, le=100, description="固定 RPM 限制(仅当 enabled=false 时生效1-100"
) )
class AdaptiveStatsResponse(BaseModel): class AdaptiveStatsResponse(BaseModel):
"""自适应统计响应""" """自适应统计响应"""
adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL") adaptive_mode: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL")
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制NULL=自适应)") rpm_limit: Optional[int] = Field(None, description="用户配置的固定限制NULL=自适应)")
effective_limit: Optional[int] = Field( effective_limit: Optional[int] = Field(
None, description="当前有效限制(自适应使用学习值,固定使用配置值)" None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
) )
learned_limit: Optional[int] = Field(None, description="学习到的并发限制") learned_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
concurrent_429_count: int concurrent_429_count: int
rpm_429_count: int rpm_429_count: int
last_429_at: Optional[str] last_429_at: Optional[str]
@@ -61,11 +62,12 @@ class KeyListItem(BaseModel):
id: str id: str
name: Optional[str] name: Optional[str]
endpoint_id: str provider_id: str
is_adaptive: bool = Field(..., description="是否为自适应模式max_concurrent=NULL") api_formats: List[str] = Field(default_factory=list)
max_concurrent: Optional[int] = Field(None, description="固定并发限制NULL=自适应") is_adaptive: bool = Field(..., description="是否为自适应模式rpm_limit=NULL")
rpm_limit: Optional[int] = Field(None, description="固定 RPM 限制NULL=自适应)")
effective_limit: Optional[int] = Field(None, description="当前有效限制") effective_limit: Optional[int] = Field(None, description="当前有效限制")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制") learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
concurrent_429_count: int concurrent_429_count: int
rpm_429_count: int rpm_429_count: int
@@ -80,22 +82,22 @@ class KeyListItem(BaseModel):
) )
async def list_adaptive_keys( async def list_adaptive_keys(
request: Request, request: Request,
endpoint_id: Optional[str] = Query(None, description="Endpoint 过滤"), provider_id: Optional[str] = Query(None, description="Provider 过滤"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
List all keys that have adaptive mode enabled List all keys that have adaptive mode enabled
Optional parameters: Optional parameters:
- endpoint_id: filter by Endpoint - provider_id: filter by Provider
""" """
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id) adapter = ListAdaptiveKeysAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch( @router.patch(
"/keys/{key_id}/mode", "/keys/{key_id}/mode",
summary="Toggle key's concurrency control mode", summary="Toggle key's RPM control mode",
) )
async def toggle_adaptive_mode( async def toggle_adaptive_mode(
key_id: str, key_id: str,
@@ -103,10 +105,10 @@ async def toggle_adaptive_mode(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
Toggle the concurrency control mode for a specific key Toggle the RPM control mode for a specific key
Parameters: Parameters:
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode - enabled: true=adaptive mode (rpm_limit=NULL), false=fixed limit mode
- fixed_limit: fixed limit value (required when enabled=false) - fixed_limit: fixed limit value (required when enabled=false)
""" """
adapter = ToggleAdaptiveModeAdapter(key_id=key_id) adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
@@ -124,7 +126,7 @@ async def get_adaptive_stats(
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
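Get adaptive concurrency statistics for the specified key Get adaptive RPM statistics for the specified key
Includes: Includes:
- Current configuration - Current configuration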
@@ -149,12 +151,12 @@ async def reset_adaptive_learning(
Reset the adaptive learning state for a specific key Reset the adaptive learning state for a specific key
Clears: Clears:
- Learned concurrency limit (learned_max_concurrent) - Learned RPM limit (learned_rpm_limit)
- 429 error counts - 429 error counts
- Adjustment history - Adjustment history
Does not change: Does not change:
- max_concurrent config (determines adaptive mode) - rpm_limit config (determines adaptive mode)
""" """
adapter = ResetAdaptiveLearningAdapter(key_id=key_id) adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -162,40 +164,40 @@ async def reset_adaptive_learning(
@router.patch( @router.patch(
"/keys/{key_id}/limit", "/keys/{key_id}/limit",
summary="Set key to fixed concurrency limit mode", summary="Set key to fixed RPM limit mode",
) )
async def set_concurrent_limit( async def set_rpm_limit(
key_id: str, key_id: str,
request: Request, request: Request,
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"), limit: int = Query(..., ge=1, le=100, description="RPM limit value (1-100)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
Set key to fixed concurrency limit mode Set key to fixed RPM limit mode
Note: Note:
- After setting this value, key switches to fixed limit mode and won't auto-adjust - After setting this value, key switches to fixed limit mode and won't auto-adjust
- To restore adaptive mode, use PATCH /keys/{key_id}/mode - To restore adaptive mode, use PATCH /keys/{key_id}/mode
""" """
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit) adapter = SetRPMLimitAdapter(key_id=key_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get( @router.get(
"/summary", "/summary",
summary="获取自适应并发的全局统计", summary="获取自适应 RPM 的全局统计",
) )
async def get_adaptive_summary( async def get_adaptive_summary(
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
""" """
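Get a global summary of adaptive concurrency statistics Get a global summary of adaptive RPM statistics
Includes: Includes:
- Number of keys with adaptive mode enabled - Number of keys with adaptive mode enabled
- Total number of 429 errors - Total number of 429 errors
- Number of concurrency limit adjustments - Number of RPM limit adjustments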
""" """
adapter = AdaptiveSummaryAdapter() adapter = AdaptiveSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -206,26 +208,29 @@ async def get_adaptive_summary(
@dataclass @dataclass
class ListAdaptiveKeysAdapter(AdminApiAdapter): class ListAdaptiveKeysAdapter(AdminApiAdapter):
endpoint_id: Optional[str] = None provider_id: Optional[str] = None
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
# Adaptive mode: max_concurrent = NULL # Adaptive mode: rpm_limit = NULL
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)) query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None))
if self.endpoint_id: if self.provider_id:
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id) query = query.filter(ProviderAPIKey.provider_id == self.provider_id)
keys = query.all() keys = query.all()
return [ return [
KeyListItem( KeyListItem(
id=key.id, id=key.id,
name=key.name, name=key.name,
endpoint_id=key.endpoint_id, provider_id=key.provider_id,
is_adaptive=key.max_concurrent is None, api_formats=key.api_formats or [],
max_concurrent=key.max_concurrent, is_adaptive=key.rpm_limit is None,
rpm_limit=key.rpm_limit,
effective_limit=( effective_limit=(
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent (key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if key.rpm_limit is None
else key.rpm_limit
), ),
learned_max_concurrent=key.learned_max_concurrent, learned_rpm_limit=key.learned_rpm_limit,
concurrent_429_count=key.concurrent_429_count or 0, concurrent_429_count=key.concurrent_429_count or 0,
rpm_429_count=key.rpm_429_count or 0, rpm_429_count=key.rpm_429_count or 0,
) )
@@ -252,28 +257,32 @@ class ToggleAdaptiveModeAdapter(AdminApiAdapter):
raise InvalidRequestException("请求数据验证失败") raise InvalidRequestException("请求数据验证失败")
if body.enabled: if body.enabled:
# Enable adaptive mode: set max_concurrent to NULL # Enable adaptive mode: set rpm_limit to NULL
key.max_concurrent = None key.rpm_limit = None
message = "已切换为自适应模式,系统将自动学习并调整并发限制" message = "已切换为自适应模式,系统将自动学习并调整 RPM 限制"
else: else:
# Disable adaptive mode: set a fixed limit # Disable adaptive mode: set a fixed limit
if body.fixed_limit is None: if body.fixed_limit is None:
raise HTTPException( raise HTTPException(
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数" status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
) )
key.max_concurrent = body.fixed_limit key.rpm_limit = body.fixed_limit
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}" message = f"已切换为固定限制模式,RPM 限制设为 {body.fixed_limit}"
context.db.commit() context.db.commit()
context.db.refresh(key) context.db.refresh(key)
is_adaptive = key.max_concurrent is None is_adaptive = key.rpm_limit is None
return { return {
"message": message, "message": message,
"key_id": key.id, "key_id": key.id,
"is_adaptive": is_adaptive, "is_adaptive": is_adaptive,
"max_concurrent": key.max_concurrent, "rpm_limit": key.rpm_limit,
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent, "effective_limit": (
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if is_adaptive
else key.rpm_limit
),
} }
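The mode switch and the effective-limit fallback in the hunks above can be sketched in isolation. This is a minimal sketch under stated assumptions, not the project's code: `KeyState`, `effective_limit`, `toggle_mode`, and `ASSUMED_INITIAL_LIMIT` are hypothetical names, and the value standing in for `RPMDefaults.INITIAL_LIMIT` is invented because the real constant is not shown in this diff.

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for RPMDefaults.INITIAL_LIMIT; the real value is not shown in this diff.
ASSUMED_INITIAL_LIMIT = 10


@dataclass
class KeyState:
    """Simplified stand-in for the ProviderAPIKey columns used by the handlers."""

    rpm_limit: Optional[int] = None          # None means adaptive mode
    learned_rpm_limit: Optional[int] = None  # value learned while in adaptive mode


def effective_limit(key: KeyState) -> int:
    """Return the limit that applies to a key right now."""
    if key.rpm_limit is None:
        # Adaptive mode: prefer the learned value, else fall back to the initial default.
        if key.learned_rpm_limit is not None:
            return key.learned_rpm_limit
        return ASSUMED_INITIAL_LIMIT
    # Fixed mode: the configured value wins.
    return key.rpm_limit


def toggle_mode(key: KeyState, enabled: bool, fixed_limit: Optional[int] = None) -> None:
    """Switch between adaptive and fixed mode, mirroring the handler's validation."""
    if enabled:
        key.rpm_limit = None
        return
    if fixed_limit is None:
        raise ValueError("fixed_limit is required when enabled is False")
    key.rpm_limit = fixed_limit
```

The point of the fallback is that a key in adaptive mode with nothing learned yet still reports a concrete effective limit rather than `None`, which is the behavioural difference from the old `learned_max_concurrent if ... else max_concurrent` expression.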
@@ -286,13 +295,13 @@ class GetAdaptiveStatsAdapter(AdminApiAdapter):
if not key: if not key:
raise HTTPException(status_code=404, detail="Key not found") raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager() adaptive_manager = get_adaptive_rpm_manager()
stats = adaptive_manager.get_adjustment_stats(key) stats = adaptive_manager.get_adjustment_stats(key)
# Map field names to match the response model # Map field names to match the response model
return AdaptiveStatsResponse( return AdaptiveStatsResponse(
adaptive_mode=stats["adaptive_mode"], adaptive_mode=stats["adaptive_mode"],
max_concurrent=stats["max_concurrent"], rpm_limit=stats["rpm_limit"],
effective_limit=stats["effective_limit"], effective_limit=stats["effective_limit"],
learned_limit=stats["learned_limit"], learned_limit=stats["learned_limit"],
concurrent_429_count=stats["concurrent_429_count"], concurrent_429_count=stats["concurrent_429_count"],
@@ -313,13 +322,13 @@ class ResetAdaptiveLearningAdapter(AdminApiAdapter):
if not key: if not key:
raise HTTPException(status_code=404, detail="Key not found") raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager() adaptive_manager = get_adaptive_rpm_manager()
adaptive_manager.reset_learning(context.db, key) adaptive_manager.reset_learning(context.db, key)
return {"message": "学习状态已重置", "key_id": key.id} return {"message": "学习状态已重置", "key_id": key.id}
@dataclass @dataclass
class SetConcurrentLimitAdapter(AdminApiAdapter): class SetRPMLimitAdapter(AdminApiAdapter):
key_id: str key_id: str
limit: int limit: int
@@ -328,25 +337,25 @@ class SetConcurrentLimitAdapter(AdminApiAdapter):
if not key: if not key:
raise HTTPException(status_code=404, detail="Key not found") raise HTTPException(status_code=404, detail="Key not found")
was_adaptive = key.max_concurrent is None was_adaptive = key.rpm_limit is None
key.max_concurrent = self.limit key.rpm_limit = self.limit
context.db.commit() context.db.commit()
context.db.refresh(key) context.db.refresh(key)
return { return {
"message": f"已设置为固定限制模式,并发限制为 {self.limit}", "message": f"已设置为固定限制模式,RPM 限制为 {self.limit}",
"key_id": key.id, "key_id": key.id,
"is_adaptive": False, "is_adaptive": False,
"max_concurrent": key.max_concurrent, "rpm_limit": key.rpm_limit,
"previous_mode": "adaptive" if was_adaptive else "fixed", "previous_mode": "adaptive" if was_adaptive else "fixed",
} }
class AdaptiveSummaryAdapter(AdminApiAdapter): class AdaptiveSummaryAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
# Adaptive mode: max_concurrent = NULL # Adaptive mode: rpm_limit = NULL
adaptive_keys = ( adaptive_keys = (
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all() context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None)).all()
) )
total_keys = len(adaptive_keys) total_keys = len(adaptive_keys)


@@ -1,9 +1,8 @@
""" """
Endpoint concurrency control management API Key RPM limit management API
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, Request from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -12,83 +11,56 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException from src.core.exceptions import NotFoundException
from src.database import get_db from src.database import get_db
from src.models.database import ProviderAPIKey, ProviderEndpoint from src.models.database import ProviderAPIKey
from src.models.endpoint_models import ( from src.models.endpoint_models import KeyRpmStatusResponse
ConcurrencyStatusResponse,
ResetConcurrencyRequest,
)
from src.services.rate_limit.concurrency_manager import get_concurrency_manager from src.services.rate_limit.concurrency_manager import get_concurrency_manager
router = APIRouter(tags=["Concurrency Control"]) router = APIRouter(tags=["RPM Control"])
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse) @router.get("/rpm/key/{key_id}", response_model=KeyRpmStatusResponse)
async def get_endpoint_concurrency( async def get_key_rpm(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""
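Get the current concurrency status of an Endpoint
Query the real-time concurrency usage of the specified Endpoint, including the current concurrency count and the maximum concurrency limit.
**Path parameters**:
- `endpoint_id`: Endpoint ID
**Response fields**:
- `endpoint_id`: Endpoint ID
- `endpoint_current_concurrency`: current concurrency count
- `endpoint_max_concurrent`: maximum concurrency limit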
"""
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
async def get_key_concurrency(
key_id: str, key_id: str,
request: Request, request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse: ) -> KeyRpmStatusResponse:
""" """
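Get the current concurrency status of a Key Get the current RPM status of a Key
Query the real-time concurrency usage of the specified API Key, including the current concurrency count and the maximum concurrency limit. Query the real-time RPM usage of the specified API Key, including the current RPM count and the maximum RPM limit.
**Path parameters**: **Path parameters**:
- `key_id`: API Key ID - `key_id`: API Key ID
**Response fields**: **Response fields**:
- `key_id`: API Key ID - `key_id`: API Key ID
- `key_current_concurrency`: current concurrency count - `current_rpm`: current RPM count
- `key_max_concurrent`: maximum concurrency limit - `rpm_limit`: RPM limit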
""" """
adapter = AdminKeyConcurrencyAdapter(key_id=key_id) adapter = AdminKeyRpmAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/concurrency") @router.delete("/rpm/key/{key_id}")
async def reset_concurrency( async def reset_key_rpm(
request: ResetConcurrencyRequest, key_id: str,
http_request: Request, http_request: Request,
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
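Reset the concurrency counters Reset the Key RPM counter
Reset the concurrency counter of the specified Endpoint or Key, to fix inaccurate counts. Reset the RPM counter of the specified API Key, to fix inaccurate counts.
Admin-only feature; use with caution. Admin-only feature; use with caution.
**Request body fields**: **Path parameters**:
- `endpoint_id`: Endpoint ID (optional) - `key_id`: API Key ID
- `key_id`: API Key ID (optional)
**Response fields**: **Response fields**:
- `message`: result message of the operation - `message`: result message of the operation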
""" """
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id) adapter = AdminResetKeyRpmAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
@@ -96,31 +68,7 @@ async def reset_concurrency(
@dataclass @dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter): class AdminKeyRpmAdapter(AdminApiAdapter):
endpoint_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
concurrency_manager = await get_concurrency_manager()
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
endpoint_id=self.endpoint_id
)
return ConcurrencyStatusResponse(
endpoint_id=self.endpoint_id,
endpoint_current_concurrency=endpoint_count,
endpoint_max_concurrent=endpoint.max_concurrent,
)
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
key_id: str key_id: str
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
@@ -130,23 +78,20 @@ class AdminKeyConcurrencyAdapter(AdminApiAdapter):
raise NotFoundException(f"Key {self.key_id} 不存在") raise NotFoundException(f"Key {self.key_id} 不存在")
concurrency_manager = await get_concurrency_manager() concurrency_manager = await get_concurrency_manager()
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id) key_count = await concurrency_manager.get_key_rpm_count(key_id=self.key_id)
return ConcurrencyStatusResponse( return KeyRpmStatusResponse(
key_id=self.key_id, key_id=self.key_id,
key_current_concurrency=key_count, current_rpm=key_count,
key_max_concurrent=key.max_concurrent, rpm_limit=key.rpm_limit,
) )
@dataclass @dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter): class AdminResetKeyRpmAdapter(AdminApiAdapter):
endpoint_id: Optional[str] key_id: str
key_id: Optional[str]
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
concurrency_manager = await get_concurrency_manager() concurrency_manager = await get_concurrency_manager()
await concurrency_manager.reset_concurrency( await concurrency_manager.reset_key_rpm(key_id=self.key_id)
endpoint_id=self.endpoint_id, key_id=self.key_id return {"message": "RPM 计数已重置"}
)
return {"message": "并发计数已重置"}
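The diff calls `get_key_rpm_count` and `reset_key_rpm` on the concurrency manager but does not show how the count is kept. The following is a purely illustrative sketch of one common approach, a fixed one-minute window held in memory. `InMemoryRpmCounter` and its `record` method are hypothetical, the methods here are synchronous whereas the project's are awaited, and nothing below should be read as the project's implementation.

```python
import time
from typing import Callable, Dict, Tuple


class InMemoryRpmCounter:
    """Hypothetical fixed-window RPM counter; not the project's ConcurrencyManager."""

    def __init__(self, clock: Callable[[], float] = time.time) -> None:
        self._clock = clock
        # key_id -> (minute bucket, number of requests seen in that bucket)
        self._counts: Dict[str, Tuple[int, int]] = {}

    def _bucket(self) -> int:
        # Every 60-second span since the epoch maps to one integer bucket.
        return int(self._clock() // 60)

    def record(self, key_id: str) -> int:
        """Count one request and return the total for the current minute."""
        bucket = self._bucket()
        prev_bucket, count = self._counts.get(key_id, (bucket, 0))
        if prev_bucket != bucket:
            count = 0  # a new minute started, so the old count no longer applies
        self._counts[key_id] = (bucket, count + 1)
        return count + 1

    def get_key_rpm_count(self, key_id: str) -> int:
        """Return the request count for the current minute (0 if stale or unknown)."""
        current = self._bucket()
        bucket, count = self._counts.get(key_id, (current, 0))
        return count if bucket == current else 0

    def reset_key_rpm(self, key_id: str) -> None:
        """Drop the stored count, as the admin reset endpoint would."""
        self._counts.pop(key_id, None)
```

Injecting the clock keeps the window logic testable without sleeping; a shared store would be needed for the same idea to work across multiple processes.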


@@ -5,7 +5,7 @@ Endpoint 健康监控 API
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Dict, List from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func from sqlalchemy import func
@@ -128,29 +128,32 @@ async def get_api_format_health_monitor(
async def get_key_health( async def get_key_health(
key_id: str, key_id: str,
request: Request, request: Request,
api_format: Optional[str] = Query(None, description="API 格式(可选,如 CLAUDE、OPENAI"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> HealthStatusResponse: ) -> HealthStatusResponse:
""" """
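Get Key health status Get Key health status
Get the health details of the specified API Key, including health score, consecutive failure count, Get the health details of the specified API Key, including health score, consecutive failure count,
circuit breaker state, and related information. circuit breaker state, and related information. Supports querying by API format.
**Path parameters**: **Path parameters**:
- `key_id`: API Key ID - `key_id`: API Key ID
**Query parameters**:
- `api_format`: optional, the API format to query (e.g. CLAUDE, OPENAI)
- When given, returns the health details for that format
- When omitted, returns a health summary for all formats
**Response fields**: **Response fields**:
- `key_id`: API Key ID - `key_id`: API Key ID
- `key_health_score`: health score (0.0-1.0) - `key_health_score`: health score (0.0-1.0)
- `key_consecutive_failures`: consecutive failure count
- `key_last_failure_at`: time of the last failure
- `key_is_active`: whether the key is active - `key_is_active`: whether the key is active
- `key_statistics`: statistics - `key_statistics`: statistics
- `circuit_breaker_open`: whether the circuit breaker is open - `health_by_format`: per-format health data (when api_format is not given)
- `circuit_breaker_open_at`: time the circuit breaker opened - `circuit_breaker_open`: whether the circuit breaker is open (when api_format is given)
- `next_probe_at`: time of the next probe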
""" """
adapter = AdminKeyHealthAdapter(key_id=key_id) adapter = AdminKeyHealthAdapter(key_id=key_id, api_format=api_format)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -158,17 +161,23 @@ async def get_key_health(
async def recover_key_health( async def recover_key_health(
key_id: str, key_id: str,
request: Request, request: Request,
api_format: Optional[str] = Query(None, description="API 格式(可选,不指定则恢复所有格式)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
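Recover Key health status Recover Key health status
Manually recover the health of the specified Key: reset the health score to 1.0, close the circuit breaker, Manually recover the health of the specified Key: reset the health score to 1.0, close the circuit breaker,
cancel automatic disabling, and reset all failure counts. cancel automatic disabling, and reset all failure counts. Supports recovery by API format.
**Path parameters**: **Path parameters**:
- `key_id`: API Key ID - `key_id`: API Key ID
**Query parameters**:
- `api_format`: optional, the API format to recover (e.g. CLAUDE, OPENAI)
- When given, only that format's health is recovered
- When omitted, all formats are recovered
**Response fields**: **Response fields**:
- `message`: result message of the operation - `message`: result message of the operation
- `details`: details - `details`: details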
@@ -176,7 +185,7 @@ async def recover_key_health(
- `circuit_breaker_open`: circuit breaker state - `circuit_breaker_open`: circuit breaker state
- `is_active`: whether the key is active - `is_active`: whether the key is active
""" """
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id) adapter = AdminRecoverKeyHealthAdapter(key_id=key_id, api_format=api_format)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -276,34 +285,9 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
) )
all_formats[api_format] = provider_count all_formats[api_format] = provider_count
# 1.1 Get all active API formats and their API Key counts # 1.1 Build the list of Endpoint IDs for each API format (used for timeline generation) and collect the active provider+format combinations
active_keys = (
db.query(
ProviderEndpoint.api_format,
func.count(ProviderAPIKey.id).label("key_count"),
)
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.group_by(ProviderEndpoint.api_format)
.all()
)
# Build the key_count mapping for all formats
key_counts: Dict[str, int] = {}
for api_format_enum, key_count in active_keys:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
key_counts[api_format] = key_count
# 1.2 Build the list of Endpoint IDs for each API format, used when generating the Usage timeline
endpoint_rows = ( endpoint_rows = (
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id) db.query(ProviderEndpoint.api_format, ProviderEndpoint.id, ProviderEndpoint.provider_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id) .join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter( .filter(
ProviderEndpoint.is_active.is_(True), ProviderEndpoint.is_active.is_(True),
@@ -312,11 +296,32 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
.all() .all()
) )
endpoint_map: Dict[str, List[str]] = defaultdict(list) endpoint_map: Dict[str, List[str]] = defaultdict(list)
for api_format_enum, endpoint_id in endpoint_rows: active_provider_formats: set[tuple[str, str]] = set()
for api_format_enum, endpoint_id, provider_id in endpoint_rows:
api_format = ( api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum) api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
) )
endpoint_map[api_format].append(endpoint_id) endpoint_map[api_format].append(endpoint_id)
active_provider_formats.add((str(provider_id), api_format))
# 1.2 Count the active Keys available for each API format (a Key belongs to a Provider and is linked to formats via api_formats)
key_counts: Dict[str, int] = {}
if active_provider_formats:
active_provider_keys = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
.filter(
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.all()
)
for provider_id, api_formats in active_provider_keys:
pid = str(provider_id)
for fmt in (api_formats or []):
if (pid, fmt) not in active_provider_formats:
continue
key_counts[fmt] = key_counts.get(fmt, 0) + 1
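# 2. Count the request status distribution per API format within the window (real statistics) # 2. Count the request status distribution per API format within the window (real statistics)
# Only count final states: success, failed, skipped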
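The new step 1.2 in the hunk above counts a key toward a format only when its provider has an active endpoint for that format. A self-contained sketch of that counting rule follows, with the database rows replaced by plain tuples. The function name `count_keys_by_format` and the sample IDs are hypothetical; the loop body follows the added lines.

```python
from typing import Dict, Iterable, List, Optional, Set, Tuple


def count_keys_by_format(
    active_provider_formats: Set[Tuple[str, str]],
    key_rows: Iterable[Tuple[str, Optional[List[str]]]],
) -> Dict[str, int]:
    """Count active keys per API format.

    active_provider_formats holds (provider_id, api_format) pairs that have an
    active endpoint; key_rows holds (provider_id, api_formats) for each active key.
    """
    key_counts: Dict[str, int] = {}
    for provider_id, api_formats in key_rows:
        pid = str(provider_id)
        for fmt in api_formats or []:
            # Skip formats the key declares but its provider has no active endpoint for.
            if (pid, fmt) not in active_provider_formats:
                continue
            key_counts[fmt] = key_counts.get(fmt, 0) + 1
    return key_counts


if __name__ == "__main__":
    active = {("p1", "CLAUDE"), ("p1", "OPENAI"), ("p2", "OPENAI")}
    rows = [("p1", ["CLAUDE", "OPENAI"]), ("p2", ["OPENAI", "CLAUDE"]), ("p3", ["OPENAI"])]
    print(count_keys_by_format(active, rows))
```

Because keys now belong to a provider rather than to a single endpoint, one key can contribute to several formats, which the earlier `GROUP BY ProviderEndpoint.api_format` join could not express.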
@@ -457,28 +462,45 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
@dataclass @dataclass
class AdminKeyHealthAdapter(AdminApiAdapter): class AdminKeyHealthAdapter(AdminApiAdapter):
key_id: str key_id: str
api_format: Optional[str] = None
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
health_data = health_monitor.get_key_health(context.db, self.key_id) health_data = health_monitor.get_key_health(context.db, self.key_id, self.api_format)
if not health_data: if not health_data:
raise NotFoundException(f"Key {self.key_id} 不存在") raise NotFoundException(f"Key {self.key_id} 不存在")
return HealthStatusResponse( # Build the response
key_id=health_data["key_id"], response_data = {
key_health_score=health_data["health_score"], "key_id": health_data["key_id"],
key_consecutive_failures=health_data["consecutive_failures"], "key_is_active": health_data["is_active"],
key_last_failure_at=health_data["last_failure_at"], "key_statistics": health_data.get("statistics"),
key_is_active=health_data["is_active"], "key_health_score": health_data.get("health_score", 1.0),
key_statistics=health_data["statistics"], }
circuit_breaker_open=health_data["circuit_breaker_open"],
circuit_breaker_open_at=health_data["circuit_breaker_open_at"], if self.api_format:
next_probe_at=health_data["next_probe_at"], # Single-format query
) response_data["api_format"] = self.api_format
response_data["key_consecutive_failures"] = health_data.get("consecutive_failures")
response_data["key_last_failure_at"] = health_data.get("last_failure_at")
circuit = health_data.get("circuit_breaker", {})
response_data["circuit_breaker_open"] = circuit.get("open", False)
response_data["circuit_breaker_open_at"] = circuit.get("open_at")
response_data["next_probe_at"] = circuit.get("next_probe_at")
response_data["half_open_until"] = circuit.get("half_open_until")
response_data["half_open_successes"] = circuit.get("half_open_successes", 0)
response_data["half_open_failures"] = circuit.get("half_open_failures", 0)
else:
# All-formats query
response_data["any_circuit_open"] = health_data.get("any_circuit_open", False)
response_data["health_by_format"] = health_data.get("health_by_format")
return HealthStatusResponse(**response_data)
@dataclass @dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter): class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
key_id: str key_id: str
api_format: Optional[str] = None
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
db = context.db db = context.db
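The reworked `AdminKeyHealthAdapter` returns a different set of fields depending on whether `api_format` is given. The branching can be sketched as a plain function over dictionaries. `build_key_health_response` is a hypothetical name, the shape of `health_data` is assumed from the keys read in the diff, and the `half_open_*` fields are left out for brevity.

```python
from typing import Any, Dict, Optional


def build_key_health_response(
    health_data: Dict[str, Any], api_format: Optional[str] = None
) -> Dict[str, Any]:
    """Assemble the response fields for a single-format or all-formats query."""
    # Fields returned in both modes.
    response: Dict[str, Any] = {
        "key_id": health_data["key_id"],
        "key_is_active": health_data["is_active"],
        "key_statistics": health_data.get("statistics"),
        "key_health_score": health_data.get("health_score", 1.0),
    }
    if api_format:
        # Single-format query: expose failure counters and circuit breaker details.
        circuit = health_data.get("circuit_breaker", {})
        response.update(
            {
                "api_format": api_format,
                "key_consecutive_failures": health_data.get("consecutive_failures"),
                "key_last_failure_at": health_data.get("last_failure_at"),
                "circuit_breaker_open": circuit.get("open", False),
                "circuit_breaker_open_at": circuit.get("open_at"),
                "next_probe_at": circuit.get("next_probe_at"),
            }
        )
    else:
        # All-formats query: return only the per-format summary.
        response["any_circuit_open"] = health_data.get("any_circuit_open", False)
        response["health_by_format"] = health_data.get("health_by_format")
    return response
```

Using `.get` with defaults means a key with no recorded health data for a format still produces a valid response with a score of 1.0 and a closed breaker.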
@@ -486,28 +508,38 @@ class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
if not key: if not key:
raise NotFoundException(f"Key {self.key_id} 不存在") raise NotFoundException(f"Key {self.key_id} 不存在")
key.health_score = 1.0 # Reset health via health_monitor.reset_health
key.consecutive_failures = 0 success = health_monitor.reset_health(db, key_id=self.key_id, api_format=self.api_format)
key.last_failure_at = None if not success:
key.circuit_breaker_open = False raise Exception("重置健康度失败")
key.circuit_breaker_open_at = None
key.next_probe_at = None # Re-enable the Key if it was disabled
if not key.is_active: if not key.is_active:
key.is_active = True key.is_active = True # type: ignore[assignment]
db.commit() db.commit()
admin_name = context.user.username if context.user else "admin" if self.api_format:
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)") logger.info(f"管理员恢复Key健康状态: {self.key_id}/{self.api_format}")
return {
return { "message": f"Key 的 {self.api_format} 格式已恢复",
"message": "Key已完全恢复", "details": {
"details": { "api_format": self.api_format,
"health_score": 1.0, "health_score": 1.0,
"circuit_breaker_open": False, "circuit_breaker_open": False,
"is_active": True, "is_active": True,
}, },
} }
else:
logger.info(f"管理员恢复Key健康状态: {self.key_id} (所有格式)")
return {
"message": "Key 所有格式已恢复",
"details": {
"health_score": 1.0,
"circuit_breaker_open": False,
"is_active": True,
},
}
class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter): class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
@@ -516,10 +548,17 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
db = context.db db = context.db
# Find all Keys with an open circuit breaker # Find all Keys with any circuit-broken format (checks the circuit_breaker_by_format JSON field)
circuit_open_keys = ( all_keys = db.query(ProviderAPIKey).all()
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
) # Select the Keys that have an open circuit breaker for any format
circuit_open_keys = []
for key in all_keys:
circuit_by_format = key.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
circuit_open_keys.append(key)
break
if not circuit_open_keys: if not circuit_open_keys:
return { return {
@@ -530,17 +569,15 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
recovered_keys = [] recovered_keys = []
for key in circuit_open_keys: for key in circuit_open_keys:
key.health_score = 1.0 # Reset health for all formats
key.consecutive_failures = 0 key.health_by_format = {} # type: ignore[assignment]
key.last_failure_at = None key.circuit_breaker_by_format = {} # type: ignore[assignment]
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
recovered_keys.append( recovered_keys.append(
{ {
"key_id": key.id, "key_id": key.id,
"key_name": key.name, "key_name": key.name,
"endpoint_id": key.endpoint_id, "provider_id": key.provider_id,
"api_formats": key.api_formats,
} }
) )
@@ -552,7 +589,6 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
HealthMonitor._open_circuit_keys = 0 HealthMonitor._open_circuit_keys = 0
health_open_circuits.set(0) health_open_circuits.set(0)
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态") logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
return { return {
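The new selection logic above scans each Key's `circuit_breaker_by_format` mapping and keeps the Key as soon as one format reports `open`. A minimal standalone sketch of that filter follows, assuming plain dicts in place of the `ProviderAPIKey` ORM rows; the helper names and the `OPENAI`/`CLAUDE` format labels are hypothetical and used only for illustration.

```python
from typing import Any, Dict, List, Optional


def has_open_circuit(circuit_by_format: Optional[Dict[str, Dict[str, Any]]]) -> bool:
    """Return True if any API format's circuit breaker is marked open."""
    # `or {}` covers rows where the JSON column is NULL.
    return any(state.get("open") for state in (circuit_by_format or {}).values())


def select_circuit_open_keys(keys: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only the keys with at least one open per-format circuit breaker."""
    return [k for k in keys if has_open_circuit(k.get("circuit_breaker_by_format"))]


# Hypothetical sample rows (plain dicts standing in for ORM objects).
keys = [
    {"id": "k1", "circuit_breaker_by_format": {"OPENAI": {"open": True}, "CLAUDE": {"open": False}}},
    {"id": "k2", "circuit_breaker_by_format": {"OPENAI": {"open": False}}},
    {"id": "k3", "circuit_breaker_by_format": None},
]
open_ids = [k["id"] for k in select_circuit_open_keys(keys)]
```

Using `any()` gives the same early exit as the `break` in the diff without the unused `fmt` loop variable. Note that the diff loads every Key and filters in Python; whether that is acceptable depends on how many Keys the table holds.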


@@ -1,5 +1,5 @@
""" """
Endpoint API Keys management Provider API Keys management
""" """
import uuid import uuid
@@ -12,103 +12,23 @@ from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.config.constants import RPMDefaults
from src.core.crypto import crypto_service from src.core.crypto import crypto_service
from src.core.exceptions import InvalidRequestException, NotFoundException from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.key_capabilities import get_capability from src.core.key_capabilities import get_capability
from src.core.logger import logger from src.core.logger import logger
from src.database import get_db from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.provider_cache import ProviderCacheService
from src.models.endpoint_models import ( from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate, EndpointAPIKeyCreate,
EndpointAPIKeyResponse, EndpointAPIKeyResponse,
EndpointAPIKeyUpdate, EndpointAPIKeyUpdate,
) )
router = APIRouter(tags=["Endpoint Keys"]) router = APIRouter(tags=["Provider Keys"])
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_endpoint_keys(
endpoint_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""
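List all Keys of an Endpoint
Returns every API Key under the given Endpoint, including each Key's configuration and statistics.
Results are ordered by priority, then by creation time.
**Path parameters**:
- `endpoint_id`: Endpoint ID
**Query parameters**:
- `skip`: number of records to skip, for pagination (default 0)
- `limit`: maximum number of records to return (1-1000, default 100)
**Response fields**:
- `id`: Key ID
- `name`: Key name
- `api_key_masked`: masked API Key
- `internal_priority`: internal priority
- `global_priority`: global priority
- `rate_multiplier`: rate multiplier
- `max_concurrent`: maximum concurrency (null means adaptive mode)
- `is_adaptive`: whether the Key is in adaptive concurrency mode
- `effective_limit`: effective concurrency limit
- `success_rate`: success rate
- `avg_response_time_ms`: average response time (milliseconds)
- other configuration and statistics fields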
"""
adapter = AdminListEndpointKeysAdapter(
endpoint_id=endpoint_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_endpoint_key(
endpoint_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
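Add a Key to an Endpoint
Adds a new API Key to the given Endpoint. Supports configuring the concurrency limit, rate multiplier,
priority, quota limits, capability restrictions, and more.
**Path parameters**:
- `endpoint_id`: Endpoint ID
**Request body fields**:
- `endpoint_id`: Endpoint ID (must match the path parameter)
- `api_key`: plaintext API Key (stored encrypted)
- `name`: Key name
- `note`: note (optional)
- `rate_multiplier`: rate multiplier (default 1.0)
- `internal_priority`: internal priority (default 100)
- `max_concurrent`: maximum concurrency (null means adaptive mode)
- `rate_limit`: per-minute request limit (optional)
- `daily_limit`: daily request limit (optional)
- `monthly_limit`: monthly request limit (optional)
- `allowed_models`: list of allowed models (optional)
- `capabilities`: capability configuration (optional)
**Response fields**:
- the full Key record, where `api_key_plain` holds the plaintext (returned only at creation)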
"""
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse) @router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
async def update_endpoint_key( async def update_endpoint_key(
key_id: str, key_id: str,
@@ -117,7 +37,7 @@ async def update_endpoint_key(
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse: ) -> EndpointAPIKeyResponse:
""" """
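Update an Endpoint Key Update a Provider Key
Updates the given Key's configuration. Supports changing the concurrency limit, rate multiplier, priority, Updates the given Key's configuration. Supports changing the concurrency limit, rate multiplier, priority,
quota limits, capability restrictions, and more. Partial updates are supported. quota limits, capability restrictions, and more. Partial updates are supported.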
@@ -131,10 +51,7 @@ async def update_endpoint_key(
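- `note`: note - `note`: note
- `rate_multiplier`: rate multiplier - `rate_multiplier`: rate multiplier
- `internal_priority`: internal priority - `internal_priority`: internal priority
- `max_concurrent`: maximum concurrency (set to null to switch to adaptive mode) - `rpm_limit`: RPM limit (set to null to switch to adaptive mode)
- `rate_limit`: per-minute request limit
- `daily_limit`: daily request limit
- `monthly_limit`: monthly request limit
- `allowed_models`: list of allowed models - `allowed_models`: list of allowed models
- `capabilities`: capability configuration - `capabilities`: capability configuration
- `is_active`: whether the Key is active - `is_active`: whether the Key is active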
@@ -209,7 +126,7 @@ async def delete_endpoint_key(
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
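Delete an Endpoint Key Delete a Provider Key
Deletes the given API Key. This operation is irreversible; use with caution. Deletes the given API Key. This operation is irreversible; use with caution.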
@@ -223,163 +140,66 @@ async def delete_endpoint_key(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}/keys/batch-priority") # ========== Provider Keys API ==========
async def batch_update_key_priority(
endpoint_id: str,
request: Request,
priority_data: BatchUpdateKeyPriorityRequest,
db: Session = Depends(get_db),
) -> dict:
"""
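Batch-update the priority of Keys under an Endpoint
Batch-updates the internal priority of multiple Keys under the given Endpoint; used for drag-and-drop ordering.
All Keys must belong to the given Endpoint. @router.get("/providers/{provider_id}/keys", response_model=List[EndpointAPIKeyResponse])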
async def list_provider_keys(
provider_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""
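List all Keys of a Provider
Returns every API Key under the given Provider; multiple API formats are supported.
Results are ordered by priority, then by creation time.
**Path parameters**: **Path parameters**:
- `endpoint_id`: Endpoint ID - `provider_id`: Provider ID
**Query parameters**:
- `skip`: number of records to skip, for pagination (default 0)
- `limit`: maximum number of records to return (1-1000, default 100)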
"""
adapter = AdminListProviderKeysAdapter(
provider_id=provider_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/providers/{provider_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_provider_key(
provider_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
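Add a Key to a Provider
Adds a new API Key to the given Provider; multiple API formats can be configured.
**Path parameters**:
- `provider_id`: Provider ID
**Request body fields**: **Request body fields**:
- `priorities`: list of priorities - `api_formats`: list of supported API formats (required)
- `key_id`: Key ID - `api_key`: plaintext API Key (stored encrypted)
- `internal_priority`: new internal priority - `name`: Key name
- other configuration fields: same as Key
**Response fields**:
- `message`: operation result message
- `updated_count`: number of Keys actually updated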
""" """
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data) adapter = AdminCreateProviderKeyAdapter(provider_id=provider_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters -------- # -------- Adapters --------
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
endpoint_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.offset(self.skip)
.limit(self.limit)
.all()
)
result: List[EndpointAPIKeyResponse] = []
for key in keys:
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
result.append(EndpointAPIKeyResponse(**key_dict))
return result
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
endpoint_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
if self.key_data.endpoint_id != self.endpoint_id:
raise InvalidRequestException("endpoint_id 不匹配")
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
# max_concurrent=NULL means adaptive mode; a number means a fixed limit
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=self.endpoint_id,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
internal_priority=self.key_data.internal_priority,
max_concurrent=self.key_data.max_concurrent, # NULL = adaptive mode
rate_limit=self.key_data.rate_limit,
daily_limit=self.key_data.daily_limit,
monthly_limit=self.key_data.monthly_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit()
db.refresh(new_key)
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
is_adaptive = new_key.max_concurrent is None
response_dict = new_key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": self.key_data.api_key,
"success_rate": 0.0,
"avg_response_time_ms": 0.0,
"is_adaptive": is_adaptive,
"effective_limit": (
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
@dataclass @dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter): class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
key_id: str key_id: str
@@ -395,14 +215,21 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
if "api_key" in update_data: if "api_key" in update_data:
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"]) update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
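# Special handling for max_concurrent: distinguish "not provided" from "explicitly set to null" # Special handling for rpm_limit: distinguish "not provided" from "explicitly set to null"
# When max_concurrent is explicitly set (present in model_fields_set), update it even if the value is None if "rpm_limit" in self.key_data.model_fields_set: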
if "max_concurrent" in self.key_data.model_fields_set: update_data["rpm_limit"] = self.key_data.rpm_limit
update_data["max_concurrent"] = self.key_data.max_concurrent if self.key_data.rpm_limit is None:
# On switching to adaptive mode, clear the learned concurrency limit so the system relearns it update_data["learned_rpm_limit"] = None
if self.key_data.max_concurrent is None: logger.info("Key %s 切换为自适应 RPM 模式", self.key_id)
update_data["learned_max_concurrent"] = None
logger.info("Key %s 切换为自适应并发模式", self.key_id) # Normalize allowed_models: empty list/empty dict -> None (meaning no restriction)
if "allowed_models" in update_data:
am = update_data["allowed_models"]
if am is not None and (
(isinstance(am, list) and len(am) == 0)
or (isinstance(am, dict) and len(am) == 0)
):
update_data["allowed_models"] = None
for field, value in update_data.items(): for field, value in update_data.items():
setattr(key, field, value) setattr(key, field, value)
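The update path above has to tell "field omitted" apart from "field explicitly set to null", which the diff does through Pydantic's `model_fields_set`. The sketch below illustrates the same rule with a plain dict standing in for the parsed request, so that key presence plays the role of `model_fields_set`; the function name `build_update` is hypothetical.

```python
from typing import Any, Dict


def build_update(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a partial-update payload into column updates.

    Assumption: `payload` contains only the fields the client actually sent,
    so `"rpm_limit" in payload` mirrors a `model_fields_set` membership check.
    """
    update: Dict[str, Any] = {}

    if "rpm_limit" in payload:
        # The field was sent: apply it even when the value is None.
        update["rpm_limit"] = payload["rpm_limit"]
        if payload["rpm_limit"] is None:
            # Switching to adaptive mode: drop the learned limit so it is relearned.
            update["learned_rpm_limit"] = None

    if "allowed_models" in payload:
        allowed = payload["allowed_models"]
        # An empty list or dict is normalized to None, meaning "no restriction".
        if isinstance(allowed, (list, dict)) and not allowed:
            allowed = None
        update["allowed_models"] = allowed

    return update


omitted = build_update({})
explicit_null = build_update({"rpm_limit": None})
fixed_limit = build_update({"rpm_limit": 60})
```

Here `omitted` leaves both limits untouched, while `explicit_null` resets `learned_rpm_limit` as well; a check such as `payload.get("rpm_limit") is None` could not distinguish those two cases.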
@@ -411,35 +238,13 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit() db.commit()
db.refresh(key) db.refresh(key)
# Invalidate the cache on any field update to keep it consistent
# This includes is_active, allowed_models, capabilities, and other fields that affect permissions and behavior
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys())) logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try: return _build_key_response(key)
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
response_dict = key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
@dataclass @dataclass
@@ -476,7 +281,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
if not key: if not key:
raise NotFoundException(f"Key {self.key_id} 不存在") raise NotFoundException(f"Key {self.key_id} 不存在")
endpoint_id = key.endpoint_id provider_id = key.provider_id
try: try:
db.delete(key) db.delete(key)
db.commit() db.commit()
@@ -485,7 +290,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}") logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
raise raise
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}") logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Provider={provider_id}")
return {"message": f"Key {self.key_id} 已删除"} return {"message": f"Key {self.key_id} 已删除"}
@@ -493,31 +298,51 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
db = context.db db = context.db
# Keys belong to a Provider; group them by key.api_formats for display
keys = ( keys = (
db.query(ProviderAPIKey, ProviderEndpoint, Provider) db.query(ProviderAPIKey, Provider)
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id) .join(Provider, ProviderAPIKey.provider_id == Provider.id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter( .filter(
ProviderAPIKey.is_active.is_(True), ProviderAPIKey.is_active.is_(True),
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True), Provider.is_active.is_(True),
) )
.order_by( .order_by(
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc() ProviderAPIKey.global_priority.asc().nullslast(),
ProviderAPIKey.internal_priority.asc(),
) )
.all() .all()
) )
provider_ids = {str(provider.id) for _key, provider in keys}
endpoints = (
db.query(
ProviderEndpoint.provider_id,
ProviderEndpoint.api_format,
ProviderEndpoint.base_url,
)
.filter(
ProviderEndpoint.provider_id.in_(provider_ids),
ProviderEndpoint.is_active.is_(True),
)
.all()
)
endpoint_base_url_map: Dict[tuple[str, str], str] = {}
for provider_id, api_format, base_url in endpoints:
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
endpoint_base_url_map[(str(provider_id), fmt)] = base_url
grouped: Dict[str, List[dict]] = {} grouped: Dict[str, List[dict]] = {}
for key, endpoint, provider in keys: for key, provider in keys:
api_format = endpoint.api_format api_formats = key.api_formats or []
if api_format not in grouped:
grouped[api_format] = [] if not api_formats:
continue # Skip Keys that have no API format
try: try:
decrypted_key = crypto_service.decrypt(key.api_key) decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}" masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception: except Exception as e:
logger.error(f"解密 Key 失败: key_id={key.id}, error={e}")
masked_key = "***ERROR***" masked_key = "***ERROR***"
# Compute health metrics
@@ -536,72 +361,209 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
cap_def = get_capability(cap_name) cap_def = get_capability(cap_name)
caps_list.append(cap_def.short_name if cap_def else cap_name) caps_list.append(cap_def.short_name if cap_def else cap_name)
grouped[api_format].append( # Build the Key info (base data)
{ key_info = {
"id": key.id, "id": key.id,
"name": key.name, "name": key.name,
"api_key_masked": masked_key, "api_key_masked": masked_key,
"internal_priority": key.internal_priority, "internal_priority": key.internal_priority,
"global_priority": key.global_priority, "global_priority": key.global_priority,
"rate_multiplier": key.rate_multiplier, "rate_multiplier": key.rate_multiplier,
"is_active": key.is_active, "is_active": key.is_active,
"circuit_breaker_open": key.circuit_breaker_open, "provider_name": provider.name,
"provider_name": provider.display_name or provider.name, "api_formats": api_formats,
"endpoint_base_url": endpoint.base_url, "capabilities": caps_list,
"api_format": api_format, "success_rate": success_rate,
"capabilities": caps_list, "avg_response_time_ms": avg_response_time_ms,
"success_rate": success_rate, "request_count": key.request_count,
"avg_response_time_ms": avg_response_time_ms, }
"request_count": key.request_count,
} # Add the Key to each supported format group and attach format-specific health data
) health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
provider_id = str(provider.id)
for api_format in api_formats:
if api_format not in grouped:
grouped[api_format] = []
# Create a copy per format and set the current format
format_key_info = key_info.copy()
format_key_info["api_format"] = api_format
format_key_info["endpoint_base_url"] = endpoint_base_url_map.get(
(provider_id, api_format)
)
# Attach format-specific health data
format_health = health_by_format.get(api_format, {})
format_circuit = circuit_by_format.get(api_format, {})
format_key_info["health_score"] = float(format_health.get("health_score") or 1.0)
format_key_info["circuit_breaker_open"] = bool(format_circuit.get("open", False))
grouped[api_format].append(format_key_info)
# Return the grouped object directly for the frontend # Return the grouped object directly for the frontend
return grouped return grouped
# ========== Adapters ==========
def _build_key_response(
key: ProviderAPIKey, api_key_plain: str | None = None
) -> EndpointAPIKeyResponse:
"""Helper that builds a Key response object."""
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.rpm_limit is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
# Compute summary fields from health_by_format (for list display)
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
# Overall health: take the lowest score across all formats
if health_by_format:
health_scores = [
float(h.get("health_score") or 1.0) for h in health_by_format.values()
]
min_health_score = min(health_scores) if health_scores else 1.0
# Take the largest consecutive-failure count
max_consecutive = max(
(int(h.get("consecutive_failures") or 0) for h in health_by_format.values()),
default=0,
)
# Take the most recent failure time
failure_times = [
h.get("last_failure_at")
for h in health_by_format.values()
if h.get("last_failure_at")
]
last_failure = max(failure_times) if failure_times else None
else:
min_health_score = 1.0
max_consecutive = 0
last_failure = None
# Check whether any format's circuit breaker is open
any_circuit_open = any(c.get("open", False) for c in circuit_by_format.values())
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": api_key_plain,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if is_adaptive
else key.rpm_limit
),
# Summary fields
"health_score": min_health_score,
"consecutive_failures": max_consecutive,
"last_failure_at": last_failure,
"circuit_breaker_open": any_circuit_open,
}
)
# Defensive: ensure api_formats exists (legacy data may be empty or missing)
if "api_formats" not in key_dict or key_dict["api_formats"] is None:
key_dict["api_formats"] = []
return EndpointAPIKeyResponse(**key_dict)
@dataclass @dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter): class AdminListProviderKeysAdapter(AdminApiAdapter):
endpoint_id: str """List all Keys of a Provider."""
priority_data: BatchUpdateKeyPriorityRequest
provider_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override] async def handle(self, context): # type: ignore[override]
db = context.db db = context.db
endpoint = ( provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first() if not provider:
) raise NotFoundException(f"Provider {self.provider_id} 不存在")
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
# Collect the IDs of all Keys to update
key_ids = [item.key_id for item in self.priority_data.priorities]
# Verify that all Keys belong to this Endpoint
keys = ( keys = (
db.query(ProviderAPIKey) db.query(ProviderAPIKey)
.filter( .filter(ProviderAPIKey.provider_id == self.provider_id)
ProviderAPIKey.id.in_(key_ids), .order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
ProviderAPIKey.endpoint_id == self.endpoint_id, .offset(self.skip)
) .limit(self.limit)
.all() .all()
) )
if len(keys) != len(key_ids): return [_build_key_response(key) for key in keys]
found_ids = {k.id for k in keys}
missing_ids = set(key_ids) - found_ids
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
# Batch-update priorities
key_map = {k.id: k for k in keys}
updated_count = 0
for item in self.priority_data.priorities:
key = key_map.get(item.key_id)
if key and key.internal_priority != item.internal_priority:
key.internal_priority = item.internal_priority
key.updated_at = datetime.now(timezone.utc)
updated_count += 1
@dataclass
class AdminCreateProviderKeyAdapter(AdminApiAdapter):
"""Add a Key to a Provider."""
provider_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
# Validate that api_formats is provided
if not self.key_data.api_formats:
raise InvalidRequestException("api_formats 为必填字段")
# Allow the same API Key to be added more than once under one Provider
# Users can create separate configuration records per API format and manage them independently
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
provider_id=self.provider_id,
api_formats=self.key_data.api_formats,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
rate_multipliers=self.key_data.rate_multipliers, # cost multiplier per API format
internal_priority=self.key_data.internal_priority,
rpm_limit=self.key_data.rpm_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
cache_ttl_minutes=self.key_data.cache_ttl_minutes,
max_probe_interval_minutes=self.key_data.max_probe_interval_minutes,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
health_by_format={}, # health stored per format
circuit_breaker_by_format={}, # circuit breaker state stored per format
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit() db.commit()
db.refresh(new_key)
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}") logger.info(
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count} f"[OK] 添加 Key: Provider={self.provider_id}, "
f"Formats={self.key_data.api_formats}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}"
)
return _build_key_response(new_key, api_key_plain=self.key_data.api_key)
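`_build_key_response` collapses the per-format health data into four summary fields: the lowest health score, the highest consecutive-failure count, the latest failure time, and whether any circuit is open. The following is a rough standalone sketch of that aggregation, assuming plain dicts and ISO-8601 timestamp strings sharing one UTC offset; the function name and the sample format labels are hypothetical. One deliberate difference: the sketch falls back to 1.0 only when the score is missing, whereas `float(h.get("health_score") or 1.0)` in the diff also maps a stored `0.0` to `1.0`, because `0.0` is falsy in Python.

```python
from typing import Any, Dict, Optional


def summarize_health(
    health_by_format: Optional[Dict[str, Dict[str, Any]]],
    circuit_by_format: Optional[Dict[str, Dict[str, Any]]],
) -> Dict[str, Any]:
    """Aggregate per-format health entries into Key-level summary fields."""
    health = health_by_format or {}
    circuits = circuit_by_format or {}

    def score(entry: Dict[str, Any]) -> float:
        value = entry.get("health_score")
        # Fall back to 1.0 only when the score is absent, so 0.0 is preserved.
        return 1.0 if value is None else float(value)

    failure_times = [h["last_failure_at"] for h in health.values() if h.get("last_failure_at")]
    return {
        # Worst format determines the overall score; 1.0 when nothing is recorded.
        "health_score": min((score(h) for h in health.values()), default=1.0),
        "consecutive_failures": max(
            (int(h.get("consecutive_failures") or 0) for h in health.values()), default=0
        ),
        # ISO-8601 strings with the same offset compare chronologically.
        "last_failure_at": max(failure_times) if failure_times else None,
        "circuit_breaker_open": any(c.get("open", False) for c in circuits.values()),
    }


summary = summarize_health(
    {
        "OPENAI": {
            "health_score": 0.0,
            "consecutive_failures": 3,
            "last_failure_at": "2026-01-10T10:00:00+00:00",
        },
        "CLAUDE": {"health_score": 0.9, "last_failure_at": "2026-01-09T08:00:00+00:00"},
    },
    {"OPENAI": {"open": True}},
)
empty_summary = summarize_health(None, None)
```

With the sample data, `summary["health_score"]` stays at `0.0` instead of being reported as fully healthy, which is the behaviour a list view would presumably want.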


@@ -67,8 +67,6 @@ async def list_provider_endpoints(
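- `custom_path`: custom path - `custom_path`: custom path
- `timeout`: timeout (seconds) - `timeout`: timeout (seconds)
- `max_retries`: maximum number of retries - `max_retries`: maximum number of retries
- `max_concurrent`: maximum concurrency
- `rate_limit`: rate limit
- `is_active`: whether the Endpoint is active - `is_active`: whether the Endpoint is active
- `total_keys`: total number of Keys - `total_keys`: total number of Keys
- `active_keys`: number of active Keys - `active_keys`: number of active Keys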
@@ -107,8 +105,6 @@ async def create_provider_endpoint(
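- `headers`: custom request headers (optional) - `headers`: custom request headers (optional)
- `timeout`: timeout (seconds, default 300) - `timeout`: timeout (seconds, default 300)
- `max_retries`: maximum number of retries (default 2) - `max_retries`: maximum number of retries (default 2)
- `max_concurrent`: maximum concurrency (optional)
- `rate_limit`: rate limit (optional)
- `config`: extra configuration (optional) - `config`: extra configuration (optional)
- `proxy`: proxy configuration (optional) - `proxy`: proxy configuration (optional)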
@@ -145,8 +141,6 @@ async def get_endpoint(
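- `custom_path`: custom path - `custom_path`: custom path
- `timeout`: timeout (seconds) - `timeout`: timeout (seconds)
- `max_retries`: maximum number of retries - `max_retries`: maximum number of retries
- `max_concurrent`: maximum concurrency
- `rate_limit`: rate limit
- `is_active`: whether the Endpoint is active - `is_active`: whether the Endpoint is active
- `total_keys`: total number of Keys - `total_keys`: total number of Keys
- `active_keys`: number of active Keys - `active_keys`: number of active Keys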
@@ -178,8 +172,6 @@ async def update_endpoint(
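- `headers`: custom request headers - `headers`: custom request headers
- `timeout`: timeout (seconds) - `timeout`: timeout (seconds)
- `max_retries`: maximum number of retries - `max_retries`: maximum number of retries
- `max_concurrent`: maximum concurrency
- `rate_limit`: rate limit
- `is_active`: whether the Endpoint is active - `is_active`: whether the Endpoint is active
- `config`: extra configuration - `config`: extra configuration
- `proxy`: proxy configuration (set to null to clear the proxy) - `proxy`: proxy configuration (set to null to clear the proxy)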
@@ -203,15 +195,15 @@ async def delete_endpoint(
""" """
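Delete an Endpoint Delete an Endpoint
Deletes the given Endpoint and cascade-deletes all associated API Keys. Deletes the given Endpoint, which affects the Provider's routing capability for that API format.
This operation is irreversible; use with caution. Keys are not deleted, but Keys that include that API format cannot be scheduled (until an Endpoint for that format is recreated).
**Path parameters**: **Path parameters**:
- `endpoint_id`: Endpoint ID - `endpoint_id`: Endpoint ID
**Response fields**: **Response fields**:
- `message`: operation result message - `message`: operation result message
- `deleted_keys_count`: number of Keys deleted along with the Endpoint - `affected_keys_count`: number of affected Keys (those that include that API format)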
""" """
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id) adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode) return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -241,39 +233,33 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
.all() .all()
) )
endpoint_ids = [ep.id for ep in endpoints] # Keys are Provider-level resources: bucket them by key.api_formats under each Endpoint.api_format
total_keys_map = {} keys = (
active_keys_map = {} db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
if endpoint_ids: .filter(ProviderAPIKey.provider_id == self.provider_id)
total_rows = ( .all()
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total")) )
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)) total_keys_map: dict[str, int] = {}
.group_by(ProviderAPIKey.endpoint_id) active_keys_map: dict[str, int] = {}
.all() for api_formats, is_active in keys:
) for fmt in (api_formats or []):
total_keys_map = {row.endpoint_id: row.total for row in total_rows} total_keys_map[fmt] = total_keys_map.get(fmt, 0) + 1
if is_active:
active_rows = ( active_keys_map[fmt] = active_keys_map.get(fmt, 0) + 1
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
.filter(
and_(
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
ProviderAPIKey.is_active.is_(True),
)
)
.group_by(ProviderAPIKey.endpoint_id)
.all()
)
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
result: List[ProviderEndpointResponse] = [] result: List[ProviderEndpointResponse] = []
for endpoint in endpoints: for endpoint in endpoints:
endpoint_format = (
endpoint.api_format
if isinstance(endpoint.api_format, str)
else endpoint.api_format.value
)
endpoint_dict = { endpoint_dict = {
**endpoint.__dict__, **endpoint.__dict__,
"provider_name": provider.name, "provider_name": provider.name,
"api_format": endpoint.api_format, "api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0), "total_keys": total_keys_map.get(endpoint_format, 0),
"active_keys": active_keys_map.get(endpoint.id, 0), "active_keys": active_keys_map.get(endpoint_format, 0),
"proxy": mask_proxy_password(endpoint.proxy), "proxy": mask_proxy_password(endpoint.proxy),
} }
endpoint_dict.pop("_sa_instance_state", None) endpoint_dict.pop("_sa_instance_state", None)
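Since Keys now hang off the Provider, the per-endpoint key counts above are derived by bucketing each Key under every format it lists, replacing the old `GROUP BY endpoint_id` queries. Below is a minimal sketch of that counting step, assuming the query rows arrive as `(api_formats, is_active)` tuples as in the diff; the function name and sample format labels are hypothetical.

```python
from typing import Dict, Iterable, List, Optional, Tuple

KeyRow = Tuple[Optional[List[str]], bool]


def count_keys_by_format(rows: Iterable[KeyRow]) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (total, active) key counts keyed by API format."""
    total: Dict[str, int] = {}
    active: Dict[str, int] = {}
    for api_formats, is_active in rows:
        # A key listing several formats is counted once under each of them.
        for fmt in api_formats or []:
            total[fmt] = total.get(fmt, 0) + 1
            if is_active:
                active[fmt] = active.get(fmt, 0) + 1
    return total, active


rows = [
    (["OPENAI", "CLAUDE"], True),
    (["OPENAI"], False),
    (None, True),  # legacy row without formats: contributes to no bucket
]
total_keys_map, active_keys_map = count_keys_by_format(rows)
```

One consequence worth keeping in mind: the per-endpoint totals no longer sum to the Provider's key count, because a multi-format Key appears under several endpoints.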
@@ -321,8 +307,6 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
headers=self.endpoint_data.headers, headers=self.endpoint_data.headers,
timeout=self.endpoint_data.timeout, timeout=self.endpoint_data.timeout,
max_retries=self.endpoint_data.max_retries, max_retries=self.endpoint_data.max_retries,
max_concurrent=self.endpoint_data.max_concurrent,
rate_limit=self.endpoint_data.rate_limit,
is_active=True, is_active=True,
config=self.endpoint_data.config, config=self.endpoint_data.config,
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None, proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
@@ -367,19 +351,23 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在") raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
endpoint_obj, provider = endpoint endpoint_obj, provider = endpoint
total_keys = ( endpoint_format = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count() endpoint_obj.api_format
if isinstance(endpoint_obj.api_format, str)
else endpoint_obj.api_format.value
) )
active_keys = ( keys = (
db.query(ProviderAPIKey) db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter( .filter(ProviderAPIKey.provider_id == endpoint_obj.provider_id)
and_( .all()
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
) )
total_keys = 0
active_keys = 0
for api_formats, is_active in keys:
if endpoint_format in (api_formats or []):
total_keys += 1
if is_active:
active_keys += 1
endpoint_dict = { endpoint_dict = {
k: v k: v
@@ -431,19 +419,21 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first() provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}") logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
total_keys = ( endpoint_format = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count() endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
) )
active_keys = ( keys = (
db.query(ProviderAPIKey) db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter( .filter(ProviderAPIKey.provider_id == endpoint.provider_id)
and_( .all()
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
) )
total_keys = 0
active_keys = 0
for api_formats, is_active in keys:
if endpoint_format in (api_formats or []):
total_keys += 1
if is_active:
active_keys += 1
endpoint_dict = { endpoint_dict = {
k: v k: v
@@ -472,12 +462,26 @@ class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
if not endpoint: if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在") raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys_count = ( endpoint_format = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count() endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
)
keys = (
db.query(ProviderAPIKey.api_formats)
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
.all()
)
affected_keys_count = sum(
1 for (api_formats,) in keys if endpoint_format in (api_formats or [])
) )
db.delete(endpoint) db.delete(endpoint)
db.commit() db.commit()
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys") logger.warning(
f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, Format={endpoint_format}, "
f"AffectedKeys={affected_keys_count}"
)
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count} return {
"message": f"Endpoint {self.endpoint_id} 已删除",
"affected_keys_count": affected_keys_count,
}
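
The adapters in this file now derive per-endpoint key counts from each key's `api_formats` list rather than from `endpoint_id`. A minimal sketch of that counting step, assuming `keys` is an iterable of `(api_formats, is_active)` tuples like the rows queried above; the helper name `count_keys_by_format` is hypothetical and does not appear in the codebase:

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Sequence, Tuple

KeyRow = Tuple[Optional[Sequence[str]], bool]


def count_keys_by_format(keys: Iterable[KeyRow]) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Return (total_keys_map, active_keys_map), both keyed by API format."""
    total: Counter = Counter()
    active: Counter = Counter()
    for api_formats, is_active in keys:
        # A key whose api_formats is None supports no format and is skipped.
        for fmt in api_formats or []:
            total[fmt] += 1
            if is_active:
                active[fmt] += 1
    return dict(total), dict(active)


rows = [(["openai", "claude"], True), (["openai"], False), (None, True)]
total_keys_map, active_keys_map = count_keys_by_format(rows)
```

A key that lists several formats is counted once per format, so the per-endpoint totals can add up to more than the number of keys the provider owns.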


@@ -125,7 +125,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
ModelCatalogProviderDetail( ModelCatalogProviderDetail(
provider_id=provider.id, provider_id=provider.id,
provider_name=provider.name, provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id, model_id=model.id,
target_model=model.provider_model_name, target_model=model.provider_model_name,
# 显示有效价格 # 显示有效价格


@@ -452,7 +452,6 @@ class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
ModelCatalogProviderDetail( ModelCatalogProviderDetail(
provider_id=provider.id, provider_id=provider.id,
provider_name=provider.name, provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id, model_id=model.id,
target_model=model.provider_model_name, target_model=model.provider_model_name,
input_price_per_1m=model.get_effective_input_price(), input_price_per_1m=model.get_effective_input_price(),


@@ -819,7 +819,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
"username": user.username if user else None, "username": user.username if user else None,
"email": user.email if user else None, "email": user.email if user else None,
"provider_id": provider_id, "provider_id": provider_id,
"provider_name": provider.display_name if provider else None, "provider_name": provider.name if provider else None,
"endpoint_id": endpoint_id, "endpoint_id": endpoint_id,
"endpoint_api_format": ( "endpoint_api_format": (
endpoint.api_format if endpoint and endpoint.api_format else None endpoint.api_format if endpoint and endpoint.api_format else None
@@ -1369,9 +1369,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
for model, provider in models: for model, provider in models:
# 检查是否是主模型名称 # 检查是否是主模型名称
if model.provider_model_name == mapping_name: if model.provider_model_name == mapping_name:
provider_names.append( provider_names.append(provider.name)
provider.display_name or provider.name
)
continue continue
# 检查是否在映射列表中 # 检查是否在映射列表中
if model.provider_model_mappings: if model.provider_model_mappings:
@@ -1381,9 +1379,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if isinstance(a, dict) if isinstance(a, dict)
] ]
if mapping_name in mapping_list: if mapping_name in mapping_list:
provider_names.append( provider_names.append(provider.name)
provider.display_name or provider.name
)
provider_names = sorted(list(set(provider_names))) provider_names = sorted(list(set(provider_names)))
mappings.append({ mappings.append({
@@ -1473,7 +1469,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
provider_model_mappings.append({ provider_model_mappings.append({
"provider_id": provider_id, "provider_id": provider_id,
"provider_name": provider.display_name or provider.name, "provider_name": provider.name,
"global_model_id": global_model_id, "global_model_id": global_model_id,
"global_model_name": global_model.name, "global_model_name": global_model.name,
"global_model_display_name": global_model.display_name, "global_model_display_name": global_model.display_name,


@@ -13,10 +13,11 @@ from sqlalchemy.orm import Session, joinedload
from src.api.handlers.base.chat_adapter_base import get_adapter_class from src.api.handlers.base.chat_adapter_base import get_adapter_class
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
from src.config.constants import TimeoutDefaults
from src.core.crypto import crypto_service from src.core.crypto import crypto_service
from src.core.logger import logger from src.core.logger import logger
from src.database.database import get_db from src.database.database import get_db
from src.models.database import Provider, ProviderEndpoint, User from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"]) router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
@@ -81,10 +82,13 @@ async def query_available_models(
Returns: Returns:
所有端点的模型列表(合并) 所有端点的模型列表(合并)
""" """
# Fetch the provider and its endpoints # Fetch the provider with its endpoints and API Keys
provider = ( provider = (
db.query(Provider) db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys)) .options(
joinedload(Provider.endpoints),
joinedload(Provider.api_keys),
)
.filter(Provider.id == request.provider_id) .filter(Provider.id == request.provider_id)
.first() .first()
) )
@@ -95,49 +99,70 @@ async def query_available_models(
# 收集所有活跃端点的配置 # 收集所有活跃端点的配置
endpoint_configs: list[dict] = [] endpoint_configs: list[dict] = []
# Build the api_format -> endpoint mapping
format_to_endpoint: dict[str, ProviderEndpoint] = {}
for endpoint in provider.endpoints:
if endpoint.is_active:
format_to_endpoint[endpoint.api_format] = endpoint
if request.api_key_id: if request.api_key_id:
# A specific API Key was given; use only the endpoint that Key belongs to # A specific API Key was given (look it up in provider.api_keys)
for endpoint in provider.endpoints: api_key = next(
for api_key in endpoint.api_keys: (key for key in provider.api_keys if key.id == request.api_key_id),
if api_key.id == request.api_key_id: None
try: )
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e: if not api_key:
logger.error(f"Failed to decrypt API key: {e}") raise HTTPException(status_code=404, detail="API Key not found")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
endpoint_configs.append({ try:
"api_key": api_key_value, api_key_value = crypto_service.decrypt(api_key.api_key)
"base_url": endpoint.base_url, except Exception as e:
"api_format": endpoint.api_format, logger.error(f"Failed to decrypt API key: {e}")
"extra_headers": endpoint.headers, raise HTTPException(status_code=500, detail="Failed to decrypt API key")
})
break # Use the Key's api_formats to find the matching Endpoints
if endpoint_configs: key_formats = api_key.api_formats or []
break for fmt in key_formats:
endpoint = format_to_endpoint.get(fmt)
if endpoint:
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": fmt,
"extra_headers": endpoint.headers,
})
if not endpoint_configs: if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found") raise HTTPException(
status_code=400,
detail="No matching endpoint found for this API Key's formats"
)
else: else:
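# Iterate over all active endpoints, taking the first usable Key of each # Iterate over all active endpoints, finding one Key that supports each endpoint's format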
for endpoint in provider.endpoints: for endpoint in provider.endpoints:
if not endpoint.is_active or not endpoint.api_keys: if not endpoint.is_active:
continue continue
# Find the first usable Key # Find the first usable Key that supports this format
for api_key in endpoint.api_keys: for api_key in provider.api_keys:
if api_key.is_active: if not api_key.is_active:
try: continue
api_key_value = crypto_service.decrypt(api_key.api_key) key_formats = api_key.api_formats or []
except Exception as e: if endpoint.api_format not in key_formats:
logger.error(f"Failed to decrypt API key: {e}") continue
continue # 尝试下一个 Key try:
endpoint_configs.append({ api_key_value = crypto_service.decrypt(api_key.api_key)
"api_key": api_key_value, except Exception as e:
"base_url": endpoint.base_url, logger.error(f"Failed to decrypt API key: {e}")
"api_format": endpoint.api_format, continue
"extra_headers": endpoint.headers, endpoint_configs.append({
}) "api_key": api_key_value,
break # 只取第一个可用的 Key "base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
})
break # 只取第一个可用的 Key
if not endpoint_configs: if not endpoint_configs:
raise HTTPException(status_code=400, detail="No active API Key found for this provider") raise HTTPException(status_code=400, detail="No active API Key found for this provider")
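
The new lookup pairs each active endpoint with the first active key whose `api_formats` include that endpoint's format. A simplified sketch of that matching, using stand-in dataclasses instead of the ORM models; `Endpoint`, `Key`, and `pick_configs` are hypothetical names, and key decryption is left out:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Endpoint:
    api_format: str
    base_url: str
    is_active: bool = True


@dataclass
class Key:
    id: str
    api_formats: Optional[List[str]] = None
    is_active: bool = True


def pick_configs(endpoints: List[Endpoint], keys: List[Key]) -> List[dict]:
    """For each active endpoint, pick the first active key supporting its format."""
    configs: List[dict] = []
    for endpoint in endpoints:
        if not endpoint.is_active:
            continue
        for key in keys:
            if key.is_active and endpoint.api_format in (key.api_formats or []):
                configs.append(
                    {
                        "key_id": key.id,
                        "base_url": endpoint.base_url,
                        "api_format": endpoint.api_format,
                    }
                )
                break  # only the first usable key per endpoint
    return configs


endpoints = [
    Endpoint("openai", "https://a.example"),
    Endpoint("claude", "https://b.example"),
    Endpoint("gemini", "https://c.example", is_active=False),
]
keys = [
    Key("k1", ["claude"], is_active=False),
    Key("k2", ["openai", "claude"]),
    Key("k3", None),
]
configs = pick_configs(endpoints, keys)
```

Because keys now belong to the provider rather than to one endpoint, a single key (here `k2`) can serve several endpoints.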
@@ -214,7 +239,6 @@ async def query_available_models(
"provider": { "provider": {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
}, },
} }
@@ -229,17 +253,14 @@ async def test_model(
测试模型连接性 测试模型连接性
向指定提供商的指定模型发送测试请求,验证模型是否可用 向指定提供商的指定模型发送测试请求,验证模型是否可用
Args:
request: 测试请求
Returns:
测试结果
""" """
# 获取提供商及其端点 # 获取提供商及其端点和 Keys
provider = ( provider = (
db.query(Provider) db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys)) .options(
joinedload(Provider.endpoints),
joinedload(Provider.api_keys),
)
.filter(Provider.id == request.provider_id) .filter(Provider.id == request.provider_id)
.first() .first()
) )
@@ -247,28 +268,38 @@ async def test_model(
if not provider: if not provider:
raise HTTPException(status_code=404, detail="Provider not found") raise HTTPException(status_code=404, detail="Provider not found")
# 找到合适的端点和API Key # 构建 api_format -> endpoint 映射
endpoint_config = None format_to_endpoint: dict[str, ProviderEndpoint] = {}
for ep in provider.endpoints:
if ep.is_active:
format_to_endpoint[ep.api_format] = ep
# 找到合适的端点和 API Key
endpoint = None endpoint = None
api_key = None api_key = None
if request.api_key_id: if request.api_key_id:
# 使用指定的API Key # 使用指定的 API Key
for ep in provider.endpoints: api_key = next(
for key in ep.api_keys: (key for key in provider.api_keys if key.id == request.api_key_id and key.is_active),
if key.id == request.api_key_id and key.is_active and ep.is_active: None
endpoint = ep )
api_key = key if api_key:
# 找到该 Key 支持的第一个活跃 Endpoint
for fmt in (api_key.api_formats or []):
if fmt in format_to_endpoint:
endpoint = format_to_endpoint[fmt]
break break
if endpoint:
break
else: else:
# 使用第一个可用的端点和密钥 # 使用第一个可用的端点和密钥
for ep in provider.endpoints: for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys: if not ep.is_active:
continue continue
for key in ep.api_keys: # 找支持该格式的第一个可用 Key
if key.is_active: for key in provider.api_keys:
if not key.is_active:
continue
if ep.api_format in (key.api_formats or []):
endpoint = ep endpoint = ep
api_key = key api_key = key
break break
@@ -284,14 +315,14 @@ async def test_model(
logger.error(f"[test-model] Failed to decrypt API key: {e}") logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key") raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# Build the request config # Build the request config (timeout is read from the Provider)
endpoint_config = { endpoint_config = {
"api_key": api_key_value, "api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录 "api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url, "base_url": endpoint.base_url,
"api_format": endpoint.api_format, "api_format": endpoint.api_format,
"extra_headers": endpoint.headers, "extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0, "timeout": provider.timeout or TimeoutDefaults.HTTP_REQUEST,
} }
try: try:
@@ -304,7 +335,6 @@ async def test_model(
"provider": { "provider": {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
}, },
"model": request.model_name, "model": request.model_name,
} }
@@ -325,7 +355,6 @@ async def test_model(
"provider": { "provider": {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
}, },
"model": request.model_name, "model": request.model_name,
} }
@@ -415,7 +444,6 @@ async def test_model(
"provider": { "provider": {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
}, },
"model": request.model_name, "model": request.model_name,
"endpoint": { "endpoint": {
@@ -433,7 +461,6 @@ async def test_model(
"provider": { "provider": {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
}, },
"model": request.model_name, "model": request.model_name,
"endpoint": { "endpoint": {


@@ -78,7 +78,7 @@ async def get_provider_stats(
""" """
获取提供商统计数据 获取提供商统计数据
获取指定提供商的计费信息、RPM 使用情况和使用统计数据。 获取指定提供商的计费信息和使用统计数据。
**路径参数**: **路径参数**:
- `provider_id`: 提供商 ID - `provider_id`: 提供商 ID
@@ -96,10 +96,6 @@ async def get_provider_stats(
- `monthly_used_usd`: 月度已使用 - `monthly_used_usd`: 月度已使用
- `quota_remaining_usd`: 剩余配额 - `quota_remaining_usd`: 剩余配额
- `quota_expires_at`: 配额过期时间 - `quota_expires_at`: 配额过期时间
- `rpm_info`: RPM 信息
- `rpm_limit`: RPM 限制
- `rpm_used`: 已使用 RPM
- `rpm_reset_at`: RPM 重置时间
- `usage_stats`: 使用统计 - `usage_stats`: 使用统计
- `total_requests`: 总请求数 - `total_requests`: 总请求数
- `successful_requests`: 成功请求数 - `successful_requests`: 成功请求数
@@ -165,7 +161,6 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
provider.billing_type = config.billing_type provider.billing_type = config.billing_type
provider.monthly_quota_usd = config.monthly_quota_usd provider.monthly_quota_usd = config.monthly_quota_usd
provider.quota_reset_day = config.quota_reset_day provider.quota_reset_day = config.quota_reset_day
provider.rpm_limit = config.rpm_limit
provider.provider_priority = config.provider_priority provider.provider_priority = config.provider_priority
from dateutil import parser from dateutil import parser
@@ -262,13 +257,6 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
), ),
}, },
"rpm_info": {
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": (
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
),
},
"usage_stats": { "usage_stats": {
"total_requests": total_requests, "total_requests": total_requests,
"successful_requests": total_success, "successful_requests": total_success,
@@ -296,8 +284,6 @@ class AdminProviderResetQuotaAdapter(AdminApiAdapter):
old_used = provider.monthly_used_usd old_used = provider.monthly_used_usd
provider.monthly_used_usd = 0.0 provider.monthly_used_usd = 0.0
provider.rpm_used = 0
provider.rpm_reset_at = None
db.commit() db.commit()
logger.info(f"Manually reset quota for provider {provider.name}") logger.info(f"Manually reset quota for provider {provider.name}")


@@ -338,27 +338,29 @@ async def import_models_from_upstream(
""" """
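Import models from an upstream provider Import models from an upstream provider
Import the model list from the upstream provider. If a global model does not exist, it is created automatically. Import the model list from the upstream provider. Imported models are stored as standalone ProviderModel records;
no GlobalModel is created automatically. A GlobalModel must be linked manually afterwards before the model can participate in routing.
**Flow**: **Flow**:
1. Check whether a global model exists for each entry in model_ids (matched by name) 1. Check whether the model already exists under the current Provider (matched by provider_model_name)
2. If it does not exist, create a new GlobalModel automatically (using the default free configuration) 2. Create a new ProviderModel (global_model_id = NULL)
3. Create a Model linked to the current Provider 3. Optionally apply price overrides (tiered_pricing, price_per_request)
4. If the model is already linked, record it in the success list
**Path parameters**: **Path parameters**:
- `provider_id`: provider ID - `provider_id`: provider ID
**Request body fields**: **Request body fields**:
- `model_ids`: array of model IDs (required; each ID is 1-100 characters) - `model_ids`: array of model IDs (required; each ID is 1-100 characters)
- `tiered_pricing`: optional tiered pricing config (applied to all imported models)
- `price_per_request`: optional per-request price (applied to all imported models)
**Response fields**: **Response fields**:
- `success`: array of successfully imported models, each containing: - `success`: array of successfully imported models, each containing:
- `model_id`: model ID - `model_id`: model ID
- `global_model_id`: global model ID
- `global_model_name`: global model name
- `provider_model_id`: provider model ID - `provider_model_id`: provider model ID
- `created_global_model`: whether a new global model was created - `global_model_id`: global model ID (if linked)
- `global_model_name`: global model name (if linked)
- `created_global_model`: whether a new global model was created (always false)
- `errors`: array of failed models, each containing: - `errors`: array of failed models, each containing:
- `model_id`: model ID - `model_id`: model ID
- `error`: error message - `error`: error message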
@@ -638,7 +640,7 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
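@dataclass @dataclass
class AdminImportFromUpstreamAdapter(AdminApiAdapter): class AdminImportFromUpstreamAdapter(AdminApiAdapter):
"""Import models from an upstream provider""" """Import models from an upstream provider (no GlobalModel is created; models are stored as standalone ProviderModel records)"""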
provider_id: str provider_id: str
payload: ImportFromUpstreamRequest payload: ImportFromUpstreamRequest
@@ -652,16 +654,13 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
success: list[ImportFromUpstreamSuccessItem] = [] success: list[ImportFromUpstreamSuccessItem] = []
errors: list[ImportFromUpstreamErrorItem] = [] errors: list[ImportFromUpstreamErrorItem] = []
# Default tiered pricing config (free) # Read the price override config
default_tiered_pricing = { tiered_pricing = None
"tiers": [ price_per_request = None
{ if hasattr(self.payload, 'tiered_pricing') and self.payload.tiered_pricing:
"up_to": None, tiered_pricing = self.payload.tiered_pricing
"input_price_per_1m": 0.0, if hasattr(self.payload, 'price_per_request') and self.payload.price_per_request is not None:
"output_price_per_1m": 0.0, price_per_request = self.payload.price_per_request
}
]
}
for model_id in self.payload.model_ids: for model_id in self.payload.model_ids:
# 输入验证:检查 model_id 长度 # 输入验证:检查 model_id 长度
@@ -678,56 +677,37 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
# 使用 savepoint 确保单个模型导入的原子性 # 使用 savepoint 确保单个模型导入的原子性
savepoint = db.begin_nested() savepoint = db.begin_nested()
try: try:
# 1. Check whether a GlobalModel with the same name already exists # 1. Check whether a ProviderModel with the same name already exists
global_model = (
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
)
created_global_model = False
if not global_model:
# 2. 创建新的 GlobalModel
global_model = GlobalModel(
name=model_id,
display_name=model_id,
default_tiered_pricing=default_tiered_pricing,
is_active=True,
)
db.add(global_model)
db.flush()
created_global_model = True
logger.info(
f"Created new GlobalModel: {model_id} during upstream import"
)
# 3. 检查是否已存在关联
existing = ( existing = (
db.query(Model) db.query(Model)
.filter( .filter(
Model.provider_id == self.provider_id, Model.provider_id == self.provider_id,
Model.global_model_id == global_model.id, Model.provider_model_name == model_id,
) )
.first() .first()
) )
if existing: if existing:
# 已存在关联,提交 savepoint 并记录成功 # 已存在,提交 savepoint 并记录成功
savepoint.commit() savepoint.commit()
success.append( success.append(
ImportFromUpstreamSuccessItem( ImportFromUpstreamSuccessItem(
model_id=model_id, model_id=model_id,
global_model_id=global_model.id, global_model_id=existing.global_model_id or "",
global_model_name=global_model.name, global_model_name=existing.global_model.name if existing.global_model else "",
provider_model_id=existing.id, provider_model_id=existing.id,
created_global_model=created_global_model, created_global_model=False,
) )
) )
continue continue
# 4. Create a new Model record # 2. Create a new Model record (not linked to a GlobalModel)
new_model = Model( new_model = Model(
provider_id=self.provider_id, provider_id=self.provider_id,
global_model_id=global_model.id, global_model_id=None, # standalone model, not linked to a GlobalModel
provider_model_name=global_model.name, provider_model_name=model_id,
is_active=True, is_active=True,
tiered_pricing=tiered_pricing,
price_per_request=price_per_request,
) )
db.add(new_model) db.add(new_model)
db.flush() db.flush()
@@ -737,12 +717,15 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
success.append( success.append(
ImportFromUpstreamSuccessItem( ImportFromUpstreamSuccessItem(
model_id=model_id, model_id=model_id,
global_model_id=global_model.id, global_model_id="", # 未关联
global_model_name=global_model.name, global_model_name="", # 未关联
provider_model_id=new_model.id, provider_model_id=new_model.id,
created_global_model=created_global_model, created_global_model=False,
) )
) )
logger.info(
f"Created independent ProviderModel: {model_id} for provider {provider.name}"
)
except Exception as e: except Exception as e:
# 回滚到 savepoint # 回滚到 savepoint
savepoint.rollback() savepoint.rollback()
@@ -753,11 +736,9 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
db.commit() db.commit()
logger.info( logger.info(
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}" f"Imported {len(success)} independent models to provider {provider.name} by {context.user.username}"
) )
# Invalidate the /v1/models list cache # No need to invalidate the /v1/models cache, because standalone models do not participate in routing
if success:
await invalidate_models_list_cache()
return ImportFromUpstreamResponse(success=success, errors=errors) return ImportFromUpstreamResponse(success=success, errors=errors)
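
The import loop above wraps each model in its own savepoint, so one failing row does not abort the batch, and it treats an existing `provider_model_name` as a success rather than an error. A rough, self-contained sketch of the same pattern, swapping SQLAlchemy's `begin_nested()` for raw SQLite `SAVEPOINT` statements via the standard-library `sqlite3` module; the `models` table layout and the `import_models` helper are assumptions made for illustration only:

```python
import sqlite3
from typing import List, Tuple


def import_models(
    conn: sqlite3.Connection, provider_id: str, model_ids: List[str]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, str]]]:
    """Import each model under its own savepoint; return (success, errors)."""
    success: List[Tuple[str, int]] = []
    errors: List[Tuple[str, str]] = []
    conn.execute("BEGIN")
    for model_id in model_ids:
        # Input validation: each ID must be 1-100 characters long.
        if not 1 <= len(model_id) <= 100:
            errors.append((model_id, "model_id must be 1-100 characters"))
            continue
        conn.execute("SAVEPOINT import_one")
        try:
            row = conn.execute(
                "SELECT id FROM models WHERE provider_id = ? AND provider_model_name = ?",
                (provider_id, model_id),
            ).fetchone()
            if row is None:
                # Standalone provider model: global_model_id stays NULL.
                cur = conn.execute(
                    "INSERT INTO models (provider_id, provider_model_name, global_model_id) "
                    "VALUES (?, ?, NULL)",
                    (provider_id, model_id),
                )
                row_id = cur.lastrowid
            else:
                row_id = row[0]  # already imported; still reported as success
            conn.execute("RELEASE SAVEPOINT import_one")
            success.append((model_id, row_id))
        except sqlite3.Error as exc:
            conn.execute("ROLLBACK TO SAVEPOINT import_one")
            conn.execute("RELEASE SAVEPOINT import_one")
            errors.append((model_id, str(exc)))
    conn.execute("COMMIT")
    return success, errors


# isolation_level=None hands transaction control to the explicit BEGIN/COMMIT above.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute(
    "CREATE TABLE models (id INTEGER PRIMARY KEY, provider_id TEXT NOT NULL, "
    "provider_model_name TEXT NOT NULL, global_model_id TEXT)"
)
success, errors = import_models(conn, "p1", ["model-a", "model-b", "model-a", ""])
```

Re-importing `model-a` returns the ID of the existing row, which matches the idempotent behaviour described in the docstring.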


@@ -11,9 +11,11 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline from src.api.base.pipeline import ApiRequestPipeline
from src.core.enums import ProviderBillingType from src.core.enums import ProviderBillingType
from src.core.exceptions import InvalidRequestException, NotFoundException from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.logger import logger
from src.database import get_db from src.database import get_db
from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest from src.models.admin_requests import CreateProviderRequest, UpdateProviderRequest
from src.models.database import Provider from src.models.database import Provider
from src.services.cache.provider_cache import ProviderCacheService
router = APIRouter(tags=["Provider CRUD"]) router = APIRouter(tags=["Provider CRUD"])
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
@@ -39,8 +41,7 @@ async def list_providers(
**Response fields**: **Response fields**:
- `id`: provider ID - `id`: provider ID
- `name`: provider name (unique identifier) - `name`: provider name (unique)
- `display_name`: display name
- `api_format`: API format (e.g. claude, openai, gemini) - `api_format`: API format (e.g. claude, openai, gemini)
- `base_url`: API base URL - `base_url`: API base URL
- `api_key`: API key (masked) - `api_key`: API key (masked)
@@ -61,8 +62,7 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
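Create a new AI model provider configuration. Create a new AI model provider configuration.
**Request body fields**: **Request body fields**:
- `name`: provider name (required, unique, used as the system identifier) - `name`: provider name (required, unique)
- `display_name`: display name (required)
- `description`: description (optional) - `description`: description (optional)
- `website`: website URL (optional) - `website`: website URL (optional)
- `billing_type`: billing type (optional; pay_as_you_go/subscription/prepaid; default pay_as_you_go) - `billing_type`: billing type (optional; pay_as_you_go/subscription/prepaid; default pay_as_you_go)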
@@ -70,16 +70,17 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
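- `quota_reset_day`: quota reset day (1-31, optional) - `quota_reset_day`: quota reset day (1-31, optional)
- `quota_last_reset_at`: time of the last quota reset (optional) - `quota_last_reset_at`: time of the last quota reset (optional)
- `quota_expires_at`: quota expiry time (optional) - `quota_expires_at`: quota expiry time (optional)
- `rpm_limit`: requests-per-minute limit (optional)
- `provider_priority`: provider priority (lower numbers mean higher priority; default 100) - `provider_priority`: provider priority (lower numbers mean higher priority; default 100)
- `is_active`: whether the provider is enabled (default true) - `is_active`: whether the provider is enabled (default true)
- `concurrent_limit`: concurrency limit (optional) - `concurrent_limit`: concurrency limit (optional)
- `timeout`: request timeout (seconds, optional)
- `max_retries`: maximum number of retries (optional)
- `proxy`: proxy configuration (optional)
- `config`: extra configuration (JSON, optional) - `config`: extra configuration (JSON, optional)
**Response fields**: **Response fields**:
- `id`: ID of the newly created provider - `id`: ID of the newly created provider
- `name`: provider name - `name`: provider name
- `display_name`: display name
- `message`: success message - `message`: success message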
""" """
adapter = AdminCreateProviderAdapter() adapter = AdminCreateProviderAdapter()
@@ -98,7 +99,6 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
**Request body fields** (all optional): **Request body fields** (all optional):
- `name`: provider name - `name`: provider name
- `display_name`: display name
- `description`: description - `description`: description
- `website`: website URL - `website`: website URL
- `billing_type`: billing type (pay_as_you_go/subscription/prepaid) - `billing_type`: billing type (pay_as_you_go/subscription/prepaid)
@@ -106,10 +106,12 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
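- `quota_reset_day`: quota reset day (1-31) - `quota_reset_day`: quota reset day (1-31)
- `quota_last_reset_at`: time of the last quota reset - `quota_last_reset_at`: time of the last quota reset
- `quota_expires_at`: quota expiry time - `quota_expires_at`: quota expiry time
- `rpm_limit`: requests-per-minute limit
- `provider_priority`: provider priority - `provider_priority`: provider priority
- `is_active`: whether the provider is enabled - `is_active`: whether the provider is enabled
- `concurrent_limit`: concurrency limit - `concurrent_limit`: concurrency limit
- `timeout`: request timeout (seconds)
- `max_retries`: maximum number of retries
- `proxy`: proxy configuration
- `config`: extra configuration (JSON) - `config`: extra configuration (JSON)
**Response fields**: **Response fields**: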
@@ -163,7 +165,6 @@ class AdminListProvidersAdapter(AdminApiAdapter):
{ {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
"api_format": api_format.value if api_format else None, "api_format": api_format.value if api_format else None,
"base_url": base_url, "base_url": base_url,
"api_key": "***" if api_key else None, "api_key": "***" if api_key else None,
@@ -215,7 +216,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
# 创建 Provider 对象 # 创建 Provider 对象
provider = Provider( provider = Provider(
name=validated_data.name, name=validated_data.name,
display_name=validated_data.display_name,
description=validated_data.description, description=validated_data.description,
website=validated_data.website, website=validated_data.website,
billing_type=billing_type, billing_type=billing_type,
@@ -223,10 +223,12 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
quota_reset_day=validated_data.quota_reset_day, quota_reset_day=validated_data.quota_reset_day,
quota_last_reset_at=validated_data.quota_last_reset_at, quota_last_reset_at=validated_data.quota_last_reset_at,
quota_expires_at=validated_data.quota_expires_at, quota_expires_at=validated_data.quota_expires_at,
rpm_limit=validated_data.rpm_limit,
provider_priority=validated_data.provider_priority, provider_priority=validated_data.provider_priority,
is_active=validated_data.is_active, is_active=validated_data.is_active,
concurrent_limit=validated_data.concurrent_limit, concurrent_limit=validated_data.concurrent_limit,
timeout=validated_data.timeout,
max_retries=validated_data.max_retries,
proxy=validated_data.proxy.model_dump() if validated_data.proxy else None,
config=validated_data.config, config=validated_data.config,
) )
@@ -246,7 +248,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
return { return {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
"message": "提供商创建成功", "message": "提供商创建成功",
} }
except InvalidRequestException: except InvalidRequestException:
@@ -289,6 +290,9 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
if field == "billing_type" and value is not None: if field == "billing_type" and value is not None:
# billing_type must be converted to the enum # billing_type must be converted to the enum
setattr(provider, field, ProviderBillingType(value)) setattr(provider, field, ProviderBillingType(value))
elif field == "proxy" and value is not None:
# proxy must be converted to a dict (if it is a Pydantic model)
setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
else: else:
setattr(provider, field, value) setattr(provider, field, value)
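
The update loop above normalizes two fields before assignment: `billing_type` becomes an enum member and `proxy` becomes a plain dict. A small sketch of that dispatch, assuming `update_data` is an ordinary dict; `BillingType`, `ProxyConfig` (with its `url` field), and `apply_updates` are stand-ins invented for this example, not the project's `ProviderBillingType` or its Pydantic proxy model:

```python
from enum import Enum
from types import SimpleNamespace
from typing import Any, Dict


class BillingType(Enum):
    PAY_AS_YOU_GO = "pay_as_you_go"
    SUBSCRIPTION = "subscription"
    PREPAID = "prepaid"


class ProxyConfig:
    """Stand-in for a Pydantic model; only model_dump() matters here."""

    def __init__(self, url: str) -> None:
        self.url = url

    def model_dump(self) -> Dict[str, Any]:
        return {"url": self.url}


def apply_updates(provider: Any, update_data: Dict[str, Any]) -> Any:
    for field, value in update_data.items():
        if field == "billing_type" and value is not None:
            # Raises ValueError for an unknown billing type.
            setattr(provider, field, BillingType(value))
        elif field == "proxy" and value is not None:
            # Accept either a ready-made dict or a model exposing model_dump().
            setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
        else:
            setattr(provider, field, value)
    return provider


provider = apply_updates(
    SimpleNamespace(),
    {
        "billing_type": "subscription",
        "proxy": ProxyConfig("http://proxy.example:8080"),
        "timeout": 30,
    },
)
```

Passing `None` for `proxy` falls through to the final branch, which clears the stored proxy instead of calling `model_dump()` on `None`.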
@@ -296,6 +300,11 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
db.commit() db.commit()
db.refresh(provider) db.refresh(provider)
# If billing_type was updated, invalidate the cache
if "billing_type" in update_data:
await ProviderCacheService.invalidate_provider_cache(provider.id)
logger.debug(f"已清除 Provider 缓存: {provider.id}")
context.add_audit_metadata( context.add_audit_metadata(
action="update_provider", action="update_provider",
provider_id=provider.id, provider_id=provider.id,


@@ -48,7 +48,6 @@ async def get_providers_summary(
**返回字段**(数组,每项包含): **返回字段**(数组,每项包含):
- `id`: 提供商 ID - `id`: 提供商 ID
- `name`: 提供商名称 - `name`: 提供商名称
- `display_name`: 显示名称
- `description`: 描述信息 - `description`: 描述信息
- `website`: 官网地址 - `website`: 官网地址
- `provider_priority`: 优先级 - `provider_priority`: 优先级
@@ -59,9 +58,9 @@ async def get_providers_summary(
- `quota_reset_day`: 配额重置日期 - `quota_reset_day`: 配额重置日期
- `quota_last_reset_at`: 上次配额重置时间 - `quota_last_reset_at`: 上次配额重置时间
- `quota_expires_at`: 配额过期时间 - `quota_expires_at`: 配额过期时间
- `rpm_limit`: RPM 限制 - `timeout`: 默认请求超时(秒)
- `rpm_used`: 已使用 RPM - `max_retries`: 默认最大重试次数
- `rpm_reset_at`: RPM 重置时间 - `proxy`: 默认代理配置
- `total_endpoints`: 端点总数 - `total_endpoints`: 端点总数
- `active_endpoints`: 活跃端点数 - `active_endpoints`: 活跃端点数
- `total_keys`: 密钥总数 - `total_keys`: 密钥总数
@@ -96,7 +95,6 @@ async def get_provider_summary(
**Response fields**: **Response fields**:
- `id`: provider ID - `id`: provider ID
- `name`: provider name - `name`: provider name
- `display_name`: display name
- `description`: description - `description`: description
- `website`: website URL - `website`: website URL
- `provider_priority`: priority - `provider_priority`: priority
@@ -107,9 +105,9 @@ async def get_provider_summary(
- `quota_reset_day`: quota reset day - `quota_reset_day`: quota reset day
- `quota_last_reset_at`: time of the last quota reset - `quota_last_reset_at`: time of the last quota reset
- `quota_expires_at`: quota expiry time - `quota_expires_at`: quota expiry time
- `rpm_limit`: RPM limit - `timeout`: default request timeout (seconds)
- `rpm_used`: RPM used - `max_retries`: default maximum number of retries
- `rpm_reset_at`: RPM reset time - `proxy`: default proxy configuration
- `total_endpoints`: total number of endpoints - `total_endpoints`: total number of endpoints
- `active_endpoints`: number of active endpoints - `active_endpoints`: number of active endpoints
- `total_keys`: total number of keys - `total_keys`: total number of keys
@@ -185,13 +183,13 @@ async def update_provider_settings(
""" """
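Update a provider's basic settings Update a provider's basic settings
Updates a provider's basic settings, such as display name, description, and priority. Pass only the fields that need updating. Updates a provider's basic settings, such as name, description, and priority. Pass only the fields that need updating.
**Path parameters**: **Path parameters**:
- `provider_id`: provider ID - `provider_id`: provider ID
**Request body fields** (all optional): **Request body fields** (all optional):
- `display_name`: display name - `name`: provider name
- `description`: description - `description`: description
- `website`: website URL - `website`: website URL
- `provider_priority`: priority - `provider_priority`: priority
@@ -199,9 +197,10 @@ async def update_provider_settings(
- `billing_type`: billing type - `billing_type`: billing type
- `monthly_quota_usd`: monthly quota (USD) - `monthly_quota_usd`: monthly quota (USD)
- `quota_reset_day`: quota reset day - `quota_reset_day`: quota reset day
- `quota_last_reset_at`: time of the last quota reset
- `quota_expires_at`: quota expiry time - `quota_expires_at`: quota expiry time
- `rpm_limit`: RPM limit - `timeout`: default request timeout (seconds)
- `max_retries`: default maximum number of retries
- `proxy`: default proxy configuration
**Response fields**: returns the updated provider summary (same format as the GET /summary endpoint) **Response fields**: returns the updated provider summary (same format as the GET /summary endpoint)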
""" """
@@ -215,18 +214,18 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
total_endpoints = len(endpoints) total_endpoints = len(endpoints)
active_endpoints = sum(1 for e in endpoints if e.is_active) active_endpoints = sum(1 for e in endpoints if e.is_active)
endpoint_ids = [e.id for e in endpoints]
# Key statistics (combined into a single query) # Key statistics (combined into a single query)
total_keys = 0 key_stats = (
active_keys = 0 db.query(
if endpoint_ids:
key_stats = db.query(
func.count(ProviderAPIKey.id).label("total"), func.count(ProviderAPIKey.id).label("total"),
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"), func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first() )
total_keys = key_stats.total or 0 .filter(ProviderAPIKey.provider_id == provider.id)
active_keys = int(key_stats.active or 0) .first()
)
total_keys = key_stats.total or 0
active_keys = int(key_stats.active or 0)
# Model statistics (combined into a single query) # Model statistics (combined into a single query)
model_stats = db.query( model_stats = db.query(
@@ -238,25 +237,34 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
api_formats = [e.api_format for e in endpoints] api_formats = [e.api_format for e in endpoints]
# Optimization: load the keys of all endpoints at once to avoid N+1 queries # Optimization: load the provider's keys at once to avoid N+1 queries
all_keys = [] all_keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.provider_id == provider.id).all()
if endpoint_ids:
all_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
)
# Group keys by endpoint_id # Group keys by api_formats (keys are linked to endpoints via api_formats)
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {} format_to_endpoint_id: dict[str, str] = {e.api_format: e.id for e in endpoints}
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {e.id: [] for e in endpoints}
for key in all_keys: for key in all_keys:
if key.endpoint_id not in keys_by_endpoint: formats = key.api_formats or []
keys_by_endpoint[key.endpoint_id] = [] for fmt in formats:
keys_by_endpoint[key.endpoint_id].append(key) endpoint_id = format_to_endpoint_id.get(fmt)
if endpoint_id:
keys_by_endpoint[endpoint_id].append(key)
endpoint_health_map: dict[str, float] = {} endpoint_health_map: dict[str, float] = {}
for endpoint in endpoints: for endpoint in endpoints:
keys = keys_by_endpoint.get(endpoint.id, []) keys = keys_by_endpoint.get(endpoint.id, [])
if keys: if keys:
health_scores = [k.health_score for k in keys if k.health_score is not None] # Get the health score for this format from health_by_format
api_fmt = endpoint.api_format
health_scores = []
for k in keys:
health_by_format = k.health_by_format or {}
if api_fmt in health_by_format:
score = health_by_format[api_fmt].get("health_score")
if score is not None:
health_scores.append(float(score))
else:
health_scores.append(1.0) # default health score
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0 avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
endpoint_health_map[endpoint.id] = avg_health endpoint_health_map[endpoint.id] = avg_health
else: else:
@@ -284,7 +292,6 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
return ProviderWithEndpointsSummary( return ProviderWithEndpointsSummary(
id=provider.id, id=provider.id,
name=provider.name, name=provider.name,
display_name=provider.display_name,
description=provider.description, description=provider.description,
website=provider.website, website=provider.website,
provider_priority=provider.provider_priority, provider_priority=provider.provider_priority,
@@ -295,9 +302,9 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
quota_reset_day=provider.quota_reset_day, quota_reset_day=provider.quota_reset_day,
quota_last_reset_at=provider.quota_last_reset_at, quota_last_reset_at=provider.quota_last_reset_at,
quota_expires_at=provider.quota_expires_at, quota_expires_at=provider.quota_expires_at,
rpm_limit=provider.rpm_limit, timeout=provider.timeout,
rpm_used=provider.rpm_used, max_retries=provider.max_retries,
rpm_reset_at=provider.rpm_reset_at, proxy=provider.proxy,
total_endpoints=total_endpoints, total_endpoints=total_endpoints,
active_endpoints=active_endpoints, active_endpoints=active_endpoints,
total_keys=total_keys, total_keys=total_keys,
@@ -341,7 +348,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
if not endpoint_ids: if not endpoint_ids:
response = ProviderEndpointHealthMonitorResponse( response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id, provider_id=provider.id,
provider_name=provider.display_name or provider.name, provider_name=provider.name,
generated_at=now, generated_at=now,
endpoints=[], endpoints=[],
) )
@@ -416,7 +423,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
response = ProviderEndpointHealthMonitorResponse( response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id, provider_id=provider.id,
provider_name=provider.display_name or provider.name, provider_name=provider.name,
generated_at=now, generated_at=now,
endpoints=endpoint_monitors, endpoints=endpoint_monitors,
) )
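The new grouping and per-format health logic in `_build_provider_summary` can be sketched in isolation as follows. This is a minimal illustration, not the project's code: `Endpoint` and `Key` are hypothetical stand-ins for `ProviderEndpoint` and `ProviderAPIKey`, and the value used for an endpoint with no keys (1.0) is an assumption, since that branch is not shown in the diff.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Endpoint:  # hypothetical stand-in for ProviderEndpoint
    id: str
    api_format: str


@dataclass
class Key:  # hypothetical stand-in for ProviderAPIKey
    api_formats: list[str] = field(default_factory=list)
    health_by_format: dict[str, dict[str, Any]] = field(default_factory=dict)


def endpoint_health(endpoints: list[Endpoint], keys: list[Key]) -> dict[str, float]:
    """Average the per-format health of the keys that serve each endpoint."""
    # A key is attached to an endpoint when one of its api_formats matches.
    format_to_endpoint_id = {e.api_format: e.id for e in endpoints}
    keys_by_endpoint: dict[str, list[Key]] = {e.id: [] for e in endpoints}
    for key in keys:
        for fmt in key.api_formats or []:
            endpoint_id = format_to_endpoint_id.get(fmt)
            if endpoint_id:
                keys_by_endpoint[endpoint_id].append(key)

    result: dict[str, float] = {}
    for endpoint in endpoints:
        scores: list[float] = []
        for key in keys_by_endpoint[endpoint.id]:
            entry = (key.health_by_format or {}).get(endpoint.api_format)
            if entry is None:
                scores.append(1.0)  # no record for this format: default health
            elif entry.get("health_score") is not None:
                scores.append(float(entry["health_score"]))
        # Assumption: an endpoint without usable scores is treated as healthy.
        result[endpoint.id] = sum(scores) / len(scores) if scores else 1.0
    return result
```

Because one key can list several formats, the same key may contribute to more than one endpoint, each time with the score recorded for that endpoint's format.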


@@ -42,6 +42,42 @@ def _get_version_from_git() -> str | None:
return None return None
def _get_current_version() -> str:
"""Return the current version string."""
version = _get_version_from_git()
if version:
return version
try:
from src._version import __version__
return __version__
except ImportError:
return "unknown"
def _parse_version(version_str: str) -> tuple:
"""Parse a version string into a comparable tuple; supports 3- or 4-segment versions.
For example:
- '0.2.5' -> (0, 2, 5, 0)
- '0.2.5.1' -> (0, 2, 5, 1)
- 'v0.2.5-4-g1234567' -> (0, 2, 5, 0)
"""
import re
version_str = version_str.lstrip("v")
main_version = re.split(r"[-+]", version_str)[0]
try:
parts = main_version.split(".")
# Normalize to 4 segments so tuples compare cleanly
int_parts = [int(p) for p in parts]
while len(int_parts) < 4:
int_parts.append(0)
return tuple(int_parts[:4])
except ValueError:
return (0, 0, 0, 0)
@router.get("/version") @router.get("/version")
async def get_system_version(): async def get_system_version():
""" """
@@ -52,18 +88,111 @@ async def get_system_version():
**Response fields**: **Response fields**:
- `version`: version string - `version`: version string
""" """
# Prefer the version from git return {"version": _get_current_version()}
version = _get_version_from_git()
if version:
return {"version": version} @router.get("/check-update")
async def check_update():
"""
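Check for system updates
Fetches the latest version from GitHub tags and compares it with the current version.
**Response fields**:
- `current_version`: current version
- `latest_version`: latest version
- `has_update`: whether an update is available
- `release_url`: link to the GitHub page of the latest version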
"""
import httpx
from src.clients.http_client import HTTPClientPool
current_version = _get_current_version()
github_repo = "Aethersailor/Aether"
github_tags_url = f"https://api.github.com/repos/{github_repo}/tags"
# Fall back to the static version file
try: try:
from src._version import __version__ async with HTTPClientPool.get_temp_client(
timeout=httpx.Timeout(connect=5.0, read=10.0, write=5.0, pool=5.0)
) as client:
response = await client.get(
github_tags_url,
headers={
"Accept": "application/vnd.github.v3+json",
"User-Agent": f"Aether/{current_version}",
},
params={"per_page": 10},
)
return {"version": __version__} if response.status_code != 200:
except ImportError: return {
return {"version": "unknown"} "current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": f"GitHub API 返回错误: {response.status_code}",
}
tags = response.json()
if not tags:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": None,
}
# Find the latest version tag (ordered by version number, not by time)
version_tags = []
for tag in tags:
tag_name = tag.get("name", "")
if tag_name and (tag_name.startswith("v") or tag_name[0].isdigit()):
version_tags.append((tag_name, _parse_version(tag_name)))
if not version_tags:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": None,
}
# Sort by version number and take the largest
version_tags.sort(key=lambda x: x[1], reverse=True)
latest_tag = version_tags[0][0]
latest_version = latest_tag.lstrip("v")
current_tuple = _parse_version(current_version)
latest_tuple = _parse_version(latest_version)
has_update = latest_tuple > current_tuple
return {
"current_version": current_version,
"latest_version": latest_version,
"has_update": has_update,
"release_url": f"https://github.com/{github_repo}/releases/tag/{latest_tag}",
"error": None,
}
except httpx.TimeoutException:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": "检查更新超时",
}
except Exception as e:
return {
"current_version": current_version,
"latest_version": None,
"has_update": False,
"release_url": None,
"error": f"检查更新失败: {str(e)}",
}
pipeline = ApiRequestPipeline() pipeline = ApiRequestPipeline()
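The reason `check_update` sorts tags by parsed tuple rather than by tag name can be seen in a small standalone sketch. The function below mirrors the logic of `_parse_version` from the diff; the name `parse_version` and the sample tag list are illustrative assumptions.

```python
import re


def parse_version(version_str: str) -> tuple[int, ...]:
    """Parse '0.2.5', '0.2.5.1' or 'v0.2.5-4-g1234567' into a 4-tuple of ints."""
    # Drop a leading 'v' and anything after '-' or '+' (e.g. git-describe suffixes).
    main_version = re.split(r"[-+]", version_str.lstrip("v"))[0]
    try:
        parts = [int(p) for p in main_version.split(".")]
    except ValueError:
        return (0, 0, 0, 0)  # unparseable versions sort lowest
    parts += [0] * (4 - len(parts))  # pad to 4 segments
    return tuple(parts[:4])


# Hypothetical tag names, in the order an API might return them.
tags = ["v0.2.9", "v0.2.10", "0.2.5.1"]

latest_by_string = max(tags)                     # lexicographic comparison
latest_by_version = max(tags, key=parse_version)  # numeric comparison

current = "v0.2.9-3-gabc1234"
has_update = parse_version(latest_by_version) > parse_version(current)
```

Plain string comparison ranks `v0.2.9` above `v0.2.10`, while the tuple comparison treats `10` as greater than `9`, which is the behavior the endpoint needs.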
@@ -601,36 +730,6 @@ class AdminExportConfigAdapter(AdminApiAdapter):
) )
endpoints_data = [] endpoints_data = []
for ep in endpoints: for ep in endpoints:
# Export endpoint keys
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
)
keys_data = []
for key in keys:
# Decrypt the API key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"rate_multiplier": key.rate_multiplier,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"is_active": key.is_active,
}
)
endpoints_data.append( endpoints_data.append(
{ {
"api_format": ep.api_format, "api_format": ep.api_format,
@@ -638,12 +737,44 @@ class AdminExportConfigAdapter(AdminApiAdapter):
"headers": ep.headers, "headers": ep.headers,
"timeout": ep.timeout, "timeout": ep.timeout,
"max_retries": ep.max_retries, "max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active, "is_active": ep.is_active,
"custom_path": ep.custom_path, "custom_path": ep.custom_path,
"config": ep.config, "config": ep.config,
"keys": keys_data, "proxy": ep.proxy,
}
)
# Export provider keys (owned by provider_id, including api_formats)
keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.provider_id == provider.id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.all()
)
keys_data = []
for key in keys:
# Decrypt the API key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
except Exception:
decrypted_key = ""
keys_data.append(
{
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"api_formats": key.api_formats or [],
"rate_multiplier": key.rate_multiplier,
"rate_multipliers": key.rate_multipliers,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"rpm_limit": key.rpm_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"cache_ttl_minutes": key.cache_ttl_minutes,
"max_probe_interval_minutes": key.max_probe_interval_minutes,
"is_active": key.is_active,
} }
) )
@@ -675,24 +806,26 @@ class AdminExportConfigAdapter(AdminApiAdapter):
providers_data.append( providers_data.append(
{ {
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
"description": provider.description, "description": provider.description,
"website": provider.website, "website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None, "billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd, "monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day, "quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority, "provider_priority": provider.provider_priority,
"is_active": provider.is_active, "is_active": provider.is_active,
"concurrent_limit": provider.concurrent_limit, "concurrent_limit": provider.concurrent_limit,
"timeout": provider.timeout,
"max_retries": provider.max_retries,
"proxy": provider.proxy,
"config": provider.config, "config": provider.config,
"endpoints": endpoints_data, "endpoints": endpoints_data,
"api_keys": keys_data,
"models": models_data, "models": models_data,
} }
) )
return { return {
"version": "1.0", "version": "2.0",
"exported_at": datetime.now(timezone.utc).isoformat(), "exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data, "global_models": global_models_data,
"providers": providers_data, "providers": providers_data,
@@ -721,7 +854,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
# Validate the config version # Validate the config version
version = payload.get("version") version = payload.get("version")
if version != "1.0": if version != "2.0":
raise InvalidRequestException(f"不支持的配置版本: {version}") raise InvalidRequestException(f"不支持的配置版本: {version}")
# Read the import options # Read the import options
@@ -810,8 +943,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
) )
elif merge_mode == "overwrite": elif merge_mode == "overwrite":
# Update the existing record # Update the existing record
existing_provider.display_name = prov_data.get( existing_provider.name = prov_data.get(
"display_name", existing_provider.display_name "name", existing_provider.name
) )
existing_provider.description = prov_data.get("description") existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website") existing_provider.website = prov_data.get("website")
@@ -825,7 +958,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_provider.quota_reset_day = prov_data.get( existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30 "quota_reset_day", 30
) )
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get( existing_provider.provider_priority = prov_data.get(
"provider_priority", 100 "provider_priority", 100
) )
@@ -833,6 +965,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_provider.concurrent_limit = prov_data.get( existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit" "concurrent_limit"
) )
existing_provider.timeout = prov_data.get("timeout", existing_provider.timeout)
existing_provider.max_retries = prov_data.get(
"max_retries", existing_provider.max_retries
)
existing_provider.proxy = prov_data.get("proxy", existing_provider.proxy)
existing_provider.config = prov_data.get("config") existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc) existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1 stats["providers"]["updated"] += 1
@@ -845,16 +982,17 @@ class AdminImportConfigAdapter(AdminApiAdapter):
new_provider = Provider( new_provider = Provider(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
name=prov_data["name"], name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"), description=prov_data.get("description"),
website=prov_data.get("website"), website=prov_data.get("website"),
billing_type=billing_type, billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"), monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30), quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100), provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True), is_active=prov_data.get("is_active", True),
concurrent_limit=prov_data.get("concurrent_limit"), concurrent_limit=prov_data.get("concurrent_limit"),
timeout=prov_data.get("timeout"),
max_retries=prov_data.get("max_retries"),
proxy=prov_data.get("proxy"),
config=prov_data.get("config"), config=prov_data.get("config"),
) )
db.add(new_provider) db.add(new_provider)
@@ -874,7 +1012,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
) )
if existing_ep: if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip": if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1 stats["endpoints"]["skipped"] += 1
elif merge_mode == "error": elif merge_mode == "error":
@@ -887,12 +1024,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
) )
existing_ep.headers = ep_data.get("headers") existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300) existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 3) existing_ep.max_retries = ep_data.get("max_retries", 2)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True) existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path") existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config") existing_ep.config = ep_data.get("config")
existing_ep.proxy = ep_data.get("proxy")
existing_ep.updated_at = datetime.now(timezone.utc) existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1 stats["endpoints"]["updated"] += 1
else: else:
@@ -903,69 +1039,107 @@ class AdminImportConfigAdapter(AdminApiAdapter):
base_url=ep_data["base_url"], base_url=ep_data["base_url"],
headers=ep_data.get("headers"), headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300), timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 3), max_retries=ep_data.get("max_retries", 2),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True), is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"), custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"), config=ep_data.get("config"),
proxy=ep_data.get("proxy"),
) )
db.add(new_ep) db.add(new_ep)
db.flush() db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1 stats["endpoints"]["created"] += 1
# Import keys # Import provider keys (owned by provider_id)
# Fetch all existing keys under the current endpoint for deduplication endpoint_format_rows = (
existing_keys = ( db.query(ProviderEndpoint.api_format)
db.query(ProviderAPIKey) .filter(ProviderEndpoint.provider_id == provider_id)
.filter(ProviderAPIKey.endpoint_id == endpoint_id) .all()
.all() )
) endpoint_formats: set[str] = set()
# Decrypt existing keys for comparison for (api_format,) in endpoint_format_rows:
existing_key_values = set() fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
for ek in existing_keys: endpoint_formats.add(fmt.strip().upper())
try: existing_keys = (
decrypted = crypto_service.decrypt(ek.api_key) db.query(ProviderAPIKey)
existing_key_values.add(decrypted) .filter(ProviderAPIKey.provider_id == provider_id)
except Exception: .all()
pass )
existing_key_values = set()
for ek in existing_keys:
try:
decrypted = crypto_service.decrypt(ek.api_key)
existing_key_values.add(decrypted)
except Exception:
pass
for key_data in ep_data.get("keys", []): for key_data in prov_data.get("api_keys", []):
if not key_data.get("api_key"): if not key_data.get("api_key"):
stats["errors"].append( stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})" f"跳过空 API Key (Provider: {prov_data['name']})"
)
continue
# Check whether the same key already exists (compared in plaintext)
if key_data["api_key"] in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
api_key=encrypted_key,
name=key_data.get("name"),
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
is_active=key_data.get("is_active", True),
) )
db.add(new_key) continue
# Add to the existing set to prevent duplicates within the same import batch
existing_key_values.add(key_data["api_key"]) plaintext_key = key_data["api_key"]
stats["keys"]["created"] += 1 if plaintext_key in existing_key_values:
stats["keys"]["skipped"] += 1
continue
raw_formats = key_data.get("api_formats") or []
if not isinstance(raw_formats, list) or len(raw_formats) == 0:
stats["errors"].append(
f"跳过无 api_formats 的 Key (Provider: {prov_data['name']})"
)
continue
normalized_formats: list[str] = []
seen: set[str] = set()
missing_formats: list[str] = []
for fmt in raw_formats:
if not isinstance(fmt, str):
continue
fmt_upper = fmt.strip().upper()
if not fmt_upper or fmt_upper in seen:
continue
seen.add(fmt_upper)
if endpoint_formats and fmt_upper not in endpoint_formats:
missing_formats.append(fmt_upper)
continue
normalized_formats.append(fmt_upper)
if missing_formats:
stats["errors"].append(
f"Key (Provider: {prov_data['name']}) 的 api_formats 未配置对应 Endpoint已跳过: {missing_formats}"
)
if len(normalized_formats) == 0:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(plaintext_key)
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
provider_id=provider_id,
api_formats=normalized_formats,
api_key=encrypted_key,
name=key_data.get("name") or "Imported Key",
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
rate_multipliers=key_data.get("rate_multipliers"),
internal_priority=key_data.get("internal_priority", 50),
global_priority=key_data.get("global_priority"),
rpm_limit=key_data.get("rpm_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
cache_ttl_minutes=key_data.get("cache_ttl_minutes", 5),
max_probe_interval_minutes=key_data.get("max_probe_interval_minutes", 32),
is_active=key_data.get("is_active", True),
health_by_format={},
circuit_breaker_by_format={},
)
db.add(new_key)
existing_key_values.add(plaintext_key)
stats["keys"]["created"] += 1
# Import models # Import models
for model_data in prov_data.get("models", []): for model_data in prov_data.get("models", []):
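The `api_formats` handling added to the key import above (trim, uppercase, deduplicate, then drop formats with no matching endpoint) can be sketched as a pure function. The helper name `normalize_formats` is hypothetical; in the diff this logic is inlined in the import loop.

```python
def normalize_formats(
    raw_formats: list, endpoint_formats: set[str]
) -> tuple[list[str], list[str]]:
    """Return (normalized, missing) format lists for one imported key.

    `endpoint_formats` holds the uppercase formats of the provider's endpoints.
    When it is empty, no format is rejected, matching the guard in the diff.
    """
    normalized: list[str] = []
    missing: list[str] = []
    seen: set[str] = set()
    for fmt in raw_formats:
        if not isinstance(fmt, str):
            continue  # ignore non-string entries
        fmt_upper = fmt.strip().upper()
        if not fmt_upper or fmt_upper in seen:
            continue  # ignore blanks and duplicates
        seen.add(fmt_upper)
        if endpoint_formats and fmt_upper not in endpoint_formats:
            missing.append(fmt_upper)  # no endpoint configured for this format
            continue
        normalized.append(fmt_upper)
    return normalized, missing
```

In the import code, a non-empty `missing` list is reported in `stats["errors"]`, and a key whose `normalized` list ends up empty is counted as skipped rather than created.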


@@ -247,7 +247,8 @@ async def get_usage_detail(
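- `request_headers`: request headers - `request_headers`: request headers
- `request_body`: request body - `request_body`: request body
- `provider_request_headers`: headers of the request sent to the provider - `provider_request_headers`: headers of the request sent to the provider
- `response_headers`: response headers - `response_headers`: provider response headers
- `client_response_headers`: response headers returned to the client
- `response_body`: response body - `response_body`: response body
- `metadata`: provider response metadata - `metadata`: provider response metadata
- `tiered_pricing`: tiered pricing information (if applicable) - `tiered_pricing`: tiered pricing information (if applicable)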
@@ -916,6 +917,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"request_body": usage_record.get_request_body(), "request_body": usage_record.get_request_body(),
"provider_request_headers": usage_record.provider_request_headers, "provider_request_headers": usage_record.provider_request_headers,
"response_headers": usage_record.response_headers, "response_headers": usage_record.response_headers,
"client_response_headers": usage_record.client_response_headers,
"response_body": usage_record.get_response_body(), "response_body": usage_record.get_response_body(),
"metadata": usage_record.request_metadata, "metadata": usage_record.request_metadata,
"tiered_pricing": tiered_pricing_info, "tiered_pricing": tiered_pricing_info,


@@ -202,20 +202,59 @@ def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
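Conditions: Conditions:
- The endpoint's api_format matches - The endpoint's api_format matches
- The endpoint is active - The endpoint is active
- The endpoint has an active key - The provider has an active key that supports this api_format (keys belong directly to the provider and are filtered by api_formats)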
""" """
rows = ( target_formats = {f.upper() for f in api_formats}
db.query(ProviderEndpoint.provider_id)
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id) # 1) First find providers with active endpoints (recording the set of formats each provider supports)
endpoint_rows = (
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
.filter( .filter(
ProviderEndpoint.api_format.in_(api_formats), ProviderEndpoint.api_format.in_(list(target_formats)),
ProviderEndpoint.is_active.is_(True), ProviderEndpoint.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
) )
.distinct()
.all() .all()
) )
return {row[0] for row in rows}
if not endpoint_rows:
return set()
provider_to_formats: dict[str, set[str]] = {}
for provider_id, fmt in endpoint_rows:
if not provider_id or not fmt:
continue
provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())
provider_ids_with_endpoints = set(provider_to_formats.keys())
if not provider_ids_with_endpoints:
return set()
# 2) Then check whether each of these providers has at least one active key that supports the format
key_rows = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
.filter(
ProviderAPIKey.provider_id.in_(provider_ids_with_endpoints),
ProviderAPIKey.is_active.is_(True),
)
.all()
)
available_provider_ids: set[str] = set()
for provider_id, key_formats in key_rows:
if not provider_id:
continue
endpoint_formats = provider_to_formats.get(provider_id)
if not endpoint_formats:
continue
formats_list = key_formats if isinstance(key_formats, list) else []
key_formats_upper = {str(f).upper() for f in formats_list}
# Available only if (requested formats ∩ provider endpoint formats ∩ key formats) is non-empty
if key_formats_upper & endpoint_formats & target_formats:
available_provider_ids.add(provider_id)
return available_provider_ids
def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]: def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
@@ -228,36 +267,64 @@ def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) ->
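3. **The endpoint's provider is associated with the model** 3. **The endpoint's provider is associated with the model**
4. The key's allowed_models permits the model (null = all models associated with the provider are allowed) 4. The key's allowed_models permits the model (null = all models associated with the provider are allowed)
""" """
# Query all active endpoints matching the format together with their active keys (also fetching endpoint_id) target_formats = {f.upper() for f in api_formats}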
endpoint_keys = (
db.query( # 1) Find providers with active endpoints (recording the set of formats each provider supports)
ProviderEndpoint.id.label("endpoint_id"), endpoint_rows = (
ProviderEndpoint.provider_id, db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
ProviderAPIKey.allowed_models,
)
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
.filter( .filter(
ProviderEndpoint.api_format.in_(api_formats), ProviderEndpoint.api_format.in_(list(target_formats)),
ProviderEndpoint.is_active.is_(True), ProviderEndpoint.is_active.is_(True),
)
.all()
)
if not endpoint_rows:
return set()
provider_to_formats: dict[str, set[str]] = {}
for provider_id, fmt in endpoint_rows:
if not provider_id or not fmt:
continue
provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())
provider_ids_with_endpoints = set(provider_to_formats.keys())
if not provider_ids_with_endpoints:
return set()
# 2) Collect allowed_models from each provider's active keys that support the format
# Keys belong directly to the provider; filter by intersecting key.api_formats with the provider's endpoint formats
key_rows = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.allowed_models, ProviderAPIKey.api_formats)
.filter(
ProviderAPIKey.provider_id.in_(provider_ids_with_endpoints),
ProviderAPIKey.is_active.is_(True), ProviderAPIKey.is_active.is_(True),
) )
.all() .all()
) )
if not endpoint_keys: # provider_id -> list[(allowed_models, usable_formats)]
provider_key_rules: dict[str, list[tuple[object, set[str]]]] = {}
for provider_id, allowed_models, key_formats in key_rows:
if not provider_id:
continue
endpoint_formats = provider_to_formats.get(provider_id)
if not endpoint_formats:
continue
formats_list = key_formats if isinstance(key_formats, list) else []
key_formats_upper = {str(f).upper() for f in formats_list}
usable_formats = key_formats_upper & endpoint_formats & target_formats
if not usable_formats:
continue
provider_key_rules.setdefault(provider_id, []).append((allowed_models, usable_formats))
provider_ids_with_format = set(provider_key_rules.keys())
if not provider_ids_with_format:
return set() return set()
# Collect allowed_models for each (provider_id, endpoint_id)
# Use provider_id as the key, because models are associated with providers
provider_allowed_models: dict[str, list[Optional[list[str]]]] = {}
provider_ids_with_format: set[str] = set()
for endpoint_id, provider_id, allowed_models in endpoint_keys:
provider_ids_with_format.add(provider_id)
if provider_id not in provider_allowed_models:
provider_allowed_models[provider_id] = []
provider_allowed_models[provider_id].append(allowed_models)
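# Only query models under providers that have an endpoint with a matching format # Only query models under providers that have an endpoint with a matching format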
models = ( models = (
db.query(Model) db.query(Model)
@@ -285,22 +352,30 @@ def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) ->
if model_provider_id not in provider_ids_with_format: if model_provider_id not in provider_ids_with_format:
continue continue
# Check whether any key under this provider allows the model # Check whether any key under this provider allows the model (allowed_models may be a list or a dict)
allowed_lists = provider_allowed_models.get(model_provider_id, []) from src.core.model_permissions import check_model_allowed
for allowed_models in allowed_lists:
rules = provider_key_rules.get(model_provider_id, [])
for allowed_models, usable_formats in rules:
# None = unrestricted
if allowed_models is None: if allowed_models is None:
# null = all models associated with the provider are allowed (already restricted by the query above)
available_model_ids.add(model_id)
break
elif model_id in allowed_models:
# Explicitly in the allow list
available_model_ids.add(model_id)
break
elif global_model and model.provider_model_name in allowed_models:
# Also check provider_model_name
available_model_ids.add(model_id) available_model_ids.add(model_id)
break break
# For a key that supports multiple formats, it is enough for any one usable format to allow the model
for fmt in usable_formats:
if check_model_allowed(
model_name=model_id,
allowed_models=allowed_models, # type: ignore[arg-type]
api_format=fmt,
resolved_model_name=(model.provider_model_name if global_model else None),
):
available_model_ids.add(model_id)
break
else:
continue
break
return available_model_ids return available_model_ids
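The availability rule in `get_available_provider_ids` reduces to a three-way set intersection. The sketch below replaces the two database queries with plain tuples so the rule can be run on its own; the function name and the row shapes are assumptions, and the rows are presumed to be pre-filtered to active endpoints and active keys.

```python
from typing import Iterable, Optional


def available_provider_ids(
    api_formats: Iterable[str],
    endpoint_rows: Iterable[tuple[str, str]],
    key_rows: Iterable[tuple[str, Optional[list]]],
) -> set[str]:
    """Providers with at least one format shared by request, endpoint and key."""
    target_formats = {f.upper() for f in api_formats}

    # 1) Formats served by each provider's endpoints, limited to requested ones.
    provider_to_formats: dict[str, set[str]] = {}
    for provider_id, fmt in endpoint_rows:
        if provider_id and fmt and str(fmt).upper() in target_formats:
            provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())

    # 2) A provider qualifies if any key shares a format with its endpoints.
    available: set[str] = set()
    for provider_id, key_formats in key_rows:
        endpoint_formats = provider_to_formats.get(provider_id)
        if not endpoint_formats:
            continue
        formats_list = key_formats if isinstance(key_formats, list) else []
        key_formats_upper = {str(f).upper() for f in formats_list}
        if key_formats_upper & endpoint_formats & target_formats:
            available.add(provider_id)
    return available
```

A key whose `api_formats` is `None` or not a list never matches, so a provider with only such keys is reported as unavailable even when its endpoint is active.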


@@ -155,7 +155,7 @@ async def get_daily_stats(
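class DashboardAdapter(ApiAdapter): class DashboardAdapter(ApiAdapter):
"""Base class for dashboard adapters that require login.""" """Base class for dashboard adapters that require login."""
mode = ApiMode.ADMIN mode = ApiMode.USER # regular users can also access the dashboard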
def authorize(self, context): # type: ignore[override] def authorize(self, context): # type: ignore[override]
if not context.user: if not context.user:


@@ -98,6 +98,7 @@ class MessageTelemetry:
request_headers: Dict[str, Any], request_headers: Dict[str, Any],
response_body: Any, response_body: Any,
response_headers: Dict[str, Any], response_headers: Dict[str, Any],
client_response_headers: Optional[Dict[str, Any]] = None,
cache_creation_tokens: int = 0, cache_creation_tokens: int = 0,
cache_read_tokens: int = 0, cache_read_tokens: int = 0,
is_stream: bool = False, is_stream: bool = False,
@@ -143,6 +144,7 @@ class MessageTelemetry:
request_body=request_body, request_body=request_body,
provider_request_headers=provider_request_headers or {}, provider_request_headers=provider_request_headers or {},
response_headers=response_headers, response_headers=response_headers,
client_response_headers=client_response_headers,
response_body=response_body, response_body=response_body,
request_id=self.request_id, request_id=self.request_id,
# Provider-side tracking info (used to record the real cost) # Provider-side tracking info (used to record the real cost)
@@ -192,6 +194,8 @@ class MessageTelemetry:
cache_creation_tokens: int = 0, cache_creation_tokens: int = 0,
cache_read_tokens: int = 0, cache_read_tokens: int = 0,
response_body: Optional[Dict[str, Any]] = None, response_body: Optional[Dict[str, Any]] = None,
response_headers: Optional[Dict[str, Any]] = None,
client_response_headers: Optional[Dict[str, Any]] = None,
# Model mapping info # Model mapping info
target_model: Optional[str] = None, target_model: Optional[str] = None,
) -> None: ) -> None:
@@ -207,6 +211,8 @@ class MessageTelemetry:
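cache_creation_tokens: cache-creation tokens cache_creation_tokens: cache-creation tokens
cache_read_tokens: cache-read tokens cache_read_tokens: cache-read tokens
response_body: response body (if a partial response exists) response_body: response body (if a partial response exists)
response_headers: response headers (the original headers returned by the Provider)
client_response_headers: response headers returned to the client
target_model: target model name after mapping (if a mapping occurred) target_model: target model name after mapping (if a mapping occurred)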
""" """
provider_name = provider or "unknown" provider_name = provider or "unknown"
@@ -232,7 +238,8 @@ class MessageTelemetry:
request_headers=request_headers, request_headers=request_headers,
request_body=request_body, request_body=request_body,
provider_request_headers=provider_request_headers or {}, provider_request_headers=provider_request_headers or {},
response_headers={}, response_headers=response_headers or {},
client_response_headers=client_response_headers,
response_body=response_body or {"error": error_message}, response_body=response_body or {"error": error_message},
request_id=self.request_id, request_id=self.request_id,
# Model mapping info # Model mapping info


@@ -351,9 +351,9 @@ class ChatAdapterBase(ApiAdapter):
# Determine the error message # Determine the error message
if isinstance(e, ProviderAuthException): if isinstance(e, ProviderAuthException):
error_message = ( error_message = (
f"提供商认证失败: {str(e)}" "上游服务认证失败"
if result.metadata.provider != "unknown" if result.metadata.provider != "unknown"
else "服务端错误: 无可用提供商" else "服务暂时不可用"
) )
result.error_message = error_message result.error_message = error_message


@@ -37,7 +37,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config from src.config.settings import config
from src.core.error_utils import extract_error_message from src.core.error_utils import extract_client_error_message
from src.core.exceptions import ( from src.core.exceptions import (
EmbeddedErrorException, EmbeddedErrorException,
ProviderAuthException, ProviderAuthException,
@@ -382,10 +382,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
http_request.is_disconnected, http_request.is_disconnected,
) )
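# Pass the provider's response headers through to the client
# Also add the required SSE headers so that streaming works correctly
client_headers = dict(ctx.response_headers) if ctx.response_headers else {}
# Add/override the headers that SSE requires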
client_headers.update(build_sse_headers())
client_headers["content-type"] = "text/event-stream"
return StreamingResponse( return StreamingResponse(
monitored_stream, monitored_stream,
media_type="text/event-stream", media_type="text/event-stream",
headers=build_sse_headers(), headers=client_headers,
background=background_tasks, background=background_tasks,
) )
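The hunk above copies the upstream headers and then layers the SSE headers on top. One detail worth illustrating is header-name case: if the upstream dict uses lowercase keys while the SSE headers use title case, a plain `dict.update` keeps both variants. The sketch below normalizes keys before merging. `SSE_HEADERS` is an assumption standing in for whatever `build_sse_headers()` returns, and `merge_client_headers` is a hypothetical helper, not code from this commit.

```python
from typing import Dict, Mapping, Optional

# Assumed values; the real ones come from build_sse_headers(), not shown in this diff.
SSE_HEADERS: Dict[str, str] = {
    "Cache-Control": "no-cache, no-transform",
    "X-Accel-Buffering": "no",
}


def merge_client_headers(upstream: Optional[Mapping[str, str]]) -> Dict[str, str]:
    # Lowercase every key so "Cache-Control" and "cache-control" cannot coexist.
    merged = {key.lower(): value for key, value in (upstream or {}).items()}
    # SSE headers win over whatever the upstream sent.
    merged.update({key.lower(): value for key, value in SSE_HEADERS.items()})
    # The response is always an event stream, regardless of the upstream content type.
    merged["content-type"] = "text/event-stream"
    return merged
```

Whether to forward every upstream header unchanged is a design choice this sketch does not settle; it only shows the override order.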
@@ -463,7 +470,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
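# Configure HTTP timeouts # Configure HTTP timeouts
# Note: the read timeout detects dropped connections; it is not the overall request timeout # Note: the read timeout detects dropped connections; it is not the overall request timeout
# The overall request timeout is enforced by asyncio.wait_for, using endpoint.timeout # The overall request timeout is enforced by asyncio.wait_for, using provider.timeout
timeout_config = httpx.Timeout( timeout_config = httpx.Timeout(
connect=config.http_connect_timeout, connect=config.http_connect_timeout,
read=config.http_read_timeout, # uses the global setting to detect dropped connections read=config.http_read_timeout, # uses the global setting to detect dropped connections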
@@ -471,14 +478,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
pool=config.http_pool_timeout, pool=config.http_pool_timeout,
) )
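# endpoint.timeout is the overall request timeout (connect + first byte) # provider.timeout is the overall request timeout (connect + first byte)
request_timeout = float(endpoint.timeout or 300) request_timeout = float(provider.timeout or 300)
# Create the HTTP client (supports proxy configuration) # Create the HTTP client (supports proxy configuration, read from the Provider)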
from src.clients.http_client import HTTPClientPool from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy( http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy, proxy_config=provider.proxy,
timeout=timeout_config, timeout=timeout_config,
) )
@@ -514,7 +521,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
try: try:
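# Wrap the whole "connect + first byte" stage in asyncio.wait_for # Wrap the whole "connect + first byte" stage in asyncio.wait_for
# endpoint.timeout bounds the overall wait so an unresponsive upstream cannot hang the request # provider.timeout bounds the overall wait so an unresponsive upstream cannot hang the request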
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout) await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
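The comments in this hunk describe a two-level timeout: the per-read timeout detects dropped connections, while `asyncio.wait_for` caps the whole "connect + first byte" stage. A minimal sketch of the outer cap, assuming a hypothetical `connect_and_prefetch` coroutine that simulates the upstream with `asyncio.sleep`:

```python
import asyncio


async def connect_and_prefetch(delay: float) -> str:
    # Hypothetical stand-in for "open the connection and read the first chunk".
    await asyncio.sleep(delay)
    return "first-chunk"


async def fetch_first_chunk(delay: float, request_timeout: float) -> str:
    try:
        # wait_for cancels the inner coroutine once request_timeout elapses.
        return await asyncio.wait_for(connect_and_prefetch(delay), timeout=request_timeout)
    except asyncio.TimeoutError:
        # The real handler reports a timeout error here; the sketch returns a marker.
        return "timeout"
```

For example, `asyncio.run(fetch_first_chunk(0.5, 0.01))` takes the timeout branch because the simulated upstream is slower than the cap.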
@@ -590,17 +597,22 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
actual_request_body = ctx.provider_request_body or original_request_body actual_request_body = ctx.provider_request_body or original_request_body
# On failure, the client receives a JSON error response
client_response_headers = {"content-type": "application/json"}
await self.telemetry.record_failure( await self.telemetry.record_failure(
provider=ctx.provider_name or "unknown", provider=ctx.provider_name or "unknown",
model=ctx.model, model=ctx.model,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
status_code=status_code, status_code=status_code,
error_message=extract_error_message(error), error_message=extract_client_error_message(error),
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
is_stream=True, is_stream=True,
api_format=ctx.api_format, api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers, provider_request_headers=ctx.provider_request_headers,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
target_model=ctx.mapped_model, target_model=ctx.mapped_model,
) )
@@ -691,64 +703,98 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}" f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
) )
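# Create the HTTP client (supports proxy configuration) # Get the reused HTTP client (supports proxy configuration, read from the Provider)
# endpoint.timeout is the overall request timeout # Note: get_proxy_client reuses the connection pool instead of creating a new client each time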
from src.clients.http_client import HTTPClientPool from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300) request_timeout = float(provider.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy( http_client = await HTTPClientPool.get_proxy_client(
proxy_config=endpoint.proxy, proxy_config=provider.proxy,
)
# Note: do not use async with, because the reused client must not be closed
# The timeout is controlled through the timeout parameter
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_hdrs,
timeout=httpx.Timeout(request_timeout), timeout=httpx.Timeout(request_timeout),
) )
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code status_code = resp.status_code
response_headers = dict(resp.headers) response_headers = dict(resp.headers)
if resp.status_code == 401: if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}") raise ProviderAuthException(str(provider.name))
elif resp.status_code == 429: elif resp.status_code == 429:
raise ProviderRateLimitException( raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}", "请求过于频繁,请稍后重试",
provider_name=str(provider.name), provider_name=str(provider.name),
response_headers=response_headers, response_headers=response_headers,
)
elif resp.status_code >= 500:
# Log the response body for debugging
error_body = ""
try:
error_body = resp.text[:1000]
logger.error(
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}"
) )
elif resp.status_code >= 500: except Exception:
# Log the response body for debugging pass
error_body = "" raise ProviderNotAvailableException(
try: f"上游服务暂时不可用 (HTTP {resp.status_code})",
error_body = resp.text[:1000] provider_name=str(provider.name),
logger.error( upstream_status=resp.status_code,
f" [{self.request_id}] 上游返回5xx错误: status={resp.status_code}, body={error_body[:500]}" upstream_response=error_body,
) )
except Exception: elif resp.status_code != 200:
pass # Log the non-200 response for debugging
raise ProviderNotAvailableException( error_body = ""
f"提供商服务不可用: {provider.name}", try:
provider_name=str(provider.name), error_body = resp.text[:1000]
upstream_status=resp.status_code, logger.warning(
upstream_response=error_body, f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
elif resp.status_code != 200:
# Log the non-200 response for debugging
error_body = ""
try:
error_body = resp.text[:1000]
logger.warning(
f" [{self.request_id}] 上游返回非200: status={resp.status_code}, body={error_body[:500]}"
)
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
) )
except Exception:
pass
raise ProviderNotAvailableException(
f"上游服务返回错误 (HTTP {resp.status_code})",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
# Parse the JSON response safely, handling possible encoding errors
try:
response_json = resp.json() response_json = resp.json()
return response_json if isinstance(response_json, dict) else {} except (UnicodeDecodeError, json.JSONDecodeError) as e:
# Capture the raw response content for debugging (stored in upstream_response)
raw_content = ""
try:
raw_content = resp.text[:500] if resp.text else "(empty)"
except Exception:
try:
raw_content = repr(resp.content[:500]) if resp.content else "(empty)"
except Exception:
raw_content = "(unable to read)"
logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, 原始内容: {raw_content}"
)
# Determine the error type and build a friendly client-facing error message (without exposing provider details)
if raw_content == "(empty)" or not raw_content.strip():
client_message = "上游服务返回了空响应"
elif raw_content.strip().startswith(("<", "<!doctype", "<!DOCTYPE")):
client_message = "上游服务返回了非预期的响应格式"
else:
client_message = "上游服务返回了无效的响应"
raise ProviderNotAvailableException(
client_message,
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=raw_content,
)
return response_json if isinstance(response_json, dict) else {}
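When JSON parsing fails, the new code sorts the raw body into three buckets (empty, markup, anything else) before choosing a client-facing message. The sketch below isolates that classification, assuming a hypothetical `classify_invalid_body` helper that returns short English labels instead of the messages used in the commit.

```python
def classify_invalid_body(raw_content: str) -> str:
    """Classify a body that failed JSON parsing: 'empty', 'html', or 'invalid'."""
    stripped = raw_content.strip()
    # "(empty)" is the placeholder the handler stores when the body has no content.
    if not stripped or stripped == "(empty)":
        return "empty"
    # Any body starting with "<" is treated as markup (HTML error page, XML, ...).
    if stripped.startswith("<"):
        return "html"
    return "invalid"
```

Because every prefix in `("<", "<!doctype", "<!DOCTYPE")` itself begins with `<`, a single `startswith("<")` check covers the same inputs as the tuple used in the diff.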
try: try:
# Parse capability requirements # Parse capability requirements
@@ -792,6 +838,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
actual_request_body = provider_request_body or original_request_body actual_request_body = provider_request_body or original_request_body
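# On non-streaming success, the client receives the provider's response headers (passed through)
# JSONResponse sets content-type automatically, but we record the complete headers actually returned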
client_response_headers = dict(response_headers) if response_headers else {}
client_response_headers["content-type"] = "application/json"
total_cost = await self.telemetry.record_success( total_cost = await self.telemetry.record_success(
provider=provider_name, provider=provider_name,
model=model, model=model,
@@ -802,6 +853,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
response_headers=response_headers, response_headers=response_headers,
client_response_headers=client_response_headers,
response_body=response_json, response_body=response_json,
cache_creation_tokens=cache_creation_tokens, cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cached_tokens, cache_read_tokens=cached_tokens,
@@ -823,7 +875,12 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f"in:{input_tokens or 0} out:{output_tokens or 0}" f"in:{input_tokens or 0} out:{output_tokens or 0}"
) )
return JSONResponse(status_code=status_code, content=response_json) # Pass the provider's response headers through
return JSONResponse(
status_code=status_code,
content=response_json,
headers=response_headers if response_headers else None,
)
except Exception as e: except Exception as e:
response_time_ms = self.elapsed_ms() response_time_ms = self.elapsed_ms()
@@ -838,17 +895,27 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
actual_request_body = provider_request_body or original_request_body actual_request_body = provider_request_body or original_request_body
# Try to extract the response headers from the exception
error_response_headers: Dict[str, str] = {}
if isinstance(e, ProviderRateLimitException) and e.response_headers:
error_response_headers = e.response_headers
elif isinstance(e, httpx.HTTPStatusError) and hasattr(e, "response"):
error_response_headers = dict(e.response.headers)
await self.telemetry.record_failure( await self.telemetry.record_failure(
provider=provider_name or "unknown", provider=provider_name or "unknown",
model=model, model=model,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
status_code=status_code, status_code=status_code,
error_message=extract_error_message(e), error_message=extract_client_error_message(e),
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
is_stream=False, is_stream=False,
api_format=api_format, api_format=api_format,
provider_request_headers=provider_request_headers, provider_request_headers=provider_request_headers,
response_headers=error_response_headers,
# On non-streaming failure, the client receives a JSON error response
client_response_headers={"content-type": "application/json"},
# Model mapping info # Model mapping info
target_model=mapped_model_result, target_model=mapped_model_result,
) )


@@ -306,9 +306,9 @@ class CliAdapterBase(ApiAdapter):
# Determine the error message # Determine the error message
if isinstance(e, ProviderAuthException): if isinstance(e, ProviderAuthException):
error_message = ( error_message = (
f"提供商认证失败: {str(e)}" "上游服务认证失败"
if result.metadata.provider != "unknown" if result.metadata.provider != "unknown"
else "服务端错误: 无可用提供商" else "服务暂时不可用"
) )
result.error_message = error_message result.error_message = error_message


@@ -47,7 +47,7 @@ from src.api.handlers.base.utils import (
) )
from src.config.constants import StreamDefaults from src.config.constants import StreamDefaults
from src.config.settings import config from src.config.settings import config
from src.core.error_utils import extract_error_message from src.core.error_utils import extract_client_error_message
from src.core.exceptions import ( from src.core.exceptions import (
EmbeddedErrorException, EmbeddedErrorException,
ProviderAuthException, ProviderAuthException,
@@ -376,10 +376,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
# Create the monitored stream # Create the monitored stream
monitored_stream = self._create_monitored_stream(ctx, stream_generator) monitored_stream = self._create_monitored_stream(ctx, stream_generator)
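# Pass the provider's response headers through to the client
# Also add the required SSE headers so that streaming works correctly
client_headers = dict(ctx.response_headers) if ctx.response_headers else {}
# Add/override the headers that SSE requires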
client_headers.update(build_sse_headers())
client_headers["content-type"] = "text/event-stream"
ctx.client_response_headers = client_headers
return StreamingResponse( return StreamingResponse(
monitored_stream, monitored_stream,
media_type="text/event-stream", media_type="text/event-stream",
headers=build_sse_headers(), headers=client_headers,
background=background_tasks, background=background_tasks,
) )
@@ -475,8 +483,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
pool=config.http_pool_timeout, pool=config.http_pool_timeout,
) )
# endpoint.timeout is the overall request timeout (connect + first byte) # provider.timeout is the overall request timeout (connect + first byte)
request_timeout = float(endpoint.timeout or 300) request_timeout = float(provider.timeout or 300)
logger.debug( logger.debug(
f" └─ [{self.request_id}] 发送流式请求: " f" └─ [{self.request_id}] 发送流式请求: "
@@ -486,11 +494,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"timeout={request_timeout}s" f"timeout={request_timeout}s"
) )
# Create the HTTP client (supports proxy configuration) # Create the HTTP client (supports proxy configuration, read from the Provider)
from src.clients.http_client import HTTPClientPool from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy( http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy, proxy_config=provider.proxy,
timeout=timeout_config, timeout=timeout_config,
) )
@@ -524,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
try: try:
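# Wrap the whole "connect + first byte" stage in asyncio.wait_for # Wrap the whole "connect + first byte" stage in asyncio.wait_for
# endpoint.timeout bounds the overall wait so an unresponsive upstream cannot hang the request # provider.timeout bounds the overall wait so an unresponsive upstream cannot hang the request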
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout) await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
@@ -636,12 +644,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0: if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT: if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据") logger.warning(f"Provider '{ctx.provider_name}' 流超时且无数据")
# Set the error state for later recording
ctx.status_code = 504
ctx.error_message = "流式响应超时,未收到有效数据"
ctx.upstream_response = f"流超时: Provider={ctx.provider_name}, elapsed={elapsed:.1f}s, chunk_count={ctx.chunk_count}, data_count=0"
error_event = { error_event = {
"type": "error", "type": "error",
"error": { "error": {
"type": "empty_stream_timeout", "type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据", "message": ctx.error_message,
}, },
} }
self._mark_first_output(ctx, output_state) self._mark_first_output(ctx, output_state)
@@ -682,12 +694,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.data_count == 0: if ctx.data_count == 0:
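# The stream has already started, so raising to trigger failover is not possible # The stream has already started, so raising to trigger failover is not possible
# Send an error event and log it # Send an error event and log it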
logger.warning(f"提供商 '{ctx.provider_name}' 返回空流式响应") logger.warning(f"Provider '{ctx.provider_name}' 返回空流式响应")
# Set the error state for later recording
ctx.status_code = 503
ctx.error_message = "上游服务返回了空的流式响应"
ctx.upstream_response = f"空流式响应: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
error_event = { error_event = {
"type": "error", "type": "error",
"error": { "error": {
"type": "empty_response", "type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应", "message": ctx.error_message,
}, },
} }
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8") yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -699,12 +715,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
except httpx.StreamClosed: except httpx.StreamClosed:
if ctx.data_count == 0: if ctx.data_count == 0:
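# The stream has already started; send an error event instead of raising # The stream has already started; send an error event instead of raising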
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据") logger.warning(f"Provider '{ctx.provider_name}' 流连接关闭且无数据")
# Set the error state for later recording
ctx.status_code = 503
ctx.error_message = "上游服务连接关闭且未返回数据"
ctx.upstream_response = f"流连接关闭: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
error_event = { error_event = {
"type": "error", "type": "error",
"error": { "error": {
"type": "stream_closed", "type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据", "message": ctx.error_message,
}, },
} }
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8") yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
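Once a stream has started, errors are delivered in-band as an SSE `error` event rather than by raising. The framing used in these hunks can be sketched as a small helper. `format_sse_error` is a hypothetical name, and the payload shape mirrors the `error_event` dict shown above.

```python
import json


def format_sse_error(error_type: str, message: str) -> bytes:
    # Same payload shape as the error_event dicts in the diff.
    error_event = {
        "type": "error",
        "error": {"type": error_type, "message": message},
    }
    # json.dumps without indent emits no raw newlines, so the payload
    # fits on a single "data:" line; the blank line terminates the event.
    return f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
```

With the default `ensure_ascii=True`, non-ASCII message text is written as `\uXXXX` escapes, which a JSON-aware client decodes back to the original characters.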
@@ -824,8 +844,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"base_url={endpoint.base_url}" f"base_url={endpoint.base_url}"
) )
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应," "上游服务返回了非预期的响应格式",
f"请检查 endpoint 的 base_url 配置是否正确" provider_name=str(provider.name),
upstream_status=200,
upstream_response=normalized_line[:500] if normalized_line else "(empty)",
) )
if not normalized_line or normalized_line.startswith(":"): if not normalized_line or normalized_line.startswith(":"):
@@ -1024,12 +1046,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0: if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT: if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据") logger.warning(f"Provider '{ctx.provider_name}' 流超时且无数据")
# Set the error state for later recording
ctx.status_code = 504
ctx.error_message = "流式响应超时,未收到有效数据"
ctx.upstream_response = f"流超时: Provider={ctx.provider_name}, elapsed={elapsed:.1f}s, chunk_count={ctx.chunk_count}, data_count=0"
error_event = { error_event = {
"type": "error", "type": "error",
"error": { "error": {
"type": "empty_stream_timeout", "type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据", "message": ctx.error_message,
}, },
} }
self._mark_first_output(ctx, output_state) self._mark_first_output(ctx, output_state)
@@ -1071,14 +1097,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.data_count == 0: if ctx.data_count == 0:
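# An empty stream usually means a misconfiguration (e.g. base_url points at a web page rather than the API) # An empty stream usually means a misconfiguration (e.g. base_url points at a web page rather than the API)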
logger.error( logger.error(
f"提供商 '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), " f"Provider '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
f"可能是 endpoint base_url 配置错误" f"可能是 endpoint base_url 配置错误"
) )
# Set the error state for later recording
ctx.status_code = 503
ctx.error_message = "上游服务返回了空的流式响应"
ctx.upstream_response = f"空流式响应: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0, 可能是 base_url 配置错误"
error_event = { error_event = {
"type": "error", "type": "error",
"error": { "error": {
"type": "empty_response", "type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应 (收到 {ctx.chunk_count} 行非 SSE 数据),请检查 endpoint 的 base_url 配置是否指向了正确的 API 地址", "message": ctx.error_message,
}, },
} }
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8") yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -1089,12 +1119,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
raise raise
except httpx.StreamClosed: except httpx.StreamClosed:
if ctx.data_count == 0: if ctx.data_count == 0:
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据") logger.warning(f"Provider '{ctx.provider_name}' 流连接关闭且无数据")
# Set the error state for later recording
ctx.status_code = 503
ctx.error_message = "上游服务连接关闭且未返回数据"
ctx.upstream_response = f"流连接关闭: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
error_event = { error_event = {
"type": "error", "type": "error",
"error": { "error": {
"type": "stream_closed", "type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据", "message": ctx.error_message,
}, },
} }
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8") yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -1289,6 +1323,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.status_code and ctx.status_code >= 400: if ctx.status_code and ctx.status_code >= 400:
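# Record the failed Usage, but use the estimated token info already received (from message_start) # Record the failed Usage, but use the estimated token info already received (from message_start)
# This way the estimated cost is recorded even if the request is interrupted # This way the estimated cost is recorded even if the request is interrupted
# On failure the client receives a JSON error response; fall back to a default if none was set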
client_response_headers = ctx.client_response_headers or {"content-type": "application/json"}
await bg_telemetry.record_failure( await bg_telemetry.record_failure(
provider=ctx.provider_name or "unknown", provider=ctx.provider_name or "unknown",
model=ctx.model, model=ctx.model,
@@ -1306,6 +1342,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
cache_creation_tokens=ctx.cache_creation_tokens, cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens, cache_read_tokens=ctx.cached_tokens,
response_body=response_body, response_body=response_body,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
# Model mapping info # Model mapping info
target_model=ctx.mapped_model, target_model=ctx.mapped_model,
) )
@@ -1319,6 +1357,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
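# Before recording statistics, let subclasses extract extra metadata from parsed_chunks # Before recording statistics, let subclasses extract extra metadata from parsed_chunks
self._finalize_stream_metadata(ctx) self._finalize_stream_metadata(ctx)
# On streaming success, the client receives the provider's response headers plus the required SSE headers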
client_response_headers = dict(ctx.response_headers) if ctx.response_headers else {}
client_response_headers.update({
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"content-type": "text/event-stream",
})
total_cost = await bg_telemetry.record_success( total_cost = await bg_telemetry.record_success(
provider=ctx.provider_name, provider=ctx.provider_name,
model=ctx.model, model=ctx.model,
@@ -1330,6 +1376,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
response_headers=ctx.response_headers, response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
response_body=response_body, response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens, cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens, cache_read_tokens=ctx.cached_tokens,
@@ -1367,13 +1414,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
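# 499 = client disconnected; should be marked as failed # 499 = client disconnected; should be marked as failed
# 503 = service unavailable (e.g. stream interrupted); should be marked as failed # 503 = service unavailable (e.g. stream interrupted); should be marked as failed
if ctx.status_code and ctx.status_code >= 400: if ctx.status_code and ctx.status_code >= 400:
# Request tracing uses upstream_response (raw response), falling back to error_message (friendly message)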
trace_error_message = ctx.upstream_response or ctx.error_message or f"HTTP {ctx.status_code}"
RequestCandidateService.mark_candidate_failed( RequestCandidateService.mark_candidate_failed(
db=bg_db, db=bg_db,
candidate_id=ctx.attempt_id, candidate_id=ctx.attempt_id,
error_type=( error_type=(
"client_disconnected" if ctx.status_code == 499 else "stream_error" "client_disconnected" if ctx.status_code == 499 else "stream_error"
), ),
error_message=ctx.error_message or f"HTTP {ctx.status_code}", error_message=trace_error_message,
status_code=ctx.status_code, status_code=ctx.status_code,
latency_ms=response_time_ms, latency_ms=response_time_ms,
extra_data={ extra_data={
@@ -1426,17 +1475,22 @@ class CliMessageHandlerBase(BaseMessageHandler):
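# Use the request body actually sent to the Provider (if any), otherwise the original request body # Use the request body actually sent to the Provider (if any), otherwise the original request body
actual_request_body = ctx.provider_request_body or original_request_body actual_request_body = ctx.provider_request_body or original_request_body
# On failure, the client receives a JSON error response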
client_response_headers = {"content-type": "application/json"}
await self.telemetry.record_failure( await self.telemetry.record_failure(
provider=ctx.provider_name or "unknown", provider=ctx.provider_name or "unknown",
model=ctx.model, model=ctx.model,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
status_code=status_code, status_code=status_code,
error_message=extract_error_message(error), error_message=extract_client_error_message(error),
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
is_stream=True, is_stream=True,
api_format=ctx.api_format, api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers, provider_request_headers=ctx.provider_request_headers,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
# Model mapping info # Model mapping info
target_model=ctx.mapped_model, target_model=ctx.mapped_model,
) )
@@ -1534,72 +1588,99 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}" f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
) )
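# Create the HTTP client (supports proxy configuration) # Get the reused HTTP client (supports proxy configuration, read from the Provider)
# endpoint.timeout is the overall request timeout # Note: get_proxy_client reuses the connection pool instead of creating a new client each time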
from src.clients.http_client import HTTPClientPool from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300) request_timeout = float(provider.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy( http_client = await HTTPClientPool.get_proxy_client(
proxy_config=endpoint.proxy, proxy_config=provider.proxy,
)
# Note: do not use async with, because the reused client must not be closed
# The timeout is controlled through the timeout parameter
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_headers,
timeout=httpx.Timeout(request_timeout), timeout=httpx.Timeout(request_timeout),
) )
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code status_code = resp.status_code
response_headers = dict(resp.headers) response_headers = dict(resp.headers)
if resp.status_code == 401: if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}") raise ProviderAuthException(str(provider.name))
elif resp.status_code == 429: elif resp.status_code == 429:
raise ProviderRateLimitException( raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}", "请求过于频繁,请稍后重试",
provider_name=str(provider.name), provider_name=str(provider.name),
response_headers=response_headers, response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None, retry_after=int(resp.headers.get("retry-after", 0)) or None,
) )
elif resp.status_code >= 500: elif resp.status_code >= 500:
error_text = resp.text error_text = resp.text
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}", f"上游服务暂时不可用 (HTTP {resp.status_code})",
provider_name=str(provider.name), provider_name=str(provider.name),
upstream_status=resp.status_code, upstream_status=resp.status_code,
upstream_response=error_text, upstream_response=error_text,
) )
elif 300 <= resp.status_code < 400: elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown") redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}" "上游服务返回重定向响应",
) provider_name=str(provider.name),
elif resp.status_code != 200: upstream_status=resp.status_code,
error_text = resp.text upstream_response=f"重定向 {resp.status_code} -> {redirect_url}",
raise ProviderNotAvailableException( )
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}", elif resp.status_code != 200:
provider_name=str(provider.name), error_text = resp.text
upstream_status=resp.status_code, raise ProviderNotAvailableException(
upstream_response=error_text, f"上游服务返回错误 (HTTP {resp.status_code})",
) provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
)
# Parse the JSON response safely, handling possible encoding errors # Parse the JSON response safely, handling possible encoding errors
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# Capture the raw response content for debugging (stored in upstream_response)
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
raw_content = ""
try: try:
response_json = resp.json() raw_content = resp.text[:500] if resp.text else "(empty)"
except (UnicodeDecodeError, json.JSONDecodeError) as e: except Exception:
# Log the raw response info for debugging try:
content_type = resp.headers.get("content-type", "unknown") raw_content = repr(resp.content[:500]) if resp.content else "(empty)"
content_encoding = resp.headers.get("content-encoding", "none") except Exception:
logger.error( raw_content = "(unable to read)"
f"[{self.request_id}] 无法解析响应 JSON: {e}, " logger.error(
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, " f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"响应长度: {len(resp.content)} bytes" f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
) f"响应长度: {len(resp.content)} bytes, 原始内容: {raw_content}"
raise ProviderNotAvailableException( )
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}" # 判断错误类型,生成友好的客户端错误消息(不暴露提供商信息)
) if raw_content == "(empty)" or not raw_content.strip():
client_message = "上游服务返回了空响应"
elif raw_content.strip().startswith(("<", "<!doctype", "<!DOCTYPE")):
client_message = "上游服务返回了非预期的响应格式"
else:
client_message = "上游服务返回了无效的响应"
raise ProviderNotAvailableException(
client_message,
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=raw_content,
)
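The branch above maps an unparseable upstream body to one of three generic client messages, so that no provider details reach the client. A minimal sketch of that classification, assuming English message strings and a hypothetical helper name `classify_unparseable_response` (neither appears in the diff):

```python
def classify_unparseable_response(raw_content: str) -> str:
    """Map an unparseable upstream body to a generic, provider-agnostic message.

    Hypothetical helper; the English messages are illustrative translations.
    """
    stripped = raw_content.strip()
    if raw_content == "(empty)" or not stripped:
        return "upstream returned an empty response"
    if stripped.startswith("<"):
        # "<!doctype" and "<!DOCTYPE" also begin with "<", so one prefix check
        # covers every HTML-like case handled by the original tuple.
        return "upstream returned an unexpected response format"
    return "upstream returned an invalid response"
```

The raw body would still be passed separately (as `upstream_response`) so that tracing keeps the original content while the client sees only the generic message.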
# 提取 Provider 响应元数据(子类可覆盖) # 提取 Provider 响应元数据(子类可覆盖)
response_metadata_result = self._extract_response_metadata(response_json) response_metadata_result = self._extract_response_metadata(response_json)
return response_json if isinstance(response_json, dict) else {} return response_json if isinstance(response_json, dict) else {}
try: try:
# 解析能力需求 # 解析能力需求
@@ -1663,6 +1744,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体 # 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = provider_request_body or original_request_body actual_request_body = provider_request_body or original_request_body
# 非流式成功时,返回给客户端的是提供商响应头(透传)
client_response_headers = dict(response_headers) if response_headers else {}
client_response_headers["content-type"] = "application/json"
total_cost = await self.telemetry.record_success( total_cost = await self.telemetry.record_success(
provider=provider_name, provider=provider_name,
model=model, model=model,
@@ -1673,6 +1758,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
response_headers=response_headers, response_headers=response_headers,
client_response_headers=client_response_headers,
response_body=response_json, response_body=response_json,
cache_creation_tokens=cache_creation_tokens, cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cached_tokens, cache_read_tokens=cached_tokens,
@@ -1691,7 +1777,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成") logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
return JSONResponse(status_code=status_code, content=response_json) # 透传提供商的响应头
return JSONResponse(
status_code=status_code,
content=response_json,
headers=response_headers if response_headers else None,
)
except Exception as e: except Exception as e:
response_time_ms = int((time.time() - sync_start_time) * 1000) response_time_ms = int((time.time() - sync_start_time) * 1000)
@@ -1707,17 +1798,27 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体 # 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = provider_request_body or original_request_body actual_request_body = provider_request_body or original_request_body
# 尝试从异常中提取响应头
error_response_headers: Dict[str, str] = {}
if isinstance(e, ProviderRateLimitException) and e.response_headers:
error_response_headers = e.response_headers
elif isinstance(e, httpx.HTTPStatusError) and hasattr(e, "response"):
error_response_headers = dict(e.response.headers)
await self.telemetry.record_failure( await self.telemetry.record_failure(
provider=provider_name or "unknown", provider=provider_name or "unknown",
model=model, model=model,
response_time_ms=response_time_ms, response_time_ms=response_time_ms,
status_code=status_code, status_code=status_code,
error_message=extract_error_message(e), error_message=extract_client_error_message(e),
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
is_stream=False, is_stream=False,
api_format=api_format, api_format=api_format,
provider_request_headers=provider_request_headers, provider_request_headers=provider_request_headers,
response_headers=error_response_headers,
# 非流式失败返回给客户端的是 JSON 错误响应
client_response_headers={"content-type": "application/json"},
# 模型映射信息 # 模型映射信息
target_model=mapped_model_result, target_model=mapped_model_result,
) )


@@ -74,7 +74,6 @@ def build_safe_headers(
return headers return headers
# 保持向后兼容的run_endpoint_check函数使用新架构
async def run_endpoint_check( async def run_endpoint_check(
*, *,
client: httpx.AsyncClient, # 保持兼容性,但内部不使用 client: httpx.AsyncClient, # 保持兼容性,但内部不使用
@@ -176,10 +175,16 @@ async def _calculate_and_record_usage(
logger.warning(f"Provider API Key not found for usage calculation: {api_key_id}") logger.warning(f"Provider API Key not found for usage calculation: {api_key_id}")
return {"error": "Provider API Key not found"} return {"error": "Provider API Key not found"}
# 获取Provider Endpoint信息 # 获取Provider Endpoint信息(通过 api_format 查找)
provider_endpoint = None provider_endpoint = None
if provider_api_key.endpoint_id: if api_format and provider_api_key.provider_id:
provider_endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == provider_api_key.endpoint_id).first() from src.models.database import Provider
provider = db.query(Provider).filter(Provider.id == provider_api_key.provider_id).first()
if provider:
for ep in provider.endpoints:
if ep.api_format == api_format and ep.is_active:
provider_endpoint = ep
break
# 获取用户的API Key用于记录关联即使实际使用的是Provider API Key # 获取用户的API Key用于记录关联即使实际使用的是Provider API Key
user_api_key = None user_api_key = None


@@ -61,11 +61,13 @@ class StreamContext:
# 响应状态 # 响应状态
status_code: int = 200 status_code: int = 200
error_message: Optional[str] = None error_message: Optional[str] = None # 客户端友好的错误消息
upstream_response: Optional[str] = None # 原始 Provider 响应(用于请求链路追踪)
has_completion: bool = False has_completion: bool = False
# 请求/响应数据 # 请求/响应数据
response_headers: Dict[str, str] = field(default_factory=dict) response_headers: Dict[str, str] = field(default_factory=dict) # 提供商响应头
client_response_headers: Dict[str, str] = field(default_factory=dict) # 返回给客户端的响应头
provider_request_headers: Dict[str, str] = field(default_factory=dict) provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None provider_request_body: Optional[Dict[str, Any]] = None
@@ -97,9 +99,11 @@ class StreamContext:
self.cached_tokens = 0 self.cached_tokens = 0
self.cache_creation_tokens = 0 self.cache_creation_tokens = 0
self.error_message = None self.error_message = None
self.upstream_response = None
self.status_code = 200 self.status_code = 200
self.first_byte_time_ms = None self.first_byte_time_ms = None
self.response_headers = {} self.response_headers = {}
self.client_response_headers = {}
self.provider_request_headers = {} self.provider_request_headers = {}
self.provider_request_body = None self.provider_request_body = None
self.response_id = None self.response_id = None
@@ -174,10 +178,24 @@ class StreamContext:
): ):
self.cache_creation_tokens = cache_creation_tokens self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None: def mark_failed(
"""标记请求失败""" self,
status_code: int,
error_message: str,
upstream_response: Optional[str] = None,
) -> None:
"""
Mark the request as failed.
Args:
status_code: HTTP status code
error_message: client-friendly error message
upstream_response: raw provider response (used for request tracing)
"""
self.status_code = status_code self.status_code = status_code
self.error_message = error_message self.error_message = error_message
if upstream_response:
self.upstream_response = upstream_response
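The change separates the client-facing message from the raw upstream body: the former is returned to the caller, the latter is kept for tracing. A minimal sketch of that split, using a hypothetical `FailureState` dataclass and `trace_message` method as stand-ins (the real `StreamContext` carries many more fields):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FailureState:
    """Hypothetical, trimmed-down stand-in for StreamContext."""

    status_code: int = 200
    error_message: Optional[str] = None      # client-friendly message
    upstream_response: Optional[str] = None  # raw provider response, for tracing

    def mark_failed(
        self,
        status_code: int,
        error_message: str,
        upstream_response: Optional[str] = None,
    ) -> None:
        self.status_code = status_code
        self.error_message = error_message
        # Only overwrite when a raw response is actually supplied.
        if upstream_response:
            self.upstream_response = upstream_response

    def trace_message(self) -> str:
        # Prefer the raw response, then the friendly message, then the status.
        return self.upstream_response or self.error_message or f"HTTP {self.status_code}"
```

With this shape, a caller can log `trace_message()` internally while returning only `error_message` to the client.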
def record_first_byte_time(self, start_time: float) -> None: def record_first_byte_time(self, start_time: float) -> None:
""" """


@@ -251,8 +251,10 @@ class StreamProcessor:
f"base_url={endpoint.base_url}" f"base_url={endpoint.base_url}"
) )
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应," "上游服务返回了非预期的响应格式",
f"请检查 endpoint 的 base_url 配置是否正确" provider_name=str(provider.name),
upstream_status=200,
upstream_response=line[:500] if line else "(empty)",
) )
# 跳过空行和注释行 # 跳过空行和注释行


@@ -154,6 +154,14 @@ class StreamTelemetryRecorder:
response_time_ms: int, response_time_ms: int,
) -> None: ) -> None:
"""记录成功的请求""" """记录成功的请求"""
# 流式成功时,返回给客户端的是提供商响应头 + SSE 必需头
client_response_headers = dict(ctx.response_headers) if ctx.response_headers else {}
client_response_headers.update({
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"content-type": "text/event-stream",
})
await telemetry.record_success( await telemetry.record_success(
provider=ctx.provider_name or "unknown", provider=ctx.provider_name or "unknown",
model=ctx.model, model=ctx.model,
@@ -165,6 +173,7 @@ class StreamTelemetryRecorder:
request_headers=original_headers, request_headers=original_headers,
request_body=actual_request_body, request_body=actual_request_body,
response_headers=ctx.response_headers, response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
response_body=response_body, response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens, cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens, cache_read_tokens=ctx.cached_tokens,
@@ -190,6 +199,9 @@ class StreamTelemetryRecorder:
response_time_ms: int, response_time_ms: int,
) -> None: ) -> None:
"""记录失败的请求""" """记录失败的请求"""
# 失败时返回给客户端的是 JSON 错误响应,如果没有设置则使用默认值
client_response_headers = ctx.client_response_headers or {"content-type": "application/json"}
await telemetry.record_failure( await telemetry.record_failure(
provider=ctx.provider_name or "unknown", provider=ctx.provider_name or "unknown",
model=ctx.model, model=ctx.model,
@@ -206,6 +218,8 @@ class StreamTelemetryRecorder:
cache_creation_tokens=ctx.cache_creation_tokens, cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens, cache_read_tokens=ctx.cached_tokens,
response_body=response_body, response_body=response_body,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
target_model=ctx.mapped_model, target_model=ctx.mapped_model,
) )
@@ -239,11 +253,13 @@ class StreamTelemetryRecorder:
) )
else: else:
error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error" error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
# Request tracing prefers upstream_response (the raw response) and falls back to error_message (the friendly message)
trace_error_message = ctx.upstream_response or ctx.error_message or f"HTTP {ctx.status_code}"
RequestCandidateService.mark_candidate_failed( RequestCandidateService.mark_candidate_failed(
db=db, db=db,
candidate_id=ctx.attempt_id, candidate_id=ctx.attempt_id,
error_type=error_type, error_type=error_type,
error_message=ctx.error_message or f"HTTP {ctx.status_code}", error_message=trace_error_message,
status_code=ctx.status_code, status_code=ctx.status_code,
latency_ms=response_time_ms, latency_ms=response_time_ms,
extra_data={ extra_data={


@@ -26,8 +26,10 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
3. **旧格式(优先级第三)** 3. **旧格式(优先级第三)**
usage.cache_creation_input_tokens usage.cache_creation_input_tokens
优先使用嵌套格式,如果嵌套格式字段存在但值为 0则智能 fallback 到旧格式。 说明:
扁平格式和嵌套格式互斥,按顺序检查 - 只要检测到新格式字段(嵌套/扁平),即视为权威来源:哪怕值为 0 也不回退到旧字段
- 仅当新格式字段完全不存在时,才回退到旧字段。
- 扁平格式和嵌套格式互斥,按顺序检查。
Args: Args:
usage: API 响应中的 usage 字典 usage: API 响应中的 usage 字典
@@ -37,27 +39,20 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
""" """
# 1. 检查嵌套格式(最新格式) # 1. 检查嵌套格式(最新格式)
cache_creation = usage.get("cache_creation") cache_creation = usage.get("cache_creation")
if isinstance(cache_creation, dict): has_nested_format = isinstance(cache_creation, dict) and (
"ephemeral_5m_input_tokens" in cache_creation
or "ephemeral_1h_input_tokens" in cache_creation
)
if has_nested_format:
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0)) cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0)) cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
total = cache_5m + cache_1h total = cache_5m + cache_1h
if total > 0: logger.debug(
logger.debug( f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}" )
) return total
return total
# 嵌套格式存在但为 0fallback 到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(
f"Nested cache_creation is 0, using old format: {old_format}"
)
return old_format
# 都是 0返回 0
return 0
# 2. 检查扁平新格式 # 2. 检查扁平新格式
has_flat_format = ( has_flat_format = (
@@ -70,22 +65,10 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0)) cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
total = cache_5m + cache_1h total = cache_5m + cache_1h
if total > 0: logger.debug(
logger.debug( f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}" )
) return total
return total
# 扁平格式存在但为 0fallback 到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(
f"Flat cache_creation is 0, using old format: {old_format}"
)
return old_format
# 都是 0返回 0
return 0
# 3. 回退到旧格式 # 3. 回退到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0)) old_format = int(usage.get("cache_creation_input_tokens", 0))
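The revised precedence treats any detected new-format field as authoritative, even when its value is 0, and falls back to the legacy field only when no new-format field is present at all. A condensed sketch of that rule follows. The function name `cache_creation_tokens` is hypothetical, and the flat 5-minute key `claude_cache_creation_5_m_tokens` is an assumption inferred from the 1-hour key visible in the diff:

```python
from typing import Any, Dict

NESTED_KEYS = ("ephemeral_5m_input_tokens", "ephemeral_1h_input_tokens")
# Assumption: the 5-minute flat key mirrors the 1-hour key shown in the diff.
FLAT_KEYS = ("claude_cache_creation_5_m_tokens", "claude_cache_creation_1_h_tokens")


def cache_creation_tokens(usage: Dict[str, Any]) -> int:
    """Return cache-creation tokens using nested > flat > legacy precedence."""
    nested = usage.get("cache_creation")
    # 1. Nested format: authoritative as soon as either key exists.
    if isinstance(nested, dict) and any(k in nested for k in NESTED_KEYS):
        return sum(int(nested.get(k, 0)) for k in NESTED_KEYS)
    # 2. Flat new format: same rule, keys live directly on `usage`.
    if any(k in usage for k in FLAT_KEYS):
        return sum(int(usage.get(k, 0)) for k in FLAT_KEYS)
    # 3. Legacy field, used only when no new-format key is present.
    return int(usage.get("cache_creation_input_tokens", 0))
```

Note that an empty nested dict does not count as "new format present", so the legacy field is still consulted in that case.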
@@ -173,8 +156,10 @@ def check_prefetched_response_error(
f"base_url={base_url}" f"base_url={base_url}"
) )
raise ProviderNotAvailableException( raise ProviderNotAvailableException(
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应," "上游服务返回了非预期的响应格式",
f"请检查 endpoint 的 base_url 配置是否正确" provider_name=provider_name,
upstream_status=200,
upstream_response=stripped.decode("utf-8", errors="replace")[:500],
) )
# 纯 JSON可能无换行/多行 JSON # 纯 JSON可能无换行/多行 JSON


@@ -2,7 +2,6 @@
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现 Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
继承 CliMessageHandlerBase只需覆盖格式特定的配置和事件处理逻辑。 继承 CliMessageHandlerBase只需覆盖格式特定的配置和事件处理逻辑。
验证新架构的有效性:代码量从数百行减少到 ~80 行。
""" """
from typing import Any, Dict, Optional from typing import Any, Dict, Optional


@@ -102,7 +102,6 @@ async def get_public_models(
- id: 模型唯一标识符 - id: 模型唯一标识符
- provider_id: 所属提供商 ID - provider_id: 所属提供商 ID
- provider_name: 提供商名称 - provider_name: 提供商名称
- provider_display_name: 提供商显示名称
- name: 模型统一名称(优先使用 GlobalModel 名称) - name: 模型统一名称(优先使用 GlobalModel 名称)
- display_name: 模型显示名称 - display_name: 模型显示名称
- description: 模型描述信息 - description: 模型描述信息
@@ -300,10 +299,20 @@ class PublicProvidersAdapter(PublicApiAdapter):
providers = query.offset(self.skip).limit(self.limit).all() providers = query.offset(self.skip).limit(self.limit).all()
result = [] result = []
for provider in providers: for provider in providers:
models_count = db.query(Model).filter(Model.provider_id == provider.id).count() models_count = (
db.query(Model)
.filter(Model.provider_id == provider.id, Model.global_model_id.isnot(None))
.count()
)
active_models_count = ( active_models_count = (
db.query(Model) db.query(Model)
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True))) .filter(
and_(
Model.provider_id == provider.id,
Model.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
.count() .count()
) )
endpoints_count = len(provider.endpoints) if provider.endpoints else 0 endpoints_count = len(provider.endpoints) if provider.endpoints else 0
@@ -313,7 +322,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
provider_data = PublicProviderResponse( provider_data = PublicProviderResponse(
id=provider.id, id=provider.id,
name=provider.name, name=provider.name,
display_name=provider.display_name,
description=provider.description, description=provider.description,
is_active=provider.is_active, is_active=provider.is_active,
provider_priority=provider.provider_priority, provider_priority=provider.provider_priority,
@@ -342,7 +350,13 @@ class PublicModelsAdapter(PublicApiAdapter):
db.query(Model, Provider) db.query(Model, Provider)
.options(joinedload(Model.global_model)) .options(joinedload(Model.global_model))
.join(Provider) .join(Provider)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True))) .filter(
and_(
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
) )
if self.provider_id is not None: if self.provider_id is not None:
query = query.filter(Model.provider_id == self.provider_id) query = query.filter(Model.provider_id == self.provider_id)
@@ -357,7 +371,6 @@ class PublicModelsAdapter(PublicApiAdapter):
id=model.id, id=model.id,
provider_id=model.provider_id, provider_id=model.provider_id,
provider_name=provider.name, provider_name=provider.name,
provider_display_name=provider.display_name,
name=unified_name, name=unified_name,
display_name=display_name, display_name=display_name,
description=global_model.config.get("description") if global_model and global_model.config else None, description=global_model.config.get("description") if global_model and global_model.config else None,
@@ -386,7 +399,13 @@ class PublicStatsAdapter(PublicApiAdapter):
active_models = ( active_models = (
db.query(Model) db.query(Model)
.join(Provider) .join(Provider)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True))) .filter(
and_(
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
.count() .count()
) )
formats = ( formats = (
@@ -418,7 +437,13 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
.options(joinedload(Model.global_model)) .options(joinedload(Model.global_model))
.join(Provider) .join(Provider)
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id) .outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True))) .filter(
and_(
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
) )
search_filter = ( search_filter = (
Model.provider_model_name.ilike(f"%{self.query}%") Model.provider_model_name.ilike(f"%{self.query}%")
@@ -439,7 +464,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
id=model.id, id=model.id,
provider_id=model.provider_id, provider_id=model.provider_id,
provider_name=provider.name, provider_name=provider.name,
provider_display_name=provider.display_name,
name=unified_name, name=unified_name,
display_name=display_name, display_name=display_name,
description=global_model.config.get("description") if global_model and global_model.config else None, description=global_model.config.get("description") if global_model and global_model.config else None,


@@ -43,7 +43,6 @@ def _serialize_provider(
provider_data: Dict[str, Any] = { provider_data: Dict[str, Any] = {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
"is_active": provider.is_active, "is_active": provider.is_active,
"provider_priority": provider.provider_priority, "provider_priority": provider.provider_priority,
} }


@@ -1023,7 +1023,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
{ {
"id": provider.id, "id": provider.id,
"name": provider.name, "name": provider.name,
"display_name": provider.display_name,
"description": provider.description, "description": provider.description,
"provider_priority": provider.provider_priority, "provider_priority": provider.provider_priority,
"endpoints": endpoints_data, "endpoints": endpoints_data,


@@ -1,10 +1,18 @@
""" """
全局HTTP客户端池管理 全局HTTP客户端池管理
避免每次请求都创建新的AsyncClient,提高性能 避免每次请求都创建新的AsyncClient,提高性能
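Performance notes:
1. Default client: reused globally when no proxy is configured
2. Proxy client cache: identical proxy configurations share one client, avoiding repeated creation
3. Connection pool reuse: keep-alive connections reduce TCP handshake overhead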
""" """
import asyncio
import hashlib
import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, Tuple
from urllib.parse import quote, urlparse from urllib.parse import quote, urlparse
import httpx import httpx
@@ -12,6 +20,32 @@ import httpx
from src.config import config from src.config import config
from src.core.logger import logger from src.core.logger import logger
# 模块级锁,避免类属性延迟初始化的竞态条件
_proxy_clients_lock = asyncio.Lock()
_default_client_lock = asyncio.Lock()
def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str:
"""
Compute the cache key for a proxy configuration.
Args:
proxy_config: proxy configuration dict
Returns:
Cache key string; "__no_proxy__" when no proxy is configured
"""
if not proxy_config:
return "__no_proxy__"
# 构建代理 URL 作为缓存键的基础
proxy_url = build_proxy_url(proxy_config)
if not proxy_url:
return "__no_proxy__"
# 使用 MD5 哈希来避免过长的键名
return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]: def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
""" """
@@ -61,11 +95,20 @@ class HTTPClientPool:
全局HTTP客户端池单例 全局HTTP客户端池单例
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接 管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
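Performance optimizations:
1. Default client: reused when no proxy is configured
2. Proxy client cache: identical proxy configurations share one client
3. LRU eviction: once the proxy client cache reaches its limit, the least recently used client is evicted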
""" """
_instance: Optional["HTTPClientPool"] = None _instance: Optional["HTTPClientPool"] = None
_default_client: Optional[httpx.AsyncClient] = None _default_client: Optional[httpx.AsyncClient] = None
_clients: Dict[str, httpx.AsyncClient] = {} _clients: Dict[str, httpx.AsyncClient] = {}
# 代理客户端缓存:{cache_key: (client, last_used_time)}
_proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {}
# 代理客户端缓存上限(避免内存泄漏)
_max_proxy_clients: int = 50
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
@@ -73,12 +116,50 @@ class HTTPClientPool:
return cls._instance return cls._instance
@classmethod @classmethod
def get_default_client(cls) -> httpx.AsyncClient: async def get_default_client_async(cls) -> httpx.AsyncClient:
""" """
获取默认的HTTP客户端 获取默认的HTTP客户端(异步线程安全版本)
用于大多数HTTP请求,具有合理的默认配置 用于大多数HTTP请求,具有合理的默认配置
""" """
if cls._default_client is not None:
return cls._default_client
async with _default_client_lock:
# 双重检查,避免重复创建
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证
timeout=httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
limits=httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
follow_redirects=True, # 跟随重定向
)
logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
"""
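Get the default HTTP client (synchronous version, kept for backward compatibility)
⚠️ Note: concurrent first calls to this method may race;
prefer the async version, get_default_client_async().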
"""
if cls._default_client is None: if cls._default_client is None:
cls._default_client = httpx.AsyncClient( cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性 http2=False, # 暂时禁用HTTP/2以提高兼容性
@@ -90,13 +171,18 @@ class HTTPClientPool:
pool=config.http_pool_timeout, pool=config.http_pool_timeout,
), ),
limits=httpx.Limits( limits=httpx.Limits(
max_connections=100, # 最大连接数 max_connections=config.http_max_connections,
max_keepalive_connections=20, # 最大保活连接数 max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=30.0, # 保活过期时间(秒) keepalive_expiry=config.http_keepalive_expiry,
), ),
follow_redirects=True, # 跟随重定向 follow_redirects=True, # 跟随重定向
) )
logger.info("全局HTTP客户端池已初始化") logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client return cls._default_client
@classmethod @classmethod
@@ -130,6 +216,101 @@ class HTTPClientPool:
return cls._clients[name] return cls._clients[name]
@classmethod
def _get_proxy_clients_lock(cls) -> asyncio.Lock:
"""获取代理客户端缓存锁(模块级单例,避免竞态条件)"""
return _proxy_clients_lock
@classmethod
async def _evict_lru_proxy_client(cls) -> None:
"""淘汰最久未使用的代理客户端"""
if len(cls._proxy_clients) < cls._max_proxy_clients:
return
# 找到最久未使用的客户端
oldest_key = min(cls._proxy_clients.keys(), key=lambda k: cls._proxy_clients[k][1])
old_client, _ = cls._proxy_clients.pop(oldest_key)
# 异步关闭旧客户端
try:
await old_client.aclose()
logger.debug(f"淘汰代理客户端: {oldest_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
@classmethod
async def get_proxy_client(
cls,
proxy_config: Optional[Dict[str, Any]] = None,
) -> httpx.AsyncClient:
"""
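Get a proxy client (cached and reused).
Identical proxy configurations reuse the same client, which avoids repeated connection setup.
Note: the returned client uses the default timeout configuration; pass a timeout argument per request to override it.
Args:
proxy_config: proxy configuration dict containing url, username, password
Returns:
A reusable httpx.AsyncClient instance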
"""
cache_key = _compute_proxy_cache_key(proxy_config)
# 无代理时返回默认客户端
if cache_key == "__no_proxy__":
return await cls.get_default_client_async()
lock = cls._get_proxy_clients_lock()
async with lock:
# 检查缓存
if cache_key in cls._proxy_clients:
client, _ = cls._proxy_clients[cache_key]
# 健康检查:如果客户端已关闭,移除并重新创建
if client.is_closed:
del cls._proxy_clients[cache_key]
logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}")
else:
# 更新最后使用时间
cls._proxy_clients[cache_key] = (client, time.time())
return client
# 淘汰旧客户端(如果超过上限)
await cls._evict_lru_proxy_client()
# 创建新客户端(使用默认超时,请求时可覆盖)
client_config: Dict[str, Any] = {
"http2": False,
"verify": True,
"follow_redirects": True,
"limits": httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
"timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
}
# 添加代理配置
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
client_config["proxy"] = proxy_url
client = httpx.AsyncClient(**client_config)
cls._proxy_clients[cache_key] = (client, time.time())
logger.debug(
f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, "
f"缓存数量: {len(cls._proxy_clients)}"
)
return client
@classmethod @classmethod
async def close_all(cls): async def close_all(cls):
"""关闭所有HTTP客户端""" """关闭所有HTTP客户端"""
@@ -143,6 +324,16 @@ class HTTPClientPool:
logger.debug(f"命名HTTP客户端已关闭: {name}") logger.debug(f"命名HTTP客户端已关闭: {name}")
cls._clients.clear() cls._clients.clear()
# 关闭代理客户端缓存
for cache_key, (client, _) in cls._proxy_clients.items():
try:
await client.aclose()
logger.debug(f"代理客户端已关闭: {cache_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
cls._proxy_clients.clear()
logger.info("所有HTTP客户端已关闭") logger.info("所有HTTP客户端已关闭")
@classmethod @classmethod
@@ -185,13 +376,15 @@ class HTTPClientPool:
""" """
创建带代理配置的HTTP客户端 创建带代理配置的HTTP客户端
⚠️ Performance warning: this method creates a new client on every call; prefer get_proxy_client() to reuse connections.
Args: Args:
proxy_config: 代理配置字典,包含 url, username, password proxy_config: 代理配置字典,包含 url, username, password
timeout: 超时配置 timeout: 超时配置
**kwargs: 其他 httpx.AsyncClient 配置参数 **kwargs: 其他 httpx.AsyncClient 配置参数
Returns: Returns:
配置好的 httpx.AsyncClient 实例 配置好的 httpx.AsyncClient 实例(调用者需要负责关闭)
""" """
client_config: Dict[str, Any] = { client_config: Dict[str, Any] = {
"http2": False, "http2": False,
@@ -213,11 +406,21 @@ class HTTPClientPool:
proxy_url = build_proxy_url(proxy_config) if proxy_config else None proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url: if proxy_url:
client_config["proxy"] = proxy_url client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}") logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}")
client_config.update(kwargs) client_config.update(kwargs)
return httpx.AsyncClient(**client_config) return httpx.AsyncClient(**client_config)
@classmethod
def get_pool_stats(cls) -> Dict[str, Any]:
"""获取连接池统计信息"""
return {
"default_client_active": cls._default_client is not None,
"named_clients_count": len(cls._clients),
"proxy_clients_count": len(cls._proxy_clients),
"max_proxy_clients": cls._max_proxy_clients,
}
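The proxy client cache combines three ideas: a hashed cache key, a lock around lookups, and least-recently-used eviction once the cache is full. The sketch below illustrates that pattern in isolation. `FakeClient` and `ProxyClientCache` are hypothetical stand-ins (not the project's classes), and a logical counter replaces `time.time()` so that the eviction order is deterministic:

```python
import asyncio
import hashlib
import itertools
from typing import Dict, Tuple


class FakeClient:
    """Hypothetical stand-in for httpx.AsyncClient, for illustration only."""

    def __init__(self, proxy_url: str) -> None:
        self.proxy_url = proxy_url
        self.is_closed = False

    async def aclose(self) -> None:
        self.is_closed = True


class ProxyClientCache:
    """Lock-protected cache that evicts the least recently used client."""

    def __init__(self, max_clients: int = 50) -> None:
        self._max = max_clients
        self._lock = asyncio.Lock()
        self._tick = itertools.count()  # logical clock instead of wall time
        self._clients: Dict[str, Tuple[FakeClient, int]] = {}

    @staticmethod
    def cache_key(proxy_url: str) -> str:
        # Hash the URL so keys stay short; MD5 is used for keying, not security.
        return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"

    async def get(self, proxy_url: str) -> FakeClient:
        key = self.cache_key(proxy_url)
        async with self._lock:
            entry = self._clients.get(key)
            if entry and not entry[0].is_closed:
                # Cache hit: refresh the last-used marker.
                self._clients[key] = (entry[0], next(self._tick))
                return entry[0]
            self._clients.pop(key, None)  # drop a closed client, if any
            if len(self._clients) >= self._max:
                oldest = min(self._clients, key=lambda k: self._clients[k][1])
                old_client, _ = self._clients.pop(oldest)
                await old_client.aclose()
            client = FakeClient(proxy_url)
            self._clients[key] = (client, next(self._tick))
            return client
```

One design consequence worth noting: a client evicted while another coroutine still holds a reference to it will be closed underneath that coroutine, so a small cache limit combined with many distinct proxies may call for reference counting or a grace period.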
# 便捷访问函数 # 便捷访问函数
def get_http_client() -> httpx.AsyncClient: def get_http_client() -> httpx.AsyncClient:


@@ -52,19 +52,33 @@ class StreamDefaults:
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
class ConcurrencyDefaults: class RPMDefaults:
"""并发控制默认值 """RPM每分钟请求数制默认值
算法说明:边界记忆 + 渐进探测 算法说明:边界记忆 + 渐进探测
- 触发 429 时记录边界last_concurrent_peak新限制 = 边界 - 1 - 触发 429 时记录边界last_rpm_peak新限制 = 边界 - 1
- 扩容时不超过边界,除非是探测性扩容(长时间无 429 - 扩容时不超过边界,除非是探测性扩容(长时间无 429
- 这样可以快速收敛到真实限制附近,避免过度保守 - 这样可以快速收敛到真实限制附近,避免过度保守
初始值 50 RPM
- 系统会根据实际使用自动调整
""" """
# 自适应并发初始限制(宽松起步,遇到 429 再降低) # 自适应 RPM 初始限制
INITIAL_LIMIT = 50 INITIAL_LIMIT = 50 # 每分钟 50 次请求
# 429错误后的冷却时间分钟- 在此期间不会增加并发限制 # === 内存模式 RPM 计数器配置 ===
# 内存模式下的最大条目限制(防止内存泄漏)
# 每个条目约占 100 字节10000 条目 = ~1MB
# 计算依据1000 Key × 5 API 格式 × 2 (buffer) = 10000
# 可通过环境变量 RPM_MAX_MEMORY_ENTRIES 覆盖
MAX_MEMORY_RPM_ENTRIES = 10000
# 内存使用告警阈值(达到此比例时记录警告日志)
# 可通过环境变量 RPM_MEMORY_WARNING_THRESHOLD 覆盖
MEMORY_WARNING_THRESHOLD = 0.6 # 60%
# 429错误后的冷却时间分钟- 在此期间不会增加 RPM 限制
COOLDOWN_AFTER_429_MINUTES = 5 COOLDOWN_AFTER_429_MINUTES = 5
# 探测间隔上限(分钟)- 用于长期探测策略 # 探测间隔上限(分钟)- 用于长期探测策略
@@ -86,30 +100,30 @@ class ConcurrencyDefaults:
# 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策 # 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策
MIN_SAMPLES_FOR_DECISION = 5 MIN_SAMPLES_FOR_DECISION = 5
# 扩容步长 - 每次扩容增加的并发数 # 扩容步长 - 每次扩容增加的 RPM
INCREASE_STEP = 2 INCREASE_STEP = 5 # 每次增加 5 RPM
# 最大并发限制上限 # 最大 RPM 限制上限(不设上限,让系统自适应学习)
MAX_CONCURRENT_LIMIT = 200 MAX_RPM_LIMIT = 10000
# 最小并发限制下限 # 最小 RPM 限制下限
# 设置为 3 而不是 1因为预留机制10%预留给缓存用户)会导致 MIN_RPM_LIMIT = 5
# 当 learned_max_concurrent=1 时新用户实际可用槽位为 0永远无法命中
# 注意:当 limit < 10 时,预留机制实际不生效(预留槽位 = 0这是可接受的 # 缓存用户预留比例(默认 10%,新用户可用 90%
MIN_CONCURRENT_LIMIT = 3 # 已被动态预留机制 (AdaptiveReservationDefaults) 替代,保留用于向后兼容
CACHE_RESERVATION_RATIO = 0.1
# === 探测性扩容参数 === # === 探测性扩容参数 ===
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容 # 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
# 探测性扩容可以突破已知边界,尝试更高的并发 # 探测性扩容可以突破已知边界,尝试更高的 RPM
PROBE_INCREASE_INTERVAL_MINUTES = 30 PROBE_INCREASE_INTERVAL_MINUTES = 30
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求 # 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
PROBE_INCREASE_MIN_REQUESTS = 10 PROBE_INCREASE_MIN_REQUESTS = 10
# === 缓存用户预留比例 ===
# 缓存用户槽位预留比例(新用户可用 1 - 此值) # 向后兼容别名
# 0.1 表示缓存用户预留 10%,新用户可用 90% ConcurrencyDefaults = RPMDefaults
CACHE_RESERVATION_RATIO = 0.1
class CircuitBreakerDefaults: class CircuitBreakerDefaults:
@@ -193,10 +207,19 @@ class AdaptiveReservationDefaults:
class TimeoutDefaults: class TimeoutDefaults:
"""超时配置默认值(秒)""" """超时配置默认值(秒)
# HTTP 请求默认超时 超时配置说明:
HTTP_REQUEST = 300 # 5分钟 - 全局默认值和 Provider 默认值统一为 120 秒
- 120 秒是 LLM API 的合理默认值:
* 大多数请求在 30 秒内完成
* 复杂推理(如 Claude extended thinking可能需要 60-90 秒
* 120 秒足够覆盖大部分场景,同时避免线程池被长时间占用
- 如需更长超时,可在 Provider 级别单独配置
"""
# HTTP 请求默认超时(与 Provider 默认值保持一致)
HTTP_REQUEST = 120 # 2分钟
# 数据库连接池获取超时 # 数据库连接池获取超时
DB_POOL = 30 DB_POOL = 30


@@ -145,6 +145,24 @@ class Config:
self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0")) self.http_write_timeout = float(os.getenv("HTTP_WRITE_TIMEOUT", "60.0"))
self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0")) self.http_pool_timeout = float(os.getenv("HTTP_POOL_TIMEOUT", "10.0"))
# HTTP 连接池配置
# HTTP_MAX_CONNECTIONS: 最大连接数,影响并发能力
# - 每个连接占用一个 socket过多会耗尽系统资源
# - 默认根据 Worker 数量自动计算:单 Worker 200多 Worker 按比例分配
# HTTP_KEEPALIVE_CONNECTIONS: 保活连接数,影响连接复用效率
# - 高频请求场景应该增大此值
# - 默认为 max_connections 的 30%(长连接场景更高效)
# HTTP_KEEPALIVE_EXPIRY: 保活过期时间(秒)
# - 过短会频繁重建连接,过长会占用资源
# - 默认 30 秒,生图等长连接场景可适当增大
self.http_max_connections = int(
os.getenv("HTTP_MAX_CONNECTIONS") or self._auto_http_max_connections()
)
self.http_keepalive_connections = int(
os.getenv("HTTP_KEEPALIVE_CONNECTIONS") or self._auto_http_keepalive_connections()
)
self.http_keepalive_expiry = float(os.getenv("HTTP_KEEPALIVE_EXPIRY", "30.0"))
# 流式处理配置 # 流式处理配置
# STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误 # STREAM_PREFETCH_LINES: 预读行数,用于检测嵌套错误
# STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭 # STREAM_STATS_DELAY: 统计记录延迟(秒),等待流完全关闭
@@ -224,6 +242,53 @@ class Config:
"""智能计算最大溢出连接数 - 与 pool_size 相同""" """智能计算最大溢出连接数 - 与 pool_size 相同"""
return self.db_pool_size return self.db_pool_size
def _auto_http_max_connections(self) -> int:
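"""
Compute the maximum number of HTTP connections automatically
Rationale:
1. Socket resources are limited (the default ulimit -n on Linux is typically 1024)
2. In multi-Worker deployments each process has its own connection pool
3. Headroom is needed for database connections, Redis connections, and so on
Formula: base_connections / workers
- Single Worker: 200 connections (suits development / low load)
- Multiple Workers: split proportionally so the total stays within the system limit
Range: 50 - 500
"""
# Base connection count: assume roughly 800 sockets are available for HTTP
# (the rest is reserved for DB, Redis, internal services, etc.)
base_connections = 800
workers = max(self.worker_processes, 1)
# Connections allotted to each Worker
per_worker = base_connections // workers
# Clamp: at least 50 (basic concurrency), at most 500 (avoid resource exhaustion)
return max(50, min(per_worker, 500))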
def _auto_http_keepalive_connections(self) -> int:
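"""
Compute the number of HTTP keep-alive connections automatically
Rationale:
1. Keep-alive connections are reused, which saves TCP handshake overhead
2. In an API gateway, upstream requests are frequent, so the keep-alive share should be fairly high
3. In long-lived scenarios such as image generation, connections stay occupied for a long time
Formula: max_connections * 0.3
- A 30% share balances reuse efficiency against resource usage
- For long-lived connections, consider raising it manually to 50-70%
Range: 10 - max_connections
"""
# Keep-alive connections are 30% of the maximum connection count
keepalive = int(self.http_max_connections * 0.3)
# At least 10 keep-alive connections, never more than max_connections
return max(10, min(keepalive, self.http_max_connections))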
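The two sizing rules above can be tried in isolation. The sketch below restates them as free functions (the function names and the `worker_processes` parameter are illustrative, not part of the commit). Note that with a base of 800 and a cap of 500, the formula as written gives a single Worker 500 connections; 200 per Worker is what four Workers receive.

```python
def auto_http_max_connections(worker_processes: int, base_connections: int = 800) -> int:
    """Split a socket budget across workers, clamped to the range 50..500."""
    workers = max(worker_processes, 1)
    per_worker = base_connections // workers
    return max(50, min(per_worker, 500))


def auto_http_keepalive_connections(max_connections: int) -> int:
    """Use 30% of the pool for keep-alive, with a floor of 10 and a ceiling of the pool size."""
    keepalive = int(max_connections * 0.3)
    return max(10, min(keepalive, max_connections))


# Per-worker pool size for a few worker counts (illustrative inputs).
sizes = {w: auto_http_max_connections(w) for w in (1, 2, 4, 8, 32)}
```

Because the result is per process, the total across a deployment is roughly `workers * per_worker`, which is why the budget is divided rather than applied to each Worker in full.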
def _parse_ttfb_timeout(self) -> float: def _parse_ttfb_timeout(self) -> float:
""" """
解析 TTFB 超时配置,带错误处理和范围限制 解析 TTFB 超时配置,带错误处理和范围限制


@@ -36,6 +36,12 @@ class CacheService:
try: try:
return json.loads(value) return json.loads(value)
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
# Non-JSON value: always return a string so callers never see mixed bytes/str
if isinstance(value, (bytes, bytearray)):
try:
return value.decode("utf-8")
except Exception:
return value
return value return value
return None return None
@@ -98,6 +104,48 @@ class CacheService:
logger.warning(f"缓存删除失败: {key} - {e}") logger.warning(f"缓存删除失败: {key} - {e}")
return False return False
@staticmethod
async def delete_pattern(pattern: str, batch_size: int = 100) -> int:
"""
Delete every cache entry matching a pattern
Iterates with SCAN and deletes in batches to avoid blocking Redis
Args:
pattern: cache key pattern (supports the * wildcard)
batch_size: maximum number of keys deleted per batch
Returns:
Number of keys deleted
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return 0
# Walk the matching keys with SCAN
deleted_count = 0
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=batch_size)
if keys:
# Delete in batches so a single DELETE never carries too many keys and blocks Redis
for i in range(0, len(keys), batch_size):
batch = keys[i : i + batch_size]
await redis.delete(*batch)
deleted_count += len(batch)
if cursor == 0:
break
if deleted_count > 0:
logger.debug(f"缓存模式删除成功: {pattern}, 删除 {deleted_count} 个键")
return deleted_count
except Exception as e:
logger.warning(f"缓存模式删除失败: {pattern} - {e}")
return 0
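The SCAN-and-delete loop can be exercised without a Redis server. The sketch below is an assumption-heavy illustration: `FakeRedis` is a hypothetical in-memory stand-in that implements only `scan` and `delete`, and `delete_pattern` takes the client as a parameter instead of calling `get_redis_client`. The loop itself mirrors the one in the diff: keep scanning until the cursor returns to 0.

```python
import asyncio
import fnmatch
from typing import Dict, List, Tuple


class FakeRedis:
    """Hypothetical in-memory stand-in for an async Redis client (SCAN and DELETE only)."""

    def __init__(self, data: Dict[str, str], page_size: int = 2) -> None:
        self.data = dict(data)
        self.page_size = page_size
        self._snapshot: List[str] = []

    async def scan(self, cursor: int = 0, match: str = "*", count: int = 10) -> Tuple[int, List[str]]:
        # Snapshot the keyspace at the start of an iteration so deletions do not shift pages.
        if cursor == 0:
            self._snapshot = sorted(self.data)
        page = self._snapshot[cursor : cursor + self.page_size]
        next_cursor = cursor + self.page_size
        if next_cursor >= len(self._snapshot):
            next_cursor = 0  # a cursor of 0 signals that the iteration is complete
        keys = [k for k in page if k in self.data and fnmatch.fnmatchcase(k, match)]
        return next_cursor, keys

    async def delete(self, *keys: str) -> int:
        removed = 0
        for key in keys:
            if self.data.pop(key, None) is not None:
                removed += 1
        return removed


async def delete_pattern(redis: FakeRedis, pattern: str, batch_size: int = 100) -> int:
    """Delete all keys matching pattern, page by page, and return how many were removed."""
    deleted_count = 0
    cursor = 0
    while True:
        cursor, keys = await redis.scan(cursor, match=pattern, count=batch_size)
        for i in range(0, len(keys), batch_size):
            batch = keys[i : i + batch_size]
            await redis.delete(*batch)
            deleted_count += len(batch)
        if cursor == 0:
            break
    return deleted_count
```

A page may contain no matching keys while the cursor is still non-zero, so the loop must test the cursor rather than stop at the first empty page.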
@staticmethod @staticmethod
async def exists(key: str) -> bool: async def exists(key: str) -> bool:
""" """


@@ -7,16 +7,19 @@ from typing import Optional
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str: def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
""" """
从异常中提取错误消息,优先使用上游响应内容 从异常中提取错误消息,优先使用上游原始响应(用于链路追踪/调试)
此函数用于 RequestCandidate 表的 error_message 字段,
用于请求链路追踪中显示原始 Provider 响应。
Args: Args:
error: 异常对象 error: 异常对象
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息 status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
Returns: Returns:
错误消息字符串 错误消息字符串(原始 Provider 响应)
""" """
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误) # 优先使用 upstream_response 属性(包含上游 Provider 的原始错误,用于调试
upstream_response = getattr(error, "upstream_response", None) upstream_response = getattr(error, "upstream_response", None)
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip(): if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
return str(upstream_response) return str(upstream_response)
@@ -26,3 +29,25 @@ def extract_error_message(error: Exception, status_code: Optional[int] = None) -
if status_code is not None: if status_code is not None:
return f"HTTP {status_code}: {error_str}" return f"HTTP {status_code}: {error_str}"
return error_str return error_str
def extract_client_error_message(error: Exception) -> str:
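"""
Extract a client-friendly error message from an exception (returned to the client / stored in Usage records)
This function feeds the error_message field of the Usage table,
which holds the friendly message shown to end users.
Args:
error: the exception object
Returns:
A friendly error message string
"""
# Prefer the message attribute (already a friendly, processed message)
message = getattr(error, "message", None)
if message and isinstance(message, str) and message.strip():
return message
# Fall back to the exception's string representation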
return str(error) or repr(error)
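To see how the two extractors diverge, the sketch below runs both against one exception. `ProviderError` is a hypothetical class used only for illustration, and the line that computes `error_str` falls outside the diff hunks, so its definition here is an assumption.

```python
from typing import Optional


class ProviderError(Exception):
    """Hypothetical exception carrying a friendly message and the raw upstream body."""

    def __init__(self, message: str, upstream_response: Optional[str] = None) -> None:
        super().__init__(message)
        self.message = message
        self.upstream_response = upstream_response


def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
    """Tracing/debugging view: prefer the raw upstream response."""
    upstream_response = getattr(error, "upstream_response", None)
    if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
        return str(upstream_response)
    error_str = str(error) or repr(error)  # assumed; this line is not shown in the diff
    if status_code is not None:
        return f"HTTP {status_code}: {error_str}"
    return error_str


def extract_client_error_message(error: Exception) -> str:
    """Client view: prefer the friendly message attribute."""
    message = getattr(error, "message", None)
    if message and isinstance(message, str) and message.strip():
        return message
    return str(error) or repr(error)


err = ProviderError("Upstream authentication failed", '{"error": "invalid x-api-key"}')
trace_text = extract_error_message(err, status_code=401)
client_text = extract_client_error_message(err)
```

Here `trace_text` holds the raw upstream body while `client_text` holds only the friendly message, which is the separation the commit introduces between the RequestCandidate and Usage records.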


@@ -205,7 +205,7 @@ class ProviderTimeoutException(ProviderException):
def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None): def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None):
super().__init__( super().__init__(
message=f"提供商 '{provider_name}' 请求超时({timeout}秒)", message=f"请求超时({timeout}秒)",
provider_name=provider_name, provider_name=provider_name,
request_metadata=request_metadata, request_metadata=request_metadata,
timeout=timeout, timeout=timeout,
@@ -217,7 +217,7 @@ class ProviderAuthException(ProviderException):
def __init__(self, provider_name: str, request_metadata: Optional[Any] = None): def __init__(self, provider_name: str, request_metadata: Optional[Any] = None):
super().__init__( super().__init__(
message=f"提供商 '{provider_name}' 认证失败请检查API密钥", message="上游服务认证失败",
provider_name=provider_name, provider_name=provider_name,
request_metadata=request_metadata, request_metadata=request_metadata,
) )
@@ -292,9 +292,8 @@ class ModelNotSupportedException(ProxyException):
"""模型不支持""" """模型不支持"""
def __init__(self, model: str, provider_name: Optional[str] = None): def __init__(self, model: str, provider_name: Optional[str] = None):
# Client-facing messages do not expose provider information
message = f"模型 '{model}' 不受支持" message = f"模型 '{model}' 不受支持"
if provider_name:
message = f"提供商 '{provider_name}' 不支持模型 '{model}'"
super().__init__( super().__init__(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
error_type="model_not_supported", error_type="model_not_supported",
@@ -307,10 +306,8 @@ class StreamingNotSupportedException(ProxyException):
"""流式请求不支持""" """流式请求不支持"""
def __init__(self, model: str, provider_name: Optional[str] = None): def __init__(self, model: str, provider_name: Optional[str] = None):
if provider_name: # 客户端消息不暴露提供商信息
message = f"模型 '{model}' 在提供商 '{provider_name}'不支持流式请求" message = f"模型 '{model}' 不支持流式请求"
else:
message = f"模型 '{model}' 不支持流式请求"
super().__init__( super().__init__(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
error_type="streaming_not_supported", error_type="streaming_not_supported",
@@ -389,7 +386,7 @@ class JSONParseException(ProviderException):
details["response_content"] = response_content details["response_content"] = response_content
super().__init__( super().__init__(
message=f"提供商 '{provider_name}' 返回了无效的JSON响应", message="上游服务返回了无效的响应",
provider_name=provider_name, provider_name=provider_name,
request_metadata=request_metadata, request_metadata=request_metadata,
**details, **details,
@@ -406,7 +403,7 @@ class EmptyStreamException(ProviderException):
request_metadata: Optional[Any] = None, request_metadata: Optional[Any] = None,
): ):
super().__init__( super().__init__(
message=f"提供商 '{provider_name}' 返回了空的流式响应status=200 但无数据)", message="上游服务返回了空的流式响应",
provider_name=provider_name, provider_name=provider_name,
request_metadata=request_metadata, request_metadata=request_metadata,
chunk_count=chunk_count, chunk_count=chunk_count,
@@ -428,11 +425,10 @@ class EmbeddedErrorException(ProviderException):
error_status: Optional[str] = None, error_status: Optional[str] = None,
request_metadata: Optional[Any] = None, request_metadata: Optional[Any] = None,
): ):
message = f"提供商 '{provider_name}' 返回了嵌套错误" # 客户端消息不暴露提供商信息
message = "上游服务返回了错误"
if error_code: if error_code:
message += f" (code={error_code})" message += f" (code={error_code})"
if error_message:
message += f": {error_message}"
super().__init__( super().__init__(
message=message, message=message,
@@ -549,12 +545,14 @@ class ErrorResponse:
if isinstance(e, ProxyException): if isinstance(e, ProxyException):
details = e.details.copy() if e.details else {} details = e.details.copy() if e.details else {}
status_code = e.status_code status_code = e.status_code
message = e.message message = e.message # 使用友好的错误消息
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息 # 如果是 ProviderNotAvailableException 且有上游错误信息
if isinstance(e, ProviderNotAvailableException) and e.upstream_response: if isinstance(e, ProviderNotAvailableException):
if e.upstream_status: if e.upstream_status:
status_code = e.upstream_status status_code = e.upstream_status
message = e.upstream_response # upstream_response 存入 details 供请求链路追踪使用,不作为客户端消息
if e.upstream_response:
details["upstream_response"] = e.upstream_response
return ErrorResponse.create( return ErrorResponse.create(
error_type=e.error_type, error_type=e.error_type,
message=message, message=message,


@@ -0,0 +1,286 @@
"""
Model permission utilities
Two allowed_models formats are supported:
1. Simple mode (list): ["claude-sonnet-4", "gpt-4o"]
2. Per-format mode (dict): {"OPENAI": ["gpt-4o"], "CLAUDE": ["claude-sonnet-4"]}
None/null means unrestricted (all models allowed)
"""
from typing import Any, Dict, List, Optional, Set, Union
# Type alias
AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]
def normalize_allowed_models(
allowed_models: AllowedModels,
api_format: Optional[str] = None,
) -> Optional[Set[str]]:
"""
Normalize allowed_models into a set of model names
Args:
allowed_models: the allowed-models configuration (list or dict)
api_format: the API format of the current request (used in dict mode)
Returns:
- None: unrestricted (all models allowed)
- Set[str]: the set of allowed model names (may be empty, meaning deny all)
"""
if allowed_models is None:
return None
# Simple mode: a plain list
if isinstance(allowed_models, list):
return set(allowed_models)
# Per-format mode: a dict
if isinstance(allowed_models, dict):
if api_format is None:
# No format given: merge the models of every format
all_models: Set[str] = set()
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
return all_models if all_models else None
# Look up the model list for the given format
api_format_upper = api_format.upper()
models = allowed_models.get(api_format_upper)
if models is None:
# Format not configured: check for the "*" wildcard
models = allowed_models.get("*")
if models is None:
# In dict mode an unconfigured format means that format is unrestricted
return None
return set(models) if isinstance(models, list) else None
# Unknown type: treat as unrestricted
return None
def check_model_allowed(
model_name: str,
allowed_models: AllowedModels,
api_format: Optional[str] = None,
resolved_model_name: Optional[str] = None,
) -> bool:
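"""
Check whether a model is allowed
Args:
model_name: the requested model name
allowed_models: the allowed-models configuration
api_format: the API format of the current request
resolved_model_name: the resolved GlobalModel.name (optional)
Returns:
True: the model may be used
False: the model may not be used
"""
allowed_set = normalize_allowed_models(allowed_models, api_format)
if allowed_set is None:
# Unrestricted
return True
if len(allowed_set) == 0:
# Empty set = deny all
return False
# Check whether the requested name or the resolved name is on the allowlist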
if model_name in allowed_set:
return True
if resolved_model_name and resolved_model_name in allowed_set:
return True
return False
def merge_allowed_models(
allowed_models_1: AllowedModels,
allowed_models_2: AllowedModels,
) -> AllowedModels:
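"""
Merge two allowed_models configurations by intersection
Rules:
- If either is None, return the other
- If both have values, take the intersection
- If both are lists, intersect the lists
- If either is a dict, intersect per API format (dict semantics are kept, so the per-format distinction is not lost)
Args:
allowed_models_1: first configuration
allowed_models_2: second configuration
Returns:
The merged configuration
"""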
if allowed_models_1 is None:
return allowed_models_2
if allowed_models_2 is None:
return allowed_models_1
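# Both are simple lists: intersect directly (returned in deterministic order)
if isinstance(allowed_models_1, list) and isinstance(allowed_models_2, list):
intersection = set(allowed_models_1) & set(allowed_models_2)
return sorted(intersection) if intersection else []
# Either side is in dict mode: intersect per API format, so a dict is never flattened into a list that would grant overly broad permissions
from src.core.enums import APIFormat
def merge_sets(a: Optional[Set[str]], b: Optional[Set[str]]) -> Optional[Set[str]]:
# None means unrestricted: under intersection it is equivalent to "restricted only by the other side"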
if a is None:
return b
if b is None:
return a
return a & b
known_formats = [fmt.value for fmt in APIFormat]
per_format: Dict[str, Optional[Set[str]]] = {}
for fmt in known_formats:
s1 = normalize_allowed_models(allowed_models_1, api_format=fmt)
s2 = normalize_allowed_models(allowed_models_2, api_format=fmt)
per_format[fmt] = merge_sets(s1, s2)
# Compute the intersection for the default (unknown format) case; "*" serves as the default so formats outside the enum are covered
default_s1 = normalize_allowed_models(allowed_models_1, api_format="__DEFAULT__")
default_s2 = normalize_allowed_models(allowed_models_2, api_format="__DEFAULT__")
default_set = merge_sets(default_s1, default_s2)
# If default_set is not None and no format is unrestricted, "*" can act as the default rule with per-format overrides as needed
can_use_wildcard = default_set is not None and all(v is not None for v in per_format.values())
merged_dict: Dict[str, List[str]] = {}
if can_use_wildcard and default_set is not None:
merged_dict["*"] = sorted(default_set)
for fmt, s in per_format.items():
# can_use_wildcard guarantees that s is not None
if s is not None and s != default_set:
merged_dict[fmt] = sorted(s)
else:
for fmt, s in per_format.items():
if s is None:
continue
merged_dict[fmt] = sorted(s)
if not merged_dict:
# Everything is unrestricted
return None
return merged_dict
def get_allowed_models_preview(
allowed_models: AllowedModels,
max_items: int = 3,
) -> str:
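"""
Build a preview string for allowed_models (for logs and error messages)
Args:
allowed_models: the allowed-models configuration
max_items: maximum number of models to show
Returns:
A preview string such as "gpt-4o, claude-sonnet-4, ..."
"""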
if allowed_models is None:
return "(不限制)"
all_models: Set[str] = set()
if isinstance(allowed_models, list):
all_models = set(allowed_models)
elif isinstance(allowed_models, dict):
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
if not all_models:
return "(无)"
sorted_models = sorted(all_models)
preview = ", ".join(sorted_models[:max_items])
if len(sorted_models) > max_items:
preview += f", ...共{len(sorted_models)}"
return preview
def is_format_mode(allowed_models: AllowedModels) -> bool:
"""
Return whether allowed_models is in per-format mode
Args:
allowed_models: the allowed-models configuration
Returns:
True: per-format mode (dict)
False: simple mode (list or None)
"""
return isinstance(allowed_models, dict)
def convert_to_format_mode(
allowed_models: AllowedModels,
api_formats: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
"""
Convert allowed_models to per-format mode
Args:
allowed_models: the original configuration
api_formats: the API formats to apply the list to
Returns:
The configuration in per-format mode
"""
if allowed_models is None:
return {}
if isinstance(allowed_models, dict):
return allowed_models
# Simple list mode -> per-format mode
if isinstance(allowed_models, list):
if not api_formats:
return {"*": allowed_models}
return {fmt.upper(): list(allowed_models) for fmt in api_formats}
return {}
def convert_to_simple_mode(allowed_models: AllowedModels) -> Optional[List[str]]:
"""
Convert allowed_models to simple list mode
Args:
allowed_models: the original configuration
Returns:
A simple list, or None
"""
if allowed_models is None:
return None
if isinstance(allowed_models, list):
return allowed_models
if isinstance(allowed_models, dict):
all_models: Set[str] = set()
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
return sorted(all_models) if all_models else None
return None
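The lookup rules in this module are easiest to follow with concrete inputs. The sketch below is a condensed restatement of `normalize_allowed_models` and `check_model_allowed` as they appear in the diff (docstrings and comments trimmed), followed by a few sample calls; the model names and formats in the samples are illustrative.

```python
from typing import Dict, List, Optional, Set, Union

AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]


def normalize_allowed_models(
    allowed_models: AllowedModels, api_format: Optional[str] = None
) -> Optional[Set[str]]:
    """Return None for 'unrestricted', otherwise the set of allowed model names."""
    if allowed_models is None:
        return None
    if isinstance(allowed_models, list):
        return set(allowed_models)
    if isinstance(allowed_models, dict):
        if api_format is None:
            merged: Set[str] = set()
            for models in allowed_models.values():
                if isinstance(models, list):
                    merged.update(models)
            return merged if merged else None
        models = allowed_models.get(api_format.upper())
        if models is None:
            models = allowed_models.get("*")  # wildcard fallback
        if models is None:
            return None  # unconfigured format is unrestricted
        return set(models) if isinstance(models, list) else None
    return None


def check_model_allowed(
    model_name: str,
    allowed_models: AllowedModels,
    api_format: Optional[str] = None,
    resolved_model_name: Optional[str] = None,
) -> bool:
    allowed = normalize_allowed_models(allowed_models, api_format)
    if allowed is None:
        return True
    if not allowed:
        return False
    if model_name in allowed:
        return True
    return bool(resolved_model_name) and resolved_model_name in allowed


per_format = {"OPENAI": ["gpt-4o"], "CLAUDE": ["claude-sonnet-4"]}
openai_ok = check_model_allowed("gpt-4o", per_format, api_format="openai")
claude_ok = check_model_allowed("gpt-4o", per_format, api_format="claude")
other_ok = check_model_allowed("any-model", per_format, api_format="gemini")
```

One consequence worth knowing: `{"OPENAI": []}` denies every model when `api_format="OPENAI"` is passed, but normalizes to `None` (unrestricted) when no format is passed, because an empty merged set is returned as `None`.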


@@ -369,7 +369,6 @@ def init_db():
_ensure_engine() _ensure_engine()
# 数据库表结构由 Alembic 迁移管理 # 数据库表结构由 Alembic 迁移管理
# 首次部署或更新后请运行: ./migrate.sh
db = _SessionLocal() db = _SessionLocal()
try: try:


@@ -52,15 +52,31 @@ class ProxyConfig(BaseModel):
class CreateProviderRequest(BaseModel): class CreateProviderRequest(BaseModel):
"""创建 Provider 请求""" """创建 Provider 请求"""
name: str = Field( name: str = Field(..., min_length=1, max_length=100, description="提供商名称(唯一)")
...,
min_length=1,
max_length=100,
description="Provider 名称(英文字母、数字、下划线、连字符)",
)
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
description: Optional[str] = Field(None, max_length=1000, description="描述") description: Optional[str] = Field(None, max_length=1000, description="描述")
website: Optional[str] = Field(None, max_length=500, description="官网地址") website: Optional[str] = Field(None, max_length=500, description="官网地址")
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
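"""Validate the name format to guard against injection attacks"""
v = v.strip()
# Allow only safe characters: letters, digits, underscores, hyphens, spaces, and Chinese characters
if not re.match(r"^[\w\s\u4e00-\u9fff-]+$", v):
raise ValueError("名称只能包含字母、数字、下划线、连字符、空格和中文")
# Check for SQL injection keywords (case-insensitive)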
sql_keywords = [
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
"ALTER", "TRUNCATE", "UNION", "EXEC", "EXECUTE", "--", "/*", "*/"
]
v_upper = v.upper()
for keyword in sql_keywords:
if keyword in v_upper:
raise ValueError(f"名称包含非法关键字: {keyword}")
return v
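The behaviour of the new name validator is easier to judge with sample inputs. The sketch below restates it as a plain function outside Pydantic, with the error messages rendered in English for readability (the commit's messages are in Chinese). In Python 3, `\w` on `str` patterns already matches Unicode word characters, so the explicit `\u4e00-\u9fff` range is redundant but harmless.

```python
import re

# Keyword list copied from the diff; matching is by substring on the upper-cased name.
SQL_KEYWORDS = [
    "SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
    "ALTER", "TRUNCATE", "UNION", "EXEC", "EXECUTE", "--", "/*", "*/",
]
NAME_PATTERN = re.compile(r"^[\w\s\u4e00-\u9fff-]+$")


def validate_name(value: str) -> str:
    """Strip the name, then reject unsafe characters and SQL keywords."""
    value = value.strip()
    if not NAME_PATTERN.match(value):
        raise ValueError("name contains characters outside the allowed set")
    upper = value.upper()
    for keyword in SQL_KEYWORDS:
        if keyword in upper:
            raise ValueError(f"name contains a forbidden keyword: {keyword}")
    return value


cleaned = validate_name("  OpenAI 官方  ")
```

Because the keyword check is a substring match, an otherwise harmless name such as `Updater` or `Dropbox Proxy` is rejected as well; callers relying on this validator may want to account for that.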
billing_type: Optional[str] = Field( billing_type: Optional[str] = Field(
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型" ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
) )
@@ -68,47 +84,16 @@ class CreateProviderRequest(BaseModel):
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)") quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间") quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间") quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)") provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
is_active: Optional[bool] = Field(True, description="是否启用") is_active: Optional[bool] = Field(True, description="是否启用")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制") concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
# 请求配置(从 Endpoint 迁移)
timeout: Optional[int] = Field(300, ge=1, le=600, description="请求超时(秒)")
max_retries: Optional[int] = Field(2, ge=0, le=10, description="最大重试次数")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置") config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
@field_validator("name") @field_validator("name", "description")
@classmethod
def validate_name(cls, v: str) -> str:
"""验证名称格式"""
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
# SQL 注入防护:检查危险关键字
dangerous_keywords = [
"SELECT",
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"CREATE",
"ALTER",
"EXEC",
"UNION",
"OR",
"AND",
"--",
";",
"'",
'"',
"<",
">",
]
upper_name = v.upper()
for keyword in dangerous_keywords:
if keyword in upper_name:
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
return v
@field_validator("display_name", "description")
@classmethod @classmethod
def sanitize_text(cls, v: Optional[str]) -> Optional[str]: def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
"""清理文本输入,防止 XSS""" """清理文本输入,防止 XSS"""
@@ -162,7 +147,7 @@ class CreateProviderRequest(BaseModel):
class UpdateProviderRequest(BaseModel): class UpdateProviderRequest(BaseModel):
"""更新 Provider 请求""" """更新 Provider 请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100) name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=1000) description: Optional[str] = Field(None, max_length=1000)
website: Optional[str] = Field(None, max_length=500) website: Optional[str] = Field(None, max_length=500)
billing_type: Optional[str] = None billing_type: Optional[str] = None
@@ -170,14 +155,17 @@ class UpdateProviderRequest(BaseModel):
quota_reset_day: Optional[int] = Field(None, ge=1, le=365) quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
quota_last_reset_at: Optional[datetime] = None quota_last_reset_at: Optional[datetime] = None
quota_expires_at: Optional[datetime] = None quota_expires_at: Optional[datetime] = None
rpm_limit: Optional[int] = Field(None, ge=0)
provider_priority: Optional[int] = Field(None, ge=0, le=1000) provider_priority: Optional[int] = Field(None, ge=0, le=1000)
is_active: Optional[bool] = None is_active: Optional[bool] = None
concurrent_limit: Optional[int] = Field(None, ge=0) concurrent_limit: Optional[int] = Field(None, ge=0)
# 请求配置(从 Endpoint 迁移)
timeout: Optional[int] = Field(None, ge=1, le=600, description="请求超时(秒)")
max_retries: Optional[int] = Field(None, ge=0, le=10, description="最大重试次数")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
config: Optional[Dict[str, Any]] = None config: Optional[Dict[str, Any]] = None
# 复用相同的验证器 # 复用相同的验证器
_sanitize_text = field_validator("display_name", "description")( _sanitize_text = field_validator("name", "description")(
CreateProviderRequest.sanitize_text.__func__ CreateProviderRequest.sanitize_text.__func__
) )
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__) _validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
@@ -196,7 +184,6 @@ class CreateEndpointRequest(BaseModel):
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径") custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级") priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
is_active: Optional[bool] = Field(True, description="是否启用") is_active: Optional[bool] = Field(True, description="是否启用")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制") concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置") config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置") proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
@@ -252,7 +239,6 @@ class UpdateEndpointRequest(BaseModel):
custom_path: Optional[str] = Field(None, max_length=200) custom_path: Optional[str] = Field(None, max_length=200)
priority: Optional[int] = Field(None, ge=0, le=1000) priority: Optional[int] = Field(None, ge=0, le=1000)
is_active: Optional[bool] = None is_active: Optional[bool] = None
rpm_limit: Optional[int] = Field(None, ge=0)
concurrent_limit: Optional[int] = Field(None, ge=0) concurrent_limit: Optional[int] = Field(None, ge=0)
config: Optional[Dict[str, Any]] = None config: Optional[Dict[str, Any]] = None
proxy: Optional[ProxyConfig] = Field(None, description="代理配置") proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
@@ -277,8 +263,7 @@ class CreateAPIKeyRequest(BaseModel):
api_key: str = Field(..., min_length=1, max_length=500, description="API Key") api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级") priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
is_active: Optional[bool] = Field(True, description="是否启用") is_active: Optional[bool] = Field(True, description="是否启用")
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM") rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制NULL=自适应)")
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
notes: Optional[str] = Field(None, max_length=500, description="备注") notes: Optional[str] = Field(None, max_length=500, description="备注")
@field_validator("api_key") @field_validator("api_key")


@@ -376,14 +376,13 @@ class ApiKeyResponse(BaseModel):
class ProviderCreate(BaseModel): class ProviderCreate(BaseModel):
"""创建提供商请求 """创建提供商请求
架构说明: 架构说明:
- Provider 仅包含提供商的元数据和计费配置 - Provider 仅包含提供商的元数据和计费配置
- API格式、URL、认证等配置应在 ProviderEndpoint 中设置 - API格式、URL、认证等配置应在 ProviderEndpoint 中设置
- API密钥应在 ProviderAPIKey 中设置 - API密钥应在 ProviderAPIKey 中设置
""" """
name: str = Field(..., min_length=1, max_length=100, description="提供商唯一标识") name: str = Field(..., min_length=1, max_length=100, description="提供商名称(唯一)")
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
description: Optional[str] = Field(None, description="提供商描述") description: Optional[str] = Field(None, description="提供商描述")
website: Optional[str] = Field(None, max_length=500, description="主站网站") website: Optional[str] = Field(None, max_length=500, description="主站网站")
@@ -397,7 +396,7 @@ class ProviderCreate(BaseModel):
class ProviderUpdate(BaseModel): class ProviderUpdate(BaseModel):
"""更新提供商请求""" """更新提供商请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100) name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None description: Optional[str] = None
website: Optional[str] = Field(None, max_length=500) website: Optional[str] = Field(None, max_length=500)
api_format: Optional[str] = None api_format: Optional[str] = None
@@ -418,7 +417,6 @@ class ProviderResponse(BaseModel):
id: str id: str
name: str name: str
display_name: str
description: Optional[str] description: Optional[str]
website: Optional[str] website: Optional[str]
api_format: str api_format: str
@@ -609,7 +607,6 @@ class PublicProviderResponse(BaseModel):
id: str id: str
name: str name: str
display_name: str
description: Optional[str] description: Optional[str]
website: Optional[str] website: Optional[str]
is_active: bool is_active: bool
@@ -627,7 +624,6 @@ class PublicModelResponse(BaseModel):
id: str id: str
provider_id: str provider_id: str
provider_name: str provider_name: str
provider_display_name: str
name: str name: str
display_name: str display_name: str
description: Optional[str] = None description: Optional[str] = None


@@ -13,9 +13,7 @@ class ProviderAPIKeyBase(BaseModel):
name: Optional[str] = Field(None, description="密钥名称/备注") name: Optional[str] = Field(None, description="密钥名称/备注")
api_key: str = Field(..., description="API密钥") api_key: str = Field(..., description="API密钥")
rate_limit: Optional[int] = Field(None, description="速率限制(每分钟请求数)") rpm_limit: Optional[int] = Field(None, description="RPM限制(每分钟请求数)NULL=自适应模式")
daily_limit: Optional[int] = Field(None, description="每日请求限制")
monthly_limit: Optional[int] = Field(None, description="每月请求限制")
priority: int = Field(0, description="优先级(越高越优先使用)") priority: int = Field(0, description="优先级(越高越优先使用)")
is_active: bool = Field(True, description="是否启用") is_active: bool = Field(True, description="是否启用")
expires_at: Optional[datetime] = Field(None, description="过期时间") expires_at: Optional[datetime] = Field(None, description="过期时间")
@@ -32,9 +30,7 @@ class ProviderAPIKeyUpdate(BaseModel):
name: Optional[str] = None name: Optional[str] = None
api_key: Optional[str] = None api_key: Optional[str] = None
rate_limit: Optional[int] = None rpm_limit: Optional[int] = None
daily_limit: Optional[int] = None
monthly_limit: Optional[int] = None
priority: Optional[int] = None priority: Optional[int] = None
is_active: Optional[bool] = None is_active: Optional[bool] = None
expires_at: Optional[datetime] = None expires_at: Optional[datetime] = None
@@ -67,5 +63,3 @@ class ProviderAPIKeyStats(BaseModel):
last_used_at: Optional[datetime] last_used_at: Optional[datetime]
is_active: bool is_active: bool
is_expired: bool is_expired: bool
remaining_daily: Optional[int] = Field(None, description="今日剩余请求数")
remaining_monthly: Optional[int] = Field(None, description="本月剩余请求数")


@@ -338,7 +338,8 @@ class Usage(Base):
request_headers = Column(JSON, nullable=True) # 客户端请求头 request_headers = Column(JSON, nullable=True) # 客户端请求头
request_body = Column(JSON, nullable=True) # 请求体7天内未压缩 request_body = Column(JSON, nullable=True) # 请求体7天内未压缩
provider_request_headers = Column(JSON, nullable=True) # 向提供商发送的请求头 provider_request_headers = Column(JSON, nullable=True) # 向提供商发送的请求头
response_headers = Column(JSON, nullable=True) # 响应头 response_headers = Column(JSON, nullable=True) # 提供商响应头
client_response_headers = Column(JSON, nullable=True) # 返回给客户端的响应头
response_body = Column(JSON, nullable=True) # 响应体7天内未压缩 response_body = Column(JSON, nullable=True) # 响应体7天内未压缩
# 压缩存储字段7天后自动压缩到这里 # 压缩存储字段7天后自动压缩到这里
@@ -513,8 +514,7 @@ class Provider(Base):
__tablename__ = "providers" __tablename__ = "providers"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True) id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
name = Column(String(100), unique=True, nullable=False, index=True) # 提供商唯一标识 name = Column(String(100), unique=True, nullable=False, index=True) # 提供商名称(唯一)
display_name = Column(String(100), nullable=False) # 显示名称
description = Column(Text, nullable=True) # 提供商描述 description = Column(Text, nullable=True) # 提供商描述
website = Column(String(500), nullable=True) # 主站网站 website = Column(String(500), nullable=True) # 主站网站
@@ -537,11 +537,6 @@ class Provider(Base):
quota_last_reset_at = Column(DateTime(timezone=True), nullable=True) # 上次额度重置时间 quota_last_reset_at = Column(DateTime(timezone=True), nullable=True) # 上次额度重置时间
quota_expires_at = Column(DateTime(timezone=True), nullable=True) # 月卡过期时间 quota_expires_at = Column(DateTime(timezone=True), nullable=True) # 月卡过期时间
# RPM限制NULL=无限制0=禁止请求
rpm_limit = Column(Integer, nullable=True) # 每分钟请求数限制NULL=无限制0=禁止请求)
rpm_used = Column(Integer, default=0) # 当前分钟已用请求数
rpm_reset_at = Column(DateTime(timezone=True), nullable=True) # RPM重置时间
# 提供商优先级 (数字越小越优先,用于提供商优先模式下的 Provider 排序) # 提供商优先级 (数字越小越优先,用于提供商优先模式下的 Provider 排序)
# 0-10: 急需消耗(如即将过期的月卡) # 0-10: 急需消耗(如即将过期的月卡)
# 11-50: 优先消耗(月卡) # 11-50: 优先消耗(月卡)
@@ -555,6 +550,15 @@ class Provider(Base):
# 限制 # 限制
concurrent_limit = Column(Integer, nullable=True) # 并发请求限制 concurrent_limit = Column(Integer, nullable=True) # 并发请求限制
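# Request settings (migrated from Endpoint, used as global defaults)
# A 300-second timeout is a reasonable default for LLM APIs:
# - Most requests complete within 30 seconds
# - Complex reasoning (e.g. Claude thinking) may take 60-120 seconds
# - 300 seconds covers extreme cases (very long contexts, complex tool calls)
timeout = Column(Integer, default=300, nullable=True) # request timeout (seconds)
max_retries = Column(Integer, default=2, nullable=True) # maximum retries
proxy = Column(JSONB, nullable=True) # proxy config: {url, username, password, enabled}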
# 配置 # 配置
config = Column(JSON, nullable=True) # 额外配置如Azure deployment name等 config = Column(JSON, nullable=True) # 额外配置如Azure deployment name等
@@ -574,6 +578,9 @@ class Provider(Base):
endpoints = relationship( endpoints = relationship(
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan" "ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
) )
api_keys = relationship(
"ProviderAPIKey", back_populates="provider", cascade="all, delete-orphan"
)
api_key_mappings = relationship( api_key_mappings = relationship(
"ApiKeyProviderMapping", back_populates="provider", cascade="all, delete-orphan" "ApiKeyProviderMapping", back_populates="provider", cascade="all, delete-orphan"
) )
@@ -599,12 +606,6 @@ class ProviderEndpoint(Base):
timeout = Column(Integer, default=300) # 超时(秒) timeout = Column(Integer, default=300) # 超时(秒)
max_retries = Column(Integer, default=2) # 最大重试次数 max_retries = Column(Integer, default=2) # 最大重试次数
# 限制
max_concurrent = Column(
Integer, nullable=True, default=None
) # 该端点的最大并发数NULL=不限制)
rate_limit = Column(Integer, nullable=True) # 每分钟请求限制
# 状态 # 状态
is_active = Column(Boolean, default=True, nullable=False) is_active = Column(Boolean, default=True, nullable=False)
@@ -632,9 +633,6 @@ class ProviderEndpoint(Base):
# 关系 # 关系
provider = relationship("Provider", back_populates="endpoints") provider = relationship("Provider", back_populates="endpoints")
api_keys = relationship(
"ProviderAPIKey", back_populates="endpoint", cascade="all, delete-orphan"
)
# 唯一约束和索引在表定义后 # 唯一约束和索引在表定义后
__table_args__ = ( __table_args__ = (
@@ -734,9 +732,11 @@ class GlobalModel(Base):
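class Model(Base): class Model(Base):
"""Provider model configuration table - how a Provider uses a given GlobalModel """Provider model configuration table - how a Provider uses a given GlobalModel
Design principles (Option A): Design principles:
- Every Model must be linked to a GlobalModel (global_model_id is not nullable) - A Model represents a Provider's concrete implementation of a model
- A Model represents a Provider's concrete implementation of a GlobalModel - global_model_id is nullable:
- When NULL: the model is not yet linked to a GlobalModel and does not take part in routing
- When set: the model is linked to a GlobalModel and takes part in routing
- provider_model_name is the actual model name on the Provider side (may differ from GlobalModel.name) - provider_model_name is the actual model name on the Provider side (may differ from GlobalModel.name)
- Pricing and capability settings may be NULL, in which case the GlobalModel defaults apply - Pricing and capability settings may be NULL, in which case the GlobalModel defaults apply
""" """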
@@ -745,7 +745,8 @@ class Model(Base):
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True) id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=False) provider_id = Column(String(36), ForeignKey("providers.id"), nullable=False)
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True) # Nullable: NULL means not linked and excluded from routing; non-NULL means linked and included in routing
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=True, index=True)
# Provider mapping configuration # Provider mapping configuration
provider_model_name = Column(String(200), nullable=False) # Primary model name on the Provider side provider_model_name = Column(String(200), nullable=False) # Primary model name on the Provider side
@@ -983,17 +984,20 @@ class Model(Base):
class ProviderAPIKey(Base): class ProviderAPIKey(Base):
"""Provider API key table - owned by a specific ProviderEndpoint""" """Provider API key table - owned directly by a Provider, supports multiple API formats"""
__tablename__ = "provider_api_keys" __tablename__ = "provider_api_keys"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True) id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
# Foreign keys # Foreign keys - linked directly to the Provider
endpoint_id = Column( provider_id = Column(
String(36), ForeignKey("provider_endpoints.id", ondelete="CASCADE"), nullable=False String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
) )
# Supported API formats (core field)
api_formats = Column(JSON, nullable=False, default=list) # ["CLAUDE", "CLAUDE_CLI"]
# API key information # API key information
api_key = Column(String(500), nullable=False) # API key (stored encrypted) api_key = Column(String(500), nullable=False) # API key (stored encrypted)
name = Column(String(100), nullable=False) # Key name (required, used for identification) name = Column(String(100), nullable=False) # Key name (required, used for identification)
@@ -1002,7 +1006,10 @@ class ProviderAPIKey(Base):
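# Cost calculation # Cost calculation
rate_multiplier = Column( rate_multiplier = Column(
Float, default=1.0, nullable=False Float, default=1.0, nullable=False
) # Cost multiplier (real cost = nominal cost × multiplier) ) # Default cost multiplier (real cost = nominal cost × multiplier)
rate_multipliers = Column(
JSON, nullable=True
) # Per-API-format cost multipliers {"CLAUDE": 1.0, "OPENAI": 0.8}
# Priority configuration (lower number = higher priority) # Priority configuration (lower number = higher priority)
internal_priority = Column( internal_priority = Column(
@@ -1012,14 +1019,11 @@ class ProviderAPIKey(Base):
Integer, nullable=True Integer, nullable=True
) # Global key priority (used in global-key-first mode to order keys across Providers; NULL = not configured, default ordering applies) ) # Global key priority (used in global-key-first mode to order keys across Providers; NULL = not configured, default ordering applies)
# Concurrency limit configuration # RPM limit configuration (adaptive learning)
# max_concurrent determines the concurrency control mode: # rpm_limit determines the RPM control mode:
# - NULL: adaptive mode, the system learns and adjusts automatically (uses learned_max_concurrent) # - NULL: adaptive mode, the system learns and adjusts automatically (uses learned_rpm_limit)
# - number: fixed-limit mode, uses the user-specified value # - number: fixed-limit mode, uses the user-specified value
max_concurrent = Column(Integer, nullable=True, default=None) rpm_limit = Column(Integer, nullable=True, default=None)
rate_limit = Column(Integer, nullable=True) # Rate limit (requests per minute)
daily_limit = Column(Integer, nullable=True) # Daily request limit
monthly_limit = Column(Integer, nullable=True) # Monthly request limit
# Model permission control # Model permission control
allowed_models = Column(JSON, nullable=True) # Allowed model list (null = all models supported) allowed_models = Column(JSON, nullable=True) # Allowed model list (null = all models supported)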
@@ -1028,16 +1032,16 @@ class ProviderAPIKey(Base):
capabilities = Column(JSON, nullable=True) # Capabilities held by the key capabilities = Column(JSON, nullable=True) # Capabilities held by the key
# Example: {"cache_1h": true, "context_1m": true} # Example: {"cache_1h": true, "context_1m": true}
# Adaptive concurrency adjustment (only effective when max_concurrent = NULL) # Adaptive RPM adjustment (only effective when rpm_limit = NULL)
learned_max_concurrent = Column( learned_rpm_limit = Column(
Integer, nullable=True Integer, nullable=True
) # Learned concurrency limit (effective value in adaptive mode) ) # Learned RPM limit (effective value in adaptive mode)
concurrent_429_count = Column(Integer, default=0, nullable=False) # Number of 429s caused by concurrency concurrent_429_count = Column(Integer, default=0, nullable=False) # Number of 429s caused by concurrency
rpm_429_count = Column(Integer, default=0, nullable=False) # Number of 429s caused by RPM rpm_429_count = Column(Integer, default=0, nullable=False) # Number of 429s caused by RPM
last_429_at = Column(DateTime(timezone=True), nullable=True) # Time of the last 429 last_429_at = Column(DateTime(timezone=True), nullable=True) # Time of the last 429
last_429_type = Column(String(50), nullable=True) # Type of the last 429: concurrent/rpm/unknown last_429_type = Column(String(50), nullable=True) # Type of the last 429: concurrent/rpm/unknown
last_concurrent_peak = Column(Integer, nullable=True) # Concurrency when the 429 was triggered last_rpm_peak = Column(Integer, nullable=True) # Peak RPM when the 429 was triggered
adjustment_history = Column(JSON, nullable=True) # Concurrency adjustment history adjustment_history = Column(JSON, nullable=True) # RPM adjustment history
# Sliding-window utilization tracking # Sliding-window utilization tracking
utilization_samples = Column( utilization_samples = Column(
JSON, nullable=True JSON, nullable=True
@@ -1046,12 +1050,9 @@ class ProviderAPIKey(Base):
DateTime(timezone=True), nullable=True DateTime(timezone=True), nullable=True
) # Time of the last probing scale-up ) # Time of the last probing scale-up
# Health tracking (sliding-window based) # Health tracking (stored per API format)
health_score = Column(Float, default=1.0) # 0.0-1.0 (kept for display; actual circuit breaking is based on the sliding window) # Structure: {"CLAUDE": {"health_score": 1.0, "consecutive_failures": 0, "last_failure_at": null, "request_results_window": []}, ...}
consecutive_failures = Column(Integer, default=0) health_by_format = Column(JSON, nullable=True, default=dict)
last_failure_at = Column(DateTime(timezone=True), nullable=True) # Time of the last failure
# Sliding window: results of the most recent N requests [{"ts": timestamp, "ok": true/false}, ...]
request_results_window = Column(JSON, nullable=True)
# Cache and circuit-breaker configuration # Cache and circuit-breaker configuration
cache_ttl_minutes = Column( cache_ttl_minutes = Column(
@@ -1061,14 +1062,9 @@ class ProviderAPIKey(Base):
Integer, default=32, nullable=False Integer, default=32, nullable=False
) # Maximum probe interval (minutes); default 32 minutes (hard cap) ) # Maximum probe interval (minutes); default 32 minutes (hard cap)
# Circuit-breaker fields (sliding window + half-open mode) # Circuit-breaker state (stored per API format)
circuit_breaker_open = Column(Boolean, default=False, nullable=False) # Whether the circuit breaker is open # Structure: {"CLAUDE": {"open": false, "open_at": null, "next_probe_at": null, "half_open_until": null, "half_open_successes": 0, "half_open_failures": 0}, ...}
circuit_breaker_open_at = Column(DateTime(timezone=True), nullable=True) # Time the circuit breaker opened circuit_breaker_by_format = Column(JSON, nullable=True, default=dict)
next_probe_at = Column(DateTime(timezone=True), nullable=True) # Next probe time
# Half-open state: let a small number of requests through to verify that the service has recovered
half_open_until = Column(DateTime(timezone=True), nullable=True) # End of the half-open state
half_open_successes = Column(Integer, default=0) # Successes while half-open
half_open_failures = Column(Integer, default=0) # Failures while half-open
# Usage statistics # Usage statistics
request_count = Column(Integer, default=0) # Request count request_count = Column(Integer, default=0) # Request count
@@ -1095,7 +1091,7 @@ class ProviderAPIKey(Base):
) )
# Relationships # Relationships
endpoint = relationship("ProviderEndpoint", back_populates="api_keys") provider = relationship("Provider", back_populates="api_keys")
class UserPreference(Base): class UserPreference(Base):
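The schema above adds a per-format `rate_multipliers` JSON column next to the scalar `rate_multiplier`, which is now described as the default. The diff does not show the lookup itself, so the following is only a minimal sketch under the assumption that a per-format entry, when present, overrides the default; the helper names `resolve_rate_multiplier` and `real_cost` are hypothetical and not part of the repository.

```python
from typing import Dict, Optional


def resolve_rate_multiplier(
    api_format: str,
    default_multiplier: float = 1.0,
    per_format: Optional[Dict[str, float]] = None,
) -> float:
    """Pick the cost multiplier for one API format.

    Assumption (not confirmed by the diff): a per-format entry wins,
    otherwise the key's default multiplier applies.
    """
    if per_format:
        value = per_format.get(api_format.upper())
        if value is not None:
            return value
    return default_multiplier


def real_cost(nominal_cost: float, multiplier: float) -> float:
    """Real cost = nominal cost x multiplier, as stated in the column comment."""
    return nominal_cost * multiplier


multipliers = {"CLAUDE": 1.0, "OPENAI": 0.8}
openai_multiplier = resolve_rate_multiplier("openai", 1.0, multipliers)
fallback_multiplier = resolve_rate_multiplier("GEMINI", 1.2, multipliers)
```

Keeping the scalar column as a fallback means keys created before the migration keep their old behaviour when `rate_multipliers` is NULL.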


@@ -4,7 +4,7 @@ ProviderEndpoint 相关的 API 模型定义
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -26,10 +26,6 @@ class ProviderEndpointCreate(BaseModel):
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)") timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
max_retries: int = Field(default=2, ge=0, le=10, description="最大重试次数") max_retries: int = Field(default=2, ge=0, le=10, description="最大重试次数")
# 限制
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制(请求/秒)")
# 额外配置 # 额外配置
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置JSON") config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置JSON")
@@ -67,8 +63,6 @@ class ProviderEndpointUpdate(BaseModel):
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头") headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)") timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数") max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
is_active: Optional[bool] = Field(default=None, description="是否启用") is_active: Optional[bool] = Field(default=None, description="是否启用")
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置") config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置") proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
@@ -103,10 +97,6 @@ class ProviderEndpointResponse(BaseModel):
timeout: int timeout: int
max_retries: int max_retries: int
# 限制
max_concurrent: Optional[int] = None
rate_limit: Optional[int] = None
# 状态 # 状态
is_active: bool is_active: bool
@@ -127,32 +117,37 @@ class ProviderEndpointResponse(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
# ========== ProviderAPIKey (new architecture) ========== # ========== ProviderAPIKey ==========
class EndpointAPIKeyCreate(BaseModel): class EndpointAPIKeyCreate(BaseModel):
"""Add an API key to an Endpoint""" """Add an API key to a Provider"""
endpoint_id: str = Field(..., description="Endpoint ID") provider_id: Optional[str] = Field(default=None, description="Provider ID (taken from the URL)")
api_key: str = Field(..., min_length=10, max_length=500, description="API key (encrypted automatically)") api_formats: Optional[List[str]] = Field(
default=None, min_length=1, description="Supported API formats (required; validated in the routing layer)"
)
api_key: str = Field(..., min_length=3, max_length=500, description="API key (encrypted automatically)")
name: str = Field(..., min_length=1, max_length=100, description="Key name (required, used for identification)") name: str = Field(..., min_length=1, max_length=100, description="Key name (required, used for identification)")
# Cost calculation # Cost calculation
rate_multiplier: float = Field( rate_multiplier: float = Field(
default=1.0, ge=0.01, description="Cost multiplier (real cost = nominal cost × multiplier)" default=1.0, ge=0.01, description="Default cost multiplier (real cost = nominal cost × multiplier)"
)
rate_multipliers: Optional[Dict[str, float]] = Field(
default=None, description="Per-API-format cost multipliers, e.g. {'CLAUDE': 1.0, 'OPENAI': 0.8}"
) )
# Priority and limits (lower number = higher priority) # Priority and limits (lower number = higher priority)
internal_priority: int = Field(default=50, description="Priority within the Endpoint (provider-first mode)") internal_priority: int = Field(default=50, description="Internal key priority (provider-first mode)")
# max_concurrent: NULL = adaptive mode (learned automatically), number = fixed-limit mode # rpm_limit: NULL = adaptive mode (learned automatically), number = fixed-limit mode
max_concurrent: Optional[int] = Field( rpm_limit: Optional[int] = Field(
default=None, ge=1, description="Maximum concurrency (NULL = adaptive mode)" default=None, ge=1, le=10000, description="RPM limit (NULL = adaptive mode)"
) )
rate_limit: Optional[int] = Field(default=None, ge=1, description="Rate limit") allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = Field(
daily_limit: Optional[int] = Field(default=None, ge=1, description="Daily limit") default=None,
monthly_limit: Optional[int] = Field(default=None, ge=1, description="Monthly limit") description="Allowed models (null = unrestricted, list = simple whitelist, dict = per API format)",
allowed_models: Optional[List[str]] = Field(
default=None, description="Allowed model list (null = all models supported)"
) )
# Capability tags # Capability tags
@@ -171,17 +166,99 @@ class EndpointAPIKeyCreate(BaseModel):
# 备注 # 备注
note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)") note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)")
@field_validator("api_formats")
@classmethod
def validate_api_formats(cls, v: Optional[List[str]]) -> Optional[List[str]]:
"""验证 API 格式列表"""
if v is None:
return v
from src.core.enums import APIFormat
allowed = [fmt.value for fmt in APIFormat]
validated = []
seen = set()
for fmt in v:
fmt_upper = fmt.upper()
if fmt_upper not in allowed:
raise ValueError(f"API 格式必须是 {allowed} 之一,当前值: {fmt}")
if fmt_upper in seen:
continue # 静默去重
seen.add(fmt_upper)
validated.append(fmt_upper)
return validated
@field_validator("allowed_models")
@classmethod
def validate_allowed_models(
cls, v: Optional[Union[List[str], Dict[str, List[str]]]]
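) -> Optional[Union[List[str], Dict[str, List[str]]]]:
"""
Normalize allowed_models:
- List mode: drop empty entries, deduplicate, preserve order
- Dict mode: upper-case the keys ("*" is accepted); for each value, drop empty entries, deduplicate, preserve order
"""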
if v is None:
return v
if isinstance(v, list):
cleaned: List[str] = []
seen: set[str] = set()
for item in v:
if not isinstance(item, str):
raise ValueError("allowed_models 列表必须为字符串数组")
name = item.strip()
if not name or name in seen:
continue
seen.add(name)
cleaned.append(name)
return cleaned
if isinstance(v, dict):
from src.core.enums import APIFormat
allowed_formats = {fmt.value for fmt in APIFormat}
normalized: Dict[str, List[str]] = {}
for raw_key, models in v.items():
if not isinstance(raw_key, str):
raise ValueError("allowed_models 字典的 key 必须为字符串")
key = raw_key.upper()
if key != "*" and key not in allowed_formats:
raise ValueError(
f"allowed_models 字典的 key 必须是 {sorted(allowed_formats)}'*',当前值: {raw_key}"
)
if models is None:
# null 表示该格式不限制,跳过(不加入字典)
continue
if not isinstance(models, list):
raise ValueError("allowed_models 字典的 value 必须为字符串数组")
cleaned: List[str] = []
seen: set[str] = set()
for item in models:
if not isinstance(item, str):
raise ValueError("allowed_models 字典的 value 必须为字符串数组")
name = item.strip()
if not name or name in seen:
continue
seen.add(name)
cleaned.append(name)
normalized[key] = cleaned
return normalized
raise ValueError("allowed_models 必须是列表或字典")
@field_validator("api_key") @field_validator("api_key")
@classmethod @classmethod
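def validate_api_key(cls, v: str) -> str: def validate_api_key(cls, v: str) -> str:
"""Validate API key safety""" """Validate API key safety"""
# Strip leading and trailing whitespace # Strip leading and trailing whitespace (length is checked by Field min_length)
v = v.strip() v = v.strip()
# Check minimum length
if len(v) < 10:
raise ValueError("API Key must be at least 10 characters long")
# Check for dangerous characters (SQL injection protection) # Check for dangerous characters (SQL injection protection)
dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"] dangerous_chars = ["'", '"', ";", "--", "/*", "*/", "<", ">"]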
for char in dangerous_chars: for char in dangerous_chars:
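The `allowed_models` normalization added in this hunk accepts either a flat list or a dict keyed by API format. The sketch below restates that logic outside of Pydantic so the two modes can be tried directly. It is a condensed illustration, not the repository's validator: the `KNOWN_FORMATS` set is a hypothetical stand-in for `src.core.enums.APIFormat`, and the shared `_clean` helper is introduced here only to avoid repeating the loop.

```python
from typing import Dict, Iterable, List, Optional, Union

# Hypothetical stand-in for the values of src.core.enums.APIFormat.
KNOWN_FORMATS = {"CLAUDE", "CLAUDE_CLI", "OPENAI"}

AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]


def _clean(names: Iterable[object]) -> List[str]:
    """Strip whitespace, drop empty entries, deduplicate, preserve order."""
    cleaned: List[str] = []
    seen = set()
    for item in names:
        if not isinstance(item, str):
            raise ValueError("allowed_models entries must be strings")
        name = item.strip()
        if not name or name in seen:
            continue
        seen.add(name)
        cleaned.append(name)
    return cleaned


def normalize_allowed_models(value: AllowedModels) -> AllowedModels:
    if value is None:
        return None  # no restriction
    if isinstance(value, list):
        return _clean(value)  # simple whitelist
    if isinstance(value, dict):
        normalized: Dict[str, List[str]] = {}
        for raw_key, models in value.items():
            if not isinstance(raw_key, str):
                raise ValueError("allowed_models dict keys must be strings")
            key = raw_key.upper()
            if key != "*" and key not in KNOWN_FORMATS:
                raise ValueError(f"unknown API format: {raw_key}")
            if models is None:
                continue  # null means this format is unrestricted
            if not isinstance(models, list):
                raise ValueError("allowed_models dict values must be lists")
            normalized[key] = _clean(models)
        return normalized
    raise ValueError("allowed_models must be a list or a dict")


as_list = normalize_allowed_models([" gpt-4o ", "gpt-4o", "", "o1"])
as_dict = normalize_allowed_models({"claude": ["a", "a"], "*": [" b "], "openai": None})
```

Note that a format mapped to `null` is dropped from the result rather than stored as an empty list, so "unrestricted" and "nothing allowed" stay distinguishable.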
@@ -218,26 +295,35 @@ class EndpointAPIKeyCreate(BaseModel):
class EndpointAPIKeyUpdate(BaseModel): class EndpointAPIKeyUpdate(BaseModel):
"""更新 Endpoint API Key""" """更新 Endpoint API Key"""
api_formats: Optional[List[str]] = Field(
default=None, min_length=1, description="支持的 API 格式列表"
)
api_key: Optional[str] = Field( api_key: Optional[str] = Field(
default=None, min_length=10, max_length=500, description="API Key将自动加密" default=None, min_length=3, max_length=500, description="API Key将自动加密"
) )
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称") name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率") rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="默认成本倍率")
rate_multipliers: Optional[Dict[str, float]] = Field(
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
)
internal_priority: Optional[int] = Field( internal_priority: Optional[int] = Field(
default=None, description="Endpoint 内部优先级(提供商优先模式,数字越小越优先)" default=None, description="Key 内部优先级(提供商优先模式,数字越小越优先)"
) )
global_priority: Optional[int] = Field( global_priority: Optional[int] = Field(
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)" default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
) )
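# max_concurrent: a special marker distinguishes "not provided" from "set to null (adaptive mode)" # rpm_limit: a special marker distinguishes "not provided" from "set to null (adaptive mode)"
# - field omitted: no update # - field omitted: no update
# - null provided: switch to adaptive mode # - null provided: switch to adaptive mode
# - number provided: set a fixed concurrency limit # - number provided: set a fixed RPM limit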
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数null=自适应模式)") rpm_limit: Optional[int] = Field(
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制") default=None, ge=1, le=10000, description="RPM 限制null=自适应模式)"
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制") )
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制") allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = Field(
allowed_models: Optional[List[str]] = Field(default=None, description="允许使用的模型列表") default=None,
description="允许使用的模型列表null=不限制,列表=简单白名单,字典=按API格式区分",
)
capabilities: Optional[Dict[str, bool]] = Field( capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}" default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
) )
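The comments on `rpm_limit` in the update model describe three cases: field omitted, explicit null, and a number. The diff does not show how the route handler tells the first two apart, so the following is a rough sketch of the semantics only, working on a plain dict payload. The function name `apply_rpm_limit_patch` is hypothetical; the 1 to 10000 bounds are taken from the `Field(ge=1, le=10000)` declaration.

```python
from typing import Any, Dict, Optional


def apply_rpm_limit_patch(current: Optional[int], payload: Dict[str, Any]) -> Optional[int]:
    """Return the new rpm_limit for a key given a partial-update payload.

    - key absent from payload: keep the current value
    - key present with None:   switch to adaptive mode (store NULL)
    - key present with an int: fixed RPM limit
    """
    if "rpm_limit" not in payload:
        return current
    value = payload["rpm_limit"]
    if value is None:
        return None
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError("rpm_limit must be an integer or null")
    if not 1 <= value <= 10000:
        raise ValueError("rpm_limit must be between 1 and 10000")
    return value


unchanged = apply_rpm_limit_patch(120, {"name": "primary"})
adaptive = apply_rpm_limit_patch(120, {"rpm_limit": None})
fixed = apply_rpm_limit_patch(None, {"rpm_limit": 300})
```

With a Pydantic v2 model the same distinction is usually available through the set of fields that were explicitly provided, rather than by comparing against the default of `None`.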
@@ -250,6 +336,36 @@ class EndpointAPIKeyUpdate(BaseModel):
is_active: Optional[bool] = Field(default=None, description="是否启用") is_active: Optional[bool] = Field(default=None, description="是否启用")
note: Optional[str] = Field(default=None, max_length=500, description="备注说明") note: Optional[str] = Field(default=None, max_length=500, description="备注说明")
@field_validator("api_formats")
@classmethod
def validate_api_formats(cls, v: Optional[List[str]]) -> Optional[List[str]]:
"""验证 API 格式列表"""
if v is None:
return v
from src.core.enums import APIFormat
allowed = [fmt.value for fmt in APIFormat]
validated = []
seen = set()
for fmt in v:
fmt_upper = fmt.upper()
if fmt_upper not in allowed:
raise ValueError(f"API 格式必须是 {allowed} 之一,当前值: {fmt}")
if fmt_upper in seen:
continue # 静默去重
seen.add(fmt_upper)
validated.append(fmt_upper)
return validated
@field_validator("allowed_models")
@classmethod
def validate_allowed_models(
cls, v: Optional[Union[List[str], Dict[str, List[str]]]]
) -> Optional[Union[List[str], Dict[str, List[str]]]]:
# 与 EndpointAPIKeyCreate 保持一致
return EndpointAPIKeyCreate.validate_allowed_models(v)
@field_validator("api_key") @field_validator("api_key")
@classmethod @classmethod
def validate_api_key(cls, v: Optional[str]) -> Optional[str]: def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
@@ -299,7 +415,9 @@ class EndpointAPIKeyResponse(BaseModel):
"""Endpoint API Key 响应""" """Endpoint API Key 响应"""
id: str id: str
endpoint_id: str
provider_id: str = Field(..., description="Provider ID")
api_formats: List[str] = Field(default=[], description="支持的 API 格式列表")
# Key 信息(脱敏) # Key 信息(脱敏)
api_key_masked: str = Field(..., description="脱敏后的 Key") api_key_masked: str = Field(..., description="脱敏后的 Key")
@@ -307,31 +425,37 @@ class EndpointAPIKeyResponse(BaseModel):
name: str = Field(..., description="密钥名称") name: str = Field(..., description="密钥名称")
# 成本计算 # 成本计算
rate_multiplier: float = Field(default=1.0, description="成本倍率") rate_multiplier: float = Field(default=1.0, description="默认成本倍率")
rate_multipliers: Optional[Dict[str, float]] = Field(
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
)
# 优先级和限制 # 优先级和限制
internal_priority: int = Field(default=50, description="Endpoint 内部优先级") internal_priority: int = Field(default=50, description="Endpoint 内部优先级")
global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级") global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级")
max_concurrent: Optional[int] = None rpm_limit: Optional[int] = None
rate_limit: Optional[int] = None allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = None
daily_limit: Optional[int] = None capabilities: Optional[Dict[str, bool]] = Field(default=None, description="Key 能力标签")
monthly_limit: Optional[int] = None
allowed_models: Optional[List[str]] = None
capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签"
)
# 缓存与熔断配置 # 缓存与熔断配置
cache_ttl_minutes: int = Field(default=5, description="缓存 TTL分钟0=禁用") cache_ttl_minutes: int = Field(default=5, description="缓存 TTL分钟0=禁用")
max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)") max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)")
# 健康度 # 按格式的健康度数据
health_score: float health_by_format: Optional[Dict[str, Any]] = Field(
consecutive_failures: int default=None, description="按 API 格式存储的健康度数据"
)
circuit_breaker_by_format: Optional[Dict[str, Any]] = Field(
default=None, description="按 API 格式存储的熔断器状态"
)
# 聚合字段(从 health_by_format 计算,用于列表显示)
health_score: float = Field(default=1.0, description="健康度(所有格式中的最低值)")
consecutive_failures: int = Field(default=0, description="连续失败次数")
last_failure_at: Optional[datetime] = None last_failure_at: Optional[datetime] = None
# 熔断器状态(滑动窗口 + 半开模式) # 聚合熔断器字段
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开") circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开(任何格式)")
circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间") circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间")
next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间") next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间")
half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间") half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间")
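The response model keeps the flat `health_score` and `circuit_breaker_open` fields but now describes them as aggregates: the lowest score across all formats, and "open for any format". A minimal sketch of such an aggregation follows, using the JSON structure documented on the `health_by_format` and `circuit_breaker_by_format` columns. The function name is hypothetical, and taking the maximum of `consecutive_failures` is an assumption, since the diff does not state how that field is aggregated.

```python
from typing import Any, Dict, Optional


def aggregate_health(
    health_by_format: Optional[Dict[str, Dict[str, Any]]],
    circuit_breaker_by_format: Optional[Dict[str, Dict[str, Any]]],
) -> Dict[str, Any]:
    """Collapse per-format health data into the flat list-view fields."""
    health = health_by_format or {}
    breakers = circuit_breaker_by_format or {}
    scores = [float(data.get("health_score", 1.0)) for data in health.values()]
    failures = [int(data.get("consecutive_failures", 0)) for data in health.values()]
    return {
        # Lowest score across formats; 1.0 when no data has been recorded.
        "health_score": min(scores, default=1.0),
        # Assumption: report the worst (largest) failure streak.
        "consecutive_failures": max(failures, default=0),
        # Open if the breaker is open for any single format.
        "circuit_breaker_open": any(b.get("open", False) for b in breakers.values()),
    }


summary = aggregate_health(
    {
        "CLAUDE": {"health_score": 0.4, "consecutive_failures": 3},
        "OPENAI": {"health_score": 1.0, "consecutive_failures": 0},
    },
    {"CLAUDE": {"open": True}, "OPENAI": {"open": False}},
)
```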
@@ -349,9 +473,9 @@ class EndpointAPIKeyResponse(BaseModel):
# 状态 # 状态
is_active: bool is_active: bool
# 自适应并发信息 # 自适应 RPM 信息
is_adaptive: bool = Field(default=False, description="是否为自适应模式(max_concurrent=NULL") is_adaptive: bool = Field(default=False, description="是否为自适应模式(rpm_limit=NULL")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制") learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
effective_limit: Optional[int] = Field(None, description="当前有效限制") effective_limit: Optional[int] = Field(None, description="当前有效限制")
# 滑动窗口利用率采样 # 滑动窗口利用率采样
utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口") utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口")
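The `is_adaptive`, `learned_rpm_limit` and `effective_limit` response fields follow from the rule stated on the model: `rpm_limit = NULL` means adaptive mode, in which the learned value is the effective one, while a number means a fixed limit. A small sketch of that rule, with a hypothetical helper name:

```python
from typing import Optional, Tuple


def effective_rpm_limit(
    rpm_limit: Optional[int], learned_rpm_limit: Optional[int]
) -> Tuple[bool, Optional[int]]:
    """Return (is_adaptive, effective_limit) for a key.

    In adaptive mode the effective limit may still be None when nothing
    has been learned yet; how the limiter treats that case is not shown
    in the diff.
    """
    is_adaptive = rpm_limit is None
    effective = learned_rpm_limit if is_adaptive else rpm_limit
    return is_adaptive, effective


fixed_mode = effective_rpm_limit(600, 450)
adaptive_mode = effective_rpm_limit(None, 450)
not_learned_yet = effective_rpm_limit(None, None)
```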
@@ -375,22 +499,42 @@ class EndpointAPIKeyResponse(BaseModel):
# ========== 健康监控相关 ========== # ========== 健康监控相关 ==========
class HealthStatusResponse(BaseModel): class FormatHealthData(BaseModel):
"""健康状态响应(仅 Key 级别)""" """单个 API 格式的健康度数据"""
# Key 健康状态 health_score: float = 1.0
error_rate: float = 0.0
window_size: int = 0
consecutive_failures: int = 0
last_failure_at: Optional[str] = None
circuit_breaker: Dict[str, Any] = Field(default_factory=dict)
class HealthStatusResponse(BaseModel):
"""健康状态响应(支持按格式查询)"""
# 基础信息
key_id: str key_id: str
key_health_score: float
key_consecutive_failures: int
key_last_failure_at: Optional[datetime] = None
key_is_active: bool key_is_active: bool
key_statistics: Optional[Dict[str, Any]] = None key_statistics: Optional[Dict[str, Any]] = None
# 熔断器状态(滑动窗口 + 半开模式 # 整体健康度(取所有格式中的最低值
key_health_score: float = 1.0
any_circuit_open: bool = False
# 按格式的健康度数据
health_by_format: Optional[Dict[str, FormatHealthData]] = None
# 单格式查询时的字段
api_format: Optional[str] = None
key_consecutive_failures: Optional[int] = None
key_last_failure_at: Optional[str] = None
# 单格式查询时的熔断器状态
circuit_breaker_open: bool = False circuit_breaker_open: bool = False
circuit_breaker_open_at: Optional[datetime] = None circuit_breaker_open_at: Optional[str] = None
next_probe_at: Optional[datetime] = None next_probe_at: Optional[str] = None
half_open_until: Optional[datetime] = None half_open_until: Optional[str] = None
half_open_successes: int = 0 half_open_successes: int = 0
half_open_failures: int = 0 half_open_failures: int = 0
@@ -402,33 +546,22 @@ class HealthSummaryResponse(BaseModel):
keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)") keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)")
# ========== Concurrency control ========== # ========== RPM control ==========
class ConcurrencyStatusResponse(BaseModel): class KeyRpmStatusResponse(BaseModel):
"""Concurrency status response""" """Key RPM status response"""
endpoint_id: Optional[str] = None key_id: str = Field(..., description="Key ID")
endpoint_current_concurrency: int = Field(default=0, description="Current Endpoint concurrency") current_rpm: int = Field(default=0, description="Current RPM count")
endpoint_max_concurrent: Optional[int] = Field(default=None, description="Maximum Endpoint concurrency") rpm_limit: Optional[int] = Field(default=None, description="RPM limit")
key_id: Optional[str] = None
key_current_concurrency: int = Field(default=0, description="Current Key concurrency")
key_max_concurrent: Optional[int] = Field(default=None, description="Maximum Key concurrency")
class ResetConcurrencyRequest(BaseModel):
"""Request to reset concurrency counters"""
endpoint_id: Optional[str] = Field(default=None, description="Endpoint ID (optional)")
key_id: Optional[str] = Field(default=None, description="Key ID (optional)")
class KeyPriorityItem(BaseModel): class KeyPriorityItem(BaseModel):
"""单个 Key 优先级项""" """单个 Key 优先级项"""
key_id: str = Field(..., description="Key ID") key_id: str = Field(..., description="Key ID")
internal_priority: int = Field(..., ge=0, description="Endpoint 内部优先级(数字越小越优先)") internal_priority: int = Field(..., ge=0, description="Key 内部优先级(数字越小越优先)")
class BatchUpdateKeyPriorityRequest(BaseModel): class BatchUpdateKeyPriorityRequest(BaseModel):
@@ -443,11 +576,9 @@ class BatchUpdateKeyPriorityRequest(BaseModel):
class ProviderUpdateRequest(BaseModel): class ProviderUpdateRequest(BaseModel):
"""Provider 基础配置更新请求""" """Provider 基础配置更新请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100) name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None description: Optional[str] = None
website: Optional[str] = Field(None, max_length=500, description="主站网站") website: Optional[str] = Field(None, max_length=500, description="主站网站")
priority: Optional[int] = None
weight: Optional[float] = Field(None, gt=0)
provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)") provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)")
is_active: Optional[bool] = None is_active: Optional[bool] = None
billing_type: Optional[str] = Field( billing_type: Optional[str] = Field(
@@ -456,9 +587,10 @@ class ProviderUpdateRequest(BaseModel):
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)") monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)")
quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日1-31") quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日1-31")
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间") quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
rpm_limit: Optional[int] = Field( # 请求配置(从 Endpoint 迁移)
None, ge=0, description="每分钟请求数限制NULL=无限制0=禁止请求)" timeout: Optional[int] = Field(None, ge=1, le=600, description="请求超时(秒")
) max_retries: Optional[int] = Field(None, ge=0, le=10, description="最大重试次数")
proxy: Optional[Dict[str, Any]] = Field(None, description="代理配置")
class ProviderWithEndpointsSummary(BaseModel): class ProviderWithEndpointsSummary(BaseModel):
@@ -467,7 +599,6 @@ class ProviderWithEndpointsSummary(BaseModel):
# Provider 基本信息 # Provider 基本信息
id: str id: str
name: str name: str
display_name: str
description: Optional[str] = None description: Optional[str] = None
website: Optional[str] = None website: Optional[str] = None
provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)") provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)")
@@ -481,12 +612,10 @@ class ProviderWithEndpointsSummary(BaseModel):
quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间") quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间")
quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间") quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间")
# RPM 限制 # 请求配置(从 Endpoint 迁移)
rpm_limit: Optional[int] = Field( timeout: Optional[int] = Field(default=300, description="请求超时(秒)")
default=None, description="每分钟请求数限制NULL=无限制0=禁止请求)" max_retries: Optional[int] = Field(default=2, description="最大重试次数")
) proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置")
rpm_used: Optional[int] = Field(default=None, description="当前分钟已用请求数")
rpm_reset_at: Optional[datetime] = Field(default=None, description="RPM 重置时间")
# Endpoint 统计 # Endpoint 统计
total_endpoints: int = Field(default=0, description="总 Endpoint 数量") total_endpoints: int = Field(default=0, description="总 Endpoint 数量")
@@ -621,12 +750,8 @@ class PublicApiFormatHealthMonitor(BaseModel):
default_factory=list, default_factory=list,
description="Usage 表生成的健康时间线healthy/warning/unhealthy/unknown", description="Usage 表生成的健康时间线healthy/warning/unhealthy/unknown",
) )
time_range_start: Optional[datetime] = Field( time_range_start: Optional[datetime] = Field(default=None, description="时间线覆盖区间开始时间")
default=None, description="时间线覆盖区间开始时间" time_range_end: Optional[datetime] = Field(default=None, description="时间线覆盖区间结束时间")
)
time_range_end: Optional[datetime] = Field(
default=None, description="时间线覆盖区间结束时间"
)
class PublicApiFormatHealthMonitorResponse(BaseModel): class PublicApiFormatHealthMonitorResponse(BaseModel):


@@ -114,7 +114,6 @@ class ModelCatalogProviderDetail(BaseModel):
provider_id: str provider_id: str
provider_name: str provider_name: str
provider_display_name: Optional[str]
model_id: Optional[str] model_id: Optional[str]
target_model: str target_model: str
input_price_per_1m: Optional[float] input_price_per_1m: Optional[float]
@@ -312,16 +311,26 @@ class ImportFromUpstreamRequest(BaseModel):
"""从上游提供商导入模型请求""" """从上游提供商导入模型请求"""
model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表") model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表")
# 价格覆盖配置(应用于所有导入的模型)
tiered_pricing: Optional[Dict] = Field(
None,
description="阶梯计费配置(可选),格式: {tiers: [{up_to, input_price_per_1m, output_price_per_1m, ...}]}"
)
price_per_request: Optional[float] = Field(
None,
ge=0,
description="按次计费价格(可选,单位:美元)"
)
class ImportFromUpstreamSuccessItem(BaseModel): class ImportFromUpstreamSuccessItem(BaseModel):
"""导入成功的模型信息""" """导入成功的模型信息"""
model_id: str = Field(..., description="上游模型 ID") model_id: str = Field(..., description="上游模型 ID")
global_model_id: str = Field(..., description="GlobalModel ID")
global_model_name: str = Field(..., description="GlobalModel 名称")
provider_model_id: str = Field(..., description="Provider Model ID") provider_model_id: str = Field(..., description="Provider Model ID")
created_global_model: bool = Field(..., description="是否新创建了 GlobalModel") global_model_id: Optional[str] = Field("", description="GlobalModel ID如果已关联")
global_model_name: Optional[str] = Field("", description="GlobalModel 名称(如果已关联)")
created_global_model: bool = Field(False, description="是否新创建了 GlobalModel始终为 false")
class ImportFromUpstreamErrorItem(BaseModel): class ImportFromUpstreamErrorItem(BaseModel):


@@ -8,7 +8,9 @@ import hashlib
import secrets import secrets
import time import time
import uuid import uuid
from collections import OrderedDict
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from threading import Lock
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import jwt import jwt
@@ -30,6 +32,44 @@ from src.services.cache.user_cache import UserCacheService
from src.services.user.apikey import ApiKeyService from src.services.user.apikey import ApiKeyService
# Throttling configuration for API key last_used_at updates
# A given API key has its last_used_at updated at most once per interval
_LAST_USED_UPDATE_INTERVAL = 60 # seconds
_LAST_USED_CACHE_MAX_SIZE = 10000 # maximum number of LRU cache entries
# In-process cache: time of each API key's most recent last_used_at update
# Implemented as an LRU with OrderedDict so memory cannot grow without bound
_api_key_last_update_times: OrderedDict[str, float] = OrderedDict()
_last_update_lock = Lock()
def _should_update_last_used(api_key_id: str) -> bool:
"""Decide whether an API key's last_used_at should be updated.
Uses throttling: a given key is updated at most once per interval.
Thread-safe; the cache size is bounded with an LRU policy.
Returns:
True if the update should proceed, False to skip it
"""
now = time.time()
with _last_update_lock:
last_update = _api_key_last_update_times.get(api_key_id, 0)
if now - last_update >= _LAST_USED_UPDATE_INTERVAL:
_api_key_last_update_times[api_key_id] = now
# LRU: move to the end (most recently used)
_api_key_last_update_times.move_to_end(api_key_id)
# When over capacity, evict the oldest entries
while len(_api_key_last_update_times) > _LAST_USED_CACHE_MAX_SIZE:
_api_key_last_update_times.popitem(last=False)
return True
return False
# JWT configuration, read from config # JWT configuration, read from config
if not config.jwt_secret_key: if not config.jwt_secret_key:
# 如果没有配置,生成一个随机密钥并警告 # 如果没有配置,生成一个随机密钥并警告
@@ -367,9 +407,10 @@ class AuthService:
logger.warning(f"API认证失败 - 用户已禁用: {user.email}") logger.warning(f"API认证失败 - 用户已禁用: {user.email}")
return None return None
# 更新最后使用时间 # 更新最后使用时间(使用节流策略,减少数据库写入)
key_record.last_used_at = datetime.now(timezone.utc) if _should_update_last_used(key_record.id):
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求 key_record.last_used_at = datetime.now(timezone.utc)
db.commit() # 立即提交事务,释放数据库锁,避免阻塞后续请求
api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12] api_key_fp = hashlib.sha256(api_key.encode()).hexdigest()[:12]
logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp) logger.debug("API认证成功: 用户 {} (api_key_fp={})", user.email, api_key_fp)
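Reviewer's sketch of the throttling idea in this file's diff: a per-key "at most once per interval" gate backed by an LRU-bounded `OrderedDict`. The class name `UpdateThrottle` and the injectable `now` parameter are illustrative assumptions (added so the behaviour can be exercised without sleeping); they are not part of the commit.

```python
import time
from collections import OrderedDict
from threading import Lock
from typing import Optional


class UpdateThrottle:
    """Allow at most one update per key per interval, with a bounded LRU table."""

    def __init__(self, interval: float = 60.0, max_size: int = 10000) -> None:
        self._interval = interval
        self._max_size = max_size
        self._last: "OrderedDict[str, float]" = OrderedDict()
        self._lock = Lock()

    def should_update(self, key_id: str, now: Optional[float] = None) -> bool:
        # `now` is injectable for testing (hypothetical parameter).
        now = time.time() if now is None else now
        with self._lock:
            last = self._last.get(key_id)
            if last is not None and now - last < self._interval:
                return False  # updated recently: skip the write
            self._last[key_id] = now
            self._last.move_to_end(key_id)  # mark as most recently used
            while len(self._last) > self._max_size:
                self._last.popitem(last=False)  # evict the oldest entry
            return True
```

One consequence of the bounded table is that an evicted key is treated as never seen, so its next request triggers a write even inside the interval. With 10000 entries that should only matter under a very large working set of keys.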


@@ -34,7 +34,7 @@ import hashlib
import random import random
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
@@ -80,8 +80,6 @@ class ProviderCandidate:
@dataclass @dataclass
class ConcurrencySnapshot: class ConcurrencySnapshot:
endpoint_current: int
endpoint_limit: Optional[int]
key_current: int key_current: int
key_limit: Optional[int] key_limit: Optional[int]
is_cached_user: bool = False is_cached_user: bool = False
@@ -91,11 +89,9 @@ class ConcurrencySnapshot:
reservation_confidence: float = 0.0 reservation_confidence: float = 0.0
def describe(self) -> str: def describe(self) -> str:
endpoint_limit_text = str(self.endpoint_limit) if self.endpoint_limit is not None else "inf"
key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf" key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf"
reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A" reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A"
return ( return (
f"endpoint={self.endpoint_current}/{endpoint_limit_text}, "
f"key={self.key_current}/{key_limit_text}, " f"key={self.key_current}/{key_limit_text}, "
f"cached={self.is_cached_user}, " f"cached={self.is_cached_user}, "
f"reserve={reservation_text}({self.reservation_phase})" f"reserve={reservation_text}({self.reservation_phase})"
@@ -121,11 +117,13 @@ class CacheAwareScheduler:
PRIORITY_MODE_GLOBAL_KEY, PRIORITY_MODE_GLOBAL_KEY,
} }
# 调度模式常量 # 调度模式常量
SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式 SCHEDULING_MODE_FIXED_ORDER = "fixed_order" # 固定顺序模式:严格按优先级,忽略缓存
SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式 SCHEDULING_MODE_CACHE_AFFINITY = "cache_affinity" # 缓存亲和模式:优先缓存,同优先级哈希分散
SCHEDULING_MODE_LOAD_BALANCE = "load_balance" # 负载均衡模式:忽略缓存,同优先级随机轮换
ALLOWED_SCHEDULING_MODES = { ALLOWED_SCHEDULING_MODES = {
SCHEDULING_MODE_FIXED_ORDER, SCHEDULING_MODE_FIXED_ORDER,
SCHEDULING_MODE_CACHE_AFFINITY, SCHEDULING_MODE_CACHE_AFFINITY,
SCHEDULING_MODE_LOAD_BALANCE,
} }
def __init__( def __init__(
@@ -244,9 +242,8 @@ class CacheAwareScheduler:
if not candidates: if not candidates:
if provider_offset == 0: if provider_offset == 0:
# 没有找到任何候选,提供友好的错误提示 # 没有找到任何候选,提供友好的错误提示(不暴露内部信息)
error_msg = f"模型 '{model_name}' 不可用" raise ProviderNotAvailableException("请求的模型当前不可用")
raise ProviderNotAvailableException(error_msg)
break break
self._metrics["total_batches"] += 1 self._metrics["total_batches"] += 1
@@ -268,7 +265,6 @@ class CacheAwareScheduler:
is_cached_user = bool(candidate.is_cached) is_cached_user = bool(candidate.is_cached)
can_use, snapshot = await self._check_concurrent_available( can_use, snapshot = await self._check_concurrent_available(
endpoint,
key, key,
is_cached_user=is_cached_user, is_cached_user=is_cached_user,
) )
@@ -310,47 +306,51 @@ class CacheAwareScheduler:
provider_offset += provider_batch_size provider_offset += provider_batch_size
raise ProviderNotAvailableException(f"所有Provider的资源当前不可用 (model={model_name})") raise ProviderNotAvailableException("服务暂时繁忙,请稍后重试")
def _get_effective_concurrent_limit(self, key: ProviderAPIKey) -> Optional[int]: def _get_effective_rpm_limit(self, key: ProviderAPIKey) -> Optional[int]:
""" """
获取有效的并发限制 获取有效的 RPM 限制
新逻辑: 新逻辑:
- max_concurrent=NULL: 启用自适应,使用 learned_max_concurrent如无学习记录则为 None - rpm_limit=NULL: 启用自适应,使用 learned_rpm_limit如无学习记录则使用默认初始值
- max_concurrent=数字: 固定限制,直接使用该值 - rpm_limit=数字: 固定限制,直接使用该值
Args: Args:
key: API Key对象 key: API Key对象
Returns: Returns:
有效的并发限制None 表示不限制) 有效的 RPM 限制None 表示不限制)
""" """
if key.max_concurrent is None: if key.rpm_limit is None:
# 自适应模式:使用学习到的值 # 自适应模式:使用学习到的值
learned = key.learned_max_concurrent learned = key.learned_rpm_limit
return int(learned) if learned is not None else None if learned is not None:
return int(learned)
# 未学习到值时,使用默认初始限制,避免无限制打爆上游
from src.config.constants import RPMDefaults
return int(RPMDefaults.INITIAL_LIMIT)
else: else:
# 固定限制模式 # 固定限制模式
return int(key.max_concurrent) return int(key.rpm_limit)
async def _check_concurrent_available( async def _check_concurrent_available(
self, self,
endpoint: ProviderEndpoint,
key: ProviderAPIKey, key: ProviderAPIKey,
is_cached_user: bool = False, is_cached_user: bool = False,
) -> Tuple[bool, ConcurrencySnapshot]: ) -> Tuple[bool, ConcurrencySnapshot]:
""" """
检查并发是否可用(使用动态预留机制) 检查 RPM 限制是否可用(使用动态预留机制)
核心逻辑 - 动态缓存预留机制: 核心逻辑 - 动态缓存预留机制:
- 总槽位: 有效并发限制(固定值或学习到的值) - 总槽位: 有效 RPM 限制(固定值或学习到的值)
- 预留比例: 由 AdaptiveReservationManager 根据置信度和负载动态计算 - 预留比例: 由 AdaptiveReservationManager 根据置信度和负载动态计算
- 缓存用户可用: 全部槽位 - 缓存用户可用: 全部槽位
- 新用户可用: 总槽位 × (1 - 动态预留比例) - 新用户可用: 总槽位 × (1 - 动态预留比例)
Args: Args:
endpoint: ProviderEndpoint对象
key: ProviderAPIKey对象 key: ProviderAPIKey对象
is_cached_user: 是否是缓存用户 is_cached_user: 是否是缓存用户
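Reviewer's sketch of the limit resolution in `_get_effective_rpm_limit` as it reads in this hunk: a fixed `rpm_limit` wins, otherwise the learned value, otherwise a default initial limit. `KeyLimits` and the value `INITIAL_LIMIT = 10` are stand-ins assumed for illustration; the real default lives in `RPMDefaults.INITIAL_LIMIT`, whose value is not shown in the diff.

```python
from dataclasses import dataclass
from typing import Optional

INITIAL_LIMIT = 10  # hypothetical stand-in for RPMDefaults.INITIAL_LIMIT


@dataclass
class KeyLimits:
    """Minimal stand-in for the two ProviderAPIKey fields used here."""

    rpm_limit: Optional[int] = None
    learned_rpm_limit: Optional[int] = None


def effective_rpm_limit(key: KeyLimits, initial_limit: int = INITIAL_LIMIT) -> int:
    if key.rpm_limit is not None:
        return int(key.rpm_limit)  # fixed limit configured by the operator
    if key.learned_rpm_limit is not None:
        return int(key.learned_rpm_limit)  # adaptive mode, learned value
    return int(initial_limit)  # adaptive mode, nothing learned yet
```

Read this way, every branch returns an integer, so the "None means unlimited" wording that remains in the docstring and the `Optional[int]` return type may no longer describe a reachable case.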
@@ -358,7 +358,7 @@ class CacheAwareScheduler:
(是否可用, 并发快照) (是否可用, 并发快照)
""" """
# 获取有效的并发限制 # 获取有效的并发限制
effective_key_limit = self._get_effective_concurrent_limit(key) effective_key_limit = self._get_effective_rpm_limit(key)
logger.debug( logger.debug(
f" -> 并发检查: _concurrency_manager={self._concurrency_manager is not None}, " f" -> 并发检查: _concurrency_manager={self._concurrency_manager is not None}, "
@@ -369,33 +369,23 @@ class CacheAwareScheduler:
# 并发管理器不可用直接返回True # 并发管理器不可用直接返回True
logger.debug(f" -> 无并发管理器,直接通过") logger.debug(f" -> 无并发管理器,直接通过")
snapshot = ConcurrencySnapshot( snapshot = ConcurrencySnapshot(
endpoint_current=0,
endpoint_limit=(
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
),
key_current=0, key_current=0,
key_limit=effective_key_limit, key_limit=effective_key_limit,
is_cached_user=is_cached_user, is_cached_user=is_cached_user,
) )
return True, snapshot return True, snapshot
# 获取当前并发 # 获取当前 RPM 计
endpoint_count, key_count = await self._concurrency_manager.get_current_concurrency( key_count = await self._concurrency_manager.get_key_rpm_count(
endpoint_id=str(endpoint.id),
key_id=str(key.id), key_id=str(key.id),
) )
can_use = True can_use = True
# 检查Endpoint级别限制
if endpoint.max_concurrent is not None:
if endpoint_count >= endpoint.max_concurrent:
can_use = False
# 计算动态预留比例 # 计算动态预留比例
reservation_result = self._reservation_manager.calculate_reservation( reservation_result = self._reservation_manager.calculate_reservation(
key=key, key=key,
current_concurrent=key_count, current_usage=key_count,
effective_limit=effective_key_limit, effective_limit=effective_key_limit,
) )
@@ -438,7 +428,8 @@ class CacheAwareScheduler:
# 使用 max 确保至少有 1 个槽位可用 # 使用 max 确保至少有 1 个槽位可用
import math import math
available_for_new = max(1, math.ceil(effective_key_limit * (1 - reservation_ratio))) # 与 ConcurrencyManager 的 Lua 脚本保持一致:使用 floor 计算新用户可用槽位
available_for_new = max(1, math.floor(effective_key_limit * (1 - reservation_ratio)))
if key_count >= available_for_new: if key_count >= available_for_new:
logger.debug( logger.debug(
f"Key {key.id[:8]}... 新用户配额已满 " f"Key {key.id[:8]}... 新用户配额已满 "
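Reviewer's sketch of the slot arithmetic changed in this hunk: new users may use `floor(limit * (1 - reservation_ratio))` slots, with a minimum of one. The function name is an assumption; the formula is the one on the right-hand side of the diff.

```python
import math


def available_for_new_users(effective_limit: int, reservation_ratio: float) -> int:
    """Slots open to non-cached users after reserving a share for cached users."""
    return max(1, math.floor(effective_limit * (1 - reservation_ratio)))
```

With a limit of 10 and a 25% reservation the product is 7.5, so `floor` gives 7 where the previous `ceil` gave 8. The `max(1, ...)` guard still leaves one slot when the product rounds down to zero, for example a limit of 1 with a 50% reservation.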
@@ -458,8 +449,6 @@ class CacheAwareScheduler:
key_limit_for_snapshot = None key_limit_for_snapshot = None
snapshot = ConcurrencySnapshot( snapshot = ConcurrencySnapshot(
endpoint_current=endpoint_count,
endpoint_limit=endpoint.max_concurrent,
key_current=key_count, key_current=key_count,
key_limit=key_limit_for_snapshot, key_limit=key_limit_for_snapshot,
is_cached_user=is_cached_user, is_cached_user=is_cached_user,
@@ -473,7 +462,7 @@ class CacheAwareScheduler:
def _get_effective_restrictions( def _get_effective_restrictions(
self, self,
user_api_key: Optional[ApiKey], user_api_key: Optional[ApiKey],
) -> Dict[str, Optional[set]]: ) -> Dict[str, Any]:
""" """
获取有效的访问限制(合并 ApiKey 和 User 的限制) 获取有效的访问限制(合并 ApiKey 和 User 的限制)
@@ -534,7 +523,10 @@ class CacheAwareScheduler:
) )
# 合并 allowed_models # 合并 allowed_models
result["allowed_models"] = merge_restrictions( # allowed_models 支持 list/dict 两种结构,不能转成 set 否则会导致权限校验失效
from src.core.model_permissions import merge_allowed_models
result["allowed_models"] = merge_allowed_models(
user_api_key.allowed_models, user.allowed_models if user else None user_api_key.allowed_models, user.allowed_models if user else None
) )
@@ -615,22 +607,25 @@ class CacheAwareScheduler:
) )
return [], global_model_id return [], global_model_id
# 0.2 检查模型是否被允许 # 0.2 检查模型是否被允许(支持简单列表和按格式字典两种模式)
if allowed_models is not None: from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
if (
requested_model_name not in allowed_models if not check_model_allowed(
and resolved_model_name not in allowed_models model_name=requested_model_name,
): allowed_models=allowed_models,
resolved_note = ( api_format=target_format.value,
f" (解析为 {resolved_model_name})" resolved_model_name=resolved_model_name,
if resolved_model_name != requested_model_name ):
else "" resolved_note = (
) f" (解析为 {resolved_model_name})"
logger.debug( if resolved_model_name != requested_model_name
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, " else ""
f"允许的模型: {allowed_models}" )
) logger.debug(
return [], global_model_id f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
f"允许的模型: {get_allowed_models_preview(allowed_models)}"
)
return [], global_model_id
# 1. 查询 Providers # 1. 查询 Providers
providers = self._query_providers( providers = self._query_providers(
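Reviewer's sketch of what a list-or-dict permission check could look like. The real `check_model_allowed` in `src.core.model_permissions` is not shown in this diff, so everything below is an assumption: the function name `is_model_allowed`, the dict being keyed by API format, and the choice to deny when a format has no entry.

```python
from typing import Dict, List, Optional, Union

AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]


def is_model_allowed(
    model_name: str,
    allowed_models: AllowedModels,
    api_format: Optional[str] = None,
    resolved_model_name: Optional[str] = None,
) -> bool:
    if allowed_models is None:
        return True  # None: every model is allowed
    if isinstance(allowed_models, dict):
        # Assumption: a format missing from the dict allows nothing.
        names = allowed_models.get(api_format or "", [])
    else:
        names = allowed_models  # simple list; [] denies everything
    if model_name in names:
        return True
    return bool(resolved_model_name) and resolved_model_name in names
```

Keeping the dict form intact is the point of the comment in the hunk: converting a per-format dict to a `set` would leave only its keys, and a membership test against format names would no longer check models at all.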
@@ -680,8 +675,9 @@ class CacheAwareScheduler:
f"(api_format={target_format.value}, model={model_name})" f"(api_format={target_format.value}, model={model_name})"
) )
# 4. 应用缓存亲和性排序(仅在缓存亲和模式下启用) # 4. 根据调度模式应用不同的排序策略
if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY: if self.scheduling_mode == self.SCHEDULING_MODE_CACHE_AFFINITY:
# 缓存亲和模式:优先使用缓存的,同优先级内哈希分散
if affinity_key and candidates: if affinity_key and candidates:
candidates = await self._apply_cache_affinity( candidates = await self._apply_cache_affinity(
candidates=candidates, candidates=candidates,
@@ -689,8 +685,13 @@ class CacheAwareScheduler:
api_format=target_format, api_format=target_format,
global_model_id=global_model_id, global_model_id=global_model_id,
) )
elif self.scheduling_mode == self.SCHEDULING_MODE_LOAD_BALANCE:
# 负载均衡模式:忽略缓存,同优先级内随机轮换
candidates = self._apply_load_balance(candidates)
for candidate in candidates:
candidate.is_cached = False
else: else:
# 固定顺序模式:标记所有候选为非缓存 # 固定顺序模式:严格按优先级,忽略缓存
for candidate in candidates: for candidate in candidates:
candidate.is_cached = False candidate.is_cached = False
@@ -716,8 +717,11 @@ class CacheAwareScheduler:
provider_query = ( provider_query = (
db.query(Provider) db.query(Provider)
.options( .options(
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys), # 预加载 Provider 级别的 api_keys
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值 selectinload(Provider.api_keys),
# 预加载 endpoints用于按 api_format 选择请求配置)
selectinload(Provider.endpoints),
# 同时加载 models 和 global_model 关系
selectinload(Provider.models).selectinload(Model.global_model), selectinload(Provider.models).selectinload(Model.global_model),
) )
.filter(Provider.is_active == True) .filter(Provider.is_active == True)
@@ -844,6 +848,7 @@ class CacheAwareScheduler:
def _check_key_availability( def _check_key_availability(
self, self,
key: ProviderAPIKey, key: ProviderAPIKey,
api_format: Optional[str],
model_name: str, model_name: str,
capability_requirements: Optional[Dict[str, bool]] = None, capability_requirements: Optional[Dict[str, bool]] = None,
resolved_model_name: Optional[str] = None, resolved_model_name: Optional[str] = None,
@@ -863,20 +868,24 @@ class CacheAwareScheduler:
Returns: Returns:
(is_available, skip_reason) (is_available, skip_reason)
""" """
# 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因) # 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因,按 API 格式
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(key) is_available, circuit_reason = health_monitor.get_circuit_breaker_status(
key, api_format=api_format
)
if not is_available: if not is_available:
return False, circuit_reason or "熔断器已打开" return False, circuit_reason or "熔断器已打开"
# 模型权限检查:使用 allowed_models 白名单 # 模型权限检查:使用 allowed_models 白名单(支持简单列表和按格式字典两种模式)
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型 # None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
if key.allowed_models is not None and ( from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
model_name not in key.allowed_models
and (not resolved_model_name or resolved_model_name not in key.allowed_models) if not check_model_allowed(
model_name=model_name,
allowed_models=key.allowed_models,
api_format=api_format,
resolved_model_name=resolved_model_name,
): ):
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(无)" return False, f"模型权限不匹配(允许: {get_allowed_models_preview(key.allowed_models)})"
suffix = "..." if len(key.allowed_models) > 3 else ""
return False, f"模型权限不匹配(允许: {allowed_preview}{suffix})"
# Key 级别的能力匹配检查 # Key 级别的能力匹配检查
# 注意:模型级别的能力检查已在 _check_model_support 中完成 # 注意:模型级别的能力检查已在 _check_model_support 中完成
@@ -906,6 +915,8 @@ class CacheAwareScheduler:
""" """
构建候选列表 构建候选列表
Key 直属 Provider通过 api_formats 筛选符合目标格式的 Key。
Args: Args:
db: 数据库会话 db: 数据库会话
providers: Provider 列表 providers: Provider 列表
@@ -921,10 +932,10 @@ class CacheAwareScheduler:
候选列表 候选列表
""" """
candidates: List[ProviderCandidate] = [] candidates: List[ProviderCandidate] = []
target_format_str = target_format.value
for provider in providers: for provider in providers:
# 检查模型支持(同时检查流式支持和模型能力需求) # 检查模型支持(同时检查流式支持和模型能力需求)
# 模型能力检查在 Provider 级别进行,如果模型不支持所需能力,整个 Provider 被跳过
supports_model, skip_reason, _model_caps = await self._check_model_support( supports_model, skip_reason, _model_caps = await self._check_model_support(
db, provider, model_name, is_stream, capability_requirements db, provider, model_name, is_stream, capability_requirements
) )
@@ -932,49 +943,63 @@ class CacheAwareScheduler:
logger.debug(f"Provider {provider.name} 不支持模型 {model_name}: {skip_reason}") logger.debug(f"Provider {provider.name} 不支持模型 {model_name}: {skip_reason}")
continue continue
# 查找目标格式对应的 Endpoint获取请求配置
target_endpoint = None
for endpoint in provider.endpoints: for endpoint in provider.endpoints:
# endpoint.api_format 是字符串target_format 是枚举
endpoint_format_str = ( endpoint_format_str = (
endpoint.api_format endpoint.api_format
if isinstance(endpoint.api_format, str) if isinstance(endpoint.api_format, str)
else endpoint.api_format.value else endpoint.api_format.value
) )
if not endpoint.is_active or endpoint_format_str != target_format.value: if endpoint.is_active and endpoint_format_str == target_format_str:
continue target_endpoint = endpoint
break
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序 if not target_endpoint:
active_keys = [key for key in endpoint.api_keys if key.is_active] logger.debug(f"Provider {provider.name} 没有活跃的 {target_format_str} 端点")
# 检查是否所有 Key 都是 TTL=0轮换模式 continue
# 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None则使用随机排序
use_random = all(
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
) if active_keys else False
if use_random and len(active_keys) > 1:
logger.debug(
f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
)
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
for key in keys: # Key 直属 Provider通过 api_formats 筛选
# Key 级别的能力检查(模型级别的能力检查已在上面完成) active_keys = [
is_available, skip_reason = self._check_key_availability( key for key in provider.api_keys
key, if key.is_active and target_format_str in (key.api_formats or [])
model_name, ]
capability_requirements,
resolved_model_name=resolved_model_name,
)
candidate = ProviderCandidate( if not active_keys:
provider=provider, logger.debug(f"Provider {provider.name} 没有支持 {target_format_str} 的活跃 Key")
endpoint=endpoint, continue
key=key,
is_skipped=not is_available,
skip_reason=skip_reason,
)
candidates.append(candidate)
if max_candidates and len(candidates) >= max_candidates: # 检查是否所有 Key 都是 TTL=0轮换模式
return candidates use_random = all(
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
) if active_keys else False
if use_random and len(active_keys) > 1:
logger.debug(
f" Provider {provider.name} 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
)
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
for key in keys:
# Key 级别的能力检查
is_available, skip_reason = self._check_key_availability(
key,
target_format_str,
model_name,
capability_requirements,
resolved_model_name=resolved_model_name,
)
candidate = ProviderCandidate(
provider=provider,
endpoint=target_endpoint,
key=key,
is_skipped=not is_available,
skip_reason=skip_reason,
)
candidates.append(candidate)
if max_candidates and len(candidates) >= max_candidates:
return candidates
return candidates return candidates
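Reviewer's sketch of the key selection this hunk moves to: keys belong to the provider and declare the formats they support in `api_formats`, so candidates are filtered by membership instead of by endpoint ownership. The `Key` dataclass is a minimal stand-in for `ProviderAPIKey`, assumed for illustration.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Key:
    """Minimal stand-in for the ProviderAPIKey fields used by the filter."""

    id: str
    is_active: bool = True
    api_formats: Optional[List[str]] = None


def keys_for_format(keys: List[Key], target_format: str) -> List[Key]:
    # `or []` treats a NULL api_formats column as "supports no format".
    return [
        key
        for key in keys
        if key.is_active and target_format in (key.api_formats or [])
    ]
```

A key listing several formats then appears as a candidate for each of them, paired with whichever endpoint of the provider matches the requested format.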
@@ -1163,6 +1188,56 @@ class CacheAwareScheduler:
return result return result
    def _apply_load_balance(
        self, candidates: List[ProviderCandidate]
    ) -> List[ProviderCandidate]:
        """
        Load-balance mode: random rotation within the same priority.

        Ordering:
        1. Group by priority: (provider_priority, internal_priority) or global_priority
        2. Shuffle randomly within each priority group
        3. Ignore cache affinity
        """
        if not candidates:
            return candidates

        from collections import defaultdict

        priority_groups: Dict[tuple, List[ProviderCandidate]] = defaultdict(list)

        # Choose the grouping according to the priority mode
        if self.priority_mode == self.PRIORITY_MODE_GLOBAL_KEY:
            # Global-key-first mode: group by global_priority
            for candidate in candidates:
                global_priority = (
                    candidate.key.global_priority
                    if candidate.key and candidate.key.global_priority is not None
                    else 999999
                )
                priority_groups[(global_priority,)].append(candidate)
        else:
            # Provider-first mode: group by (provider_priority, internal_priority)
            for candidate in candidates:
                key = (
                    candidate.provider.provider_priority or 999999,
                    candidate.key.internal_priority if candidate.key else 999999,
                )
                priority_groups[key].append(candidate)

        result: List[ProviderCandidate] = []
        for priority in sorted(priority_groups.keys()):
            group = priority_groups[priority]
            if len(group) > 1:
                # Shuffle randomly within the same priority
                shuffled = list(group)
                random.shuffle(shuffled)
                result.extend(shuffled)
            else:
                result.extend(group)

        return result
def _shuffle_keys_by_internal_priority( def _shuffle_keys_by_internal_priority(
self, self,
keys: List[ProviderAPIKey], keys: List[ProviderAPIKey],
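Reviewer's sketch of the ordering used by `_apply_load_balance`: bucket by priority, walk the buckets in ascending order, and shuffle inside each bucket. Candidates are reduced to `(name, priority)` tuples and the `rng` parameter is an assumption added so the shuffle can be seeded.

```python
import random
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

Item = Tuple[str, int]  # (candidate name, priority); lower priority sorts first


def shuffle_within_priority(
    items: List[Item], rng: Optional[random.Random] = None
) -> List[Item]:
    rng = rng or random.Random()
    groups: Dict[int, List[Item]] = defaultdict(list)
    for item in items:
        groups[item[1]].append(item)

    result: List[Item] = []
    for priority in sorted(groups):
        group = list(groups[priority])
        rng.shuffle(group)  # random order only among equal priorities
        result.extend(group)
    return result
```

The priority order across groups is always preserved; only the order among equals varies between calls, which is what spreads load without letting a lower-priority key jump ahead.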

src/services/cache/provider_cache.py vendored Normal file

@@ -0,0 +1,185 @@
"""
Provider cache service: reduces Provider and ProviderAPIKey queries.

Caches the Provider billing_type and the ProviderAPIKey rate_multiplier.
UsageService.record_usage() queries these values frequently, but they rarely change.
"""

from typing import Optional, Tuple

from sqlalchemy.orm import Session

from src.config.constants import CacheTTL
from src.core.cache_service import CacheService
from src.core.enums import ProviderBillingType
from src.core.logger import logger
from src.models.database import Provider, ProviderAPIKey


class ProviderCacheService:
    """Provider cache service.

    Provides cached lookups for Provider and ProviderAPIKey to reduce database access.
    Used mainly by UsageService to obtain the rate multiplier and billing type.
    """

    CACHE_TTL = CacheTTL.PROVIDER  # 5 minutes

    @staticmethod
    async def get_provider_api_key_rate_multiplier(
        db: Session, provider_api_key_id: str, api_format: Optional[str] = None
    ) -> Optional[float]:
        """
        Get the rate_multiplier of a ProviderAPIKey (cached).

        Prefers the multiplier configured for the given API format and
        falls back to the default multiplier when there is none.

        Args:
            db: database session
            provider_api_key_id: ProviderAPIKey ID
            api_format: API format (optional), e.g. "CLAUDE", "OPENAI"

        Returns:
            rate_multiplier, or None if the key is not found
        """
        # The cache key includes api_format
        format_suffix = api_format.upper() if api_format else "default"
        cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}:{format_suffix}"

        # 1. Try the cache first
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"ProviderAPIKey rate_multiplier cache hit: {provider_api_key_id[:8]}... format={format_suffix}")
            # A cached "NOT_FOUND" means the row does not exist in the database
            if cached_data == "NOT_FOUND":
                return None
            return float(cached_data)

        # 2. Cache miss: query the database
        provider_key = (
            db.query(ProviderAPIKey.rate_multiplier, ProviderAPIKey.rate_multipliers)
            .filter(ProviderAPIKey.id == provider_api_key_id)
            .first()
        )

        # 3. Compute the multiplier and write it to the cache
        if provider_key:
            # Prefer rate_multipliers[api_format], fall back to rate_multiplier
            rate_multiplier = provider_key.rate_multiplier or 1.0
            if api_format and provider_key.rate_multipliers:
                format_upper = api_format.upper()
                if format_upper in provider_key.rate_multipliers:
                    rate_multiplier = provider_key.rate_multipliers[format_upper]

            await CacheService.set(
                cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"ProviderAPIKey rate_multiplier cached: {provider_api_key_id[:8]}... format={format_suffix} value={rate_multiplier}")
            return rate_multiplier
        else:
            # Cache the negative result
            await CacheService.set(
                cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            return None

    @staticmethod
    async def get_provider_billing_type(
        db: Session, provider_id: str
    ) -> Optional[ProviderBillingType]:
        """
        Get the billing_type of a Provider (cached).

        Args:
            db: database session
            provider_id: Provider ID

        Returns:
            billing_type, or None if the provider is not found
        """
        cache_key = f"provider:billing_type:{provider_id}"

        # 1. Try the cache first
        cached_data = await CacheService.get(cache_key)
        if cached_data is not None:
            logger.debug(f"Provider billing_type cache hit: {provider_id[:8]}...")
            if cached_data == "NOT_FOUND":
                return None
            try:
                return ProviderBillingType(cached_data)
            except ValueError:
                # Invalid cached value: delete it and query again
                await CacheService.delete(cache_key)

        # 2. Cache miss: query the database
        provider = (
            db.query(Provider.billing_type).filter(Provider.id == provider_id).first()
        )

        # 3. Write to the cache
        if provider:
            billing_type = provider.billing_type
            await CacheService.set(
                cache_key, billing_type.value, ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            logger.debug(f"Provider billing_type cached: {provider_id[:8]}...")
            return billing_type
        else:
            # Cache the negative result
            await CacheService.set(
                cache_key, "NOT_FOUND", ttl_seconds=ProviderCacheService.CACHE_TTL
            )
            return None

    @staticmethod
    async def get_rate_multiplier_and_free_tier(
        db: Session,
        provider_api_key_id: Optional[str],
        provider_id: Optional[str],
        api_format: Optional[str] = None,
    ) -> Tuple[float, bool]:
        """
        Get the rate multiplier and whether the provider is on a free tier (cached).

        This is the cached version of UsageService._get_rate_multiplier_and_free_tier.

        Args:
            db: database session
            provider_api_key_id: ProviderAPIKey ID (optional)
            provider_id: Provider ID (optional)
            api_format: API format (optional), used to look up the per-format multiplier

        Returns:
            (rate_multiplier, is_free_tier) tuple
        """
        actual_rate_multiplier = 1.0
        is_free_tier = False

        # Get the rate multiplier (supports lookup by API format)
        if provider_api_key_id:
            rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
                db, provider_api_key_id, api_format
            )
            if rate_multiplier is not None:
                actual_rate_multiplier = rate_multiplier

        # Get the billing type
        if provider_id:
            billing_type = await ProviderCacheService.get_provider_billing_type(db, provider_id)
            if billing_type == ProviderBillingType.FREE_TIER:
                is_free_tier = True

        return actual_rate_multiplier, is_free_tier

    @staticmethod
    async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
        """Clear the ProviderAPIKey cache (including the entries for every API format)."""
        # Use pattern matching to delete the entries for all formats
        await CacheService.delete_pattern(f"provider_api_key:rate_multiplier:{provider_api_key_id}:*")
        logger.debug(f"ProviderAPIKey cache cleared: {provider_api_key_id[:8]}...")

    @staticmethod
    async def invalidate_provider_cache(provider_id: str) -> None:
        """Clear the Provider cache."""
        await CacheService.delete(f"provider:billing_type:{provider_id}")
        logger.debug(f"Provider cache cleared: {provider_id[:8]}...")
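Reviewer's sketch of the pattern this new file follows: cache-aside lookup with a sentinel for negative results, so a missing row is also served from cache until the TTL expires. `TTLCache`, `cached_lookup` and the injectable `now` are in-memory stand-ins assumed for illustration; the real code goes through the async `CacheService`.

```python
import time
from typing import Callable, Dict, Optional, Tuple

NOT_FOUND = "NOT_FOUND"  # sentinel stored for rows that do not exist


class TTLCache:
    """Tiny in-memory cache with per-entry expiry (stand-in for CacheService)."""

    def __init__(self) -> None:
        self._data: Dict[str, Tuple[float, object]] = {}

    def get(self, key: str, now: float) -> Optional[object]:
        entry = self._data.get(key)
        if entry is None or entry[0] <= now:
            return None  # missing or expired
        return entry[1]

    def set(self, key: str, value: object, ttl: float, now: float) -> None:
        self._data[key] = (now + ttl, value)


def cached_lookup(
    cache: TTLCache,
    key: str,
    loader: Callable[[str], Optional[float]],
    ttl: float = 300.0,
    now: Optional[float] = None,
) -> Optional[float]:
    now = time.time() if now is None else now
    cached = cache.get(key, now)
    if cached is not None:
        return None if cached == NOT_FOUND else float(cached)
    value = loader(key)  # cache miss: hit the backing store once
    cache.set(key, NOT_FOUND if value is None else value, ttl, now)
    return value
```

The sentinel is what separates "not cached" from "cached as absent"; without it, every lookup for a deleted key would fall through to the database.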


@@ -70,20 +70,21 @@ class EndpointHealthService:
db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all() db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
) )
# 收集所有 endpoint_ids # 收集所有 provider_ids
all_endpoint_ids = [ep.id for ep in endpoints] all_provider_ids = list(set(ep.provider_id for ep in endpoints))
# 批量查询所有密钥 # 批量查询所有密钥(通过 provider_id 关联)
all_keys = ( all_keys = (
db.query(ProviderAPIKey) db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id.in_(all_endpoint_ids)) .filter(ProviderAPIKey.provider_id.in_(all_provider_ids))
.all() .all()
) if all_endpoint_ids else [] ) if all_provider_ids else []
# 按 endpoint_id 分组密钥 # 按 api_format 分组密钥(通过 api_formats 字段)
keys_by_endpoint: Dict[str, List[ProviderAPIKey]] = defaultdict(list) keys_by_format: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
for key in all_keys: for key in all_keys:
keys_by_endpoint[key.endpoint_id].append(key) for fmt in (key.api_formats or []):
keys_by_format[fmt].append(key)
# 按 API 格式聚合 # 按 API 格式聚合
format_stats = defaultdict( format_stats = defaultdict(
@@ -106,18 +107,36 @@ class EndpointHealthService:
format_stats[api_format]["endpoint_ids"].append(ep.id) format_stats[api_format]["endpoint_ids"].append(ep.id)
format_stats[api_format]["provider_ids"].add(ep.provider_id) format_stats[api_format]["provider_ids"].add(ep.provider_id)
# 从预加载的密钥中获取 # 统计每个格式的密钥(直接从 keys_by_format 获取
keys = keys_by_endpoint.get(ep.id, []) for api_format, keys in keys_by_format.items():
format_stats[api_format]["total_keys"] += len(keys) if api_format not in format_stats:
# 如果有 Key 但没有对应的 Endpoint跳过
continue
# 统计活跃密钥和健康度 # 去重(同一个 Key 可能支持多个格式)
if ep.is_active: seen_key_ids = set()
for key in keys: unique_keys = []
format_stats[api_format]["key_ids"].append(key.id) for key in keys:
if key.is_active and not key.circuit_breaker_open: if key.id not in seen_key_ids:
format_stats[api_format]["active_keys"] += 1 seen_key_ids.add(key.id)
health_score = key.health_score if key.health_score is not None else 1.0 unique_keys.append(key)
format_stats[api_format]["health_scores"].append(health_score)
format_stats[api_format]["total_keys"] = len(unique_keys)
for key in unique_keys:
format_stats[api_format]["key_ids"].append(key.id)
# 检查该格式的熔断器状态
circuit_by_format = key.circuit_breaker_by_format or {}
format_circuit = circuit_by_format.get(api_format, {})
is_circuit_open = format_circuit.get("open", False)
if key.is_active and not is_circuit_open:
format_stats[api_format]["active_keys"] += 1
# 获取该格式的健康度
health_by_format = key.health_by_format or {}
format_health = health_by_format.get(api_format, {})
health_score = float(format_health.get("health_score") or 1.0)
format_stats[api_format]["health_scores"].append(health_score)
# 批量生成所有格式的时间线数据 # 批量生成所有格式的时间线数据
all_key_ids = [] all_key_ids = []
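Reviewer's sketch of the per-format aggregation in this hunk: each key is counted once under every format it lists, and the circuit state and health score are read from that format's own entry. Keys are plain dicts here, and recording a score for every key (not only active ones) is an assumption, since the flattened hunk does not preserve indentation.

```python
from collections import defaultdict
from typing import Any, Dict, List, Set


def summarize_by_format(keys: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    stats: Dict[str, Dict[str, Any]] = defaultdict(
        lambda: {"total_keys": 0, "active_keys": 0, "health_scores": []}
    )
    seen: Dict[str, Set[str]] = defaultdict(set)

    for key in keys:
        for fmt in key.get("api_formats") or []:
            if key["id"] in seen[fmt]:
                continue  # count each key once per format
            seen[fmt].add(key["id"])
            entry = stats[fmt]
            entry["total_keys"] += 1

            circuit = (key.get("circuit_breaker_by_format") or {}).get(fmt, {})
            if key.get("is_active", True) and not circuit.get("open", False):
                entry["active_keys"] += 1

            health = (key.get("health_by_format") or {}).get(fmt, {})
            score = health.get("health_score")
            # `is None` keeps a stored 0.0; `score or 1.0` would turn it into 1.0.
            entry["health_scores"].append(1.0 if score is None else float(score))
    return dict(stats)
```

The explicit `is None` test is a deliberate difference from the hunk, which uses `format_health.get("health_score") or 1.0`; with that expression a key whose health has dropped to exactly 0.0 would be reported as fully healthy.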
@@ -372,7 +391,7 @@ class EndpointHealthService:
segments: int = 100, segments: int = 100,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
从真实使用记录生成时间线数据(兼容旧接口,使用批量查询优化) 从真实使用记录生成时间线数据(使用批量查询优化)
Args: Args:
db: 数据库会话 db: 数据库会话
@@ -391,13 +410,34 @@ class EndpointHealthService:
"time_range_end": None, "time_range_end": None,
} }
# 先查询该 API 格式下的所有密钥 # 基于 endpoint_ids 反推 provider_ids 与 api_format再选出支持该格式的 keys
key_ids = [ endpoint_rows = (
k.id db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
for k in db.query(ProviderAPIKey.id) .filter(ProviderEndpoint.id.in_(endpoint_ids))
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
.all() .all()
] )
if not endpoint_rows:
return {
"timeline": ["unknown"] * 100,
"time_range_start": None,
"time_range_end": None,
}
provider_ids = {str(pid) for pid, _fmt in endpoint_rows}
# 同一调用中 endpoint_ids 来自同一 api_format上层已按格式分组
api_format = (
endpoint_rows[0][1].value
if hasattr(endpoint_rows[0][1], "value")
else str(endpoint_rows[0][1])
)
keys = (
db.query(ProviderAPIKey.id, ProviderAPIKey.api_formats)
.filter(ProviderAPIKey.provider_id.in_(provider_ids))
.all()
)
key_ids = [str(key_id) for key_id, formats in keys if api_format in (formats or [])]
if not key_ids: if not key_ids:
return { return {


@@ -1,11 +1,15 @@
""" """
健康监控器 - Endpoint 和 Key 的健康度追踪 健康监控器 - Endpoint 和 Key 的健康度追踪(按 API 格式区分)
功能: 功能:
1. 基于滑动窗口的错误率计算 1. 基于滑动窗口的错误率计算(按 API 格式独立)
2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭 2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭(按 API 格式独立)
3. 半开状态允许少量请求验证服务恢复 3. 半开状态允许少量请求验证服务恢复
4. 提供健康度查询和管理 API 4. 提供健康度查询和管理 API
数据结构:
- health_by_format: {"CLAUDE": {"health_score": 1.0, "consecutive_failures": 0, ...}, ...}
- circuit_breaker_by_format: {"CLAUDE": {"open": false, "open_at": null, ...}, ...}
""" """
import os import os
@@ -30,8 +34,30 @@ class CircuitState:
HALF_OPEN = "half_open" # 半开(验证恢复) HALF_OPEN = "half_open" # 半开(验证恢复)
# Default health data structure
def _default_health_data() -> Dict[str, Any]:
    return {
        "health_score": 1.0,
        "consecutive_failures": 0,
        "last_failure_at": None,
        "request_results_window": [],
    }


# Default circuit breaker data structure
def _default_circuit_data() -> Dict[str, Any]:
    return {
        "open": False,
        "open_at": None,
        "next_probe_at": None,
        "half_open_until": None,
        "half_open_successes": 0,
        "half_open_failures": 0,
    }
class HealthMonitor: class HealthMonitor:
"""健康监控器(滑动窗口 + 半开状态模式)""" """健康监控器(滑动窗口 + 半开状态模式,按 API 格式区分"""
# === 滑动窗口配置 === # === 滑动窗口配置 ===
WINDOW_SIZE = int(os.getenv("HEALTH_WINDOW_SIZE", str(CircuitBreakerDefaults.WINDOW_SIZE))) WINDOW_SIZE = int(os.getenv("HEALTH_WINDOW_SIZE", str(CircuitBreakerDefaults.WINDOW_SIZE)))
@@ -96,6 +122,38 @@ class HealthMonitor:
_circuit_history: List[Dict[str, Any]] = [] _circuit_history: List[Dict[str, Any]] = []
_open_circuit_keys: int = 0 _open_circuit_keys: int = 0
    # ==================== Data access helpers ====================

    @classmethod
    def _get_health_data(cls, key: ProviderAPIKey, api_format: str) -> Dict[str, Any]:
        """Get the health data for the given format; return defaults if absent."""
        health_by_format = key.health_by_format or {}
        if api_format not in health_by_format:
            return _default_health_data()
        return health_by_format[api_format]

    @classmethod
    def _set_health_data(cls, key: ProviderAPIKey, api_format: str, data: Dict[str, Any]) -> None:
        """Set the health data for the given format."""
        health_by_format = dict(key.health_by_format or {})
        health_by_format[api_format] = data
        key.health_by_format = health_by_format  # type: ignore[assignment]

    @classmethod
    def _get_circuit_data(cls, key: ProviderAPIKey, api_format: str) -> Dict[str, Any]:
        """Get the circuit breaker data for the given format; return defaults if absent."""
        circuit_by_format = key.circuit_breaker_by_format or {}
        if api_format not in circuit_by_format:
            return _default_circuit_data()
        return circuit_by_format[api_format]

    @classmethod
    def _set_circuit_data(cls, key: ProviderAPIKey, api_format: str, data: Dict[str, Any]) -> None:
        """Set the circuit breaker data for the given format."""
        circuit_by_format = dict(key.circuit_breaker_by_format or {})
        circuit_by_format[api_format] = data
        key.circuit_breaker_by_format = circuit_by_format  # type: ignore[assignment]
# ==================== 核心方法 ==================== # ==================== 核心方法 ====================
@classmethod @classmethod
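Reviewer's sketch of the accessor pattern in the helper methods above: reads fall back to a fresh default, and writes build a new outer dict and reassign the attribute. `FakeKey` is a stand-in for `ProviderAPIKey`. The likely motivation is an assumption on my part: a plain SQLAlchemy JSON column does not detect in-place mutation unless it is wrapped in a mutable type, so reassigning is the simple way to mark the column as changed.

```python
from typing import Any, Dict, Optional


def default_health_data() -> Dict[str, Any]:
    return {
        "health_score": 1.0,
        "consecutive_failures": 0,
        "last_failure_at": None,
        "request_results_window": [],
    }


class FakeKey:
    """Stand-in for ProviderAPIKey with a nullable JSON-like attribute."""

    def __init__(self) -> None:
        self.health_by_format: Optional[Dict[str, Dict[str, Any]]] = None


def get_health_data(key: FakeKey, api_format: str) -> Dict[str, Any]:
    by_format = key.health_by_format or {}
    if api_format not in by_format:
        return default_health_data()  # fresh dict, not stored until set
    return by_format[api_format]


def set_health_data(key: FakeKey, api_format: str, data: Dict[str, Any]) -> None:
    by_format = dict(key.health_by_format or {})  # copy the outer dict
    by_format[api_format] = data
    key.health_by_format = by_format  # reassign so the change is visible
```

Because each format has its own entry, a run of failures on one format leaves the health score and circuit state of the key's other formats untouched.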
@@ -103,9 +161,21 @@ class HealthMonitor:
cls, cls,
db: Session, db: Session,
key_id: Optional[str] = None, key_id: Optional[str] = None,
api_format: Optional[str] = None,
response_time_ms: Optional[int] = None, response_time_ms: Optional[int] = None,
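) -> None: ) -> None:
"""Record a successful request""" """Record a successful request (per API format)
Args:
db: database session
key_id: Key ID (required)
api_format: API format (required; used to track health separately for each format)
response_time_ms: response time (optional)
Note:
api_format is logically required, but the signature stays Optional for backward compatibility.
If it is not provided, the first entry in the Key's api_formats is used as a fallback.
"""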
try: try:
if not key_id: if not key_id:
return return
@@ -114,39 +184,96 @@ class HealthMonitor:
if not key: if not key:
return return
# api_format compatibility: if not provided, fall back to the Key's first format
effective_api_format = api_format
if not effective_api_format:
if key.api_formats and len(key.api_formats) > 0:
effective_api_format = key.api_formats[0]
logger.debug(
f"record_success: api_format 未提供,使用默认格式 {effective_api_format}"
)
else:
logger.warning(
f"record_success: api_format 未提供且 Key 无可用格式: key_id={key_id[:8]}..."
)
return
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
now_ts = now.timestamp() now_ts = now.timestamp()
# Load the health data for the current format
health_data = cls._get_health_data(key, effective_api_format)
circuit_data = cls._get_circuit_data(key, effective_api_format)
# 1. Update the sliding window # 1. Update the sliding window
cls._add_to_window(key, now_ts, success=True) window = health_data.get("request_results_window") or []
window.append({"ts": now_ts, "ok": True})
cutoff_ts = now_ts - cls.WINDOW_SECONDS
window = [r for r in window if r["ts"] > cutoff_ts]
if len(window) > cls.WINDOW_SIZE:
window = window[-cls.WINDOW_SIZE :]
health_data["request_results_window"] = window
# 2. Update the health score (for display) # 2. Update the health score (for display)
new_score = min(float(key.health_score or 0) + cls.SUCCESS_INCREMENT, 1.0) current_score = float(health_data.get("health_score") or 0)
key.health_score = new_score # type: ignore[assignment] new_score = min(current_score + cls.SUCCESS_INCREMENT, 1.0)
health_data["health_score"] = new_score
# 3. Update statistics # 3. Update statistics
key.consecutive_failures = 0 # type: ignore[assignment] health_data["consecutive_failures"] = 0
key.last_failure_at = None # type: ignore[assignment] health_data["last_failure_at"] = None
# 4. Handle the circuit breaker state
state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN:
# Half-open state: record the success
circuit_data["half_open_successes"] = int(
circuit_data.get("half_open_successes") or 0
) + 1
if circuit_data["half_open_successes"] >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
# Success threshold reached: close the circuit breaker
cls._close_circuit_data(circuit_data, health_data, reason="半开状态验证成功")
cls._push_circuit_event(
{
"event": "closed",
"key_id": key.id,
"api_format": effective_api_format,
"reason": "半开状态验证成功",
"timestamp": now.isoformat(),
}
)
logger.info(
f"[CLOSED] Key 熔断器关闭: {key.id[:8]}.../{effective_api_format} | 原因: 半开状态验证成功"
)
elif state == CircuitState.OPEN:
# Success while open (a successful probe): enter the half-open state
cls._enter_half_open_data(circuit_data, now)
cls._push_circuit_event(
{
"event": "half_open",
"key_id": key.id,
"api_format": effective_api_format,
"timestamp": now.isoformat(),
}
)
logger.info(
f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}.../{effective_api_format} | "
f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
)
# Save the data
cls._set_health_data(key, effective_api_format, health_data)
cls._set_circuit_data(key, effective_api_format, circuit_data)
# Update global statistics
key.success_count = int(key.success_count or 0) + 1 # type: ignore[assignment] key.success_count = int(key.success_count or 0) + 1 # type: ignore[assignment]
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment] key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
if response_time_ms: if response_time_ms:
key.total_response_time_ms = int(key.total_response_time_ms or 0) + response_time_ms # type: ignore[assignment] key.total_response_time_ms = int(key.total_response_time_ms or 0) + response_time_ms # type: ignore[assignment]
# 4. Handle the circuit breaker state
state = cls._get_circuit_state(key, now)
if state == CircuitState.HALF_OPEN:
# Half-open state: record the success
key.half_open_successes = int(key.half_open_successes or 0) + 1 # type: ignore[assignment]
if int(key.half_open_successes or 0) >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
# Success threshold reached: close the circuit breaker
cls._close_circuit(key, now, reason="半开状态验证成功")
elif state == CircuitState.OPEN:
# Success while open (a successful probe): enter the half-open state
cls._enter_half_open(key, now)
db.flush() db.flush()
get_batch_committer().mark_dirty(db) get_batch_committer().mark_dirty(db)
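The window update in step 1 appends the new result, drops entries older than `WINDOW_SECONDS`, then keeps at most `WINDOW_SIZE` of the newest entries. The following is a small standalone sketch of that sequence together with the error-rate calculation. The constant values and the function names `push_result` and `error_rate` are assumptions for illustration; the real limits come from configuration that this diff does not show.

```python
from typing import Any, Dict, List

# Hypothetical stand-ins for HealthMonitor.WINDOW_SECONDS / WINDOW_SIZE.
WINDOW_SECONDS = 300
WINDOW_SIZE = 5

Window = List[Dict[str, Any]]


def push_result(window: Window, now_ts: float, ok: bool) -> Window:
    """Append one result, drop expired entries, then cap the window length."""
    updated = list(window) + [{"ts": now_ts, "ok": ok}]
    cutoff_ts = now_ts - WINDOW_SECONDS
    updated = [r for r in updated if r["ts"] > cutoff_ts]
    return updated[-WINDOW_SIZE:]


def error_rate(window: Window, now_ts: float) -> float:
    """Share of failed requests among the entries that are still inside the window."""
    cutoff_ts = now_ts - WINDOW_SECONDS
    valid = [r for r in window if r["ts"] > cutoff_ts]
    if not valid:
        return 0.0
    return sum(1 for r in valid if not r["ok"]) / len(valid)
```

Returning a new list instead of mutating the argument keeps the sketch free of shared state; the commit itself rebinds `window` and stores it back into `health_data`.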
@@ -159,9 +286,21 @@ class HealthMonitor:
cls, cls,
db: Session, db: Session,
key_id: Optional[str] = None, key_id: Optional[str] = None,
api_format: Optional[str] = None,
error_type: Optional[str] = None, error_type: Optional[str] = None,
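) -> None: ) -> None:
"""Record a failed request""" """Record a failed request (per API format)
Args:
db: database session
key_id: Key ID (required)
api_format: API format (required; used to track health separately for each format)
error_type: error type (optional)
Note:
api_format is logically required, but the signature stays Optional for backward compatibility.
If it is not provided, the first entry in the Key's api_formats is used as a fallback.
"""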
try: try:
if not key_id: if not key_id:
return return
@@ -170,46 +309,117 @@ class HealthMonitor:
if not key: if not key:
return return
# api_format compatibility: if not provided, fall back to the Key's first format
effective_api_format = api_format
if not effective_api_format:
if key.api_formats and len(key.api_formats) > 0:
effective_api_format = key.api_formats[0]
logger.debug(
f"record_failure: api_format 未提供,使用默认格式 {effective_api_format}"
)
else:
logger.warning(
f"record_failure: api_format 未提供且 Key 无可用格式: key_id={key_id[:8]}..."
)
return
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
now_ts = now.timestamp() now_ts = now.timestamp()
# Load the health data for the current format
health_data = cls._get_health_data(key, effective_api_format)
circuit_data = cls._get_circuit_data(key, effective_api_format)
# 1. Update the sliding window # 1. Update the sliding window
cls._add_to_window(key, now_ts, success=False) window = health_data.get("request_results_window") or []
window.append({"ts": now_ts, "ok": False})
cutoff_ts = now_ts - cls.WINDOW_SECONDS
window = [r for r in window if r["ts"] > cutoff_ts]
if len(window) > cls.WINDOW_SIZE:
window = window[-cls.WINDOW_SIZE :]
health_data["request_results_window"] = window
# 2. Update the health score (for display) # 2. Update the health score (for display)
new_score = max(float(key.health_score or 1) - cls.FAILURE_DECREMENT, 0.0) current_score = float(health_data.get("health_score") or 1)
key.health_score = new_score # type: ignore[assignment] new_score = max(current_score - cls.FAILURE_DECREMENT, 0.0)
health_data["health_score"] = new_score
# 3. Update statistics # 3. Update statistics
key.consecutive_failures = int(key.consecutive_failures or 0) + 1 # type: ignore[assignment] health_data["consecutive_failures"] = (
key.last_failure_at = now # type: ignore[assignment] int(health_data.get("consecutive_failures") or 0) + 1
key.error_count = int(key.error_count or 0) + 1 # type: ignore[assignment] )
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment] health_data["last_failure_at"] = now.isoformat()
# 4. Handle the circuit breaker state # 4. Handle the circuit breaker state
state = cls._get_circuit_state(key, now) state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN: if state == CircuitState.HALF_OPEN:
# Half-open state: record the failure # Half-open state: record the failure
key.half_open_failures = int(key.half_open_failures or 0) + 1 # type: ignore[assignment] circuit_data["half_open_failures"] = int(
circuit_data.get("half_open_failures") or 0
) + 1
if int(key.half_open_failures or 0) >= cls.HALF_OPEN_FAILURE_THRESHOLD: if circuit_data["half_open_failures"] >= cls.HALF_OPEN_FAILURE_THRESHOLD:
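# Failure threshold reached: reopen the circuit breaker # Failure threshold reached: reopen the circuit breaker
cls._open_circuit(key, now, reason="半开状态验证失败") # Note: half-open is itself a sub-state of open, so the open-circuit counter is not incremented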
consecutive = int(health_data.get("consecutive_failures") or 0)
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
cls._open_circuit_data(
circuit_data, now, recovery_seconds, reason="半开状态验证失败"
)
cls._push_circuit_event(
{
"event": "opened",
"key_id": key.id,
"api_format": effective_api_format,
"reason": "半开状态验证失败",
"recovery_seconds": recovery_seconds,
"timestamp": now.isoformat(),
}
)
logger.warning(
f"[OPEN] Key 熔断器打开: {key.id[:8]}.../{effective_api_format} | 原因: 半开状态验证失败 | "
f"{recovery_seconds}秒后进入半开状态"
)
elif state == CircuitState.CLOSED: elif state == CircuitState.CLOSED:
# Closed state: check whether the circuit breaker should open # Closed state: check whether the circuit breaker should open
error_rate = cls._calculate_error_rate(key, now_ts) error_rate = cls._calculate_error_rate_from_window(window, now_ts)
window = key.request_results_window or []
if len(window) >= cls.MIN_REQUESTS and error_rate >= cls.ERROR_RATE_THRESHOLD: if len(window) >= cls.MIN_REQUESTS and error_rate >= cls.ERROR_RATE_THRESHOLD:
cls._open_circuit( consecutive = int(health_data.get("consecutive_failures") or 0)
key, now, reason=f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}" recovery_seconds = cls._calculate_recovery_seconds(consecutive)
reason = f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}"
cls._open_circuit_data(circuit_data, now, recovery_seconds, reason=reason)
cls._open_circuit_keys += 1
health_open_circuits.set(cls._open_circuit_keys)
cls._push_circuit_event(
{
"event": "opened",
"key_id": key.id,
"api_format": effective_api_format,
"reason": reason,
"recovery_seconds": recovery_seconds,
"timestamp": now.isoformat(),
}
)
logger.warning(
f"[OPEN] Key 熔断器打开: {key.id[:8]}.../{effective_api_format} | 原因: {reason} | "
f"{recovery_seconds}秒后进入半开状态"
) )
# Save the data
cls._set_health_data(key, effective_api_format, health_data)
cls._set_circuit_data(key, effective_api_format, circuit_data)
# Update global statistics
key.error_count = int(key.error_count or 0) + 1 # type: ignore[assignment]
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
key.last_error_at = now # type: ignore[assignment]
logger.debug( logger.debug(
f"[WARN] Key 健康度下降: {key_id[:8]}... -> {new_score:.2f} " f"[WARN] Key 健康度下降: {key_id[:8]}.../{effective_api_format} -> {new_score:.2f} "
f"(连续失败 {key.consecutive_failures} 次, error_type={error_type})" f"(连续失败 {health_data['consecutive_failures']} 次, error_type={error_type})"
) )
db.flush() db.flush()
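The failure path clamps the score with `max(current_score - cls.FAILURE_DECREMENT, 0.0)`, so a stored score of exactly `0.0` is possible, and `0.0` is falsy in Python. The sketch below contrasts the `value or default` idiom used in this diff with an explicit `None` check. The helper names are hypothetical, and the second variant is shown only as an alternative under that assumption, not as something this commit does.

```python
from typing import Any, Dict


def score_with_or(health_data: Dict[str, Any]) -> float:
    """The `or` idiom: any falsy stored value (None, 0, 0.0) falls back to 1.0."""
    return float(health_data.get("health_score") or 1.0)


def score_with_none_check(health_data: Dict[str, Any]) -> float:
    """Fall back to 1.0 only when no score is stored at all."""
    value = health_data.get("health_score")
    return 1.0 if value is None else float(value)
```

Both helpers agree when the field is missing. They differ only for a stored `0.0`, which the `or` idiom reads back as a full score.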
@@ -222,31 +432,13 @@ class HealthMonitor:
# ==================== Sliding window methods ==================== # ==================== Sliding window methods ====================
@classmethod @classmethod
def _add_to_window(cls, key: ProviderAPIKey, now_ts: float, success: bool) -> None: def _calculate_error_rate_from_window(
"""Add a request result to the sliding window""" cls, window: List[Dict[str, Any]], now_ts: float
window: List[Dict[str, Any]] = key.request_results_window or [] ) -> float:
"""Compute the error rate from the window data"""
# Add the new record
window.append({"ts": now_ts, "ok": success})
# Remove expired records
cutoff_ts = now_ts - cls.WINDOW_SECONDS
window = [r for r in window if r["ts"] > cutoff_ts]
# Cap the window size
if len(window) > cls.WINDOW_SIZE:
window = window[-cls.WINDOW_SIZE :]
key.request_results_window = window # type: ignore[assignment]
@classmethod
def _calculate_error_rate(cls, key: ProviderAPIKey, now_ts: float) -> float:
"""Compute the error rate within the sliding window"""
window: List[Dict[str, Any]] = key.request_results_window or []
if not window: if not window:
return 0.0 return 0.0
# Filter out expired records
cutoff_ts = now_ts - cls.WINDOW_SECONDS cutoff_ts = now_ts - cls.WINDOW_SECONDS
valid_records = [r for r in window if r["ts"] > cutoff_ts] valid_records = [r for r in window if r["ts"] > cutoff_ts]
@@ -256,157 +448,158 @@ class HealthMonitor:
failures = sum(1 for r in valid_records if not r["ok"]) failures = sum(1 for r in valid_records if not r["ok"])
return failures / len(valid_records) return failures / len(valid_records)
# ==================== Circuit breaker state methods ==================== # ==================== Circuit breaker state methods (operate on data dicts) ====================
@classmethod @classmethod
def _get_circuit_state(cls, key: ProviderAPIKey, now: datetime) -> str: def _get_circuit_state_from_data(cls, circuit_data: Dict[str, Any], now: datetime) -> str:
"""Return the current circuit breaker state""" """Return the current circuit breaker state from the data dict"""
if not key.circuit_breaker_open: if not circuit_data.get("open"):
return CircuitState.CLOSED return CircuitState.CLOSED
# Check whether the breaker is in the half-open state # Check whether the breaker is in the half-open state
if key.half_open_until and now < key.half_open_until: half_open_until_str = circuit_data.get("half_open_until")
return CircuitState.HALF_OPEN if half_open_until_str:
half_open_until = datetime.fromisoformat(half_open_until_str)
if now < half_open_until:
return CircuitState.HALF_OPEN
# Check whether the probe time has arrived (enter half-open) # Check whether the probe time has arrived (enter half-open)
if key.next_probe_at and now >= key.next_probe_at: next_probe_str = circuit_data.get("next_probe_at")
return CircuitState.HALF_OPEN if next_probe_str:
next_probe_at = datetime.fromisoformat(next_probe_str)
if now >= next_probe_at:
return CircuitState.HALF_OPEN
return CircuitState.OPEN return CircuitState.OPEN
@classmethod @classmethod
def _open_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None: def _open_circuit_data(
"""Open the circuit breaker""" cls,
was_open = key.circuit_breaker_open circuit_data: Dict[str, Any],
now: datetime,
key.circuit_breaker_open = True # type: ignore[assignment] recovery_seconds: int,
key.circuit_breaker_open_at = now # type: ignore[assignment] reason: str,
key.half_open_until = None # type: ignore[assignment] ) -> None:
key.half_open_successes = 0 # type: ignore[assignment] """Open the circuit breaker (operates on the data dict)"""
key.half_open_failures = 0 # type: ignore[assignment] circuit_data["open"] = True
circuit_data["open_at"] = now.isoformat()
# Compute the next probe time (when the half-open state begins) circuit_data["half_open_until"] = None
consecutive = int(key.consecutive_failures or 0) circuit_data["half_open_successes"] = 0
recovery_seconds = cls._calculate_recovery_seconds(consecutive) circuit_data["half_open_failures"] = 0
key.next_probe_at = now + timedelta(seconds=recovery_seconds) # type: ignore[assignment] circuit_data["next_probe_at"] = (now + timedelta(seconds=recovery_seconds)).isoformat()
if not was_open:
cls._open_circuit_keys += 1
health_open_circuits.set(cls._open_circuit_keys)
logger.warning(
f"[OPEN] Key 熔断器打开: {key.id[:8]}... | 原因: {reason} | "
f"{recovery_seconds}秒后进入半开状态"
)
cls._push_circuit_event(
{
"event": "opened",
"key_id": key.id,
"reason": reason,
"recovery_seconds": recovery_seconds,
"timestamp": now.isoformat(),
}
)
@classmethod @classmethod
def _enter_half_open(cls, key: ProviderAPIKey, now: datetime) -> None: def _enter_half_open_data(cls, circuit_data: Dict[str, Any], now: datetime) -> None:
"""Enter the half-open state""" """Enter the half-open state (operates on the data dict)"""
key.half_open_until = now + timedelta(seconds=cls.HALF_OPEN_DURATION) # type: ignore[assignment] circuit_data["half_open_until"] = (
key.half_open_successes = 0 # type: ignore[assignment] now + timedelta(seconds=cls.HALF_OPEN_DURATION)
key.half_open_failures = 0 # type: ignore[assignment] ).isoformat()
circuit_data["half_open_successes"] = 0
logger.info( circuit_data["half_open_failures"] = 0
f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}... | "
f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
)
cls._push_circuit_event(
{
"event": "half_open",
"key_id": key.id,
"timestamp": now.isoformat(),
}
)
@classmethod @classmethod
def _close_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None: def _close_circuit_data(
"""Close the circuit breaker""" cls, circuit_data: Dict[str, Any], health_data: Dict[str, Any], reason: str
key.circuit_breaker_open = False # type: ignore[assignment] ) -> None:
key.circuit_breaker_open_at = None # type: ignore[assignment] """Close the circuit breaker (operates on the data dict)"""
key.next_probe_at = None # type: ignore[assignment] circuit_data["open"] = False
key.half_open_until = None # type: ignore[assignment] circuit_data["open_at"] = None
key.half_open_successes = 0 # type: ignore[assignment] circuit_data["next_probe_at"] = None
key.half_open_failures = 0 # type: ignore[assignment] circuit_data["half_open_until"] = None
circuit_data["half_open_successes"] = 0
circuit_data["half_open_failures"] = 0
# Quickly restore the health score # Quickly restore the health score
key.health_score = max(float(key.health_score or 0), cls.PROBE_RECOVERY_SCORE) # type: ignore[assignment] current_score = float(health_data.get("health_score") or 0)
health_data["health_score"] = max(current_score, cls.PROBE_RECOVERY_SCORE)
cls._open_circuit_keys = max(0, cls._open_circuit_keys - 1) cls._open_circuit_keys = max(0, cls._open_circuit_keys - 1)
health_open_circuits.set(cls._open_circuit_keys) health_open_circuits.set(cls._open_circuit_keys)
logger.info(f"[CLOSED] Key 熔断器关闭: {key.id[:8]}... | 原因: {reason}")
cls._push_circuit_event(
{
"event": "closed",
"key_id": key.id,
"reason": reason,
"timestamp": now.isoformat(),
}
)
@classmethod @classmethod
def _calculate_recovery_seconds(cls, consecutive_failures: int) -> int: def _calculate_recovery_seconds(cls, consecutive_failures: int) -> int:
"""Compute the recovery wait time (exponential backoff)""" """Compute the recovery wait time (exponential backoff)"""
# Exponential backoff: 30s -> 60s -> 120s -> 240s -> 300s (cap) exponent = min(consecutive_failures // 5, 4)
exponent = min(consecutive_failures // 5, 4) # one level up for every 5 failures
seconds = cls.INITIAL_RECOVERY_SECONDS * (cls.RECOVERY_BACKOFF**exponent) seconds = cls.INITIAL_RECOVERY_SECONDS * (cls.RECOVERY_BACKOFF**exponent)
return min(int(seconds), cls.MAX_RECOVERY_SECONDS) return min(int(seconds), cls.MAX_RECOVERY_SECONDS)
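As a worked example of the backoff formula, the sketch below reproduces the calculation as a standalone function. The three constants are assumptions inferred from the removed comment (30s -> 60s -> 120s -> 240s -> 300s cap); their actual definitions are not part of this diff, and `recovery_seconds` is a hypothetical name.

```python
# Assumed values, inferred from the "30s -> 60s -> 120s -> 240s -> 300s" comment.
INITIAL_RECOVERY_SECONDS = 30
RECOVERY_BACKOFF = 2
MAX_RECOVERY_SECONDS = 300


def recovery_seconds(consecutive_failures: int) -> int:
    """Wait time before the next probe: one backoff level per 5 failures, capped."""
    exponent = min(consecutive_failures // 5, 4)
    seconds = INITIAL_RECOVERY_SECONDS * (RECOVERY_BACKOFF ** exponent)
    return min(int(seconds), MAX_RECOVERY_SECONDS)


if __name__ == "__main__":
    for failures in (0, 5, 10, 15, 20):
        print(failures, recovery_seconds(failures))
```

Under these assumed constants the fourth level would be 480 seconds, so the cap is what produces the final 300-second step.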
# ==================== Status query methods ==================== # ==================== Status query methods ====================
@classmethod @classmethod
def is_circuit_breaker_closed(cls, resource: ProviderAPIKey) -> bool: def is_circuit_breaker_closed(
"""Check whether the circuit breaker lets requests through""" cls, resource: ProviderAPIKey, api_format: Optional[str] = None
if not resource.circuit_breaker_open: ) -> bool:
"""Check whether the circuit breaker lets requests through (per API format)"""
if not api_format:
# Legacy callers: check whether the circuit breaker is open for any format
circuit_by_format = resource.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
return False
return True
circuit_data = cls._get_circuit_data(resource, api_format)
if not circuit_data.get("open"):
return True return True
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
state = cls._get_circuit_state(resource, now) state = cls._get_circuit_state_from_data(circuit_data, now)
# The half-open state lets requests through # The half-open state lets requests through
if state == CircuitState.HALF_OPEN: if state == CircuitState.HALF_OPEN:
return True return True
# Check whether the probe time has arrived # Check whether the probe time has arrived
if resource.next_probe_at and now >= resource.next_probe_at: next_probe_str = circuit_data.get("next_probe_at")
# Automatically enter the half-open state if next_probe_str:
cls._enter_half_open(resource, now) next_probe_at = datetime.fromisoformat(next_probe_str)
return True if now >= next_probe_at:
# Automatically enter the half-open state
cls._enter_half_open_data(circuit_data, now)
cls._set_circuit_data(resource, api_format, circuit_data)
return True
return False return False
@classmethod @classmethod
def get_circuit_breaker_status( def get_circuit_breaker_status(
cls, resource: ProviderAPIKey cls, resource: ProviderAPIKey, api_format: Optional[str] = None
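) -> Tuple[bool, Optional[str]]: ) -> Tuple[bool, Optional[str]]:
"""Return the detailed circuit breaker status""" """Return the detailed circuit breaker status (per API format)"""
if not resource.circuit_breaker_open: if not api_format:
# Legacy callers: return the status of the first open circuit breaker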
circuit_by_format = resource.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
return cls._get_status_from_circuit_data(circuit_data)
return True, None
circuit_data = cls._get_circuit_data(resource, api_format)
return cls._get_status_from_circuit_data(circuit_data)
@classmethod
def _get_status_from_circuit_data(
cls, circuit_data: Dict[str, Any]
) -> Tuple[bool, Optional[str]]:
"""Build the status description from the circuit breaker data"""
if not circuit_data.get("open"):
return True, None return True, None
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
state = cls._get_circuit_state(resource, now) state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN: if state == CircuitState.HALF_OPEN:
successes = int(resource.half_open_successes or 0) successes = int(circuit_data.get("half_open_successes") or 0)
return True, f"半开状态({successes}/{cls.HALF_OPEN_SUCCESS_THRESHOLD}成功)" return True, f"半开状态({successes}/{cls.HALF_OPEN_SUCCESS_THRESHOLD}成功)"
if resource.next_probe_at: next_probe_str = circuit_data.get("next_probe_at")
if now >= resource.next_probe_at: if next_probe_str:
next_probe_at = datetime.fromisoformat(next_probe_str)
if now >= next_probe_at:
return True, None return True, None
remaining = resource.next_probe_at - now remaining = next_probe_at - now
remaining_seconds = int(remaining.total_seconds()) remaining_seconds = int(remaining.total_seconds())
if remaining_seconds >= 60: if remaining_seconds >= 60:
time_str = f"{remaining_seconds // 60}min{remaining_seconds % 60}s" time_str = f"{remaining_seconds // 60}min{remaining_seconds % 60}s"
@@ -417,8 +610,10 @@ class HealthMonitor:
return False, "熔断中" return False, "熔断中"
@classmethod @classmethod
def get_key_health(cls, db: Session, key_id: str) -> Optional[Dict[str, Any]]: def get_key_health(
"""Return the Key's health status""" cls, db: Session, key_id: str, api_format: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""Return the Key's health status (supports querying by format)"""
try: try:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first() key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if not key: if not key:
@@ -427,24 +622,15 @@ class HealthMonitor:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
now_ts = now.timestamp() now_ts = now.timestamp()
# Compute the current error rate
error_rate = cls._calculate_error_rate(key, now_ts)
window = key.request_results_window or []
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
avg_response_time_ms = ( avg_response_time_ms = (
int(key.total_response_time_ms or 0) / int(key.success_count or 1) int(key.total_response_time_ms or 0) / int(key.success_count or 1)
if key.success_count if key.success_count
else 0 else 0
) )
return { # Global statistics
result = {
"key_id": key.id, "key_id": key.id,
"health_score": float(key.health_score or 1.0),
"error_rate": error_rate,
"window_size": len(valid_window),
"consecutive_failures": int(key.consecutive_failures or 0),
"last_failure_at": key.last_failure_at.isoformat() if key.last_failure_at else None,
"is_active": key.is_active, "is_active": key.is_active,
"statistics": { "statistics": {
"request_count": int(key.request_count or 0), "request_count": int(key.request_count or 0),
@@ -457,25 +643,84 @@ class HealthMonitor:
), ),
"avg_response_time_ms": round(avg_response_time_ms, 2), "avg_response_time_ms": round(avg_response_time_ms, 2),
}, },
"circuit_breaker": {
"state": cls._get_circuit_state(key, now),
"open": key.circuit_breaker_open,
"open_at": (
key.circuit_breaker_open_at.isoformat()
if key.circuit_breaker_open_at
else None
),
"next_probe_at": (
key.next_probe_at.isoformat() if key.next_probe_at else None
),
"half_open_until": (
key.half_open_until.isoformat() if key.half_open_until else None
),
"half_open_successes": int(key.half_open_successes or 0),
"half_open_failures": int(key.half_open_failures or 0),
},
} }
# Per-format health data
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
if api_format:
# Query a single format
health_data = cls._get_health_data(key, api_format)
circuit_data = cls._get_circuit_data(key, api_format)
window = health_data.get("request_results_window") or []
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
result["api_format"] = api_format
result["health_score"] = float(health_data.get("health_score") or 1.0)
result["error_rate"] = cls._calculate_error_rate_from_window(window, now_ts)
result["window_size"] = len(valid_window)
result["consecutive_failures"] = int(
health_data.get("consecutive_failures") or 0
)
result["last_failure_at"] = health_data.get("last_failure_at")
result["circuit_breaker"] = {
"state": cls._get_circuit_state_from_data(circuit_data, now),
"open": circuit_data.get("open", False),
"open_at": circuit_data.get("open_at"),
"next_probe_at": circuit_data.get("next_probe_at"),
"half_open_until": circuit_data.get("half_open_until"),
"half_open_successes": int(circuit_data.get("half_open_successes") or 0),
"half_open_failures": int(circuit_data.get("half_open_failures") or 0),
}
else:
# Return the health data for all formats
formats_health = {}
for fmt in (key.api_formats or []):
health_data = health_by_format.get(fmt, _default_health_data())
circuit_data = circuit_by_format.get(fmt, _default_circuit_data())
window = health_data.get("request_results_window") or []
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
formats_health[fmt] = {
"health_score": float(health_data.get("health_score") or 1.0),
"error_rate": cls._calculate_error_rate_from_window(window, now_ts),
"window_size": len(valid_window),
"consecutive_failures": int(
health_data.get("consecutive_failures") or 0
),
"last_failure_at": health_data.get("last_failure_at"),
"circuit_breaker": {
"state": cls._get_circuit_state_from_data(circuit_data, now),
"open": circuit_data.get("open", False),
"open_at": circuit_data.get("open_at"),
"next_probe_at": circuit_data.get("next_probe_at"),
"half_open_until": circuit_data.get("half_open_until"),
"half_open_successes": int(
circuit_data.get("half_open_successes") or 0
),
"half_open_failures": int(
circuit_data.get("half_open_failures") or 0
),
},
}
result["health_by_format"] = formats_health
# Compute the overall health score (take the minimum)
if formats_health:
result["health_score"] = min(
h["health_score"] for h in formats_health.values()
)
result["any_circuit_open"] = any(
h["circuit_breaker"]["open"] for h in formats_health.values()
)
else:
result["health_score"] = 1.0
result["any_circuit_open"] = False
return result
except Exception as e: except Exception as e:
logger.error(f"获取 Key 健康状态失败: {e}") logger.error(f"获取 Key 健康状态失败: {e}")
return None return None
@@ -507,23 +752,24 @@ class HealthMonitor:
# ==================== Management methods ==================== # ==================== Management methods ====================
@classmethod @classmethod
def reset_health(cls, db: Session, key_id: Optional[str] = None) -> bool: def reset_health(
"""Reset the health score""" cls, db: Session, key_id: Optional[str] = None, api_format: Optional[str] = None
) -> bool:
"""Reset the health score (supports resetting a single format)"""
try: try:
if key_id: if key_id:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first() key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if key: if key:
key.health_score = 1.0 # type: ignore[assignment] if api_format:
key.consecutive_failures = 0 # type: ignore[assignment] # Reset a single format
key.last_failure_at = None # type: ignore[assignment] cls._set_health_data(key, api_format, _default_health_data())
key.request_results_window = [] # type: ignore[assignment] cls._set_circuit_data(key, api_format, _default_circuit_data())
key.circuit_breaker_open = False # type: ignore[assignment] logger.info(f"[RESET] 重置 Key 健康度: {key_id}/{api_format}")
key.circuit_breaker_open_at = None # type: ignore[assignment] else:
key.next_probe_at = None # type: ignore[assignment] # Reset all formats
key.half_open_until = None # type: ignore[assignment] key.health_by_format = {} # type: ignore[assignment]
key.half_open_successes = 0 # type: ignore[assignment] key.circuit_breaker_by_format = {} # type: ignore[assignment]
key.half_open_failures = 0 # type: ignore[assignment] logger.info(f"[RESET] 重置 Key 所有格式健康度: {key_id}")
logger.info(f"[RESET] 重置 Key 健康度: {key_id}")
db.flush() db.flush()
get_batch_committer().mark_dirty(db) get_batch_committer().mark_dirty(db)
@@ -542,7 +788,9 @@ class HealthMonitor:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first() key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if key and not key.is_active: if key and not key.is_active:
key.is_active = True # type: ignore[assignment] key.is_active = True # type: ignore[assignment]
key.consecutive_failures = 0 # type: ignore[assignment] # Reset the health data for all formats
key.health_by_format = {} # type: ignore[assignment]
key.circuit_breaker_by_format = {} # type: ignore[assignment]
logger.info(f"[OK] 手动启用 Key: {key_id}") logger.info(f"[OK] 手动启用 Key: {key_id}")
db.flush() db.flush()
@@ -566,14 +814,28 @@ class HealthMonitor:
), ),
).first() ).first()
key_stats = db.query( # Count Keys (requires iterating over the JSON fields to compute the circuit breaker state)
func.count(ProviderAPIKey.id).label("total"), keys = db.query(ProviderAPIKey).all()
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"), total_keys = len(keys)
func.sum(case((ProviderAPIKey.health_score < 0.5, 1), else_=0)).label("unhealthy"), active_keys = sum(1 for k in keys if k.is_active)
func.sum(case((ProviderAPIKey.circuit_breaker_open == True, 1), else_=0)).label( unhealthy_keys = 0
"circuit_open" circuit_open_keys = 0
),
).first() for key in keys:
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
# Check whether any format has a health score below 0.5
for fmt, health_data in health_by_format.items():
if float(health_data.get("health_score") or 1.0) < 0.5:
unhealthy_keys += 1
break
# Check whether any format has an open circuit breaker
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
circuit_open_keys += 1
break
return { return {
"endpoints": { "endpoints": {
@@ -582,10 +844,10 @@ class HealthMonitor:
"unhealthy": int(endpoint_stats.unhealthy or 0) if endpoint_stats else 0, "unhealthy": int(endpoint_stats.unhealthy or 0) if endpoint_stats else 0,
}, },
"keys": { "keys": {
"total": key_stats.total or 0 if key_stats else 0, "total": total_keys,
"active": int(key_stats.active or 0) if key_stats else 0, "active": active_keys,
"unhealthy": int(key_stats.unhealthy or 0) if key_stats else 0, "unhealthy": unhealthy_keys,
"circuit_open": int(key_stats.circuit_open or 0) if key_stats else 0, "circuit_open": circuit_open_keys,
}, },
} }
@@ -618,8 +880,9 @@ class HealthMonitor:
db: Session, db: Session,
endpoint_id: Optional[str] = None, endpoint_id: Optional[str] = None,
key_id: Optional[str] = None, key_id: Optional[str] = None,
api_format: Optional[str] = None,
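) -> bool: ) -> bool:
"""Check whether the resource is eligible for probing (legacy-compatible interface)""" """Check whether the resource is eligible for probing (per API format)"""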
if not cls.ALLOW_AUTO_RECOVER: if not cls.ALLOW_AUTO_RECOVER:
return False return False
@@ -628,13 +891,53 @@ class HealthMonitor:
if key_id: if key_id:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first() key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if key and key.circuit_breaker_open: if key:
now = datetime.now(timezone.utc) if api_format:
state = cls._get_circuit_state(key, now) circuit_data = cls._get_circuit_data(key, api_format)
return state == CircuitState.HALF_OPEN if circuit_data.get("open"):
now = datetime.now(timezone.utc)
state = cls._get_circuit_state_from_data(circuit_data, now)
return state == CircuitState.HALF_OPEN
else:
# Legacy callers: check whether any format is in the half-open state
circuit_by_format = key.circuit_breaker_by_format or {}
now = datetime.now(timezone.utc)
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN:
return True
return False return False
# ==================== Convenience methods ====================
@classmethod
def get_health_score(
cls, key: ProviderAPIKey, api_format: Optional[str] = None
) -> float:
"""Return the health score for the given format"""
if not api_format:
# Return the lowest health score across all formats
health_by_format = key.health_by_format or {}
if not health_by_format:
return 1.0
return min(
float(h.get("health_score") or 1.0) for h in health_by_format.values()
)
health_data = cls._get_health_data(key, api_format)
return float(health_data.get("health_score") or 1.0)
@classmethod
def is_any_circuit_open(cls, key: ProviderAPIKey) -> bool:
"""Check whether the circuit breaker is open for any format"""
circuit_by_format = key.circuit_breaker_by_format or {}
for circuit_data in circuit_by_format.values():
if circuit_data.get("open"):
return True
return False
# Global health monitor instance # Global health monitor instance
health_monitor = HealthMonitor() health_monitor = HealthMonitor()
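Because the per-format circuit data now lives in a JSON dict, timestamps are stored as ISO 8601 strings and parsed back with `datetime.fromisoformat` when the state is evaluated. The following is a minimal standalone sketch of that state lookup, mirroring `_get_circuit_state_from_data`. The state labels and the function name `circuit_state` are assumptions for illustration; the real `CircuitState` values are not shown in this diff.

```python
from datetime import datetime, timedelta, timezone
from typing import Any, Dict

# Hypothetical labels standing in for the CircuitState constants.
CLOSED, OPEN, HALF_OPEN = "closed", "open", "half_open"


def circuit_state(circuit_data: Dict[str, Any], now: datetime) -> str:
    """Derive the breaker state from a JSON-friendly dict with ISO timestamps."""
    if not circuit_data.get("open"):
        return CLOSED
    half_open_until = circuit_data.get("half_open_until")
    if half_open_until and now < datetime.fromisoformat(half_open_until):
        return HALF_OPEN  # still inside an active half-open period
    next_probe_at = circuit_data.get("next_probe_at")
    if next_probe_at and now >= datetime.fromisoformat(next_probe_at):
        return HALF_OPEN  # recovery wait elapsed, probing is allowed
    return OPEN


if __name__ == "__main__":
    now = datetime.now(timezone.utc)
    data = {
        "open": True,
        "half_open_until": None,
        "next_probe_at": (now + timedelta(seconds=30)).isoformat(),
    }
    print(circuit_state(data, now))
    print(circuit_state(data, now + timedelta(seconds=31)))
```

The comparison only works when `now` and the stored timestamps are both timezone-aware, which holds here because the values are produced from `datetime.now(timezone.utc)`.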


@@ -216,7 +216,7 @@ class ModelService:
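def delete_model(db: Session, model_id: str): # UUID def delete_model(db: Session, model_id: str): # UUID
"""Delete a model """Delete a model
Deletion logic in the new architecture: Deletion logic:
- A Model is only a Provider's implementation of a GlobalModel; deleting it does not affect the GlobalModel - A Model is only a Provider's implementation of a GlobalModel; deleting it does not affect the GlobalModel
- Check whether it is the last implementation of that GlobalModel (if so, warn but still allow deletion) - Check whether it is the last implementation of that GlobalModel (if so, warn but still allow deletion)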
""" """
@@ -384,7 +384,7 @@ class ModelService:
@staticmethod @staticmethod
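def convert_to_response(model: Model) -> ModelResponse: def convert_to_response(model: Model) -> ModelResponse:
"""Convert to the response model (new architecture: take display info and defaults from GlobalModel)""" """Convert to the response model (take display info and defaults from GlobalModel)"""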
return ModelResponse( return ModelResponse(
id=model.id, id=model.id,
provider_id=model.provider_id, provider_id=model.provider_id,


@@ -171,7 +171,8 @@ class CandidateResolver:
) )
candidate_record_map[(candidate_index, 0)] = record_id candidate_record_map[(candidate_index, 0)] = record_id
else: else:
max_retries_for_candidate = endpoint.max_retries if candidate.is_cached else 1 # max_retries has moved from Endpoint to Provider (Endpoint may still keep the old field for compatibility)
max_retries_for_candidate = int(provider.max_retries or 2) if candidate.is_cached else 1
for retry_index in range(max_retries_for_candidate): for retry_index in range(max_retries_for_candidate):
record_id = str(uuid.uuid4()) record_id = str(uuid.uuid4())
@@ -236,7 +237,7 @@ class CandidateResolver:
total = 0 total = 0
for candidate in all_candidates: for candidate in all_candidates:
if not candidate.is_skipped: if not candidate.is_skipped:
endpoint = candidate.endpoint provider = candidate.provider
max_retries = int(endpoint.max_retries) if candidate.is_cached else 1 max_retries = int(provider.max_retries or 2) if candidate.is_cached else 1
total += max_retries total += max_retries
return total return total
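The attempt count above now reads `max_retries` from the provider and defaults to 2 when the field is empty, while non-cached candidates always get a single attempt. The sketch below models that rule with hypothetical, stripped-down `Provider` and `Candidate` dataclasses; the real classes carry many more fields that this diff does not show.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Provider:  # hypothetical minimal stand-in
    max_retries: Optional[int] = None


@dataclass
class Candidate:  # hypothetical minimal stand-in
    provider: Provider
    is_cached: bool = False
    is_skipped: bool = False


def total_attempts(candidates: List[Candidate]) -> int:
    """Sum the attempts for all candidates that are not skipped."""
    total = 0
    for candidate in candidates:
        if candidate.is_skipped:
            continue
        # Cached candidates use the provider's retry budget (default 2); others get 1.
        total += int(candidate.provider.max_retries or 2) if candidate.is_cached else 1
    return total
```

One consequence of the `or 2` default is that a provider configured with `max_retries=0` is treated the same as an unset value and also receives 2 attempts when cached.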
