11 Commits

Author SHA1 Message Date
fawney19
7faca5512a feat(ui): improve the key creation flow and dashboard empty states
- KeyFormDialog: in add mode, keep the dialog open after saving and clear the form so keys can be added back to back
- KeyFormDialog: button label now reflects edit vs. add mode
- Dashboard: improve stat-card loading states and empty-data placeholders
2026-01-10 19:32:36 +08:00
fawney19
ad84272084 fix: allow regular users to access the dashboard endpoints
Change the DashboardAdapter mode from ApiMode.ADMIN to ApiMode.USER so that
regular users can access the /api/dashboard/stats and /api/dashboard/daily-stats endpoints.
2026-01-10 19:31:19 +08:00
fawney19
09e0f594ff refactor: rework rate limiting and health monitoring to distinguish API formats
- Rename adaptive_concurrency to adaptive_rpm, switching from concurrency control to RPM control
- Health monitor now tracks health scores and circuit-breaker state independently per API format
- Add a model_permissions module for configuring allowed models per format
- Refactor the provider-related frontend form components and add a Collapsible UI component
- Add a database migration script for the new data structures
2026-01-10 18:48:35 +08:00
fawney19
dd2fbf4424 style(ui): adjust column widths of the associated-providers table in the model detail drawer 2026-01-08 13:37:41 +08:00
fawney19
99b12a49c6 Merge pull request #78 from fawney19/perf/optimize
perf: optimize HTTP client connection-pool reuse
2026-01-08 13:37:13 +08:00
fawney19
ea35efe440 perf: optimize HTTP client connection-pool reuse
- Add get_proxy_client(): identical proxy configurations reuse the same client
- Add LRU eviction with a cap of 50 proxy clients to prevent memory leaks (the idea is sketched after the changed-files summary below)
- Add get_default_client_async(), an async, thread-safe variant
- Use a module-level lock to avoid race conditions when initializing class attributes
- Optimize ConcurrencyManager with Redis MGET batch reads to cut round trips
- Add get_pool_stats() for connection-pool statistics
2026-01-08 13:34:59 +08:00
fawney19
bf09e740e9 fix(ui): improve interactions on the provider detail page
- Model-list delete button turns red only on hover
- Batch model-association dialog: expand when only global models exist, collapse all groups when there are several
2026-01-08 11:25:52 +08:00
fawney19
60c77cec56 Merge pull request #77 from AAEE86/ui
style(ui): improve text visibility in dark mode for model badges
2026-01-08 10:52:54 +08:00
fawney19
0e4a1dddb5 refactor(ui): improve batch endpoint creation UI and performance
- Adjust layout: move the API URL to the top and the API format selection below it
- Improve checkbox styling: replace the native checkboxes with custom ones
- Sort API formats by column so base formats and their CLI counterparts line up vertically
- Switch the request configuration to a more compact 4-column layout
- Create endpoints concurrently with Promise.allSettled for better performance (see the sketch after the changed-files summary below)
- Improve error reporting: surface the specific error message to the user on failure
- Remove the unused Select component import and the selectOpen variable
2026-01-08 10:50:25 +08:00
AAEE86
1cf18b6e12 feat(ui): support batch endpoint creation with multiple API formats (#76)
Replace single API format selector with multi-select checkbox interface in endpoint creation dialog. Users can now select multiple API formats to create multiple endpoints simultaneously with shared configuration (URL, path, timeout, etc.).

- Change API format selection from dropdown to checkbox grid layout
- Add selectedFormats array to track multiple format selections
- Implement batch creation logic with individual error handling
- Update submit button to show endpoint count being created
- Adjust form layout to improve visual hierarchy
- Display appropriate success/failure messages for batch operations
- Reset selectedFormats on form reset
2026-01-08 10:42:14 +08:00
AAEE86
f9a8be898a style(ui): improve text visibility in dark mode for model badges 2026-01-08 10:26:58 +08:00
102 changed files with 7077 additions and 4255 deletions
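The client-pool reuse in commit ea35efe440 lives in the Python backend, but the core idea is simple: cache one HTTP client per distinct proxy configuration, reuse it on a hit, and evict the least recently used client once the cache passes a cap (50 in the commit). The sketch below illustrates only that idea in TypeScript; the HttpClient type and makeClient factory are hypothetical stand-ins, not part of this repository.

// Illustration of an LRU-bounded client cache; not the project's actual Python implementation.
type HttpClient = { close(): void } // hypothetical stand-in for a pooled client

const MAX_PROXY_CLIENTS = 50
const clientCache = new Map<string, HttpClient>() // Map preserves insertion order

function getProxyClient(
  proxyConfig: { url: string; username?: string },
  makeClient: (cfg: { url: string; username?: string }) => HttpClient
): HttpClient {
  const key = JSON.stringify(proxyConfig) // identical proxy configs share one client
  const cached = clientCache.get(key)
  if (cached) {
    // Refresh recency: delete + set moves the entry to the end of the iteration order.
    clientCache.delete(key)
    clientCache.set(key, cached)
    return cached
  }
  const client = makeClient(proxyConfig)
  clientCache.set(key, client)
  if (clientCache.size > MAX_PROXY_CLIENTS) {
    // Evict the least recently used client (the first key in iteration order).
    const oldest = clientCache.keys().next().value as string
    clientCache.get(oldest)?.close()
    clientCache.delete(oldest)
  }
  return client
}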

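Commits 0e4a1dddb5 and 1cf18b6e12 create one endpoint per selected API format with Promise.allSettled and report failures individually. A minimal sketch of that flow, assuming the createEndpoint helper shown in the endpoint API diff further down; selectedFormats and the form object are stand-ins for the dialog's own state:

import { createEndpoint } from '@/api/endpoints'

async function createSelectedEndpoints(
  providerId: string,
  selectedFormats: string[],
  form: { base_url: string; custom_path?: string }
) {
  // Fire all creations concurrently; allSettled keeps per-format results independent.
  const results = await Promise.allSettled(
    selectedFormats.map((api_format) =>
      createEndpoint(providerId, {
        provider_id: providerId,
        api_format,
        base_url: form.base_url,
        custom_path: form.custom_path || undefined,
        is_active: true,
      })
    )
  )
  // Pair each result with its format so failures can be reported precisely.
  const failed = results
    .map((result, i) => ({ result, format: selectedFormats[i] }))
    .filter((item) => item.result.status === 'rejected')
  return { created: results.length - failed.length, failed }
}

Because each promise settles on its own, one failing format does not abort the others, and the rejected entries still carry the backend error for direct display to the user.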
View File

@@ -0,0 +1,530 @@
"""consolidated schema updates
Revision ID: m4n5o6p7q8r9
Revises: 02a45b66b7c4
Create Date: 2026-01-10 20:00:00.000000
This migration consolidates all schema changes from 2026-01-08 to 2026-01-10:
1. provider_api_keys: keys reference the Provider directly (provider_id, api_formats)
2. provider_api_keys: add rate_multipliers JSON column (per-format rate multipliers)
3. models: make global_model_id nullable (supports standalone ProviderModels)
4. providers: add timeout, max_retries, proxy (migrated from endpoints)
5. providers: rename display_name to name (the original name column is dropped)
6. provider_api_keys: max_concurrent -> rpm_limit (concurrency limits become RPM limits)
7. provider_api_keys: store health data per format (health_by_format, circuit_breaker_by_format)
8. provider_endpoints: drop the deprecated rate_limit column
9. usage: add client_response_headers column
10. provider_api_keys: drop endpoint_id (keys are no longer bound to an Endpoint)
11. provider_endpoints: drop the deprecated max_concurrent column
12. providers: drop the deprecated rpm_limit, rpm_used, rpm_reset_at columns
"""
import logging
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import inspect
# 配置日志
alembic_logger = logging.getLogger("alembic.runtime.migration")
revision = "m4n5o6p7q8r9"
down_revision = "02a45b66b7c4"
branch_labels = None
depends_on = None
def _column_exists(table_name: str, column_name: str) -> bool:
"""Check if a column exists in the table"""
bind = op.get_bind()
inspector = inspect(bind)
columns = [col["name"] for col in inspector.get_columns(table_name)]
return column_name in columns
def _constraint_exists(table_name: str, constraint_name: str) -> bool:
"""Check if a constraint exists"""
bind = op.get_bind()
inspector = inspect(bind)
fks = inspector.get_foreign_keys(table_name)
return any(fk.get("name") == constraint_name for fk in fks)
def _index_exists(table_name: str, index_name: str) -> bool:
"""Check if an index exists"""
bind = op.get_bind()
inspector = inspect(bind)
indexes = inspector.get_indexes(table_name)
return any(idx.get("name") == index_name for idx in indexes)
def upgrade() -> None:
"""Apply all consolidated schema changes"""
bind = op.get_bind()
# ========== 1. provider_api_keys: 添加 provider_id 和 api_formats ==========
if not _column_exists("provider_api_keys", "provider_id"):
op.add_column("provider_api_keys", sa.Column("provider_id", sa.String(36), nullable=True))
# 数据迁移:从 endpoint 获取 provider_id
op.execute("""
UPDATE provider_api_keys k
SET provider_id = e.provider_id
FROM provider_endpoints e
WHERE k.endpoint_id = e.id AND k.provider_id IS NULL
""")
# Check for orphan keys that cannot be linked to a provider
result = bind.execute(sa.text(
"SELECT COUNT(*) FROM provider_api_keys WHERE provider_id IS NULL"
))
orphan_count = result.scalar() or 0
if orphan_count > 0:
# Use the logger to make the warning stand out
alembic_logger.warning("=" * 60)
alembic_logger.warning(f"[MIGRATION WARNING] Found {orphan_count} orphan keys that cannot be linked to a Provider")
alembic_logger.warning("=" * 60)
alembic_logger.info("Backing up orphan keys to the _orphan_api_keys_backup table...")
# Back up the orphan rows to a temporary table first to avoid data loss
op.execute("""
CREATE TABLE IF NOT EXISTS _orphan_api_keys_backup AS
SELECT *, NOW() as backup_at
FROM provider_api_keys
WHERE provider_id IS NULL
""")
# Record the IDs of the backed-up keys
orphan_ids = bind.execute(sa.text(
"SELECT id, name FROM provider_api_keys WHERE provider_id IS NULL"
)).fetchall()
alembic_logger.info("Backed-up orphan keys:")
for key_id, key_name in orphan_ids:
alembic_logger.info(f" - Key: {key_name} (ID: {key_id})")
# Delete the orphan rows
op.execute("DELETE FROM provider_api_keys WHERE provider_id IS NULL")
alembic_logger.info(f"Backed up and deleted {orphan_count} orphan keys")
# Recovery guide
alembic_logger.warning("-" * 60)
alembic_logger.warning("[RECOVERY GUIDE] To restore orphan keys:")
alembic_logger.warning(" 1. Query the backup table: SELECT * FROM _orphan_api_keys_backup;")
alembic_logger.warning(" 2. Determine the correct provider_id")
alembic_logger.warning(" 3. Restore with:")
alembic_logger.warning(" INSERT INTO provider_api_keys (...)")
alembic_logger.warning(" SELECT ... FROM _orphan_api_keys_backup WHERE ...;")
alembic_logger.warning("-" * 60)
# 设置 NOT NULL 并创建外键
op.alter_column("provider_api_keys", "provider_id", nullable=False)
if not _constraint_exists("provider_api_keys", "fk_provider_api_keys_provider"):
op.create_foreign_key(
"fk_provider_api_keys_provider",
"provider_api_keys",
"providers",
["provider_id"],
["id"],
ondelete="CASCADE",
)
if not _index_exists("provider_api_keys", "idx_provider_api_keys_provider_id"):
op.create_index("idx_provider_api_keys_provider_id", "provider_api_keys", ["provider_id"])
if not _column_exists("provider_api_keys", "api_formats"):
op.add_column("provider_api_keys", sa.Column("api_formats", sa.JSON(), nullable=True))
# 数据迁移:从 endpoint 获取 api_format
op.execute("""
UPDATE provider_api_keys k
SET api_formats = json_build_array(e.api_format)
FROM provider_endpoints e
WHERE k.endpoint_id = e.id AND k.api_formats IS NULL
""")
op.alter_column("provider_api_keys", "api_formats", nullable=False, server_default="[]")
# 修改 endpoint_id 为可空,外键改为 SET NULL
if _constraint_exists("provider_api_keys", "provider_api_keys_endpoint_id_fkey"):
op.drop_constraint("provider_api_keys_endpoint_id_fkey", "provider_api_keys", type_="foreignkey")
op.alter_column("provider_api_keys", "endpoint_id", nullable=True)
# 不再重建外键,因为后面会删除这个字段
# ========== 2. provider_api_keys: 添加 rate_multipliers ==========
if not _column_exists("provider_api_keys", "rate_multipliers"):
op.add_column(
"provider_api_keys",
sa.Column("rate_multipliers", postgresql.JSON(astext_type=sa.Text()), nullable=True),
)
# 数据迁移:将 rate_multiplier 按 api_formats 转换
op.execute("""
UPDATE provider_api_keys
SET rate_multipliers = (
SELECT jsonb_object_agg(elem, rate_multiplier)
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
)
WHERE api_formats IS NOT NULL
AND api_formats::text != '[]'
AND api_formats::text != 'null'
AND rate_multipliers IS NULL
""")
# ========== 3. models: global_model_id 改为可空 ==========
op.alter_column("models", "global_model_id", existing_type=sa.String(36), nullable=True)
# ========== 4. providers: 添加 timeout, max_retries, proxy ==========
if not _column_exists("providers", "timeout"):
op.add_column(
"providers",
sa.Column("timeout", sa.Integer(), nullable=True, comment="请求超时(秒)"),
)
if not _column_exists("providers", "max_retries"):
op.add_column(
"providers",
sa.Column("max_retries", sa.Integer(), nullable=True, comment="最大重试次数"),
)
if not _column_exists("providers", "proxy"):
op.add_column(
"providers",
sa.Column("proxy", postgresql.JSONB(), nullable=True, comment="代理配置"),
)
# 从端点迁移数据到 provider
op.execute("""
UPDATE providers p
SET
timeout = COALESCE(
p.timeout,
(SELECT MAX(e.timeout) FROM provider_endpoints e WHERE e.provider_id = p.id AND e.timeout IS NOT NULL),
300
),
max_retries = COALESCE(
p.max_retries,
(SELECT MAX(e.max_retries) FROM provider_endpoints e WHERE e.provider_id = p.id AND e.max_retries IS NOT NULL),
2
),
proxy = COALESCE(
p.proxy,
(SELECT e.proxy FROM provider_endpoints e WHERE e.provider_id = p.id AND e.proxy IS NOT NULL ORDER BY e.created_at LIMIT 1)
)
WHERE p.timeout IS NULL OR p.max_retries IS NULL
""")
# ========== 5. providers: display_name -> name ==========
# 注意:这里假设 display_name 已经被重命名为 name
# 如果 display_name 仍然存在,则需要执行重命名
if _column_exists("providers", "display_name"):
# 删除旧的 name 索引
if _index_exists("providers", "ix_providers_name"):
op.drop_index("ix_providers_name", table_name="providers")
# 如果存在旧的 name 列,先删除
if _column_exists("providers", "name"):
op.drop_column("providers", "name")
# 重命名 display_name 为 name
op.alter_column("providers", "display_name", new_column_name="name")
# 创建新索引
op.create_index("ix_providers_name", "providers", ["name"], unique=True)
# ========== 6. provider_api_keys: max_concurrent -> rpm_limit ==========
if _column_exists("provider_api_keys", "max_concurrent"):
op.alter_column("provider_api_keys", "max_concurrent", new_column_name="rpm_limit")
if _column_exists("provider_api_keys", "learned_max_concurrent"):
op.alter_column("provider_api_keys", "learned_max_concurrent", new_column_name="learned_rpm_limit")
if _column_exists("provider_api_keys", "last_concurrent_peak"):
op.alter_column("provider_api_keys", "last_concurrent_peak", new_column_name="last_rpm_peak")
# 删除废弃字段
for col in ["rate_limit", "daily_limit", "monthly_limit"]:
if _column_exists("provider_api_keys", col):
op.drop_column("provider_api_keys", col)
# ========== 7. provider_api_keys: 健康度改为按格式存储 ==========
if not _column_exists("provider_api_keys", "health_by_format"):
op.add_column(
"provider_api_keys",
sa.Column(
"health_by_format",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
comment="按API格式存储的健康度数据",
),
)
if not _column_exists("provider_api_keys", "circuit_breaker_by_format"):
op.add_column(
"provider_api_keys",
sa.Column(
"circuit_breaker_by_format",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
comment="按API格式存储的熔断器状态",
),
)
# 数据迁移:如果存在旧字段,迁移数据到新结构
if _column_exists("provider_api_keys", "health_score"):
op.execute("""
UPDATE provider_api_keys
SET health_by_format = (
SELECT jsonb_object_agg(
elem,
jsonb_build_object(
'health_score', COALESCE(health_score, 1.0),
'consecutive_failures', COALESCE(consecutive_failures, 0),
'last_failure_at', last_failure_at,
'request_results_window', COALESCE(request_results_window::jsonb, '[]'::jsonb)
)
)
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
)
WHERE api_formats IS NOT NULL
AND api_formats::text != '[]'
AND health_by_format IS NULL
""")
# Circuit breaker migration strategy:
# Do not copy the old circuit_breaker_open state to every format; reset them all to closed instead.
# Rationale: the old single circuit-breaker state may have opened because one specific format failed;
# copying it to every format would wrongly mark formats that still work fine as unavailable.
if _column_exists("provider_api_keys", "circuit_breaker_open"):
op.execute("""
UPDATE provider_api_keys
SET circuit_breaker_by_format = (
SELECT jsonb_object_agg(
elem,
jsonb_build_object(
'open', false,
'open_at', NULL,
'next_probe_at', NULL,
'half_open_until', NULL,
'half_open_successes', 0,
'half_open_failures', 0
)
)
FROM jsonb_array_elements_text(api_formats::jsonb) AS elem
)
WHERE api_formats IS NOT NULL
AND api_formats::text != '[]'
AND circuit_breaker_by_format IS NULL
""")
# 设置默认空对象
op.execute("""
UPDATE provider_api_keys
SET health_by_format = '{}'::jsonb
WHERE health_by_format IS NULL
""")
op.execute("""
UPDATE provider_api_keys
SET circuit_breaker_by_format = '{}'::jsonb
WHERE circuit_breaker_by_format IS NULL
""")
# 创建 GIN 索引
if not _index_exists("provider_api_keys", "ix_provider_api_keys_health_by_format"):
op.create_index(
"ix_provider_api_keys_health_by_format",
"provider_api_keys",
["health_by_format"],
postgresql_using="gin",
)
if not _index_exists("provider_api_keys", "ix_provider_api_keys_circuit_breaker_by_format"):
op.create_index(
"ix_provider_api_keys_circuit_breaker_by_format",
"provider_api_keys",
["circuit_breaker_by_format"],
postgresql_using="gin",
)
# 删除旧字段
old_health_columns = [
"health_score",
"consecutive_failures",
"last_failure_at",
"request_results_window",
"circuit_breaker_open",
"circuit_breaker_open_at",
"next_probe_at",
"half_open_until",
"half_open_successes",
"half_open_failures",
]
for col in old_health_columns:
if _column_exists("provider_api_keys", col):
op.drop_column("provider_api_keys", col)
# ========== 8. provider_endpoints: 删除废弃的 rate_limit 列 ==========
if _column_exists("provider_endpoints", "rate_limit"):
op.drop_column("provider_endpoints", "rate_limit")
# ========== 9. usage: 添加 client_response_headers ==========
if not _column_exists("usage", "client_response_headers"):
op.add_column(
"usage",
sa.Column("client_response_headers", sa.JSON(), nullable=True),
)
# ========== 10. provider_api_keys: 删除 endpoint_id ==========
# Key 不再与 Endpoint 绑定,通过 provider_id + api_formats 关联
if _column_exists("provider_api_keys", "endpoint_id"):
# 确保外键已删除(前面可能已经删除)
try:
bind = op.get_bind()
inspector = inspect(bind)
for fk in inspector.get_foreign_keys("provider_api_keys"):
constrained = fk.get("constrained_columns") or []
if "endpoint_id" in constrained:
name = fk.get("name")
if name:
op.drop_constraint(name, "provider_api_keys", type_="foreignkey")
except Exception:
pass # 外键可能已经不存在
op.drop_column("provider_api_keys", "endpoint_id")
# ========== 11. provider_endpoints: 删除废弃的 max_concurrent 列 ==========
if _column_exists("provider_endpoints", "max_concurrent"):
op.drop_column("provider_endpoints", "max_concurrent")
# ========== 12. providers: 删除废弃的 RPM 相关字段 ==========
if _column_exists("providers", "rpm_limit"):
op.drop_column("providers", "rpm_limit")
if _column_exists("providers", "rpm_used"):
op.drop_column("providers", "rpm_used")
if _column_exists("providers", "rpm_reset_at"):
op.drop_column("providers", "rpm_reset_at")
alembic_logger.info("[OK] Consolidated migration completed successfully")
def downgrade() -> None:
"""
Downgrade is complex due to data migrations.
For safety, this only removes new columns without restoring old structure.
Manual intervention may be required for full rollback.
"""
bind = op.get_bind()
# 12. 恢复 providers RPM 相关字段
if not _column_exists("providers", "rpm_limit"):
op.add_column("providers", sa.Column("rpm_limit", sa.Integer(), nullable=True))
if not _column_exists("providers", "rpm_used"):
op.add_column(
"providers",
sa.Column("rpm_used", sa.Integer(), server_default="0", nullable=True),
)
if not _column_exists("providers", "rpm_reset_at"):
op.add_column(
"providers",
sa.Column("rpm_reset_at", sa.DateTime(timezone=True), nullable=True),
)
# 11. 恢复 provider_endpoints.max_concurrent
if not _column_exists("provider_endpoints", "max_concurrent"):
op.add_column("provider_endpoints", sa.Column("max_concurrent", sa.Integer(), nullable=True))
# 10. 恢复 endpoint_id
if not _column_exists("provider_api_keys", "endpoint_id"):
op.add_column("provider_api_keys", sa.Column("endpoint_id", sa.String(36), nullable=True))
# 9. 删除 client_response_headers
if _column_exists("usage", "client_response_headers"):
op.drop_column("usage", "client_response_headers")
# 8. Restore provider_endpoints.rate_limit (if needed)
if not _column_exists("provider_endpoints", "rate_limit"):
op.add_column("provider_endpoints", sa.Column("rate_limit", sa.Integer(), nullable=True))
# 7. 删除健康度 JSON 字段
bind.execute(sa.text("DROP INDEX IF EXISTS ix_provider_api_keys_health_by_format"))
bind.execute(sa.text("DROP INDEX IF EXISTS ix_provider_api_keys_circuit_breaker_by_format"))
if _column_exists("provider_api_keys", "health_by_format"):
op.drop_column("provider_api_keys", "health_by_format")
if _column_exists("provider_api_keys", "circuit_breaker_by_format"):
op.drop_column("provider_api_keys", "circuit_breaker_by_format")
# 6. rpm_limit -> max_concurrent (simplified: rename only)
if _column_exists("provider_api_keys", "rpm_limit"):
op.alter_column("provider_api_keys", "rpm_limit", new_column_name="max_concurrent")
if _column_exists("provider_api_keys", "learned_rpm_limit"):
op.alter_column("provider_api_keys", "learned_rpm_limit", new_column_name="learned_max_concurrent")
if _column_exists("provider_api_keys", "last_rpm_peak"):
op.alter_column("provider_api_keys", "last_rpm_peak", new_column_name="last_concurrent_peak")
# 恢复已删除的字段
if not _column_exists("provider_api_keys", "rate_limit"):
op.add_column("provider_api_keys", sa.Column("rate_limit", sa.Integer(), nullable=True))
if not _column_exists("provider_api_keys", "daily_limit"):
op.add_column("provider_api_keys", sa.Column("daily_limit", sa.Integer(), nullable=True))
if not _column_exists("provider_api_keys", "monthly_limit"):
op.add_column("provider_api_keys", sa.Column("monthly_limit", sa.Integer(), nullable=True))
# 5. name -> display_name (需要先删除索引)
if _index_exists("providers", "ix_providers_name"):
op.drop_index("ix_providers_name", table_name="providers")
op.alter_column("providers", "name", new_column_name="display_name")
# 重新添加原 name 字段
op.add_column("providers", sa.Column("name", sa.String(100), nullable=True))
op.execute("""
UPDATE providers
SET name = LOWER(REPLACE(REPLACE(display_name, ' ', '_'), '-', '_'))
""")
op.alter_column("providers", "name", nullable=False)
op.create_index("ix_providers_name", "providers", ["name"], unique=True)
# 4. 删除 providers 的 timeout, max_retries, proxy
if _column_exists("providers", "proxy"):
op.drop_column("providers", "proxy")
if _column_exists("providers", "max_retries"):
op.drop_column("providers", "max_retries")
if _column_exists("providers", "timeout"):
op.drop_column("providers", "timeout")
# 3. models: global_model_id 改回 NOT NULL
result = bind.execute(sa.text(
"SELECT COUNT(*) FROM models WHERE global_model_id IS NULL"
))
orphan_model_count = result.scalar() or 0
if orphan_model_count > 0:
alembic_logger.warning(f"[WARN] Found {orphan_model_count} standalone models without a global_model_id; they will be deleted")
op.execute("DELETE FROM models WHERE global_model_id IS NULL")
alembic_logger.info(f"Deleted {orphan_model_count} standalone models")
op.alter_column("models", "global_model_id", nullable=False)
# 2. 删除 rate_multipliers
if _column_exists("provider_api_keys", "rate_multipliers"):
op.drop_column("provider_api_keys", "rate_multipliers")
# 1. 删除 provider_id 和 api_formats
if _index_exists("provider_api_keys", "idx_provider_api_keys_provider_id"):
op.drop_index("idx_provider_api_keys_provider_id", table_name="provider_api_keys")
if _constraint_exists("provider_api_keys", "fk_provider_api_keys_provider"):
op.drop_constraint("fk_provider_api_keys_provider", "provider_api_keys", type_="foreignkey")
if _column_exists("provider_api_keys", "api_formats"):
op.drop_column("provider_api_keys", "api_formats")
if _column_exists("provider_api_keys", "provider_id"):
op.drop_column("provider_api_keys", "provider_id")
# Restore the endpoint_id foreign key (simplified: only recreate the FK, do not enforce NOT NULL)
if _column_exists("provider_api_keys", "endpoint_id"):
if not _constraint_exists("provider_api_keys", "provider_api_keys_endpoint_id_fkey"):
op.create_foreign_key(
"provider_api_keys_endpoint_id_fkey",
"provider_api_keys",
"provider_endpoints",
["endpoint_id"],
["id"],
ondelete="SET NULL",
)
alembic_logger.info("[OK] Downgrade completed (simplified version)")

View File

@@ -67,7 +67,6 @@ export interface GlobalModelExport {
export interface ProviderExport {
name: string
display_name: string
description?: string | null
website?: string | null
billing_type?: string | null
@@ -76,10 +75,13 @@ export interface ProviderExport {
rpm_limit?: number | null
provider_priority?: number
is_active: boolean
rate_limit?: number | null
concurrent_limit?: number | null
timeout?: number | null
max_retries?: number | null
proxy?: any
config?: any
endpoints: EndpointExport[]
api_keys: ProviderKeyExport[]
models: ModelExport[]
}
@@ -89,27 +91,26 @@ export interface EndpointExport {
headers?: any
timeout?: number
max_retries?: number
max_concurrent?: number | null
rate_limit?: number | null
is_active: boolean
custom_path?: string | null
config?: any
keys: KeyExport[]
proxy?: any
}
export interface KeyExport {
export interface ProviderKeyExport {
api_key: string
name?: string | null
note?: string | null
api_formats: string[]
rate_multiplier?: number
rate_multipliers?: Record<string, number> | null
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null
rate_limit?: number | null
daily_limit?: number | null
monthly_limit?: number | null
allowed_models?: string[] | null
rpm_limit?: number | null
allowed_models?: any
capabilities?: any
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
is_active: boolean
}

View File

@@ -155,6 +155,7 @@ export interface RequestDetail {
request_body?: Record<string, any>
provider_request_headers?: Record<string, any>
response_headers?: Record<string, any>
client_response_headers?: Record<string, any>
response_body?: Record<string, any>
metadata?: Record<string, any>
// 阶梯计费信息

View File

@@ -14,7 +14,7 @@ export async function toggleAdaptiveMode(
message: string
key_id: string
is_adaptive: boolean
max_concurrent: number | null
rpm_limit: number | null
effective_limit: number | null
}> {
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/mode`, data)
@@ -22,16 +22,16 @@ export async function toggleAdaptiveMode(
}
/**
* 设置 Key 的固定并发限制
* 设置 Key 的固定 RPM 限制
*/
export async function setConcurrentLimit(
export async function setRpmLimit(
keyId: string,
limit: number
): Promise<{
message: string
key_id: string
is_adaptive: boolean
max_concurrent: number
rpm_limit: number
previous_mode: string
}> {
const response = await client.patch(`/api/admin/adaptive/keys/${keyId}/limit`, null, {

View File

@@ -27,15 +27,9 @@ export async function createEndpoint(
api_format: string
base_url: string
custom_path?: string
auth_type?: string
auth_header?: string
headers?: Record<string, string>
timeout?: number
max_retries?: number
priority?: number
weight?: number
max_concurrent?: number
rate_limit?: number
is_active?: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
@@ -52,16 +46,10 @@ export async function updateEndpoint(
endpointId: string,
data: Partial<{
base_url: string
custom_path: string
auth_type: string
auth_header: string
custom_path: string | null
headers: Record<string, string>
timeout: number
max_retries: number
priority: number
weight: number
max_concurrent: number
rate_limit: number
is_active: boolean
config: Record<string, any>
proxy: ProxyConfig | null
@@ -74,7 +62,7 @@ export async function updateEndpoint(
/**
* 删除 Endpoint
*/
export async function deleteEndpoint(endpointId: string): Promise<{ message: string; deleted_keys_count: number }> {
export async function deleteEndpoint(endpointId: string): Promise<{ message: string; affected_keys_count: number }> {
const response = await client.delete(`/api/admin/endpoints/${endpointId}`)
return response.data
}

View File

@@ -32,16 +32,21 @@ export async function getKeyHealth(keyId: string): Promise<HealthStatus> {
/**
* Recover a key's health (one-click recovery: reset health score + close the circuit breaker + cancel auto-disable)
* @param keyId Key ID
* @param apiFormat Optional; specify an API format (e.g. CLAUDE, OPENAI); if omitted, all formats are recovered
*/
export async function recoverKeyHealth(keyId: string): Promise<{
export async function recoverKeyHealth(keyId: string, apiFormat?: string): Promise<{
message: string
details: {
api_format?: string
health_score: number
circuit_breaker_open: boolean
is_active: boolean
}
}> {
const response = await client.patch(`/api/admin/endpoints/health/keys/${keyId}`)
const response = await client.patch(`/api/admin/endpoints/health/keys/${keyId}`, null, {
params: apiFormat ? { api_format: apiFormat } : undefined
})
return response.data
}
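A short usage sketch for the per-format recovery parameter added above. The key ID is a placeholder, and the import assumes recoverKeyHealth is re-exported from the same '@/api/endpoints' barrel as the other helpers in this PR:

import { recoverKeyHealth } from '@/api/endpoints'

async function recoverClaudeThenAll(keyId: string) {
  // Recover only the CLAUDE format; other formats keep their current health state.
  const single = await recoverKeyHealth(keyId, 'CLAUDE')
  console.log(single.message, single.details.api_format)
  // Without the second argument, every configured format is recovered at once.
  await recoverKeyHealth(keyId)
}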

View File

@@ -1,5 +1,5 @@
import client from '../client'
import type { EndpointAPIKey } from './types'
import type { EndpointAPIKey, AllowedModels } from './types'
/**
* 能力定义类型
@@ -49,67 +49,6 @@ export async function getModelCapabilities(modelName: string): Promise<ModelCapa
return response.data
}
/**
* 获取 Endpoint 的所有 Keys
*/
export async function getEndpointKeys(endpointId: string): Promise<EndpointAPIKey[]> {
const response = await client.get(`/api/admin/endpoints/${endpointId}/keys`)
return response.data
}
/**
* 为 Endpoint 添加 Key
*/
export async function addEndpointKey(
endpointId: string,
data: {
endpoint_id: string
api_key: string
name: string // 密钥名称(必填)
rate_multiplier?: number // 成本倍率(默认 1.0
internal_priority?: number // Endpoint 内部优先级(数字越小越优先)
max_concurrent?: number // 最大并发数(留空=自适应模式)
rate_limit?: number
daily_limit?: number
monthly_limit?: number
cache_ttl_minutes?: number // 缓存 TTL分钟0=禁用
max_probe_interval_minutes?: number // 熔断探测间隔(分钟)
allowed_models?: string[] // 允许使用的模型列表
capabilities?: Record<string, boolean> // 能力标签配置
note?: string // 备注说明(可选)
}
): Promise<EndpointAPIKey> {
const response = await client.post(`/api/admin/endpoints/${endpointId}/keys`, data)
return response.data
}
/**
* 更新 Endpoint Key
*/
export async function updateEndpointKey(
keyId: string,
data: Partial<{
api_key: string
name: string // 密钥名称
rate_multiplier: number // 成本倍率
internal_priority: number // Endpoint 内部优先级(提供商优先模式,数字越小越优先)
global_priority: number // 全局 Key 优先级(全局 Key 优先模式,数字越小越优先)
max_concurrent: number // 最大并发数(留空=自适应模式)
rate_limit: number
daily_limit: number
monthly_limit: number
cache_ttl_minutes: number // 缓存 TTL分钟0=禁用
max_probe_interval_minutes: number // 熔断探测间隔(分钟)
allowed_models: string[] | null // 允许使用的模型列表null 表示允许所有
capabilities: Record<string, boolean> | null // 能力标签配置
is_active: boolean
note: string // 备注说明
}>
): Promise<EndpointAPIKey> {
const response = await client.put(`/api/admin/endpoints/keys/${keyId}`, data)
return response.data
}
/**
* 获取完整的 API Key用于查看和复制
*/
@@ -119,22 +58,71 @@ export async function revealEndpointKey(keyId: string): Promise<{ api_key: strin
}
/**
* 删除 Endpoint Key
* 删除 Key
*/
export async function deleteEndpointKey(keyId: string): Promise<{ message: string }> {
const response = await client.delete(`/api/admin/endpoints/keys/${keyId}`)
return response.data
}
// ========== Provider 级别的 Keys API ==========
/**
* 批量更新 Endpoint Keys 的优先级(用于拖动排序)
* 获取 Provider 的所有 Keys
*/
export async function batchUpdateKeyPriority(
endpointId: string,
priorities: Array<{ key_id: string; internal_priority: number }>
): Promise<{ message: string; updated_count: number }> {
const response = await client.put(`/api/admin/endpoints/${endpointId}/keys/batch-priority`, {
priorities
})
export async function getProviderKeys(providerId: string): Promise<EndpointAPIKey[]> {
const response = await client.get(`/api/admin/endpoints/providers/${providerId}/keys`)
return response.data
}
/**
* 为 Provider 添加 Key
*/
export async function addProviderKey(
providerId: string,
data: {
api_formats: string[] // Supported API formats (required)
api_key: string
name: string
rate_multiplier?: number // Default cost multiplier
rate_multipliers?: Record<string, number> | null // Per-API-format cost multipliers
internal_priority?: number
rpm_limit?: number | null // RPM limit (leave empty for adaptive mode)
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
allowed_models?: AllowedModels
capabilities?: Record<string, boolean>
note?: string
}
): Promise<EndpointAPIKey> {
const response = await client.post(`/api/admin/endpoints/providers/${providerId}/keys`, data)
return response.data
}
/**
* 更新 Key
*/
export async function updateProviderKey(
keyId: string,
data: Partial<{
api_formats: string[] // 支持的 API 格式列表
api_key: string
name: string
rate_multiplier: number // 默认成本倍率
rate_multipliers: Record<string, number> | null // 按 API 格式的成本倍率
internal_priority: number
global_priority: number | null
rpm_limit: number | null // RPM 限制(留空=自适应模式)
cache_ttl_minutes: number
max_probe_interval_minutes: number
allowed_models: AllowedModels
capabilities: Record<string, boolean> | null
is_active: boolean
note: string
}>
): Promise<EndpointAPIKey> {
const response = await client.put(`/api/admin/endpoints/keys/${keyId}`, data)
return response.data
}
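A usage sketch for the provider-level key API above. The provider ID, key value, and model names are placeholders, and the import assumes the '@/api/endpoints' barrel export used elsewhere in this PR:

import { addProviderKey } from '@/api/endpoints'

async function addExampleKey(providerId: string) {
  // One key serving two API formats: OPENAI traffic gets a cheaper multiplier than the
  // default, and each format has its own model whitelist. rpm_limit is omitted, so the
  // key runs in adaptive mode.
  return addProviderKey(providerId, {
    api_formats: ['CLAUDE', 'OPENAI'],
    api_key: 'sk-placeholder',
    name: 'example-key',
    rate_multiplier: 1.0,
    rate_multipliers: { OPENAI: 0.8 },
    allowed_models: { CLAUDE: ['claude-3-opus'], OPENAI: ['gpt-4'] },
  })
}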

View File

@@ -147,14 +147,26 @@ export async function queryProviderUpstreamModels(
/**
* Import models from the upstream provider
* @param providerId Provider ID
* @param modelIds List of model IDs
* @param options Optional settings
* @param options.tiered_pricing Tiered pricing configuration
* @param options.price_per_request Per-request price
*/
export async function importModelsFromUpstream(
providerId: string,
modelIds: string[]
modelIds: string[],
options?: {
tiered_pricing?: object
price_per_request?: number
}
): Promise<ImportFromUpstreamResponse> {
const response = await client.post(
`/api/admin/providers/${providerId}/import-from-upstream`,
{ model_ids: modelIds }
{
model_ids: modelIds,
...options
}
)
return response.data
}
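A usage sketch for the new optional pricing parameters above. The model IDs and prices are placeholders, the tiered_pricing shape is illustrative only (the parameter is typed as a plain object), and the import path is an assumption:

import { importModelsFromUpstream } from '@/api/endpoints'

async function importWithPricing(providerId: string) {
  return importModelsFromUpstream(providerId, ['model-a', 'model-b'], {
    // Illustrative shape only; the backend defines the actual tiered_pricing schema.
    tiered_pricing: { tiers: [{ up_to_tokens: 128000, input_price_per_1m: 2.5 }] },
    price_per_request: 0.01,
  })
}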

View File

@@ -1,5 +1,5 @@
import client from '../client'
import type { ProviderWithEndpointsSummary } from './types'
import type { ProviderWithEndpointsSummary, ProxyConfig } from './types'
/**
* 获取 Providers 摘要(包含 Endpoints 统计)
@@ -23,7 +23,7 @@ export async function getProvider(providerId: string): Promise<ProviderWithEndpo
export async function updateProvider(
providerId: string,
data: Partial<{
display_name: string
name: string
description: string
website: string
provider_priority: number
@@ -33,6 +33,10 @@ export async function updateProvider(
quota_last_reset_at: string // 周期开始时间
quota_expires_at: string
rpm_limit: number | null
// 请求配置(从 Endpoint 迁移)
timeout: number
max_retries: number
proxy: ProxyConfig | null
cache_ttl_minutes: number // 0 = caching not supported; >0 = caching supported with this TTL (minutes)
max_probe_interval_minutes: number
is_active: boolean
@@ -83,7 +87,6 @@ export interface TestModelResponse {
provider?: {
id: string
name: string
display_name: string
}
model?: string
}
@@ -92,4 +95,3 @@ export async function testModel(data: TestModelRequest): Promise<TestModelRespon
const response = await client.post('/api/admin/provider-query/test-model', data)
return response.data
}

View File

@@ -20,6 +20,38 @@ export const API_FORMAT_LABELS: Record<string, string> = {
[API_FORMATS.GEMINI_CLI]: 'Gemini CLI',
}
// API 格式缩写映射(用于空间紧凑的显示场景)
export const API_FORMAT_SHORT: Record<string, string> = {
[API_FORMATS.OPENAI]: 'O',
[API_FORMATS.OPENAI_CLI]: 'OC',
[API_FORMATS.CLAUDE]: 'C',
[API_FORMATS.CLAUDE_CLI]: 'CC',
[API_FORMATS.GEMINI]: 'G',
[API_FORMATS.GEMINI_CLI]: 'GC',
}
// API 格式排序顺序(统一的显示顺序)
export const API_FORMAT_ORDER: string[] = [
API_FORMATS.OPENAI,
API_FORMATS.OPENAI_CLI,
API_FORMATS.CLAUDE,
API_FORMATS.CLAUDE_CLI,
API_FORMATS.GEMINI,
API_FORMATS.GEMINI_CLI,
]
// 工具函数:按标准顺序排序 API 格式数组
export function sortApiFormats(formats: string[]): string[] {
return [...formats].sort((a, b) => {
const aIdx = API_FORMAT_ORDER.indexOf(a)
const bIdx = API_FORMAT_ORDER.indexOf(b)
if (aIdx === -1 && bIdx === -1) return 0
if (aIdx === -1) return 1
if (bIdx === -1) return -1
return aIdx - bIdx
})
}
/**
* 代理配置类型
*/
@@ -37,18 +69,9 @@ export interface ProviderEndpoint {
api_format: string
base_url: string
custom_path?: string // 自定义请求路径(可选,为空则使用 API 格式默认路径)
auth_type: string
auth_header?: string
headers?: Record<string, string>
timeout: number
max_retries: number
priority: number
weight: number
max_concurrent?: number
rate_limit?: number
health_score: number
consecutive_failures: number
last_failure_at?: string
is_active: boolean
config?: Record<string, any>
proxy?: ProxyConfig | null
@@ -58,25 +81,55 @@ export interface ProviderEndpoint {
updated_at: string
}
/**
* Model permission configuration (supports both a simple list and a per-format dictionary)
*
* Examples:
* 1. No restriction (all models allowed): null
* 2. Simple list mode (all API formats share the same whitelist): ["gpt-4", "claude-3-opus"]
* 3. Per-format dictionary mode (each API format has its own whitelist):
* { "OPENAI": ["gpt-4"], "CLAUDE": ["claude-3-opus"] }
*/
export type AllowedModels = string[] | Record<string, string[]> | null
// Type guards for AllowedModels
export function isAllowedModelsList(value: AllowedModels): value is string[] {
return Array.isArray(value)
}
export function isAllowedModelsDict(value: AllowedModels): value is Record<string, string[]> {
if (value === null || typeof value !== 'object' || Array.isArray(value)) {
return false
}
// Verify that every value is an array of strings
return Object.values(value).every(
(v) => Array.isArray(v) && v.every((item) => typeof item === 'string')
)
}
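// Illustrative helper, not part of this diff: resolve the effective whitelist for a given
// API format using the two guards above. Returns null when all models are allowed.
export function resolveAllowedModels(value: AllowedModels, apiFormat: string): string[] | null {
  if (value === null) return null // no restriction
  if (isAllowedModelsList(value)) return value // one shared whitelist for every format
  if (isAllowedModelsDict(value)) return value[apiFormat] ?? null // per-format whitelist
  return null // malformed value: treat as unrestricted
}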
export interface EndpointAPIKey {
id: string
endpoint_id: string
provider_id: string
api_formats: string[] // 支持的 API 格式列表
api_key_masked: string
api_key_plain?: string | null
name: string // 密钥名称(必填,用于识别)
rate_multiplier: number // 成本倍率(真实成本 = 表面成本 × 倍率)
internal_priority: number // Endpoint 内部优先级
rate_multiplier: number // 默认成本倍率(真实成本 = 表面成本 × 倍率)
rate_multipliers?: Record<string, number> | null // 按 API 格式的成本倍率,如 {"CLAUDE": 1.0, "OPENAI": 0.8}
internal_priority: number // Key 内部优先级
global_priority?: number | null // 全局 Key 优先级
max_concurrent?: number
rate_limit?: number
daily_limit?: number
monthly_limit?: number
allowed_models?: string[] | null // Allowed models (null = all models supported)
rpm_limit?: number | null // RPM rate limit (1-10000); null means adaptive mode
allowed_models?: AllowedModels // Allowed models (null = unrestricted, list = simple whitelist, dict = per-format whitelists)
capabilities?: Record<string, boolean> | null // Capability flags (e.g. cache_1h, context_1m)
// 缓存与熔断配置
cache_ttl_minutes: number // 缓存 TTL分钟0=禁用
max_probe_interval_minutes: number // 熔断探测间隔(分钟)
// 按格式的健康度数据
health_by_format?: Record<string, FormatHealthData>
circuit_breaker_by_format?: Record<string, FormatCircuitBreakerData>
// 聚合字段(从 health_by_format 计算,用于列表显示)
health_score: number
circuit_breaker_open?: boolean
consecutive_failures: number
last_failure_at?: string
request_count: number
@@ -89,10 +142,10 @@ export interface EndpointAPIKey {
last_used_at?: string
created_at: string
updated_at: string
// Adaptive concurrency fields
is_adaptive?: boolean // Whether the key is in adaptive mode (max_concurrent=NULL)
effective_limit?: number // Current effective limit (learned value when adaptive, configured value when fixed)
learned_max_concurrent?: number
// Adaptive RPM fields
is_adaptive?: boolean // Whether the key is in adaptive mode (rpm_limit=NULL)
effective_limit?: number // Current effective RPM limit (learned value when adaptive, configured value when fixed)
learned_rpm_limit?: number // Learned RPM limit
// 滑动窗口利用率采样
utilization_samples?: Array<{ ts: number; util: number }> // 利用率采样窗口
last_probe_increase_at?: string // 上次探测性扩容时间
@@ -100,8 +153,7 @@ export interface EndpointAPIKey {
rpm_429_count?: number
last_429_at?: string
last_429_type?: string
// 熔断器字段(滑动窗口 + 半开模式)
circuit_breaker_open?: boolean
// 单格式场景的熔断器字段
circuit_breaker_open_at?: string
next_probe_at?: string
half_open_until?: string
@@ -110,17 +162,36 @@ export interface EndpointAPIKey {
request_results_window?: Array<{ ts: number; ok: boolean }> // 请求结果滑动窗口
}
// Per-format health data
export interface FormatHealthData {
health_score: number
error_rate: number
window_size: number
consecutive_failures: number
last_failure_at?: string | null
circuit_breaker: FormatCircuitBreakerData
}
// Per-format circuit-breaker data
export interface FormatCircuitBreakerData {
open: boolean
open_at?: string | null
next_probe_at?: string | null
half_open_until?: string | null
half_open_successes: number
half_open_failures: number
}
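// Illustrative sample, not part of this diff: how health_by_format might look for a key
// that serves CLAUDE and OPENAI. Values are made up; field names follow the interfaces above.
const exampleHealthByFormat: Record<string, FormatHealthData> = {
  CLAUDE: {
    health_score: 0.92,
    error_rate: 0.05,
    window_size: 20,
    consecutive_failures: 1,
    last_failure_at: '2026-01-10T10:12:00Z',
    circuit_breaker: { open: false, half_open_successes: 0, half_open_failures: 0 },
  },
  OPENAI: {
    health_score: 1.0,
    error_rate: 0,
    window_size: 20,
    consecutive_failures: 0,
    circuit_breaker: { open: false, half_open_successes: 0, half_open_failures: 0 },
  },
}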
export interface EndpointAPIKeyUpdate {
api_formats?: string[] // 支持的 API 格式列表
name?: string
api_key?: string // 仅在需要更新时提供
rate_multiplier?: number
rate_multiplier?: number // 默认成本倍率
rate_multipliers?: Record<string, number> | null // 按 API 格式的成本倍率
internal_priority?: number
global_priority?: number | null
max_concurrent?: number | null // null 表示切换为自适应模式
rate_limit?: number
daily_limit?: number
monthly_limit?: number
allowed_models?: string[] | null
rpm_limit?: number | null // RPM rate limit (1-10000); null switches the key to adaptive mode
allowed_models?: AllowedModels
capabilities?: Record<string, boolean> | null
cache_ttl_minutes?: number
max_probe_interval_minutes?: number
@@ -198,7 +269,6 @@ export interface PublicEndpointStatusMonitorResponse {
export interface ProviderWithEndpointsSummary {
id: string
name: string
display_name: string
description?: string
website?: string
provider_priority: number
@@ -208,9 +278,10 @@ export interface ProviderWithEndpointsSummary {
quota_reset_day?: number
quota_last_reset_at?: string // 当前周期开始时间
quota_expires_at?: string
rpm_limit?: number | null
rpm_used?: number
rpm_reset_at?: string
// 请求配置(从 Endpoint 迁移)
timeout?: number // 请求超时(秒)
max_retries?: number // 最大重试次数
proxy?: ProxyConfig | null // 代理配置
is_active: boolean
total_endpoints: number
active_endpoints: number
@@ -253,13 +324,10 @@ export interface HealthSummary {
}
}
export interface ConcurrencyStatus {
endpoint_id?: string
endpoint_current_concurrency: number
endpoint_max_concurrent?: number
key_id?: string
key_current_concurrency: number
key_max_concurrent?: number
export interface KeyRpmStatus {
key_id: string
current_rpm: number
rpm_limit?: number
}
export interface ProviderModelMapping {
@@ -361,7 +429,6 @@ export interface ModelPriceRange {
export interface ModelCatalogProviderDetail {
provider_id: string
provider_name: string
provider_display_name?: string | null
model_id?: string | null
target_model: string
input_price_per_1m?: number | null
@@ -534,10 +601,10 @@ export interface UpstreamModel {
*/
export interface ImportFromUpstreamSuccessItem {
model_id: string
global_model_id: string
global_model_name: string
provider_model_id: string
created_global_model: boolean
global_model_id?: string // 可选,未关联时为空字符串
global_model_name?: string // 可选,未关联时为空字符串
created_global_model: boolean // 始终为 false不再自动创建 GlobalModel
}
/**

View File

@@ -0,0 +1,15 @@
<script setup lang="ts">
import { CollapsibleContent, type CollapsibleContentProps } from 'radix-vue'
import { cn } from '@/lib/utils'
const props = defineProps<CollapsibleContentProps & { class?: string }>()
</script>
<template>
<CollapsibleContent
v-bind="props"
:class="cn('overflow-hidden data-[state=closed]:animate-collapsible-up data-[state=open]:animate-collapsible-down', props.class)"
>
<slot />
</CollapsibleContent>
</template>

View File

@@ -0,0 +1,11 @@
<script setup lang="ts">
import { CollapsibleTrigger, type CollapsibleTriggerProps } from 'radix-vue'
const props = defineProps<CollapsibleTriggerProps>()
</script>
<template>
<CollapsibleTrigger v-bind="props" as-child>
<slot />
</CollapsibleTrigger>
</template>

View File

@@ -0,0 +1,15 @@
<script setup lang="ts">
import { CollapsibleRoot, type CollapsibleRootEmits, type CollapsibleRootProps } from 'radix-vue'
import { useForwardPropsEmits } from 'radix-vue'
const props = defineProps<CollapsibleRootProps>()
const emits = defineEmits<CollapsibleRootEmits>()
const forwarded = useForwardPropsEmits(props, emits)
</script>
<template>
<CollapsibleRoot v-bind="forwarded">
<slot />
</CollapsibleRoot>
</template>

View File

@@ -65,3 +65,8 @@ export { default as RefreshButton } from './refresh-button.vue'
// Tooltip 提示系列
export { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from './tooltip'
// Collapsible (expand/collapse) components
export { default as Collapsible } from './collapsible.vue'
export { default as CollapsibleTrigger } from './collapsible-trigger.vue'
export { default as CollapsibleContent } from './collapsible-content.vue'
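The three Collapsible pieces exported above compose like the other radix-vue wrappers in this codebase. A minimal usage sketch, with imports assuming the '@/components/ui' barrel shown in this index file:

<script setup lang="ts">
import { ref } from 'vue'
import { Collapsible, CollapsibleTrigger, CollapsibleContent, Button } from '@/components/ui'

const open = ref(false)
</script>
<template>
  <Collapsible v-model:open="open">
    <!-- CollapsibleTrigger renders as-child, so it needs exactly one child element -->
    <CollapsibleTrigger>
      <Button variant="outline">Toggle group</Button>
    </CollapsibleTrigger>
    <CollapsibleContent>
      <p>Collapsed content goes here.</p>
    </CollapsibleContent>
  </Collapsible>
</template>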

View File

@@ -186,7 +186,7 @@
@click.stop
@change="toggleSelection('allowed_providers', provider.id)"
>
<span class="text-sm">{{ provider.display_name || provider.name }}</span>
<span class="text-sm">{{ provider.name }}</span>
</div>
<div
v-if="providers.length === 0"

View File

@@ -460,13 +460,13 @@
<TableHead class="h-10 font-semibold">
Provider
</TableHead>
<TableHead class="w-[120px] h-10 font-semibold">
<TableHead class="w-[100px] h-10 font-semibold">
能力
</TableHead>
<TableHead class="w-[180px] h-10 font-semibold">
<TableHead class="w-[200px] h-10 font-semibold">
价格 ($/M)
</TableHead>
<TableHead class="w-[80px] h-10 font-semibold text-center">
<TableHead class="w-[100px] h-10 font-semibold text-center">
操作
</TableHead>
</TableRow>
@@ -484,7 +484,7 @@
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
:title="provider.is_active ? '活跃' : '停用'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
<span class="font-medium truncate">{{ provider.name }}</span>
</div>
</TableCell>
<TableCell class="py-3">
@@ -595,7 +595,7 @@
class="w-2 h-2 rounded-full shrink-0"
:class="provider.is_active ? 'bg-green-500' : 'bg-gray-300'"
/>
<span class="font-medium truncate">{{ provider.display_name }}</span>
<span class="font-medium truncate">{{ provider.name }}</span>
</div>
<div class="flex items-center gap-1 shrink-0">
<Button

View File

@@ -531,20 +531,23 @@ watch(() => props.open, async (isOpen) => {
// 加载数据
async function loadData() {
await Promise.all([loadGlobalModels(), loadExistingModels()])
// 默认折叠全局模型组
collapsedGroups.value = new Set(['global'])
// 检查缓存,如果有缓存数据则直接使用
const cachedModels = getCachedModels(props.providerId)
if (cachedModels) {
if (cachedModels && cachedModels.length > 0) {
upstreamModels.value = cachedModels
upstreamModelsLoaded.value = true
// 折叠所有上游模型组
// 有多个分组时全部折叠
const allGroups = new Set(['global'])
for (const model of cachedModels) {
if (model.api_format) {
collapsedGroups.value.add(model.api_format)
allGroups.add(model.api_format)
}
}
collapsedGroups.value = allGroups
} else {
// 只有全局模型时展开
collapsedGroups.value = new Set()
}
}
@@ -585,8 +588,8 @@ async function fetchUpstreamModels(forceRefresh = false) {
} else {
upstreamModels.value = result.models
upstreamModelsLoaded.value = true
// 折叠所有上游模型组
const allGroups = new Set(collapsedGroups.value)
// 有多个分组时全部折叠
const allGroups = new Set(['global'])
for (const model of result.models) {
if (model.api_format) {
allGroups.add(model.api_format)

View File

@@ -1,52 +1,142 @@
<template>
<Dialog
:model-value="internalOpen"
:title="isEditMode ? '编辑 API 端点' : '添加 API 端点'"
:description="isEditMode ? `修改 ${provider?.display_name} 的端点配置` : '为提供商添加新的 API 端点'"
:icon="isEditMode ? SquarePen : Link"
size="xl"
title="端点管理"
:description="`管理 ${provider?.name} 的 API 端点`"
:icon="Settings"
size="2xl"
@update:model-value="handleDialogUpdate"
>
<form
class="space-y-6"
@submit.prevent="handleSubmit()"
>
<!-- API 配置 -->
<div class="space-y-4">
<h3
v-if="isEditMode"
class="text-sm font-medium"
<!-- 已有端点列表 -->
<div
v-if="localEndpoints.length > 0"
class="space-y-2"
>
API 配置
</h3>
<div class="grid grid-cols-2 gap-4">
<!-- API 格式 -->
<Label class="text-muted-foreground">已配置的端点</Label>
<div class="space-y-2">
<Label for="api_format">API 格式 *</Label>
<template v-if="isEditMode">
<Input
id="api_format"
v-model="form.api_format"
disabled
class="bg-muted"
/>
<p class="text-xs text-muted-foreground">
API 格式创建后不可修改
</p>
</template>
<template v-else>
<Select
v-model="form.api_format"
v-model:open="selectOpen"
required
<div
v-for="endpoint in localEndpoints"
:key="endpoint.id"
class="rounded-md border px-3 py-2"
:class="{ 'opacity-50': !endpoint.is_active }"
>
<SelectTrigger>
<SelectValue placeholder="请选择 API 格式" />
<!-- 编辑模式 -->
<template v-if="editingEndpointId === endpoint.id">
<div class="space-y-2">
<div class="flex items-center gap-2">
<span class="text-sm font-medium w-24 shrink-0">{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
<div class="flex items-center gap-1 ml-auto">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="保存"
:disabled="savingEndpointId === endpoint.id"
@click="saveEndpointUrl(endpoint)"
>
<Check class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="取消"
@click="cancelEdit"
>
<X class="w-3.5 h-3.5" />
</Button>
</div>
</div>
<div class="grid grid-cols-2 gap-2">
<div class="space-y-1">
<Label class="text-xs text-muted-foreground">Base URL</Label>
<Input
v-model="editingUrl"
class="h-8 text-sm"
placeholder="https://api.example.com"
@keyup.escape="cancelEdit"
/>
</div>
<div class="space-y-1">
<Label class="text-xs text-muted-foreground">自定义路径 (可选)</Label>
<Input
v-model="editingPath"
class="h-8 text-sm"
:placeholder="editingDefaultPath || '留空使用默认路径'"
@keyup.escape="cancelEdit"
/>
</div>
</div>
</div>
</template>
<!-- 查看模式 -->
<template v-else>
<div class="flex items-center gap-3">
<div class="w-24 shrink-0">
<span class="text-sm font-medium">{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
</div>
<div class="flex-1 min-w-0">
<span class="text-sm text-muted-foreground truncate block">
{{ endpoint.base_url }}{{ endpoint.custom_path ? endpoint.custom_path : '' }}
</span>
</div>
<div class="flex items-center gap-1 shrink-0">
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="编辑"
@click="startEdit(endpoint)"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
:title="endpoint.is_active ? '停用' : '启用'"
:disabled="togglingEndpointId === endpoint.id"
@click="handleToggleEndpoint(endpoint)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7 text-destructive hover:text-destructive"
title="删除"
:disabled="deletingEndpointId === endpoint.id"
@click="handleDeleteEndpoint(endpoint)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</div>
</template>
</div>
</div>
</div>
<!-- 添加新端点 -->
<div
v-if="availableFormats.length > 0"
class="space-y-3 pt-3 border-t"
>
<Label class="text-muted-foreground">添加新端点</Label>
<div class="flex items-end gap-3">
<div class="w-32 shrink-0 space-y-1.5">
<Label class="text-xs">API 格式</Label>
<Select
v-model="newEndpoint.api_format"
v-model:open="formatSelectOpen"
>
<SelectTrigger class="h-9">
<SelectValue placeholder="选择格式" />
</SelectTrigger>
<SelectContent>
<SelectItem
v-for="format in apiFormats"
v-for="format in availableFormats"
:key="format.value"
:value="format.value"
>
@@ -54,192 +144,57 @@
</SelectItem>
</SelectContent>
</Select>
</template>
</div>
<!-- API URL -->
<div class="space-y-2">
<Label for="base_url">API URL *</Label>
<div class="flex-1 space-y-1.5">
<Label class="text-xs">Base URL</Label>
<Input
id="base_url"
v-model="form.base_url"
v-model="newEndpoint.base_url"
placeholder="https://api.example.com"
required
class="h-9"
/>
</div>
</div>
<!-- 自定义路径 -->
<div class="space-y-2">
<Label for="custom_path">自定义请求路径可选</Label>
<div class="w-40 shrink-0 space-y-1.5">
<Label class="text-xs">自定义路径</Label>
<Input
id="custom_path"
v-model="form.custom_path"
:placeholder="defaultPathPlaceholder"
v-model="newEndpoint.custom_path"
:placeholder="newEndpointDefaultPath || '可选'"
class="h-9"
/>
</div>
</div>
<!-- 请求配置 -->
<div class="space-y-4">
<h3 class="text-sm font-medium">
请求配置
</h3>
<div class="grid grid-cols-3 gap-4">
<div class="space-y-2">
<Label for="timeout">超时</Label>
<Input
id="timeout"
v-model.number="form.timeout"
type="number"
placeholder="300"
/>
</div>
<div class="space-y-2">
<Label for="max_retries">最大重试</Label>
<Input
id="max_retries"
v-model.number="form.max_retries"
type="number"
placeholder="3"
/>
</div>
<div class="space-y-2">
<Label for="max_concurrent">最大并发</Label>
<Input
id="max_concurrent"
:model-value="form.max_concurrent ?? ''"
type="number"
placeholder="无限制"
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
/>
</div>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="rate_limit">速率限制(请求/分钟)</Label>
<Input
id="rate_limit"
:model-value="form.rate_limit ?? ''"
type="number"
placeholder="无限制"
@update:model-value="(v) => form.rate_limit = parseNumberInput(v)"
/>
</div>
</div>
</div>
<!-- 代理配置 -->
<div class="space-y-4">
<div class="flex items-center justify-between">
<h3 class="text-sm font-medium">
代理配置
</h3>
<div class="flex items-center gap-2">
<Switch v-model="proxyEnabled" />
<span class="text-sm text-muted-foreground">启用代理</span>
<Button
variant="outline"
size="sm"
class="h-9 shrink-0"
:disabled="!newEndpoint.api_format || !newEndpoint.base_url || addingEndpoint"
@click="handleAddEndpoint"
>
{{ addingEndpoint ? '添加中...' : '添加' }}
</Button>
</div>
</div>
<!-- 空状态 -->
<div
v-if="proxyEnabled"
class="space-y-4 rounded-lg border p-4"
v-if="localEndpoints.length === 0 && availableFormats.length === 0"
class="text-center py-8 text-muted-foreground"
>
<div class="space-y-2">
<Label for="proxy_url">代理 URL *</Label>
<Input
id="proxy_url"
v-model="form.proxy_url"
placeholder="http://host:port 或 socks5://host:port"
required
:class="proxyUrlError ? 'border-red-500' : ''"
/>
<p
v-if="proxyUrlError"
class="text-xs text-red-500"
>
{{ proxyUrlError }}
</p>
<p
v-else
class="text-xs text-muted-foreground"
>
支持 HTTPHTTPSSOCKS5 代理
</p>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="proxy_user">用户名可选</Label>
<Input
:id="`proxy_user_${formId}`"
v-model="form.proxy_username"
:name="`proxy_user_${formId}`"
placeholder="代理认证用户名"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-2">
<Label :for="`proxy_pass_${formId}`">密码可选</Label>
<Input
:id="`proxy_pass_${formId}`"
v-model="form.proxy_password"
:name="`proxy_pass_${formId}`"
type="text"
:placeholder="passwordPlaceholder"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
:style="{ '-webkit-text-security': 'disc', 'text-security': 'disc' }"
/>
<p>所有 API 格式都已配置</p>
</div>
</div>
</div>
</div>
</form>
<template #footer>
<Button
type="button"
variant="outline"
:disabled="loading"
@click="handleCancel"
@click="handleClose"
>
取消
</Button>
<Button
:disabled="loading || !form.base_url || (!isEditMode && !form.api_format)"
@click="handleSubmit()"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存修改' : '创建') }}
关闭
</Button>
</template>
</Dialog>
<!-- 确认清空凭据对话框 -->
<AlertDialog
v-model="showClearCredentialsDialog"
title="清空代理凭据"
description="代理 URL 为空,但用户名和密码仍有值。是否清空这些凭据并继续保存?"
type="warning"
confirm-text="清空并保存"
cancel-text="返回编辑"
@confirm="confirmClearCredentials"
@cancel="showClearCredentialsDialog = false"
/>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { ref, computed, onMounted, watch } from 'vue'
import {
Dialog,
Button,
@@ -250,17 +205,15 @@ import {
SelectValue,
SelectContent,
SelectItem,
Switch,
} from '@/components/ui'
import AlertDialog from '@/components/common/AlertDialog.vue'
import { Link, SquarePen } from 'lucide-vue-next'
import { Settings, Edit, Trash2, Check, X, Power } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
import { parseNumberInput } from '@/utils/form'
import { log } from '@/utils/logger'
import {
createEndpoint,
updateEndpoint,
deleteEndpoint,
API_FORMAT_LABELS,
type ProviderEndpoint,
type ProviderWithEndpointsSummary
} from '@/api/endpoints'
@@ -269,7 +222,7 @@ import { adminApi } from '@/api/admin'
const props = defineProps<{
modelValue: boolean
provider: ProviderWithEndpointsSummary | null
endpoint?: ProviderEndpoint | null // 编辑模式时传入
endpoints?: ProviderEndpoint[]
}>()
const emit = defineEmits<{
@@ -279,35 +232,55 @@ const emit = defineEmits<{
}>()
const { success, error: showError } = useToast()
const loading = ref(false)
const selectOpen = ref(false)
const proxyEnabled = ref(false)
const showClearCredentialsDialog = ref(false) // 确认清空凭据对话框
// 生成随机 ID 防止浏览器自动填充
const formId = Math.random().toString(36).substring(2, 10)
// 状态
const addingEndpoint = ref(false)
const editingEndpointId = ref<string | null>(null)
const editingUrl = ref('')
const editingPath = ref('')
const savingEndpointId = ref<string | null>(null)
const deletingEndpointId = ref<string | null>(null)
const togglingEndpointId = ref<string | null>(null)
const formatSelectOpen = ref(false)
// 内部状态
const internalOpen = computed(() => props.modelValue)
// 表单数据
const form = ref({
// 新端点表单
const newEndpoint = ref({
api_format: '',
base_url: '',
custom_path: '',
timeout: 300,
max_retries: 2,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
is_active: true,
// 代理配置
proxy_url: '',
proxy_username: '',
proxy_password: '',
})
// API 格式列表
const apiFormats = ref<Array<{ value: string; label: string; default_path: string; aliases: string[] }>>([])
const apiFormats = ref<Array<{ value: string; label: string; default_path: string }>>([])
// 本地端点列表
const localEndpoints = ref<ProviderEndpoint[]>([])
// 可用的格式(未添加的)
const availableFormats = computed(() => {
const existingFormats = localEndpoints.value.map(e => e.api_format)
return apiFormats.value.filter(f => !existingFormats.includes(f.value))
})
// 获取指定 API 格式的默认路径
function getDefaultPath(apiFormat: string): string {
const format = apiFormats.value.find(f => f.value === apiFormat)
return format?.default_path || ''
}
// 当前编辑端点的默认路径
const editingDefaultPath = computed(() => {
const endpoint = localEndpoints.value.find(e => e.id === editingEndpointId.value)
return endpoint ? getDefaultPath(endpoint.api_format) : ''
})
// 新端点选择的格式的默认路径
const newEndpointDefaultPath = computed(() => {
return getDefaultPath(newEndpoint.value.api_format)
})
// 加载 API 格式列表
const loadApiFormats = async () => {
@@ -316,221 +289,127 @@ const loadApiFormats = async () => {
apiFormats.value = response.formats
} catch (error) {
log.error('加载API格式失败:', error)
if (!isEditMode.value) {
showError('加载API格式失败', '错误')
}
}
}
// 根据选择的 API 格式计算默认路径
const defaultPath = computed(() => {
const format = apiFormats.value.find(f => f.value === form.value.api_format)
return format?.default_path || '/'
})
// 动态 placeholder
const defaultPathPlaceholder = computed(() => {
return `留空使用默认路径:${defaultPath.value}`
})
// 检查是否有已保存的密码(后端返回 *** 表示有密码)
const hasExistingPassword = computed(() => {
if (!props.endpoint?.proxy) return false
const proxy = props.endpoint.proxy as { password?: string }
return proxy?.password === MASKED_PASSWORD
})
// 密码输入框的 placeholder
const passwordPlaceholder = computed(() => {
if (hasExistingPassword.value) {
return '已保存密码,留空保持不变'
}
return '代理认证密码'
})
// 代理 URL 验证
const proxyUrlError = computed(() => {
// 只有启用代理且填写了 URL 时才验证
if (!proxyEnabled.value || !form.value.proxy_url) {
return ''
}
const url = form.value.proxy_url.trim()
// 检查禁止的特殊字符
if (/[\n\r]/.test(url)) {
return '代理 URL 包含非法字符'
}
// 验证协议(不支持 SOCKS4
if (!/^(http|https|socks5):\/\//i.test(url)) {
return '代理 URL 必须以 http://, https:// 或 socks5:// 开头'
}
try {
const parsed = new URL(url)
if (!parsed.host) {
return '代理 URL 必须包含有效的 host'
}
// 禁止 URL 中内嵌认证信息
if (parsed.username || parsed.password) {
return '请勿在 URL 中包含用户名和密码,请使用独立的认证字段'
}
} catch {
return '代理 URL 格式无效'
}
return ''
})
// 组件挂载时加载API格式
onMounted(() => {
loadApiFormats()
})
// 重置表单
function resetForm() {
form.value = {
api_format: '',
base_url: '',
custom_path: '',
timeout: 300,
max_retries: 2,
max_concurrent: undefined,
rate_limit: undefined,
is_active: true,
proxy_url: '',
proxy_username: '',
proxy_password: '',
// 监听 props 变化
watch(() => props.modelValue, (open) => {
if (open) {
localEndpoints.value = [...(props.endpoints || [])]
// 重置编辑状态
editingEndpointId.value = null
editingUrl.value = ''
editingPath.value = ''
} else {
// 关闭对话框时完全清空新端点表单
newEndpoint.value = { api_format: '', base_url: '', custom_path: '' }
}
proxyEnabled.value = false
}, { immediate: true })
watch(() => props.endpoints, (endpoints) => {
if (props.modelValue) {
localEndpoints.value = [...(endpoints || [])]
}
}, { deep: true })
// 开始编辑
function startEdit(endpoint: ProviderEndpoint) {
editingEndpointId.value = endpoint.id
editingUrl.value = endpoint.base_url
editingPath.value = endpoint.custom_path || ''
}
// 原始密码占位符(后端返回的脱敏标记)
const MASKED_PASSWORD = '***'
// 加载端点数据(编辑模式)
function loadEndpointData() {
if (!props.endpoint) return
const proxy = props.endpoint.proxy as { url?: string; username?: string; password?: string; enabled?: boolean } | null
form.value = {
api_format: props.endpoint.api_format,
base_url: props.endpoint.base_url,
custom_path: props.endpoint.custom_path || '',
timeout: props.endpoint.timeout,
max_retries: props.endpoint.max_retries,
max_concurrent: props.endpoint.max_concurrent || undefined,
rate_limit: props.endpoint.rate_limit || undefined,
is_active: props.endpoint.is_active,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
// 如果密码是脱敏标记,显示为空(让用户知道有密码但看不到)
proxy_password: proxy?.password === MASKED_PASSWORD ? '' : (proxy?.password || ''),
// 取消编辑
function cancelEdit() {
editingEndpointId.value = null
editingUrl.value = ''
editingPath.value = ''
}
// 根据 enabled 字段或 url 存在判断是否启用代理
proxyEnabled.value = proxy?.enabled ?? !!proxy?.url
}
// 保存端点
async function saveEndpointUrl(endpoint: ProviderEndpoint) {
if (!editingUrl.value) return
// 使用 useFormDialog 统一处理对话框逻辑
const { isEditMode, handleDialogUpdate, handleCancel } = useFormDialog({
isOpen: () => props.modelValue,
entity: () => props.endpoint,
isLoading: loading,
onClose: () => emit('update:modelValue', false),
loadData: loadEndpointData,
resetForm,
})
// 构建代理配置
// - 有 URL 时始终保存配置,通过 enabled 字段控制是否启用
// - 无 URL 时返回 null
function buildProxyConfig(): { url: string; username?: string; password?: string; enabled: boolean } | null {
if (!form.value.proxy_url) {
// 没填 URL无代理配置
return null
}
return {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: proxyEnabled.value, // 开关状态决定是否启用
}
}
// 提交表单
const handleSubmit = async (skipCredentialCheck = false) => {
if (!props.provider && !props.endpoint) return
// 只在开关开启且填写了 URL 时验证
if (proxyEnabled.value && form.value.proxy_url && proxyUrlError.value) {
showError(proxyUrlError.value, '代理配置错误')
return
}
// 检查:开关开启但没有 URL却有用户名或密码
const hasOrphanedCredentials = proxyEnabled.value
&& !form.value.proxy_url
&& (form.value.proxy_username || form.value.proxy_password)
if (hasOrphanedCredentials && !skipCredentialCheck) {
// 弹出确认对话框
showClearCredentialsDialog.value = true
return
}
loading.value = true
savingEndpointId.value = endpoint.id
try {
const proxyConfig = buildProxyConfig()
if (isEditMode.value && props.endpoint) {
// 更新端点
await updateEndpoint(props.endpoint.id, {
base_url: form.value.base_url,
custom_path: form.value.custom_path || undefined,
timeout: form.value.timeout,
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
await updateEndpoint(endpoint.id, {
base_url: editingUrl.value,
custom_path: editingPath.value || null, // 空字符串时传 null 清空
})
success('端点已更新', '保存成功')
success('端点已更新')
emit('endpointUpdated')
} else if (props.provider) {
// 创建端点
cancelEdit()
} catch (error: any) {
showError(error.response?.data?.detail || '更新失败', '错误')
} finally {
savingEndpointId.value = null
}
}
// 添加端点
async function handleAddEndpoint() {
if (!props.provider || !newEndpoint.value.api_format || !newEndpoint.value.base_url) return
addingEndpoint.value = true
try {
await createEndpoint(props.provider.id, {
provider_id: props.provider.id,
api_format: form.value.api_format,
base_url: form.value.base_url,
custom_path: form.value.custom_path || undefined,
timeout: form.value.timeout,
max_retries: form.value.max_retries,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
is_active: form.value.is_active,
proxy: proxyConfig,
api_format: newEndpoint.value.api_format,
base_url: newEndpoint.value.base_url,
custom_path: newEndpoint.value.custom_path || undefined,
is_active: true,
})
success('端点创建成功', '成功')
success(`已添加 ${API_FORMAT_LABELS[newEndpoint.value.api_format] || newEndpoint.value.api_format} 端点`)
// 重置表单,保留 URL
const url = newEndpoint.value.base_url
newEndpoint.value = { api_format: '', base_url: url, custom_path: '' }
emit('endpointCreated')
resetForm()
}
emit('update:modelValue', false)
} catch (error: any) {
const action = isEditMode.value ? '更新' : '创建'
showError(error.response?.data?.detail || `${action}端点失败`, '错误')
showError(error.response?.data?.detail || '添加失败', '错误')
} finally {
loading.value = false
addingEndpoint.value = false
}
}
// 确认清空凭据并继续保存
const confirmClearCredentials = () => {
form.value.proxy_username = ''
form.value.proxy_password = ''
showClearCredentialsDialog.value = false
handleSubmit(true) // 跳过凭据检查,直接提交
// 切换端点启用状态
async function handleToggleEndpoint(endpoint: ProviderEndpoint) {
togglingEndpointId.value = endpoint.id
try {
const newStatus = !endpoint.is_active
await updateEndpoint(endpoint.id, { is_active: newStatus })
success(newStatus ? '端点已启用' : '端点已停用')
emit('endpointUpdated')
} catch (error: any) {
showError(error.response?.data?.detail || '操作失败', '错误')
} finally {
togglingEndpointId.value = null
}
}
// 删除端点
async function handleDeleteEndpoint(endpoint: ProviderEndpoint) {
deletingEndpointId.value = endpoint.id
try {
await deleteEndpoint(endpoint.id)
success(`已删除 ${API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format} 端点`)
emit('endpointUpdated')
} catch (error: any) {
showError(error.response?.data?.detail || '删除失败', '错误')
} finally {
deletingEndpointId.value = null
}
}
// 关闭对话框
function handleDialogUpdate(value: boolean) {
emit('update:modelValue', value)
}
function handleClose() {
emit('update:modelValue', false)
}
</script>
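The endpoint form above validates the proxy URL before saving and only persists credentials together with an explicit enabled flag. Below is a minimal standalone sketch of that behaviour, assuming the same ProxyConfig payload shape the form sends (url, username, password, enabled); it is an illustration of the logic shown in the diff, not the component's actual code.

// Hypothetical standalone helpers mirroring the dialog's proxy handling.
interface ProxyConfig {
  url: string
  username?: string
  password?: string
  enabled: boolean
}

// Returns an error message, or '' when the URL is acceptable.
function validateProxyUrl(raw: string): string {
  const url = raw.trim()
  if (/[\n\r]/.test(url)) return 'Proxy URL contains illegal characters'
  // SOCKS4 is intentionally not supported.
  if (!/^(http|https|socks5):\/\//i.test(url)) {
    return 'Proxy URL must start with http://, https:// or socks5://'
  }
  try {
    const parsed = new URL(url)
    if (!parsed.host) return 'Proxy URL must contain a valid host'
    // Credentials belong in the dedicated auth fields, not in the URL.
    if (parsed.username || parsed.password) return 'Do not embed credentials in the URL'
  } catch {
    return 'Proxy URL is malformed'
  }
  return ''
}

// No URL means no proxy config at all; otherwise the enabled flag decides
// whether the saved config is actually used.
function buildProxyConfig(
  url: string,
  enabled: boolean,
  username?: string,
  password?: string,
): ProxyConfig | null {
  if (!url) return null
  return { url, username: username || undefined, password: password || undefined, enabled }
}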

View File

@@ -1,147 +1,160 @@
<template>
<Dialog
:model-value="isOpen"
title="配置允许的模型"
description="选择该 API Key 允许访问的模型,留空则允许访问所有模型"
:icon="Settings2"
title="获取上游模型"
:description="`使用密钥 ${props.apiKey?.name || props.apiKey?.api_key_masked || ''} 从上游获取模型列表。导入的模型需要关联全局模型后才能参与路由。`"
:icon="Layers"
size="2xl"
@update:model-value="handleDialogUpdate"
>
<div class="space-y-4 py-2">
<!-- 已选模型展示 -->
<div
v-if="selectedModels.length > 0"
class="space-y-2"
>
<div class="flex items-center justify-between px-1">
<div class="text-xs font-medium text-muted-foreground">
已选模型 ({{ selectedModels.length }})
<!-- 操作区域 -->
<div class="flex items-center justify-between">
<div class="text-sm text-muted-foreground">
<span v-if="!hasQueried">点击获取按钮查询上游可用模型</span>
<span v-else-if="upstreamModels.length > 0">
{{ upstreamModels.length }} 个模型已选 {{ selectedModels.length }}
</span>
<span v-else>未找到可用模型</span>
</div>
<Button
type="button"
variant="ghost"
variant="outline"
size="sm"
class="h-6 text-xs hover:text-destructive"
@click="clearModels"
:disabled="loading"
@click="fetchUpstreamModels"
>
清空
<RefreshCw
class="w-3.5 h-3.5 mr-1.5"
:class="{ 'animate-spin': loading }"
/>
{{ hasQueried ? '刷新' : '获取模型' }}
</Button>
</div>
<div class="flex flex-wrap gap-1.5 p-2 bg-muted/20 rounded-lg border border-border/40 min-h-[40px]">
<Badge
v-for="modelName in selectedModels"
:key="modelName"
variant="secondary"
class="text-[11px] px-2 py-0.5 bg-background border-border/60 shadow-sm"
>
{{ getModelLabel(modelName) }}
<button
class="ml-0.5 hover:text-destructive focus:outline-none"
@click.stop="toggleModel(modelName, false)"
>
&times;
</button>
</Badge>
</div>
</div>
<!-- 模型列表区域 -->
<div class="space-y-2">
<div class="flex items-center justify-between px-1">
<div class="text-xs font-medium text-muted-foreground">
可选模型列表
</div>
<div
v-if="!loadingModels && availableModels.length > 0"
class="text-[10px] text-muted-foreground/60"
>
{{ availableModels.length }} 个模型
</div>
</div>
<!-- 加载状态 -->
<div
v-if="loadingModels"
v-if="loading"
class="flex flex-col items-center justify-center py-12 space-y-3"
>
<div class="animate-spin rounded-full h-8 w-8 border-2 border-primary/20 border-t-primary" />
<span class="text-xs text-muted-foreground">正在加载模型列表...</span>
<span class="text-xs text-muted-foreground">正在从上游获取模型列表...</span>
</div>
<!-- 错误状态 -->
<div
v-else-if="errorMessage"
class="flex flex-col items-center justify-center py-12 text-destructive border border-dashed border-destructive/30 rounded-lg bg-destructive/5"
>
<AlertCircle class="w-10 h-10 mb-2 opacity-50" />
<span class="text-sm text-center px-4">{{ errorMessage }}</span>
<Button
variant="outline"
size="sm"
class="mt-3"
@click="fetchUpstreamModels"
>
重试
</Button>
</div>
<!-- 未查询状态 -->
<div
v-else-if="!hasQueried"
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
>
<Layers class="w-10 h-10 mb-2 opacity-20" />
<span class="text-sm">点击上方按钮获取模型列表</span>
</div>
<!-- 无模型 -->
<div
v-else-if="availableModels.length === 0"
v-else-if="upstreamModels.length === 0"
class="flex flex-col items-center justify-center py-12 text-muted-foreground border border-dashed rounded-lg bg-muted/10"
>
<Box class="w-10 h-10 mb-2 opacity-20" />
<span class="text-sm">暂无可选模型</span>
<span class="text-sm">上游 API 未返回可用模型</span>
</div>
<!-- 模型列表 -->
<div v-else class="space-y-2">
<!-- 全选/取消 -->
<div class="flex items-center justify-between px-1">
<div class="flex items-center gap-2">
<Checkbox
:checked="isAllSelected"
:indeterminate="isPartiallySelected"
@update:checked="toggleSelectAll"
/>
<span class="text-xs text-muted-foreground">
{{ isAllSelected ? '取消全选' : '全选' }}
</span>
</div>
<div class="text-xs text-muted-foreground">
{{ newModelsCount }} 个新模型不在本地
</div>
</div>
<div class="max-h-[320px] overflow-y-auto pr-1 space-y-1 custom-scrollbar">
<div
v-else
class="max-h-[320px] overflow-y-auto pr-1 space-y-1.5 custom-scrollbar"
>
<div
v-for="model in availableModels"
:key="model.global_model_name"
v-for="model in upstreamModels"
:key="`${model.id}:${model.api_format || ''}`"
class="group flex items-center gap-3 px-3 py-2.5 rounded-lg border transition-all duration-200 cursor-pointer select-none"
:class="[
selectedModels.includes(model.global_model_name)
selectedModels.includes(model.id)
? 'border-primary/40 bg-primary/5 shadow-sm'
: 'border-border/40 bg-background hover:border-primary/20 hover:bg-muted/30'
]"
@click="toggleModel(model.global_model_name, !selectedModels.includes(model.global_model_name))"
@click="toggleModel(model.id)"
>
<!-- Checkbox -->
<Checkbox
:checked="selectedModels.includes(model.global_model_name)"
:checked="selectedModels.includes(model.id)"
class="data-[state=checked]:bg-primary data-[state=checked]:border-primary"
@click.stop
@update:checked="checked => toggleModel(model.global_model_name, checked)"
@update:checked="checked => toggleModel(model.id, checked)"
/>
<!-- Info -->
<div class="flex-1 min-w-0">
<div class="flex items-center justify-between gap-2">
<span class="text-sm font-medium truncate text-foreground/90">{{ model.display_name }}</span>
<span
v-if="hasPricing(model)"
class="text-[10px] font-mono text-muted-foreground/80 bg-muted/30 px-1.5 py-0.5 rounded border border-border/30 shrink-0"
>
{{ formatPricingShort(model) }}
<div class="flex items-center gap-2">
<span class="text-sm font-medium truncate text-foreground/90">
{{ model.display_name || model.id }}
</span>
<Badge
v-if="model.api_format"
variant="outline"
class="text-[10px] px-1.5 py-0 shrink-0"
>
{{ API_FORMAT_LABELS[model.api_format] || model.api_format }}
</Badge>
<Badge
v-if="isModelExisting(model.id)"
variant="secondary"
class="text-[10px] px-1.5 py-0 shrink-0"
>
已存在
</Badge>
</div>
<div class="text-[11px] text-muted-foreground/60 font-mono truncate mt-0.5">
{{ model.global_model_name }}
{{ model.id }}
</div>
</div>
<!-- 测试按钮 -->
<Button
variant="ghost"
size="icon"
class="h-7 w-7 shrink-0"
title="测试模型连接"
:disabled="testingModelName === model.global_model_name"
@click.stop="testModelConnection(model)"
<div
v-if="model.owned_by"
class="text-[10px] text-muted-foreground/50 shrink-0"
>
<Loader2
v-if="testingModelName === model.global_model_name"
class="w-3.5 h-3.5 animate-spin"
/>
<Play
v-else
class="w-3.5 h-3.5"
/>
</Button>
{{ model.owned_by }}
</div>
</div>
</div>
</div>
</div>
<template #footer>
<div class="flex items-center justify-end gap-2 w-full pt-2">
<div class="flex items-center justify-between w-full pt-2">
<div class="text-xs text-muted-foreground">
<span v-if="selectedModels.length > 0 && newSelectedCount > 0">
将导入 {{ newSelectedCount }} 个新模型
</span>
</div>
<div class="flex items-center gap-2">
<Button
variant="outline"
class="h-9"
@@ -150,36 +163,37 @@
取消
</Button>
<Button
:disabled="saving"
class="h-9 min-w-[80px]"
@click="handleSave"
:disabled="importing || selectedModels.length === 0 || newSelectedCount === 0"
class="h-9 min-w-[100px]"
@click="handleImport"
>
<Loader2
v-if="saving"
v-if="importing"
class="w-3.5 h-3.5 mr-1.5 animate-spin"
/>
{{ saving ? '保存中' : '保存配置' }}
{{ importing ? '导入中' : `导入 ${newSelectedCount} 个模型` }}
</Button>
</div>
</div>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, computed, watch } from 'vue'
import { Box, Loader2, Settings2, Play } from 'lucide-vue-next'
import { Box, Layers, Loader2, RefreshCw, AlertCircle } from 'lucide-vue-next'
import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue'
import Checkbox from '@/components/ui/checkbox.vue'
import { useToast } from '@/composables/useToast'
import { parseApiError, parseTestModelError } from '@/utils/errorParser'
import { adminApi } from '@/api/admin'
import {
updateEndpointKey,
getProviderAvailableSourceModels,
testModel,
importModelsFromUpstream,
getProviderModels,
type EndpointAPIKey,
type ProviderAvailableSourceModel
type UpstreamModel,
API_FORMAT_LABELS,
} from '@/api/endpoints'
const props = defineProps<{
@@ -196,130 +210,116 @@ const emit = defineEmits<{
const { success, error: showError } = useToast()
const isOpen = computed(() => props.open)
const saving = ref(false)
const loadingModels = ref(false)
const availableModels = ref<ProviderAvailableSourceModel[]>([])
const loading = ref(false)
const importing = ref(false)
const hasQueried = ref(false)
const errorMessage = ref('')
const upstreamModels = ref<UpstreamModel[]>([])
const selectedModels = ref<string[]>([])
const initialModels = ref<string[]>([])
const testingModelName = ref<string | null>(null)
const existingModelIds = ref<Set<string>>(new Set())
// 计算属性
const isAllSelected = computed(() =>
upstreamModels.value.length > 0 &&
selectedModels.value.length === upstreamModels.value.length
)
const isPartiallySelected = computed(() =>
selectedModels.value.length > 0 &&
selectedModels.value.length < upstreamModels.value.length
)
const newModelsCount = computed(() =>
upstreamModels.value.filter(m => !existingModelIds.value.has(m.id)).length
)
const newSelectedCount = computed(() =>
selectedModels.value.filter(id => !existingModelIds.value.has(id)).length
)
// 检查模型是否已存在
function isModelExisting(modelId: string): boolean {
return existingModelIds.value.has(modelId)
}
// 监听对话框打开
watch(() => props.open, (open) => {
if (open) {
loadData()
resetState()
loadExistingModels()
}
})
async function loadData() {
// 初始化已选模型
if (props.apiKey?.allowed_models) {
selectedModels.value = [...props.apiKey.allowed_models]
initialModels.value = [...props.apiKey.allowed_models]
} else {
function resetState() {
hasQueried.value = false
errorMessage.value = ''
upstreamModels.value = []
selectedModels.value = []
initialModels.value = []
}
// 加载可选模型
if (props.providerId) {
await loadAvailableModels()
}
}
async function loadAvailableModels() {
// 加载已存在的模型列表
async function loadExistingModels() {
if (!props.providerId) return
try {
loadingModels.value = true
const response = await getProviderAvailableSourceModels(props.providerId)
availableModels.value = response.models
} catch (err: any) {
const errorMessage = parseApiError(err, '加载模型列表失败')
showError(errorMessage, '错误')
} finally {
loadingModels.value = false
const models = await getProviderModels(props.providerId)
existingModelIds.value = new Set(
models.map((m: { provider_model_name: string }) => m.provider_model_name)
)
} catch {
existingModelIds.value = new Set()
}
}
const modelLabelMap = computed(() => {
const map = new Map<string, string>()
availableModels.value.forEach(model => {
map.set(model.global_model_name, model.display_name || model.global_model_name)
})
return map
})
// 获取上游模型
async function fetchUpstreamModels() {
if (!props.providerId || !props.apiKey) return
function getModelLabel(modelName: string): string {
return modelLabelMap.value.get(modelName) ?? modelName
}
loading.value = true
errorMessage.value = ''
function hasPricing(model: ProviderAvailableSourceModel): boolean {
const input = model.price.input_price_per_1m ?? 0
const output = model.price.output_price_per_1m ?? 0
return input > 0 || output > 0
}
try {
const response = await adminApi.queryProviderModels(props.providerId, props.apiKey.id)
function formatPricingShort(model: ProviderAvailableSourceModel): string {
const input = model.price.input_price_per_1m ?? 0
const output = model.price.output_price_per_1m ?? 0
if (input > 0 || output > 0) {
return `$${formatPrice(input)}/$${formatPrice(output)}`
}
return ''
}
function formatPrice(value?: number | null): string {
if (value === undefined || value === null || value === 0) return '0'
return value.toFixed(2)
}
function toggleModel(modelName: string, checked: boolean) {
if (checked) {
if (!selectedModels.value.includes(modelName)) {
selectedModels.value = [...selectedModels.value, modelName]
if (response.success && response.data?.models) {
upstreamModels.value = response.data.models
// 默认选中所有新模型
selectedModels.value = response.data.models
.filter((m: UpstreamModel) => !existingModelIds.value.has(m.id))
.map((m: UpstreamModel) => m.id)
hasQueried.value = true
// 如果有部分失败,显示警告提示
if (response.data.error) {
showError(`部分格式获取失败: ${response.data.error}`, '警告')
}
} else {
selectedModels.value = selectedModels.value.filter(name => name !== modelName)
errorMessage.value = response.data?.error || '获取上游模型失败'
}
} catch (err: any) {
errorMessage.value = err.response?.data?.detail || '获取上游模型失败'
} finally {
loading.value = false
}
}
function clearModels() {
// 切换模型选择
function toggleModel(modelId: string, checked?: boolean) {
const shouldSelect = checked !== undefined ? checked : !selectedModels.value.includes(modelId)
if (shouldSelect) {
if (!selectedModels.value.includes(modelId)) {
selectedModels.value = [...selectedModels.value, modelId]
}
} else {
selectedModels.value = selectedModels.value.filter(id => id !== modelId)
}
}
// 全选/取消全选
function toggleSelectAll(checked: boolean) {
if (checked) {
selectedModels.value = upstreamModels.value.map(m => m.id)
} else {
selectedModels.value = []
}
// 测试模型连接
async function testModelConnection(model: ProviderAvailableSourceModel) {
if (!props.providerId || !props.apiKey || testingModelName.value) return
testingModelName.value = model.global_model_name
try {
const result = await testModel({
provider_id: props.providerId,
model_name: model.provider_model_name,
api_key_id: props.apiKey.id,
message: "hello"
})
if (result.success) {
success(`模型 "${model.display_name}" 测试成功`)
} else {
showError(`模型测试失败: ${parseTestModelError(result)}`)
}
} catch (err: any) {
const errorMsg = err.response?.data?.detail || err.message || '测试请求失败'
showError(`模型测试失败: ${errorMsg}`)
} finally {
testingModelName.value = null
}
}
function areArraysEqual(a: string[], b: string[]): boolean {
if (a.length !== b.length) return false
const sortedA = [...a].sort()
const sortedB = [...b].sort()
return sortedA.every((value, index) => value === sortedB[index])
}
function handleDialogUpdate(value: boolean) {
@@ -332,30 +332,44 @@ function handleCancel() {
emit('close')
}
async function handleSave() {
if (!props.apiKey) return
// 导入选中的模型
async function handleImport() {
if (!props.providerId || selectedModels.value.length === 0) return
// 检查是否有变化
const hasChanged = !areArraysEqual(selectedModels.value, initialModels.value)
if (!hasChanged) {
emit('close')
// 过滤出新模型(不在已存在列表中的)
const modelsToImport = selectedModels.value.filter(id => !existingModelIds.value.has(id))
if (modelsToImport.length === 0) {
showError('所选模型都已存在', '提示')
return
}
saving.value = true
importing.value = true
try {
await updateEndpointKey(props.apiKey.id, {
// 空数组时发送 null表示允许所有模型
allowed_models: selectedModels.value.length > 0 ? [...selectedModels.value] : null
})
success('允许的模型已更新', '成功')
const response = await importModelsFromUpstream(props.providerId, modelsToImport)
const successCount = response.success?.length || 0
const errorCount = response.errors?.length || 0
if (successCount > 0 && errorCount === 0) {
success(`成功导入 ${successCount} 个模型`, '导入成功')
emit('saved')
emit('close')
} else if (successCount > 0 && errorCount > 0) {
success(`成功导入 ${successCount} 个模型,${errorCount} 个失败`, '部分成功')
emit('saved')
// 刷新列表以更新已存在状态
await loadExistingModels()
// 更新选中列表,移除已成功导入的
const successIds = new Set(response.success?.map((s: { model_id: string }) => s.model_id) || [])
selectedModels.value = selectedModels.value.filter(id => !successIds.has(id))
} else {
const errorMsg = response.errors?.[0]?.error || '导入失败'
showError(errorMsg, '导入失败')
}
} catch (err: any) {
const errorMessage = parseApiError(err, '保存失败')
showError(errorMessage, '错误')
showError(err.response?.data?.detail || '导入失败', '错误')
} finally {
saving.value = false
importing.value = false
}
}
</script>
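For reference, the import round-trip in the dialog above reduces to: fetch the upstream model list with a specific provider key, drop the models already present locally, and submit the rest. A condensed sketch of that flow follows; the function names are the ones imported above, and their exact signatures are assumed from the usage shown in this diff.

// Condensed sketch of the fetch-and-import flow (signatures assumed from usage above).
import { adminApi } from '@/api/admin'
import { importModelsFromUpstream, getProviderModels, type UpstreamModel } from '@/api/endpoints'

async function importNewUpstreamModels(providerId: string, keyId: string) {
  // 1. Which provider models already exist locally?
  const existing = new Set(
    (await getProviderModels(providerId)).map(
      (m: { provider_model_name: string }) => m.provider_model_name,
    ),
  )

  // 2. Ask the upstream API (through the selected key) for its model list.
  const res = await adminApi.queryProviderModels(providerId, keyId)
  if (!res.success || !res.data?.models) {
    throw new Error(res.data?.error || 'failed to query upstream models')
  }

  // 3. Import only the models that are not known yet.
  const toImport = res.data.models
    .map((m: UpstreamModel) => m.id)
    .filter((id: string) => !existing.has(id))
  if (toImport.length === 0) return { imported: 0, failed: 0 }

  const result = await importModelsFromUpstream(providerId, toImport)
  return {
    imported: result.success?.length ?? 0,
    failed: result.errors?.length ?? 0,
  }
}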

View File

@@ -0,0 +1,696 @@
<template>
<Dialog
:model-value="isOpen"
title="模型权限"
:description="`管理密钥 ${props.apiKey?.name || ''} 可访问的模型,清空右侧列表表示允许全部`"
:icon="Shield"
size="4xl"
@update:model-value="handleDialogUpdate"
>
<template #default>
<div class="space-y-4">
<!-- 字典模式警告 -->
<div
v-if="isDictMode"
class="rounded-lg border border-amber-500/50 bg-amber-50 dark:bg-amber-950/30 p-3"
>
<p class="text-sm text-amber-700 dark:text-amber-400">
<strong>注意</strong>此密钥使用按 API 格式区分的模型权限配置
编辑后将转换为统一列表模式原有的格式区分信息将丢失
</p>
</div>
<!-- 密钥信息头部 -->
<div class="rounded-lg border bg-muted/30 p-4">
<div class="flex items-start justify-between">
<div>
<p class="font-semibold text-lg">{{ apiKey?.name }}</p>
<p class="text-sm text-muted-foreground font-mono">
{{ apiKey?.api_key_masked }}
</p>
</div>
<Badge
:variant="allowedModels.length === 0 ? 'default' : 'outline'"
class="text-xs"
>
{{ allowedModels.length === 0 ? '允许全部' : `限制 ${allowedModels.length} 个模型` }}
</Badge>
</div>
</div>
<!-- 左右对比布局 -->
<div class="flex gap-2 items-stretch">
<!-- 左侧可添加的模型 -->
<div class="flex-1 space-y-2">
<div class="flex items-center justify-between gap-2">
<p class="text-sm font-medium shrink-0">可添加</p>
<div class="flex-1 relative">
<Search class="absolute left-2 top-1/2 -translate-y-1/2 w-3.5 h-3.5 text-muted-foreground" />
<Input
v-model="searchQuery"
placeholder="搜索模型..."
class="pl-7 h-7 text-xs"
/>
</div>
<button
v-if="upstreamModelsLoaded"
type="button"
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
title="刷新上游模型"
:disabled="fetchingUpstreamModels"
@click="fetchUpstreamModels()"
>
<RefreshCw
class="w-3.5 h-3.5"
:class="{ 'animate-spin': fetchingUpstreamModels }"
/>
</button>
<button
v-else-if="!fetchingUpstreamModels"
type="button"
class="p-1.5 hover:bg-muted rounded-md transition-colors shrink-0"
title="从提供商获取模型"
@click="fetchUpstreamModels()"
>
<Zap class="w-3.5 h-3.5" />
</button>
<Loader2
v-else
class="w-3.5 h-3.5 animate-spin text-muted-foreground shrink-0"
/>
</div>
<div class="border rounded-lg h-80 overflow-y-auto">
<div
v-if="loadingGlobalModels"
class="flex items-center justify-center h-full"
>
<Loader2 class="w-6 h-6 animate-spin text-primary" />
</div>
<div
v-else-if="totalAvailableCount === 0 && !upstreamModelsLoaded"
class="flex flex-col items-center justify-center h-full text-muted-foreground"
>
<Shield class="w-10 h-10 mb-2 opacity-30" />
<p class="text-sm">{{ searchQuery ? '无匹配结果' : '暂无可添加模型' }}</p>
</div>
<div v-else class="p-2 space-y-2">
<!-- 全局模型折叠组 -->
<div
v-if="availableGlobalModels.length > 0 || !upstreamModelsLoaded"
class="border rounded-lg overflow-hidden"
>
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
<button
type="button"
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
@click="toggleGroupCollapse('global')"
>
<ChevronDown
class="w-4 h-4 transition-transform shrink-0"
:class="collapsedGroups.has('global') ? '-rotate-90' : ''"
/>
<span class="text-xs font-medium">全局模型</span>
<span class="text-xs text-muted-foreground">
({{ availableGlobalModels.length }})
</span>
</button>
<button
v-if="availableGlobalModels.length > 0"
type="button"
class="text-xs text-primary hover:underline shrink-0"
@click.stop="selectAllGlobalModels"
>
{{ isAllGlobalModelsSelected ? '取消' : '全选' }}
</button>
</div>
<div
v-show="!collapsedGroups.has('global')"
class="p-2 space-y-1 border-t"
>
<div
v-if="availableGlobalModels.length === 0"
class="py-4 text-center text-xs text-muted-foreground"
>
所有全局模型均已添加
</div>
<div
v-for="model in availableGlobalModels"
v-else
:key="model.name"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedLeftIds.includes(model.name)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleLeftSelection(model.name)"
>
<Checkbox
:checked="selectedLeftIds.includes(model.name)"
@update:checked="toggleLeftSelection(model.name)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">{{ model.display_name }}</p>
<p class="text-xs text-muted-foreground truncate font-mono">{{ model.name }}</p>
</div>
</div>
</div>
</div>
<!-- 从提供商获取的模型折叠组 -->
<div
v-for="group in upstreamModelGroups"
:key="group.api_format"
class="border rounded-lg overflow-hidden"
>
<div class="flex items-center gap-2 px-3 py-2 bg-muted/30">
<button
type="button"
class="flex items-center gap-2 flex-1 hover:bg-muted/50 -mx-1 px-1 rounded transition-colors"
@click="toggleGroupCollapse(group.api_format)"
>
<ChevronDown
class="w-4 h-4 transition-transform shrink-0"
:class="collapsedGroups.has(group.api_format) ? '-rotate-90' : ''"
/>
<span class="text-xs font-medium">
{{ API_FORMAT_LABELS[group.api_format] || group.api_format }}
</span>
<span class="text-xs text-muted-foreground">
({{ group.models.length }})
</span>
</button>
<button
type="button"
class="text-xs text-primary hover:underline shrink-0"
@click.stop="selectAllUpstreamModels(group.api_format)"
>
{{ isUpstreamGroupAllSelected(group.api_format) ? '取消' : '全选' }}
</button>
</div>
<div
v-show="!collapsedGroups.has(group.api_format)"
class="p-2 space-y-1 border-t"
>
<div
v-for="model in group.models"
:key="model.id"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedLeftIds.includes(model.id)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleLeftSelection(model.id)"
>
<Checkbox
:checked="selectedLeftIds.includes(model.id)"
@update:checked="toggleLeftSelection(model.id)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">{{ model.id }}</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ model.owned_by || model.id }}
</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- 中间操作按钮 -->
<div class="flex flex-col items-center justify-center w-12 shrink-0 gap-2">
<Button
variant="outline"
size="sm"
class="w-9 h-8"
:class="selectedLeftIds.length > 0 ? 'border-primary' : ''"
:disabled="selectedLeftIds.length === 0"
title="添加选中"
@click="addSelected"
>
<ChevronRight
class="w-6 h-6 stroke-[3]"
:class="selectedLeftIds.length > 0 ? 'text-primary' : ''"
/>
</Button>
<Button
variant="outline"
size="sm"
class="w-9 h-8"
:class="selectedRightIds.length > 0 ? 'border-primary' : ''"
:disabled="selectedRightIds.length === 0"
title="移除选中"
@click="removeSelected"
>
<ChevronLeft
class="w-6 h-6 stroke-[3]"
:class="selectedRightIds.length > 0 ? 'text-primary' : ''"
/>
</Button>
</div>
<!-- 右侧已添加的允许模型 -->
<div class="flex-1 space-y-2">
<div class="flex items-center justify-between">
<p class="text-sm font-medium">已添加</p>
<Button
v-if="allowedModels.length > 0"
variant="ghost"
size="sm"
class="h-6 px-2 text-xs"
@click="toggleSelectAllRight"
>
{{ isAllRightSelected ? '取消' : '全选' }}
</Button>
</div>
<div class="border rounded-lg h-80 overflow-y-auto">
<div
v-if="allowedModels.length === 0"
class="flex flex-col items-center justify-center h-full text-muted-foreground"
>
<Shield class="w-10 h-10 mb-2 opacity-30" />
<p class="text-sm">允许访问全部模型</p>
<p class="text-xs mt-1">添加模型以限制访问范围</p>
</div>
<div v-else class="p-2 space-y-1">
<div
v-for="modelName in allowedModels"
:key="'allowed-' + modelName"
class="flex items-center gap-2 p-2 rounded-lg border transition-colors cursor-pointer"
:class="selectedRightIds.includes(modelName)
? 'border-primary bg-primary/10'
: 'hover:bg-muted/50'"
@click="toggleRightSelection(modelName)"
>
<Checkbox
:checked="selectedRightIds.includes(modelName)"
@update:checked="toggleRightSelection(modelName)"
@click.stop
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ getModelDisplayName(modelName) }}
</p>
<p class="text-xs text-muted-foreground truncate font-mono">
{{ modelName }}
</p>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</template>
<template #footer>
<div class="flex items-center justify-between w-full">
<p class="text-xs text-muted-foreground">
{{ hasChanges ? '有未保存的更改' : '' }}
</p>
<div class="flex items-center gap-2">
<Button variant="outline" @click="handleCancel">取消</Button>
<Button :disabled="saving || !hasChanges" @click="handleSave">
{{ saving ? '保存中...' : '保存' }}
</Button>
</div>
</div>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, computed, watch, onUnmounted } from 'vue'
import {
Shield,
Search,
RefreshCw,
Loader2,
Zap,
ChevronRight,
ChevronLeft,
ChevronDown
} from 'lucide-vue-next'
import { Dialog, Button, Input, Checkbox, Badge } from '@/components/ui'
import { useToast } from '@/composables/useToast'
import { parseApiError } from '@/utils/errorParser'
import {
updateProviderKey,
API_FORMAT_LABELS,
type EndpointAPIKey,
type AllowedModels,
} from '@/api/endpoints'
import { getGlobalModels, type GlobalModelResponse } from '@/api/global-models'
import { adminApi } from '@/api/admin'
import type { UpstreamModel } from '@/api/endpoints/types'
interface AvailableModel {
name: string
display_name: string
}
const props = defineProps<{
open: boolean
apiKey: EndpointAPIKey | null
providerId: string
}>()
const emit = defineEmits<{
close: []
saved: []
}>()
const { success, error: showError } = useToast()
const isOpen = computed(() => props.open)
const saving = ref(false)
const loadingGlobalModels = ref(false)
const fetchingUpstreamModels = ref(false)
const upstreamModelsLoaded = ref(false)
// 用于取消异步操作的标志
let loadingCancelled = false
// 搜索
const searchQuery = ref('')
// 折叠状态
const collapsedGroups = ref<Set<string>>(new Set())
// 可用模型列表(全局模型)
const allGlobalModels = ref<AvailableModel[]>([])
// 上游模型列表
const upstreamModels = ref<UpstreamModel[]>([])
// 已添加的允许模型(右侧)
const allowedModels = ref<string[]>([])
const initialAllowedModels = ref<string[]>([])
// 选中状态
const selectedLeftIds = ref<string[]>([])
const selectedRightIds = ref<string[]>([])
// 是否有更改
const hasChanges = computed(() => {
if (allowedModels.value.length !== initialAllowedModels.value.length) return true
const sorted1 = [...allowedModels.value].sort()
const sorted2 = [...initialAllowedModels.value].sort()
return sorted1.some((v, i) => v !== sorted2[i])
})
// 计算可添加的全局模型(排除已添加的)
const availableGlobalModelsBase = computed(() => {
const allowedSet = new Set(allowedModels.value)
return allGlobalModels.value.filter(m => !allowedSet.has(m.name))
})
// 搜索过滤后的全局模型
const availableGlobalModels = computed(() => {
if (!searchQuery.value.trim()) return availableGlobalModelsBase.value
const query = searchQuery.value.toLowerCase()
return availableGlobalModelsBase.value.filter(m =>
m.name.toLowerCase().includes(query) ||
m.display_name.toLowerCase().includes(query)
)
})
// 计算可添加的上游模型(排除已添加的)
const availableUpstreamModelsBase = computed(() => {
const allowedSet = new Set(allowedModels.value)
return upstreamModels.value.filter(m => !allowedSet.has(m.id))
})
// 搜索过滤后的上游模型
const availableUpstreamModels = computed(() => {
if (!searchQuery.value.trim()) return availableUpstreamModelsBase.value
const query = searchQuery.value.toLowerCase()
return availableUpstreamModelsBase.value.filter(m =>
m.id.toLowerCase().includes(query) ||
(m.owned_by && m.owned_by.toLowerCase().includes(query))
)
})
// 按 API 格式分组的上游模型
const upstreamModelGroups = computed(() => {
const groups: Record<string, UpstreamModel[]> = {}
for (const model of availableUpstreamModels.value) {
const format = model.api_format || 'unknown'
if (!groups[format]) groups[format] = []
groups[format].push(model)
}
const order = Object.keys(API_FORMAT_LABELS)
return Object.entries(groups)
.map(([api_format, models]) => ({ api_format, models }))
.sort((a, b) => {
const aIndex = order.indexOf(a.api_format)
const bIndex = order.indexOf(b.api_format)
if (aIndex === -1 && bIndex === -1) return a.api_format.localeCompare(b.api_format)
if (aIndex === -1) return 1
if (bIndex === -1) return -1
return aIndex - bIndex
})
})
// 总可添加数量
const totalAvailableCount = computed(() => {
return availableGlobalModels.value.length + availableUpstreamModels.value.length
})
// 右侧全选状态
const isAllRightSelected = computed(() =>
allowedModels.value.length > 0 &&
selectedRightIds.value.length === allowedModels.value.length
)
// 全局模型是否全选
const isAllGlobalModelsSelected = computed(() => {
if (availableGlobalModels.value.length === 0) return false
return availableGlobalModels.value.every(m => selectedLeftIds.value.includes(m.name))
})
// 检查某个上游组是否全选
function isUpstreamGroupAllSelected(apiFormat: string): boolean {
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
if (!group || group.models.length === 0) return false
return group.models.every(m => selectedLeftIds.value.includes(m.id))
}
// 获取模型显示名称
function getModelDisplayName(name: string): string {
const globalModel = allGlobalModels.value.find(m => m.name === name)
if (globalModel) return globalModel.display_name
const upstreamModel = upstreamModels.value.find(m => m.id === name)
if (upstreamModel) return upstreamModel.id
return name
}
// 加载全局模型
async function loadGlobalModels() {
loadingGlobalModels.value = true
try {
const response = await getGlobalModels({ limit: 1000 })
// 检查是否已取消dialog 已关闭)
if (loadingCancelled) return
allGlobalModels.value = response.models.map((m: GlobalModelResponse) => ({
name: m.name,
display_name: m.display_name
}))
} catch (err) {
if (loadingCancelled) return
showError('加载全局模型失败', '错误')
} finally {
loadingGlobalModels.value = false
}
}
// 从提供商获取模型(使用当前 key
async function fetchUpstreamModels() {
if (!props.providerId || !props.apiKey) return
try {
fetchingUpstreamModels.value = true
// 使用当前 key 的 ID 来查询上游模型
const response = await adminApi.queryProviderModels(props.providerId, props.apiKey.id)
// 检查是否已取消
if (loadingCancelled) return
if (response.success && response.data?.models) {
upstreamModels.value = response.data.models
upstreamModelsLoaded.value = true
const allGroups = new Set(['global'])
for (const model of response.data.models) {
if (model.api_format) allGroups.add(model.api_format)
}
collapsedGroups.value = allGroups
} else {
showError(response.data?.error || '获取上游模型失败', '错误')
}
} catch (err: any) {
if (loadingCancelled) return
showError(err.response?.data?.detail || '获取上游模型失败', '错误')
} finally {
fetchingUpstreamModels.value = false
}
}
// 切换折叠状态
function toggleGroupCollapse(group: string) {
if (collapsedGroups.value.has(group)) {
collapsedGroups.value.delete(group)
} else {
collapsedGroups.value.add(group)
}
collapsedGroups.value = new Set(collapsedGroups.value)
}
// 是否为字典模式(按 API 格式区分)
const isDictMode = ref(false)
// 解析 allowed_models
function parseAllowedModels(allowed: AllowedModels): string[] {
if (allowed === null || allowed === undefined) {
isDictMode.value = false
return []
}
if (Array.isArray(allowed)) {
isDictMode.value = false
return [...allowed]
}
// 字典模式:合并所有格式的模型,并设置警告标志
isDictMode.value = true
const all = new Set<string>()
for (const models of Object.values(allowed)) {
models.forEach(m => all.add(m))
}
return Array.from(all)
}
// 左侧选择
function toggleLeftSelection(name: string) {
const idx = selectedLeftIds.value.indexOf(name)
if (idx === -1) {
selectedLeftIds.value.push(name)
} else {
selectedLeftIds.value.splice(idx, 1)
}
}
// 右侧选择
function toggleRightSelection(name: string) {
const idx = selectedRightIds.value.indexOf(name)
if (idx === -1) {
selectedRightIds.value.push(name)
} else {
selectedRightIds.value.splice(idx, 1)
}
}
// 右侧全选切换
function toggleSelectAllRight() {
if (isAllRightSelected.value) {
selectedRightIds.value = []
} else {
selectedRightIds.value = [...allowedModels.value]
}
}
// 全选全局模型
function selectAllGlobalModels() {
const allNames = availableGlobalModels.value.map(m => m.name)
const allSelected = allNames.every(name => selectedLeftIds.value.includes(name))
if (allSelected) {
selectedLeftIds.value = selectedLeftIds.value.filter(id => !allNames.includes(id))
} else {
const newNames = allNames.filter(name => !selectedLeftIds.value.includes(name))
selectedLeftIds.value.push(...newNames)
}
}
// 全选某个 API 格式的上游模型
function selectAllUpstreamModels(apiFormat: string) {
const group = upstreamModelGroups.value.find(g => g.api_format === apiFormat)
if (!group) return
const allIds = group.models.map(m => m.id)
const allSelected = allIds.every(id => selectedLeftIds.value.includes(id))
if (allSelected) {
selectedLeftIds.value = selectedLeftIds.value.filter(id => !allIds.includes(id))
} else {
const newIds = allIds.filter(id => !selectedLeftIds.value.includes(id))
selectedLeftIds.value.push(...newIds)
}
}
// 添加选中的模型到右侧
function addSelected() {
for (const name of selectedLeftIds.value) {
if (!allowedModels.value.includes(name)) {
allowedModels.value.push(name)
}
}
selectedLeftIds.value = []
}
// 从右侧移除选中的模型
function removeSelected() {
allowedModels.value = allowedModels.value.filter(
name => !selectedRightIds.value.includes(name)
)
selectedRightIds.value = []
}
// 监听对话框打开
watch(() => props.open, async (open) => {
if (open && props.apiKey) {
// 重置取消标志
loadingCancelled = false
const parsed = parseAllowedModels(props.apiKey.allowed_models ?? null)
allowedModels.value = [...parsed]
initialAllowedModels.value = [...parsed]
selectedLeftIds.value = []
selectedRightIds.value = []
searchQuery.value = ''
upstreamModels.value = []
upstreamModelsLoaded.value = false
collapsedGroups.value = new Set()
await loadGlobalModels()
} else {
// dialog 关闭时设置取消标志
loadingCancelled = true
}
})
// 组件卸载时取消所有异步操作
onUnmounted(() => {
loadingCancelled = true
})
function handleDialogUpdate(value: boolean) {
if (!value) emit('close')
}
function handleCancel() {
emit('close')
}
async function handleSave() {
if (!props.apiKey) return
saving.value = true
try {
// 空列表 = null允许全部
const newAllowed: AllowedModels = allowedModels.value.length > 0
? [...allowedModels.value]
: null
await updateProviderKey(props.apiKey.id, { allowed_models: newAllowed })
success('模型权限已更新', '成功')
emit('saved')
emit('close')
} catch (err: any) {
showError(parseApiError(err, '保存失败'), '错误')
} finally {
saving.value = false
}
}
</script>
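The permission model behind this dialog treats allowed_models as either null (allow everything), a flat string list, or a dictionary keyed by API format. A small sketch of that normalization is shown below, assuming the AllowedModels type imported above covers exactly those three shapes; the format keys in the example are hypothetical.

// AllowedModels is assumed to be: null | string[] | Record<string, string[]>,
// matching the three cases handled by parseAllowedModels above.
type AllowedModels = null | string[] | Record<string, string[]>

function normalizeAllowedModels(allowed: AllowedModels | undefined): {
  models: string[]   // flat, de-duplicated list
  dictMode: boolean  // true when the source was keyed by API format
} {
  if (allowed === null || allowed === undefined) {
    return { models: [], dictMode: false } // empty list = allow all models
  }
  if (Array.isArray(allowed)) {
    return { models: [...allowed], dictMode: false }
  }
  // Dictionary mode: merge every format's list; the per-format detail is lost
  // once the result is saved back as a flat list.
  const merged = new Set<string>()
  for (const models of Object.values(allowed)) {
    models.forEach((m) => merged.add(m))
  }
  return { models: Array.from(merged), dictMode: true }
}

// Example with hypothetical format keys: a per-format restriction collapses to a union list.
const flat = normalizeAllowedModels({ openai: ['gpt-4o'], claude: ['claude-3-5-sonnet'] })
// flat.models -> ['gpt-4o', 'claude-3-5-sonnet'], flat.dictMode -> true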

View File

@@ -2,22 +2,18 @@
<Dialog
:model-value="isOpen"
:title="isEditMode ? '编辑密钥' : '添加密钥'"
:description="isEditMode ? '修改 API 密钥配置' : '为端点添加新的 API 密钥'"
:description="isEditMode ? '修改 API 密钥配置' : '为提供商添加新的 API 密钥'"
:icon="isEditMode ? SquarePen : Key"
size="2xl"
@update:model-value="handleDialogUpdate"
>
<form
class="space-y-5"
class="space-y-4"
autocomplete="off"
@submit.prevent="handleSave"
>
<!-- 基本信息 -->
<div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2">
基本信息
</h3>
<div class="grid grid-cols-2 gap-4">
<div class="grid grid-cols-2 gap-3">
<div>
<Label :for="keyNameInputId">密钥名称 *</Label>
<Input
@@ -36,23 +32,6 @@
data-1p-ignore="true"
/>
</div>
<div>
<Label for="rate_multiplier">成本倍率 *</Label>
<Input
id="rate_multiplier"
v-model.number="form.rate_multiplier"
type="number"
step="0.01"
min="0.01"
required
placeholder="1.0"
/>
<p class="text-xs text-muted-foreground mt-1">
真实成本 = 表面成本 × 倍率
</p>
</div>
</div>
<div>
<Label :for="apiKeyInputId">API 密钥 {{ editingKey ? '' : '*' }}</Label>
<Input
@@ -83,10 +62,12 @@
v-else-if="editingKey"
class="text-xs text-muted-foreground mt-1"
>
留空表示不修改输入新值则覆盖
留空表示不修改
</p>
</div>
</div>
<!-- 备注 -->
<div>
<Label for="note">备注</Label>
<Input
@@ -95,98 +76,115 @@
placeholder="可选的备注信息"
/>
</div>
<!-- API 格式选择 -->
<div v-if="sortedApiFormats.length > 0">
<Label class="mb-1.5 block">支持的 API 格式 *</Label>
<div class="grid grid-cols-2 sm:grid-cols-3 gap-2">
<div
v-for="format in sortedApiFormats"
:key="format"
class="flex items-center justify-between rounded-md border px-2 py-1.5 transition-colors cursor-pointer"
:class="form.api_formats.includes(format)
? 'bg-primary/5 border-primary/30'
: 'bg-muted/30 border-border hover:border-muted-foreground/30'"
@click="toggleApiFormat(format)"
>
<div class="flex items-center gap-1.5 min-w-0">
<span
class="w-4 h-4 rounded border flex items-center justify-center text-xs shrink-0"
:class="form.api_formats.includes(format)
? 'bg-primary border-primary text-primary-foreground'
: 'border-muted-foreground/30'"
>
<span v-if="form.api_formats.includes(format)"></span>
</span>
<span
class="text-sm whitespace-nowrap"
:class="form.api_formats.includes(format) ? 'text-primary' : 'text-muted-foreground'"
>{{ API_FORMAT_LABELS[format] || format }}</span>
</div>
<div
class="flex items-center shrink-0 ml-2 text-xs text-muted-foreground gap-1"
@click.stop
>
<span>×</span>
<input
:value="form.rate_multipliers[format] ?? ''"
type="number"
step="0.01"
min="0.01"
placeholder="1"
class="w-9 bg-transparent text-right outline-none [appearance:textfield] [&::-webkit-outer-spin-button]:appearance-none [&::-webkit-inner-spin-button]:appearance-none"
:class="form.api_formats.includes(format) ? 'text-primary' : 'text-muted-foreground'"
title="成本倍率"
@input="(e) => updateRateMultiplier(format, (e.target as HTMLInputElement).value)"
>
</div>
</div>
</div>
</div>
<!-- 调度与限流 -->
<div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2">
调度与限流
</h3>
<div class="grid grid-cols-2 gap-4">
<!-- 配置项 -->
<div class="grid grid-cols-4 gap-3">
<div>
<Label for="internal_priority">内部优先级</Label>
<Label
for="internal_priority"
class="text-xs"
>优先级</Label>
<Input
id="internal_priority"
v-model.number="form.internal_priority"
type="number"
min="0"
class="h-8"
/>
<p class="text-xs text-muted-foreground mt-1">
数字越小越优先
<p class="text-xs text-muted-foreground mt-0.5">
越小越优先
</p>
</div>
<div>
<Label for="max_concurrent">最大并发</Label>
<Label
for="rpm_limit"
class="text-xs"
>RPM 限制</Label>
<Input
id="max_concurrent"
:model-value="form.max_concurrent ?? ''"
id="rpm_limit"
:model-value="form.rpm_limit ?? ''"
type="number"
min="1"
placeholder="留空启用自适应"
@update:model-value="(v) => form.max_concurrent = parseNumberInput(v)"
max="10000"
placeholder="自适应"
class="h-8"
@update:model-value="(v) => form.rpm_limit = parseNullableNumberInput(v, { min: 1, max: 10000 })"
/>
<p class="text-xs text-muted-foreground mt-1">
留空 = 自适应模式
<p class="text-xs text-muted-foreground mt-0.5">
留空自适应
</p>
</div>
</div>
<div class="grid grid-cols-3 gap-4">
<div>
<Label for="rate_limit">速率限制(/分钟)</Label>
<Input
id="rate_limit"
:model-value="form.rate_limit ?? ''"
type="number"
min="1"
@update:model-value="(v) => form.rate_limit = parseNumberInput(v)"
/>
</div>
<div>
<Label for="daily_limit">每日限制</Label>
<Input
id="daily_limit"
:model-value="form.daily_limit ?? ''"
type="number"
min="1"
@update:model-value="(v) => form.daily_limit = parseNumberInput(v)"
/>
</div>
<div>
<Label for="monthly_limit">每月限制</Label>
<Input
id="monthly_limit"
:model-value="form.monthly_limit ?? ''"
type="number"
min="1"
@update:model-value="(v) => form.monthly_limit = parseNumberInput(v)"
/>
</div>
</div>
</div>
<!-- 缓存与熔断 -->
<div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2">
缓存与熔断
</h3>
<div class="grid grid-cols-2 gap-4">
<div>
<Label for="cache_ttl_minutes">缓存 TTL (分钟)</Label>
<Label
for="cache_ttl_minutes"
class="text-xs"
>缓存 TTL</Label>
<Input
id="cache_ttl_minutes"
:model-value="form.cache_ttl_minutes ?? ''"
type="number"
min="0"
max="60"
class="h-8"
@update:model-value="(v) => form.cache_ttl_minutes = parseNumberInput(v, { min: 0, max: 60 }) ?? 5"
/>
<p class="text-xs text-muted-foreground mt-1">
0 = 禁用缓存亲和性
<p class="text-xs text-muted-foreground mt-0.5">
分钟0禁用
</p>
</div>
<div>
<Label for="max_probe_interval_minutes">熔断探测间隔 (分钟)</Label>
<Label
for="max_probe_interval_minutes"
class="text-xs"
>熔断探测</Label>
<Input
id="max_probe_interval_minutes"
:model-value="form.max_probe_interval_minutes ?? ''"
@@ -194,37 +192,31 @@
min="2"
max="32"
placeholder="32"
class="h-8"
@update:model-value="(v) => form.max_probe_interval_minutes = parseNumberInput(v, { min: 2, max: 32 }) ?? 32"
/>
<p class="text-xs text-muted-foreground mt-1">
范围 2-32 分钟
<p class="text-xs text-muted-foreground mt-0.5">
分钟,2-32
</p>
</div>
</div>
</div>
<!-- 能力标签配置 -->
<div
v-if="availableCapabilities.length > 0"
class="space-y-3"
>
<h3 class="text-sm font-medium border-b pb-2">
能力标签
</h3>
<div class="flex flex-wrap gap-2">
<label
<!-- 能力标签 -->
<div v-if="availableCapabilities.length > 0">
<Label class="text-xs mb-1.5 block">能力标签</Label>
<div class="flex flex-wrap gap-1.5">
<button
v-for="cap in availableCapabilities"
:key="cap.name"
class="flex items-center gap-2 px-3 py-1.5 rounded-md border border-border bg-muted/30 cursor-pointer text-sm"
type="button"
class="inline-flex items-center gap-1.5 px-2.5 py-1 rounded-md border text-sm transition-colors"
:class="form.capabilities[cap.name]
? 'bg-primary/10 border-primary/50 text-primary'
: 'bg-card border-border hover:bg-muted/50 text-muted-foreground'"
@click="form.capabilities[cap.name] = !form.capabilities[cap.name]"
>
<input
type="checkbox"
:checked="form.capabilities[cap.name] || false"
class="rounded"
@change="form.capabilities[cap.name] = !form.capabilities[cap.name]"
>
<span>{{ cap.display_name }}</span>
</label>
{{ cap.display_name }}
</button>
</div>
</div>
</form>
@@ -240,25 +232,27 @@
:disabled="saving"
@click="handleSave"
>
{{ saving ? '保存中...' : '保存' }}
{{ saving ? (isEditMode ? '保存中...' : '添加中...') : (isEditMode ? '保存' : '添加') }}
</Button>
</template>
</Dialog>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { ref, computed, onMounted, watch } from 'vue'
import { Dialog, Button, Input, Label } from '@/components/ui'
import { Key, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
import { useFormDialog } from '@/composables/useFormDialog'
import { parseApiError } from '@/utils/errorParser'
import { parseNumberInput } from '@/utils/form'
import { parseNumberInput, parseNullableNumberInput } from '@/utils/form'
import { log } from '@/utils/logger'
import {
addEndpointKey,
updateEndpointKey,
addProviderKey,
updateProviderKey,
getAllCapabilities,
API_FORMAT_LABELS,
sortApiFormats,
type EndpointAPIKey,
type EndpointAPIKeyUpdate,
type ProviderEndpoint,
@@ -270,6 +264,7 @@ const props = defineProps<{
endpoint: ProviderEndpoint | null
editingKey: EndpointAPIKey | null
providerId: string | null
availableApiFormats: string[] // Provider 支持的所有 API 格式
}>()
const emit = defineEmits<{
@@ -279,6 +274,9 @@ const emit = defineEmits<{
const { success, error: showError } = useToast()
// 排序后的可用 API 格式列表
const sortedApiFormats = computed(() => sortApiFormats(props.availableApiFormats))
const isOpen = computed(() => props.open)
const saving = ref(false)
const formNonce = ref(createFieldNonce())
@@ -297,12 +295,10 @@ const availableCapabilities = ref<CapabilityDefinition[]>([])
const form = ref({
name: '',
api_key: '',
rate_multiplier: 1.0,
internal_priority: 50,
max_concurrent: undefined as number | undefined,
rate_limit: undefined as number | undefined,
daily_limit: undefined as number | undefined,
monthly_limit: undefined as number | undefined,
api_formats: [] as string[], // 支持的 API 格式列表
rate_multipliers: {} as Record<string, number>, // 按 API 格式的成本倍率
internal_priority: 10,
rpm_limit: undefined as number | null | undefined, // RPM 限制null=自适应undefined=保持原值)
cache_ttl_minutes: 5,
max_probe_interval_minutes: 32,
note: '',
@@ -323,6 +319,43 @@ onMounted(() => {
loadCapabilities()
})
// API 格式切换
function toggleApiFormat(format: string) {
const index = form.value.api_formats.indexOf(format)
if (index === -1) {
// 添加格式
form.value.api_formats.push(format)
} else {
// 移除格式前检查:至少保留一个格式
if (form.value.api_formats.length <= 1) {
showError('至少需要选择一个 API 格式', '验证失败')
return
}
// 移除格式,但保留倍率配置(用户可能只是临时取消)
form.value.api_formats.splice(index, 1)
}
}
// 更新指定格式的成本倍率
function updateRateMultiplier(format: string, value: string | number) {
// 使用对象替换以确保 Vue 3 响应性
const newMultipliers = { ...form.value.rate_multipliers }
if (value === '' || value === null || value === undefined) {
// 清空时删除该格式的配置(使用默认倍率)
delete newMultipliers[format]
} else {
const numValue = typeof value === 'string' ? parseFloat(value) : value
// 限制倍率范围0.01 - 100
if (!isNaN(numValue) && numValue >= 0.01 && numValue <= 100) {
newMultipliers[format] = numValue
}
}
// 替换整个对象以触发响应式更新
form.value.rate_multipliers = newMultipliers
}
// API 密钥输入框样式计算
function getApiKeyInputClass(): string {
const classes = []
@@ -363,12 +396,10 @@ function resetForm() {
form.value = {
name: '',
api_key: '',
rate_multiplier: 1.0,
internal_priority: 50,
max_concurrent: undefined,
rate_limit: undefined,
daily_limit: undefined,
monthly_limit: undefined,
api_formats: [], // 默认不选中任何格式
rate_multipliers: {},
internal_priority: 10,
rpm_limit: undefined,
cache_ttl_minutes: 5,
max_probe_interval_minutes: 32,
note: '',
@@ -377,6 +408,14 @@ function resetForm() {
}
}
// 添加成功后清除部分字段以便继续添加
function clearForNextAdd() {
formNonce.value = createFieldNonce()
apiKeyFocused.value = false
form.value.name = ''
form.value.api_key = ''
}
// 加载密钥数据(编辑模式)
function loadKeyData() {
if (!props.editingKey) return
@@ -385,13 +424,13 @@ function loadKeyData() {
form.value = {
name: props.editingKey.name,
api_key: '',
rate_multiplier: props.editingKey.rate_multiplier || 1.0,
internal_priority: props.editingKey.internal_priority ?? 50,
api_formats: props.editingKey.api_formats?.length > 0
? [...props.editingKey.api_formats]
: [], // 编辑模式下保持原有选择,不默认全选
rate_multipliers: { ...(props.editingKey.rate_multipliers || {}) },
internal_priority: props.editingKey.internal_priority ?? 10,
// 保留原始的 null/undefined 状态null 表示自适应模式
max_concurrent: props.editingKey.max_concurrent ?? undefined,
rate_limit: props.editingKey.rate_limit ?? undefined,
daily_limit: props.editingKey.daily_limit ?? undefined,
monthly_limit: props.editingKey.monthly_limit ?? undefined,
rpm_limit: props.editingKey.rpm_limit ?? undefined,
cache_ttl_minutes: props.editingKey.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: props.editingKey.max_probe_interval_minutes ?? 32,
note: props.editingKey.note || '',
@@ -415,7 +454,11 @@ function createFieldNonce(): string {
}
async function handleSave() {
if (!props.endpoint) return
// 必须有 providerId
if (!props.providerId) {
showError('无法保存:缺少提供商信息', '错误')
return
}
// 提交前验证
if (apiKeyError.value) {
@@ -429,6 +472,12 @@ async function handleSave() {
return
}
// 验证至少选择一个 API 格式
if (form.value.api_formats.length === 0) {
showError('请至少选择一个 API 格式', '验证失败')
return
}
// 过滤出有效的能力配置(只包含值为 true 的)
const activeCapabilities: Record<string, boolean> = {}
for (const [key, value] of Object.entries(form.value.capabilities)) {
@@ -440,21 +489,27 @@ async function handleSave() {
saving.value = true
try {
// 准备 rate_multipliers 数据:只保留已选中格式的倍率配置
const filteredMultipliers: Record<string, number> = {}
for (const format of form.value.api_formats) {
if (form.value.rate_multipliers[format] !== undefined) {
filteredMultipliers[format] = form.value.rate_multipliers[format]
}
}
const rateMultipliersData = Object.keys(filteredMultipliers).length > 0
? filteredMultipliers
: null
if (props.editingKey) {
// 更新模式
// 注意:max_concurrent 需要显式发送 null 来切换到自适应模式
// undefined 会在 JSON 中被忽略,所以用 null 表示"清空/自适应"
// 注意:rpm_limit 使用 null 表示自适应模式
// undefined 表示"保持原值不变"会在 JSON 序列化时被忽略
const updateData: EndpointAPIKeyUpdate = {
api_formats: form.value.api_formats,
name: form.value.name,
rate_multiplier: form.value.rate_multiplier,
rate_multipliers: rateMultipliersData,
internal_priority: form.value.internal_priority,
// 显式使用 null 表示自适应模式,这样后端能区分"未提供"和"设置为 null"
// 注意:只有 max_concurrent 需要这种处理,因为它有"自适应模式"的概念
// 其他限制字段rate_limit 等)不支持"清空"操作undefined 会被 JSON 忽略即不更新
max_concurrent: form.value.max_concurrent === undefined ? null : form.value.max_concurrent,
rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit,
rpm_limit: form.value.rpm_limit,
cache_ttl_minutes: form.value.cache_ttl_minutes,
max_probe_interval_minutes: form.value.max_probe_interval_minutes,
note: form.value.note,
@@ -466,26 +521,27 @@ async function handleSave() {
updateData.api_key = form.value.api_key
}
await updateEndpointKey(props.editingKey.id, updateData)
await updateProviderKey(props.editingKey.id, updateData)
success('密钥已更新', '成功')
} else {
// 新增
await addEndpointKey(props.endpoint.id, {
endpoint_id: props.endpoint.id,
// 新增模式
await addProviderKey(props.providerId, {
api_formats: form.value.api_formats,
api_key: form.value.api_key,
name: form.value.name,
rate_multiplier: form.value.rate_multiplier,
rate_multipliers: rateMultipliersData,
internal_priority: form.value.internal_priority,
max_concurrent: form.value.max_concurrent,
rate_limit: form.value.rate_limit,
daily_limit: form.value.daily_limit,
monthly_limit: form.value.monthly_limit,
rpm_limit: form.value.rpm_limit,
cache_ttl_minutes: form.value.cache_ttl_minutes,
max_probe_interval_minutes: form.value.max_probe_interval_minutes,
note: form.value.note,
capabilities: capabilitiesData || undefined
})
success('密钥已添加', '成功')
// 添加模式:不关闭对话框,只清除名称和密钥以便继续添加
emit('saved')
clearForNextAdd()
return
}
emit('saved')
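Two details of the save path above are easy to miss: rate_multipliers is trimmed to the selected API formats before being sent, and rpm_limit uses null to mean "adaptive" while undefined means "leave unchanged". The sketch below restates that payload preparation with the field names taken from the form; the update-payload shape is assumed from the EndpointAPIKeyUpdate usage in this diff.

// Sketch of the payload preparation used before addProviderKey/updateProviderKey.
interface KeyFormState {
  api_formats: string[]
  rate_multipliers: Record<string, number>
  rpm_limit?: number | null // null = adaptive RPM, undefined = keep the stored value
}

function buildRateMultipliers(form: KeyFormState): Record<string, number> | null {
  // Keep only multipliers for formats that are actually selected.
  const filtered: Record<string, number> = {}
  for (const format of form.api_formats) {
    const value = form.rate_multipliers[format]
    if (value !== undefined) filtered[format] = value
  }
  return Object.keys(filtered).length > 0 ? filtered : null
}

function buildUpdatePayload(form: KeyFormState) {
  return {
    api_formats: form.api_formats,
    rate_multipliers: buildRateMultipliers(form),
    // Sending null switches the key to adaptive RPM; omitting the field
    // (undefined is dropped during JSON serialization) keeps the current limit.
    rpm_limit: form.rpm_limit,
  }
}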

View File

@@ -95,7 +95,7 @@
<!-- 提供商信息 -->
<div class="flex-1 min-w-0 flex items-center gap-2">
<span class="font-medium text-sm truncate">{{ provider.display_name }}</span>
<span class="font-medium text-sm truncate">{{ provider.name }}</span>
<Badge
v-if="!provider.is_active"
variant="secondary"
@@ -395,7 +395,7 @@ import { Dialog } from '@/components/ui'
import Button from '@/components/ui/button.vue'
import Badge from '@/components/ui/badge.vue'
import { useToast } from '@/composables/useToast'
import { updateProvider, updateEndpointKey } from '@/api/endpoints'
import { updateProvider, updateProviderKey } from '@/api/endpoints'
import type { ProviderWithEndpointsSummary } from '@/api/endpoints'
import { adminApi } from '@/api/admin'
@@ -696,7 +696,7 @@ async function save() {
const keys = keysByFormat.value[format]
keys.forEach((key) => {
// 使用用户设置的 priority 值,相同 priority 会做负载均衡
keyUpdates.push(updateEndpointKey(key.id, { global_priority: key.priority }))
keyUpdates.push(updateProviderKey(key.id, { global_priority: key.priority }))
})
}
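The priority save above collects one updateProviderKey call per key into keyUpdates. One way to run such a batch without letting a single failure abort the rest is Promise.allSettled; the following is a hedged sketch under that assumption, not necessarily how the component awaits the batch.

// Hypothetical helper for running the collected key updates concurrently;
// updateProviderKey is the API function imported above.
import { updateProviderKey } from '@/api/endpoints'

async function applyPriorityUpdates(keys: Array<{ id: string; priority: number }>) {
  const keyUpdates = keys.map((key) =>
    updateProviderKey(key.id, { global_priority: key.priority }),
  )
  // allSettled lets partial failures surface without rejecting the whole batch.
  const results = await Promise.allSettled(keyUpdates)
  const failed = results.filter((r) => r.status === 'rejected').length
  return { updated: results.length - failed, failed }
}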

View File

@@ -25,12 +25,12 @@
<template v-else-if="provider">
<!-- 头部:名称 + 快捷操作 -->
<div class="sticky top-0 z-10 bg-background border-b p-4 sm:p-6">
<div class="sticky top-0 z-10 bg-background border-b px-4 sm:px-6 pt-4 sm:pt-6 pb-3 sm:pb-3">
<div class="flex items-start justify-between gap-3 sm:gap-4">
<div class="space-y-1 flex-1 min-w-0">
<div class="flex items-center gap-2">
<h2 class="text-lg sm:text-xl font-bold truncate">
{{ provider.display_name }}
{{ provider.name }}
</h2>
<Badge
:variant="provider.is_active ? 'default' : 'secondary'"
@@ -39,9 +39,11 @@
{{ provider.is_active ? '活跃' : '已停用' }}
</Badge>
</div>
<div class="flex items-center gap-2 flex-wrap">
<span class="text-sm text-muted-foreground font-mono">{{ provider.name }}</span>
<template v-if="provider.website">
<!-- 网站链接 -->
<div
v-if="provider.website"
class="flex items-center gap-2"
>
<span class="text-muted-foreground">·</span>
<a
:href="provider.website"
@@ -49,10 +51,7 @@
rel="noopener noreferrer"
class="text-xs text-primary hover:underline truncate"
title="访问官网"
>
{{ provider.website }}
</a>
</template>
>{{ provider.website }}</a>
</div>
</div>
<div class="flex items-center gap-1 shrink-0">
@@ -82,6 +81,22 @@
</Button>
</div>
</div>
<!-- 端点 API 格式 -->
<div class="flex items-center gap-1.5 flex-wrap mt-3">
<template v-for="endpoint in endpoints" :key="endpoint.id">
<span
class="text-xs px-2 py-0.5 rounded-md border border-border bg-background hover:bg-accent hover:border-accent-foreground/20 cursor-pointer transition-colors font-medium"
:class="{ 'opacity-40': !endpoint.is_active }"
:title="`编辑 ${API_FORMAT_LABELS[endpoint.api_format]} 端点`"
@click="handleEditEndpoint(endpoint)"
>{{ API_FORMAT_LABELS[endpoint.api_format] || endpoint.api_format }}</span>
</template>
<span
class="text-xs px-2 py-0.5 rounded-md border border-dashed border-border hover:bg-accent hover:border-accent-foreground/20 cursor-pointer transition-colors text-muted-foreground"
title="编辑端点"
@click="showAddEndpointDialog"
>编辑</span>
</div>
</div>
<div class="space-y-6 p-4 sm:p-6">
@@ -127,241 +142,43 @@
</div>
</Card>
<!-- 端点与密钥管理 -->
<!-- 密钥管理 -->
<Card class="overflow-hidden">
<div class="p-4 border-b border-border/60">
<div class="flex items-center justify-between">
<h3 class="text-sm font-semibold flex items-center gap-2">
<span>端点与密钥管理</span>
<h3 class="text-sm font-semibold">
密钥管理
</h3>
<Button
v-if="endpoints.length > 0"
variant="outline"
size="sm"
class="h-8"
@click="showAddEndpointDialog"
@click="handleAddKeyToFirstEndpoint"
>
<Plus class="w-3.5 h-3.5 mr-1.5" />
添加端点
添加密钥
</Button>
</div>
</div>
<!-- 端点列表 -->
<div
v-if="endpoints.length > 0"
class="divide-y divide-border/40"
>
<div
v-for="endpoint in endpoints"
:key="endpoint.id"
class="group"
>
<!-- 端点头部 - 可点击展开/收起 -->
<div
class="p-4 hover:bg-muted/30 transition-colors cursor-pointer"
@click="toggleEndpoint(endpoint.id)"
>
<div class="flex items-center justify-between">
<div class="flex items-center gap-3 flex-1 min-w-0">
<ChevronRight
class="w-4 h-4 text-muted-foreground transition-transform shrink-0"
:class="{ 'rotate-90': expandedEndpoints.has(endpoint.id) }"
/>
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="text-sm font-medium">{{ endpoint.api_format }}</span>
<Badge
v-if="!endpoint.is_active"
variant="secondary"
class="text-[10px] px-1.5 py-0"
>
已停用
</Badge>
<span class="text-xs text-muted-foreground flex items-center gap-1">
<Key class="w-3 h-3" />
{{ endpoint.keys?.filter((k: EndpointAPIKey) => k.is_active).length || 0 }}
</span>
<span
v-if="endpoint.max_retries"
class="text-xs text-muted-foreground"
>
{{ endpoint.max_retries }}次重试
</span>
<span
v-if="endpoint.timeout"
class="text-xs text-muted-foreground"
>
{{ endpoint.timeout }}s
</span>
</div>
<div class="flex items-center gap-1.5 mt-0.5">
<span class="text-xs text-muted-foreground font-mono truncate">
{{ endpoint.base_url }}
</span>
<Button
variant="ghost"
size="icon"
class="h-5 w-5 shrink-0"
title="复制 Base URL"
@click.stop="copyToClipboard(endpoint.base_url)"
>
<Copy class="w-3 h-3" />
</Button>
</div>
</div>
</div>
<div
class="flex items-center gap-1"
@click.stop
>
<Button
v-if="hasUnhealthyKeys(endpoint)"
variant="ghost"
size="icon"
class="h-8 w-8 text-green-600"
title="恢复所有密钥健康状态"
:disabled="recoveringEndpointId === endpoint.id"
@click="handleRecoverAllKeys(endpoint)"
>
<Loader2
v-if="recoveringEndpointId === endpoint.id"
class="w-3.5 h-3.5 animate-spin"
/>
<RefreshCw
v-else
class="w-3.5 h-3.5"
/>
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="添加密钥"
@click="handleAddKey(endpoint)"
>
<Plus class="w-4 h-4" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="编辑端点"
@click="handleEditEndpoint(endpoint)"
>
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
:disabled="togglingEndpointId === endpoint.id"
:title="endpoint.is_active ? '点击停用' : '点击启用'"
@click="toggleEndpointActive(endpoint)"
>
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-8 w-8"
title="删除端点"
@click="handleDeleteEndpoint(endpoint)"
>
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</div>
</div>
<!-- 端点详情 - 可展开区域 -->
<div
v-if="expandedEndpoints.has(endpoint.id)"
class="px-4 pb-4 bg-muted/20 border-t border-border/40"
>
<div class="space-y-3 pt-3">
<!-- 端点配置信息 -->
<div
v-if="endpoint.custom_path || endpoint.rpm_limit"
class="flex flex-wrap gap-x-4 gap-y-1 text-xs"
>
<div v-if="endpoint.custom_path">
<span class="text-muted-foreground">自定义路径:</span>
<span class="ml-1 font-mono">{{ endpoint.custom_path }}</span>
</div>
<div v-if="endpoint.rpm_limit">
<span class="text-muted-foreground">RPM:</span>
<span class="ml-1 font-medium">{{ endpoint.rpm_limit }}</span>
</div>
</div>
<!-- 密钥列表 -->
<div class="space-y-2">
<div
v-if="endpoint.keys && endpoint.keys.length > 0"
class="space-y-2"
v-if="allKeys.length > 0"
class="divide-y divide-border/40"
>
<div
v-for="key in endpoint.keys"
v-for="{ key, endpoint } in allKeys"
:key="key.id"
draggable="true"
class="p-3 bg-background rounded-md border transition-all duration-150 group/key"
:class="{
'border-border/40 hover:border-border/80': dragState.targetKeyId !== key.id,
'border-primary border-2 bg-primary/5': dragState.targetKeyId === key.id,
'opacity-50': dragState.draggedKeyId === key.id,
'cursor-grabbing': dragState.isDragging
}"
@dragstart="handleDragStart($event, key, endpoint)"
@dragend="handleDragEnd"
@dragover="handleDragOver($event, key)"
@dragleave="handleDragLeave"
@drop="handleDrop($event, key, endpoint)"
class="px-4 py-2.5 hover:bg-muted/30 transition-colors"
>
<!-- 密钥主要信息行 -->
<div class="flex items-center justify-between mb-2">
<!-- 第一行:名称 + 状态 + 操作按钮 -->
<div class="flex items-center justify-between gap-2">
<div class="flex items-center gap-2 flex-1 min-w-0">
<!-- 拖动手柄 -->
<div
class="cursor-grab active:cursor-grabbing text-muted-foreground/50 hover:text-muted-foreground"
title="拖动排序"
>
<GripVertical class="w-4 h-4" />
</div>
<div class="min-w-0">
<div class="flex items-center gap-1.5">
<span class="text-xs font-medium truncate">{{ key.name || '未命名密钥' }}</span>
<Badge
:variant="key.is_active ? 'default' : 'secondary'"
class="text-[10px] px-1.5 py-0 shrink-0"
>
{{ key.is_active ? '活跃' : '禁用' }}
</Badge>
</div>
<div class="flex items-center gap-1">
<span class="text-[10px] font-mono text-muted-foreground truncate max-w-[180px]">
{{ revealedKeys.has(key.id) ? revealedKeys.get(key.id) : key.api_key_masked }}
<span class="text-sm font-medium truncate">{{ key.name || '未命名密钥' }}</span>
<span class="text-xs font-mono text-muted-foreground">
{{ key.api_key_masked }}
</span>
<Button
variant="ghost"
size="icon"
class="h-5 w-5 shrink-0"
:title="revealedKeys.has(key.id) ? '隐藏密钥' : '显示密钥'"
:disabled="revealingKeyId === key.id"
@click.stop="toggleKeyReveal(key)"
>
<Loader2
v-if="revealingKeyId === key.id"
class="w-3 h-3 animate-spin"
/>
<EyeOff
v-else-if="revealedKeys.has(key.id)"
class="w-3 h-3"
/>
<Eye
v-else
class="w-3 h-3"
/>
</Button>
<Button
variant="ghost"
size="icon"
@@ -371,14 +188,36 @@
>
<Copy class="w-3 h-3" />
</Button>
<Badge
v-if="!key.is_active"
variant="secondary"
class="text-[10px] px-1.5 py-0 shrink-0"
>
禁用
</Badge>
<Badge
v-if="key.circuit_breaker_open"
variant="destructive"
class="text-[10px] px-1.5 py-0 shrink-0"
>
熔断
</Badge>
</div>
</div>
<div class="flex items-center gap-1.5 ml-auto shrink-0">
<!-- 并发 + 健康度 + 操作按钮 -->
<div class="flex items-center gap-1 shrink-0">
<!-- RPM 限制信息放在最前面 -->
<span
v-if="key.rpm_limit || key.is_adaptive"
class="text-[10px] text-muted-foreground mr-1"
>
{{ key.is_adaptive ? '自适应' : key.rpm_limit }} RPM
</span>
<!-- 健康度 -->
<div
v-if="key.health_score !== undefined"
class="flex items-center gap-1"
class="flex items-center gap-1 mr-1"
>
<div class="w-12 h-1 bg-muted/80 rounded-full overflow-hidden">
<div class="w-10 h-1.5 bg-muted/80 rounded-full overflow-hidden">
<div
class="h-full transition-all duration-300"
:class="getHealthScoreBarColor(key.health_score || 0)"
@@ -386,22 +225,12 @@
/>
</div>
<span
class="text-[10px] font-bold tabular-nums w-[30px] text-right"
class="text-[10px] font-medium tabular-nums"
:class="getHealthScoreColor(key.health_score || 0)"
>
{{ ((key.health_score || 0) * 100).toFixed(0) }}%
</span>
</div>
<Badge
v-if="key.circuit_breaker_open"
variant="destructive"
class="text-[10px] px-1.5 py-0"
>
熔断
</Badge>
</div>
</div>
<div class="flex items-center gap-1 ml-2">
<Button
v-if="key.circuit_breaker_open || (key.health_score !== undefined && key.health_score < 0.5)"
variant="ghost"
@@ -410,16 +239,16 @@
title="刷新健康状态"
@click="handleRecoverKey(key)"
>
<RefreshCw class="w-3 h-3" />
<RefreshCw class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
class="h-7 w-7"
title="配置允许的模型"
@click="handleConfigKeyModels(key)"
title="模型权限"
@click="handleKeyPermissions(key)"
>
<Layers class="w-3 h-3" />
<Shield class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
@@ -428,7 +257,7 @@
title="编辑密钥"
@click="handleEditKey(endpoint, key)"
>
<Edit class="w-3 h-3" />
<Edit class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
@@ -438,7 +267,7 @@
:title="key.is_active ? '点击停用' : '点击启用'"
@click="toggleKeyActive(key)"
>
<Power class="w-3 h-3" />
<Power class="w-3.5 h-3.5" />
</Button>
<Button
variant="ghost"
@@ -447,90 +276,46 @@
title="删除密钥"
@click="handleDeleteKey(key)"
>
<Trash2 class="w-3 h-3" />
<Trash2 class="w-3.5 h-3.5" />
</Button>
</div>
</div>
<!-- 密钥详细信息 -->
<div class="flex items-center text-[11px]">
<!-- 左侧固定信息 -->
<div class="flex items-center gap-2">
<!-- 可点击编辑的优先级 -->
<!-- 第二行:优先级 + API 格式展开显示 + 统计信息 -->
<div class="flex items-center gap-1.5 mt-1 text-[11px] text-muted-foreground">
<!-- 优先级放最前面,支持点击编辑 -->
<span
v-if="editingPriorityKey !== key.id"
class="text-muted-foreground cursor-pointer hover:text-foreground hover:bg-muted/50 px-1 rounded transition-colors"
title="点击编辑优先级,数字越小优先级越高"
title="点击编辑优先级"
class="font-medium text-foreground/80 cursor-pointer hover:text-primary hover:underline"
@click="startEditPriority(key)"
>
P {{ key.internal_priority }}
</span>
<!-- 编辑模式 -->
<span
v-else
class="flex items-center gap-1"
>
<span class="text-muted-foreground">P</span>
>P{{ key.internal_priority }}</span>
<input
ref="priorityInput"
v-model.number="editingPriorityValue"
type="number"
class="w-12 h-5 px-1 text-[11px] border rounded bg-background focus:outline-none focus:ring-1 focus:ring-primary"
min="0"
@keyup.enter="savePriority(key, endpoint)"
@keyup.escape="cancelEditPriority"
@blur="savePriority(key, endpoint)"
v-else
ref="priorityInputRef"
v-model="editingPriorityValue"
type="text"
inputmode="numeric"
pattern="[0-9]*"
class="w-8 h-5 px-1 text-[11px] text-center border rounded bg-background focus:outline-none focus:ring-1 focus:ring-primary font-medium text-foreground/80"
@keydown="(e) => handlePriorityKeydown(e, key)"
@blur="handlePriorityBlur(key)"
>
</span>
<span
class="text-muted-foreground"
title="成本倍率,实际成本 = 模型价格 × 倍率"
<span class="text-muted-foreground/40">|</span>
<!-- API 格式:展开显示每个格式和倍率 -->
<template
v-for="(format, idx) in getKeyApiFormats(key, endpoint)"
:key="format"
>
{{ key.rate_multiplier }}x
</span>
<span
v-if="key.success_rate !== undefined"
class="text-muted-foreground"
title="成功率 = 成功次数 / 总请求数"
>
{{ (key.success_rate * 100).toFixed(1) }}% ({{ key.success_count }}/{{ key.request_count }})
</span>
</div>
<!-- 右侧动态信息 -->
<div class="flex items-center gap-2 ml-auto">
<span v-if="idx > 0" class="text-muted-foreground/40">/</span>
<span>{{ API_FORMAT_SHORT[format] || format }} {{ getKeyRateMultiplier(key, format) }}x</span>
</template>
<span v-if="key.rate_limit">| {{ key.rate_limit }}rpm</span>
<span
v-if="key.next_probe_at"
class="text-amber-600 dark:text-amber-400"
title="熔断器探测恢复时间"
>
{{ formatProbeTime(key.next_probe_at) }}探测
| {{ formatProbeTime(key.next_probe_at) }}探测
</span>
<span
v-if="key.rate_limit"
class="text-muted-foreground"
title="每分钟请求数限制"
>
{{ key.rate_limit }}rpm
</span>
<span
v-if="key.max_concurrent || key.is_adaptive"
class="text-muted-foreground"
:title="key.is_adaptive ? `自适应并发限制(学习值: ${key.learned_max_concurrent ?? '未学习'}` : `固定并发限制: ${key.max_concurrent}`"
>
{{ key.is_adaptive ? '自适应' : '固定' }}并发: {{ key.is_adaptive ? (key.learned_max_concurrent ?? '学习中') : key.max_concurrent }}
</span>
</div>
</div>
</div>
</div>
<div
v-else
class="text-xs text-muted-foreground text-center py-4"
>
暂无密钥
</div>
</div>
</div>
</div>
</div>
</div>
@@ -540,12 +325,12 @@
v-else
class="p-8 text-center text-muted-foreground"
>
<Server class="w-12 h-12 mx-auto mb-3 opacity-50" />
<Key class="w-12 h-12 mx-auto mb-3 opacity-50" />
<p class="text-sm">
暂无端点配置
暂无密钥配置
</p>
<p class="text-xs mt-1">
点击上方"添加端点"按钮创建第一个端点
{{ endpoints.length > 0 ? '点击上方"添加密钥"按钮创建第一个密钥' : '请先添加端点,然后再添加密钥' }}
</p>
</div>
</Card>
@@ -575,12 +360,12 @@
</Transition>
</Teleport>
<!-- 端点表单对话框(添加/编辑) -->
<!-- 端点表单对话框(管理/编辑) -->
<EndpointFormDialog
v-if="provider && open"
v-model="endpointDialogOpen"
:provider="provider"
:endpoint="endpointToEdit"
:endpoints="endpoints"
@endpoint-created="handleEndpointChanged"
@endpoint-updated="handleEndpointChanged"
/>
@@ -606,17 +391,18 @@
:endpoint="currentEndpoint"
:editing-key="editingKey"
:provider-id="provider ? provider.id : null"
:available-api-formats="provider?.api_formats || []"
@close="keyFormDialogOpen = false"
@saved="handleKeyChanged"
/>
<!-- 密钥允许模型配置对话框 -->
<KeyAllowedModelsDialog
<!-- 模型权限对话框 -->
<KeyAllowedModelsEditDialog
v-if="open"
:open="keyAllowedModelsDialogOpen"
:open="keyPermissionsDialogOpen"
:api-key="editingKey"
:provider-id="provider ? provider.id : null"
@close="keyAllowedModelsDialogOpen = false"
:provider-id="providerId || ''"
@close="keyPermissionsDialogOpen = false"
@saved="handleKeyChanged"
/>
@@ -639,7 +425,7 @@
v-if="open && provider"
:open="modelFormDialogOpen"
:provider-id="provider.id"
:provider-name="provider.display_name"
:provider-name="provider.name"
:editing-model="editingModel"
@update:open="modelFormDialogOpen = $event"
@saved="handleModelSaved"
@@ -650,7 +436,7 @@
v-if="open"
:model-value="deleteModelConfirmOpen"
title="移除模型支持"
:description="`确定要移除提供商 ${provider?.display_name} 对模型 ${modelToDelete?.global_model_display_name || modelToDelete?.provider_model_name} 的支持吗?这不会删除全局模型,只是该提供商将不再支持此模型。`"
:description="`确定要移除提供商 ${provider?.name} 对模型 ${modelToDelete?.global_model_display_name || modelToDelete?.provider_model_name} 的支持吗?这不会删除全局模型,只是该提供商将不再支持此模型。`"
confirm-text="移除"
cancel-text="取消"
type="danger"
@@ -664,7 +450,7 @@
v-if="open && provider"
:open="batchAssignDialogOpen"
:provider-id="provider.id"
:provider-name="provider.display_name"
:provider-name="provider.name"
:provider-identifier="provider.name"
@update:open="batchAssignDialogOpen = $event"
@changed="handleBatchAssignChanged"
@@ -672,7 +458,7 @@
</template>
<script setup lang="ts">
import { ref, watch, computed } from 'vue'
import { ref, watch, computed, nextTick } from 'vue'
import {
Server,
Plus,
@@ -684,11 +470,12 @@ import {
X,
Loader2,
Power,
Layers,
GripVertical,
Copy,
Eye,
EyeOff
EyeOff,
ExternalLink,
Shield
} from 'lucide-vue-next'
import { useEscapeKey } from '@/composables/useEscapeKey'
import Button from '@/components/ui/button.vue'
@@ -699,7 +486,7 @@ import { useClipboard } from '@/composables/useClipboard'
import { getProvider, getProviderEndpoints } from '@/api/endpoints'
import {
KeyFormDialog,
KeyAllowedModelsDialog,
KeyAllowedModelsEditDialog,
ModelsTab,
ModelAliasesTab,
BatchAssignModelsDialog
@@ -711,14 +498,17 @@ import {
deleteEndpoint as deleteEndpointAPI,
deleteEndpointKey,
recoverKeyHealth,
getEndpointKeys,
getProviderKeys,
updateEndpoint,
updateEndpointKey,
batchUpdateKeyPriority,
updateProviderKey,
revealEndpointKey,
type ProviderEndpoint,
type EndpointAPIKey,
type Model
type Model,
API_FORMAT_LABELS,
API_FORMAT_ORDER,
API_FORMAT_SHORT,
sortApiFormats,
} from '@/api/endpoints'
import { deleteModel as deleteModelAPI } from '@/api/endpoints/models'
@@ -747,17 +537,17 @@ const { copyToClipboard } = useClipboard()
const loading = ref(false)
const provider = ref<any>(null)
const endpoints = ref<ProviderEndpointWithKeys[]>([])
const providerKeys = ref<EndpointAPIKey[]>([]) // Provider 级别的 keys
const expandedEndpoints = ref<Set<string>>(new Set())
// 端点相关状态
const endpointDialogOpen = ref(false)
const endpointToEdit = ref<ProviderEndpoint | null>(null)
const deleteEndpointConfirmOpen = ref(false)
const endpointToDelete = ref<ProviderEndpoint | null>(null)
// 密钥相关状态
const keyFormDialogOpen = ref(false)
const keyAllowedModelsDialogOpen = ref(false)
const keyPermissionsDialogOpen = ref(false)
const currentEndpoint = ref<ProviderEndpoint | null>(null)
const editingKey = ref<EndpointAPIKey | null>(null)
const deleteKeyConfirmOpen = ref(false)
@@ -791,13 +581,15 @@ const dragState = ref({
// 点击编辑优先级相关状态
const editingPriorityKey = ref<string | null>(null)
const editingPriorityValue = ref<number>(0)
const priorityInputRef = ref<HTMLInputElement[] | null>(null)
const prioritySaving = ref(false)
// 任意模态窗口打开时,阻止抽屉被误关闭
const hasBlockingDialogOpen = computed(() =>
endpointDialogOpen.value ||
deleteEndpointConfirmOpen.value ||
keyFormDialogOpen.value ||
keyAllowedModelsDialogOpen.value ||
keyPermissionsDialogOpen.value ||
deleteKeyConfirmOpen.value ||
modelFormDialogOpen.value ||
deleteModelConfirmOpen.value ||
@@ -806,6 +598,36 @@ const hasBlockingDialogOpen = computed(() =>
modelAliasesTabRef.value?.dialogOpen
)
// 所有密钥的扁平列表(带端点信息)
// key 通过 api_formats 字段确定支持的格式,endpoint 可能为 undefined
const allKeys = computed(() => {
const result: { key: EndpointAPIKey; endpoint?: ProviderEndpointWithKeys }[] = []
const seenKeyIds = new Set<string>()
// 1. 先添加 Provider 级别的 keys
for (const key of providerKeys.value) {
if (!seenKeyIds.has(key.id)) {
seenKeyIds.add(key.id)
// key 没有关联特定 endpoint
result.push({ key, endpoint: undefined })
}
}
// 2. 再遍历所有端点的 keys(历史数据)
for (const endpoint of endpoints.value) {
if (endpoint.keys) {
for (const key of endpoint.keys) {
if (!seenKeyIds.has(key.id)) {
seenKeyIds.add(key.id)
result.push({ key, endpoint })
}
}
}
}
return result
})
// 监听 providerId 变化
watch(() => props.providerId, (newId) => {
if (newId && props.open) {
@@ -823,18 +645,18 @@ watch(() => props.open, (newOpen) => {
// 重置所有状态
provider.value = null
endpoints.value = []
providerKeys.value = [] // 清空 Provider 级别的 keys
expandedEndpoints.value.clear()
// 重置所有对话框状态
endpointDialogOpen.value = false
deleteEndpointConfirmOpen.value = false
keyFormDialogOpen.value = false
keyAllowedModelsDialogOpen.value = false
keyPermissionsDialogOpen.value = false
deleteKeyConfirmOpen.value = false
batchAssignDialogOpen.value = false
// 重置临时数据
endpointToEdit.value = null
endpointToDelete.value = null
currentEndpoint.value = null
editingKey.value = null
@@ -873,15 +695,14 @@ async function handleRelatedDataRefresh() {
emit('refresh')
}
// 显示添加端点对话框
// 显示端点管理对话框
function showAddEndpointDialog() {
endpointToEdit.value = null // 添加模式
endpointDialogOpen.value = true
}
// ===== 端点事件处理 =====
function handleEditEndpoint(endpoint: ProviderEndpoint) {
endpointToEdit.value = endpoint // 编辑模式
function handleEditEndpoint(_endpoint: ProviderEndpoint) {
// 点击任何端点都打开管理对话框
endpointDialogOpen.value = true
}
@@ -907,9 +728,8 @@ async function confirmDeleteEndpoint() {
}
async function handleEndpointChanged() {
await loadEndpoints()
await Promise.all([loadProvider(), loadEndpoints()])
emit('refresh')
endpointToEdit.value = null
}
// ===== 密钥事件处理 =====
@@ -919,15 +739,22 @@ function handleAddKey(endpoint: ProviderEndpoint) {
keyFormDialogOpen.value = true
}
function handleEditKey(endpoint: ProviderEndpoint, key: EndpointAPIKey) {
currentEndpoint.value = endpoint
// 添加密钥(如果有多个端点则添加到第一个)
function handleAddKeyToFirstEndpoint() {
if (endpoints.value.length > 0) {
handleAddKey(endpoints.value[0])
}
}
function handleEditKey(endpoint: ProviderEndpoint | undefined, key: EndpointAPIKey) {
currentEndpoint.value = endpoint || null
editingKey.value = key
keyFormDialogOpen.value = true
}
function handleConfigKeyModels(key: EndpointAPIKey) {
function handleKeyPermissions(key: EndpointAPIKey) {
editingKey.value = key
keyAllowedModelsDialogOpen.value = true
keyPermissionsDialogOpen.value = true
}
// 切换密钥显示/隐藏
@@ -1080,7 +907,7 @@ async function toggleKeyActive(key: EndpointAPIKey) {
togglingKeyId.value = key.id
try {
const newStatus = !key.is_active
await updateEndpointKey(key.id, { is_active: newStatus })
await updateProviderKey(key.id, { is_active: newStatus })
key.is_active = newStatus
showSuccess(newStatus ? '密钥已启用' : '密钥已停用')
emit('refresh')
@@ -1240,9 +1067,11 @@ async function handleDrop(event: DragEvent, targetKey: EndpointAPIKey, endpoint:
handleDragEnd()
// 调用 API 批量更新
// 调用 API 批量更新(使用循环调用 updateProviderKey 替代已废弃的 batchUpdateKeyPriority)
try {
await batchUpdateKeyPriority(endpoint.id, priorities)
await Promise.all(
priorities.map(p => updateProviderKey(p.key_id, { internal_priority: p.internal_priority }))
)
showSuccess('优先级已更新')
// 重新加载以获取更新后的数据
await loadEndpoints()
@@ -1258,15 +1087,43 @@ async function handleDrop(event: DragEvent, targetKey: EndpointAPIKey, endpoint:
function startEditPriority(key: EndpointAPIKey) {
editingPriorityKey.value = key.id
editingPriorityValue.value = key.internal_priority ?? 0
prioritySaving.value = false
nextTick(() => {
// v-for 中的 ref 是数组,取第一个元素
const input = Array.isArray(priorityInputRef.value) ? priorityInputRef.value[0] : priorityInputRef.value
input?.focus()
input?.select()
})
}
function cancelEditPriority() {
editingPriorityKey.value = null
prioritySaving.value = false
}
async function savePriority(key: EndpointAPIKey, endpoint: ProviderEndpointWithKeys) {
function handlePriorityKeydown(e: KeyboardEvent, key: EndpointAPIKey) {
if (e.key === 'Enter') {
e.preventDefault()
e.stopPropagation()
if (!prioritySaving.value) {
prioritySaving.value = true
savePriority(key)
}
} else if (e.key === 'Escape') {
e.preventDefault()
cancelEditPriority()
}
}
function handlePriorityBlur(key: EndpointAPIKey) {
// 如果已经在保存中(Enter 触发),不重复保存
if (prioritySaving.value) return
savePriority(key)
}
async function savePriority(key: EndpointAPIKey) {
const keyId = editingPriorityKey.value
const newPriority = editingPriorityValue.value
const newPriority = parseInt(String(editingPriorityValue.value), 10) || 0
if (!keyId || newPriority < 0) {
cancelEditPriority()
@@ -1282,17 +1139,15 @@ async function savePriority(key: EndpointAPIKey, endpoint: ProviderEndpointWithK
cancelEditPriority()
try {
await updateEndpointKey(keyId, { internal_priority: newPriority })
await updateProviderKey(keyId, { internal_priority: newPriority })
showSuccess('优先级已更新')
// 更新本地数据
if (endpoint.keys) {
const keyToUpdate = endpoint.keys.find(k => k.id === keyId)
// 更新本地数据 - 更新 providerKeys 中的数据
const keyToUpdate = providerKeys.value.find(k => k.id === keyId)
if (keyToUpdate) {
keyToUpdate.internal_priority = newPriority
}
// 重新排序
endpoint.keys.sort((a, b) => (a.internal_priority ?? 0) - (b.internal_priority ?? 0))
}
providerKeys.value.sort((a, b) => (a.internal_priority ?? 0) - (b.internal_priority ?? 0))
emit('refresh')
} catch (err: any) {
showError(err.response?.data?.detail || '更新优先级失败', '错误')
@@ -1318,6 +1173,28 @@ function formatProbeTime(probeTime: string): string {
return '即将探测'
}
// 获取密钥的 API 格式列表(按指定顺序排序)
function getKeyApiFormats(key: EndpointAPIKey, endpoint?: ProviderEndpointWithKeys): string[] {
let formats: string[] = []
if (key.api_formats && key.api_formats.length > 0) {
formats = [...key.api_formats]
} else if (endpoint) {
formats = [endpoint.api_format]
}
// 使用统一的排序函数
return sortApiFormats(formats)
}
// 获取密钥在指定 API 格式下的成本倍率
function getKeyRateMultiplier(key: EndpointAPIKey, format: string): number {
// 优先使用 rate_multipliers 中指定格式的倍率
if (key.rate_multipliers && key.rate_multipliers[format] !== undefined) {
return key.rate_multipliers[format]
}
// 回退到默认倍率
return key.rate_multiplier || 1.0
}
// 健康度颜色
function getHealthScoreColor(score: number): string {
if (score >= 0.8) return 'text-green-600 dark:text-green-400'
@@ -1354,22 +1231,22 @@ async function loadEndpoints() {
if (!props.providerId) return
try {
const endpointsList = await getProviderEndpoints(props.providerId)
// 并行加载端点列表和 Provider 级别的 keys
const [endpointsList, providerKeysResult] = await Promise.all([
getProviderEndpoints(props.providerId),
getProviderKeys(props.providerId).catch(() => []),
])
// 为每个端点加载其密钥
const endpointsWithKeys = await Promise.all(
endpointsList.map(async (endpoint) => {
try {
const keys = await getEndpointKeys(endpoint.id)
return { ...endpoint, keys }
} catch {
// 如果获取密钥失败,返回空数组
return { ...endpoint, keys: [] }
}
providerKeys.value = providerKeysResult
// 按 API 格式排序
endpoints.value = endpointsList.sort((a, b) => {
const aIdx = API_FORMAT_ORDER.indexOf(a.api_format)
const bIdx = API_FORMAT_ORDER.indexOf(b.api_format)
if (aIdx === -1 && bIdx === -1) return 0
if (aIdx === -1) return 1
if (bIdx === -1) return -1
return aIdx - bIdx
})
)
endpoints.value = endpointsWithKeys
} catch (err: any) {
showError(err.response?.data?.detail || '加载端点失败', '错误')
}
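
Illustrative sketch (not part of this commit): the drawer above resolves a key's cost multiplier per API format and falls back to the legacy single multiplier. A minimal TypeScript version of that lookup order, with the key type deliberately abbreviated to the fields used here:

// Hypothetical, simplified shape — the real EndpointAPIKey type has more fields.
interface KeyRateFields {
  rate_multiplier?: number
  rate_multipliers?: Record<string, number> | null
}

function resolveRateMultiplier(key: KeyRateFields, format: string): number {
  // Per-format override wins; otherwise fall back to the legacy single multiplier, then 1.0.
  return key.rate_multipliers?.[format] ?? key.rate_multiplier ?? 1.0
}

// resolveRateMultiplier({ rate_multiplier: 1.2, rate_multipliers: { CLAUDE: 0.8 } }, 'CLAUDE') // 0.8
// resolveRateMultiplier({ rate_multiplier: 1.2, rate_multipliers: { CLAUDE: 0.8 } }, 'OPENAI') // 1.2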

View File

@@ -4,47 +4,29 @@
:title="isEditMode ? '编辑提供商' : '添加提供商'"
:description="isEditMode ? '更新提供商配置。API 端点和密钥需在详情页面单独管理。' : '创建新的提供商配置。创建后可以为其添加 API 端点和密钥。'"
:icon="isEditMode ? SquarePen : Server"
size="2xl"
size="xl"
@update:model-value="handleDialogUpdate"
>
<form
class="space-y-6"
class="space-y-5"
@submit.prevent="handleSubmit"
>
<!-- 基本信息 -->
<div class="space-y-4">
<div class="space-y-3">
<h3 class="text-sm font-medium border-b pb-2">
基本信息
</h3>
<!-- 添加模式显示提供商标识 -->
<div
v-if="!isEditMode"
class="space-y-2"
>
<Label for="name">提供商标识 *</Label>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-1.5">
<Label for="name">名称 *</Label>
<Input
id="name"
v-model="form.name"
placeholder="例如: openai-primary"
required
/>
<p class="text-xs text-muted-foreground">
唯一ID创建后不可修改
</p>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<Label for="display_name">显示名称 *</Label>
<Input
id="display_name"
v-model="form.display_name"
placeholder="例如: OpenAI 主账号"
required
/>
</div>
<div class="space-y-2">
<div class="space-y-1.5">
<Label for="website">主站链接</Label>
<Input
id="website"
@@ -55,24 +37,28 @@
</div>
</div>
<div class="space-y-2">
<div class="space-y-1.5">
<Label for="description">描述</Label>
<Textarea
<Input
id="description"
v-model="form.description"
placeholder="提供商描述(可选)"
rows="2"
/>
</div>
</div>
<!-- 计费与限流 -->
<div class="space-y-4">
<!-- 计费与限流 / 请求配置 -->
<div class="space-y-3">
<div class="grid grid-cols-2 gap-4">
<h3 class="text-sm font-medium border-b pb-2">
计费与限流
</h3>
<h3 class="text-sm font-medium border-b pb-2">
请求配置
</h3>
</div>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-2">
<div class="space-y-1.5">
<Label>计费类型</Label>
<Select
v-model="form.billing_type"
@@ -82,81 +68,131 @@
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="monthly_quota">
月卡额度
</SelectItem>
<SelectItem value="pay_as_you_go">
按量付费
</SelectItem>
<SelectItem value="free_tier">
免费套餐
</SelectItem>
<SelectItem value="monthly_quota">月卡额度</SelectItem>
<SelectItem value="pay_as_you_go">按量付费</SelectItem>
<SelectItem value="free_tier">免费套餐</SelectItem>
</SelectContent>
</Select>
</div>
<div class="space-y-2">
<Label>RPM 限制</Label>
<div class="grid grid-cols-2 gap-4">
<div class="space-y-1.5">
<Label>超时时间 ()</Label>
<Input
:model-value="form.rpm_limit ?? ''"
:model-value="form.timeout ?? ''"
type="number"
min="1"
max="600"
placeholder="默认 300"
@update:model-value="(v) => form.timeout = parseNumberInput(v)"
/>
</div>
<div class="space-y-1.5">
<Label>最大重试次数</Label>
<Input
:model-value="form.max_retries ?? ''"
type="number"
min="0"
placeholder="不限制请留空"
@update:model-value="(v) => form.rpm_limit = parseNumberInput(v)"
max="10"
placeholder="默认 2"
@update:model-value="(v) => form.max_retries = parseNumberInput(v)"
/>
</div>
</div>
</div>
<!-- 月卡配置 -->
<div
v-if="form.billing_type === 'monthly_quota'"
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
>
<div class="space-y-2">
<div class="space-y-1.5">
<Label class="text-xs">周期额度 (USD)</Label>
<Input
:model-value="form.monthly_quota_usd ?? ''"
type="number"
step="0.01"
min="0"
class="h-9"
@update:model-value="(v) => form.monthly_quota_usd = parseNumberInput(v, { allowFloat: true })"
/>
</div>
<div class="space-y-2">
<div class="space-y-1.5">
<Label class="text-xs">重置周期 (天)</Label>
<Input
:model-value="form.quota_reset_day ?? ''"
type="number"
min="1"
max="365"
class="h-9"
@update:model-value="(v) => form.quota_reset_day = parseNumberInput(v) ?? 30"
/>
</div>
<div class="space-y-2">
<div class="space-y-1.5">
<Label class="text-xs">
周期开始时间
<span class="text-red-500">*</span>
周期开始时间 <span class="text-red-500">*</span>
</Label>
<Input
v-model="form.quota_last_reset_at"
type="datetime-local"
class="h-9"
/>
<p class="text-xs text-muted-foreground">
系统会自动统计从该时间点开始的使用量
</p>
</div>
<div class="space-y-2">
<div class="space-y-1.5">
<Label class="text-xs">过期时间</Label>
<Input
v-model="form.quota_expires_at"
type="datetime-local"
class="h-9"
/>
<p class="text-xs text-muted-foreground">
留空表示永久有效
</p>
</div>
</div>
</div>
<!-- 代理配置 -->
<div class="space-y-3">
<div class="flex items-center justify-between">
<h3 class="text-sm font-medium">
代理配置
</h3>
<div class="flex items-center gap-2">
<Switch
:model-value="form.proxy_enabled"
@update:model-value="(v: boolean) => form.proxy_enabled = v"
/>
<span class="text-sm text-muted-foreground">启用代理</span>
</div>
</div>
<div
v-if="form.proxy_enabled"
class="grid grid-cols-2 gap-4 p-3 border rounded-lg bg-muted/50"
>
<div class="space-y-1.5">
<Label class="text-xs">代理地址 *</Label>
<Input
v-model="form.proxy_url"
placeholder="http://proxy:port 或 socks5://proxy:port"
/>
</div>
<div class="grid grid-cols-2 gap-3">
<div class="space-y-1.5">
<Label class="text-xs">用户名</Label>
<Input
v-model="form.proxy_username"
placeholder="可选"
autocomplete="off"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
<div class="space-y-1.5">
<Label class="text-xs">密码</Label>
<Input
v-model="form.proxy_password"
type="password"
placeholder="可选"
autocomplete="new-password"
data-form-type="other"
data-lpignore="true"
data-1p-ignore="true"
/>
</div>
</div>
</div>
</div>
@@ -172,7 +208,7 @@
取消
</Button>
<Button
:disabled="loading || !form.display_name || (!isEditMode && !form.name)"
:disabled="loading || !form.name"
@click="handleSubmit"
>
{{ loading ? (isEditMode ? '保存中...' : '创建中...') : (isEditMode ? '保存' : '创建') }}
@@ -187,13 +223,13 @@ import {
Dialog,
Button,
Input,
Textarea,
Label,
Select,
SelectTrigger,
SelectValue,
SelectContent,
SelectItem,
Switch,
} from '@/components/ui'
import { Server, SquarePen } from 'lucide-vue-next'
import { useToast } from '@/composables/useToast'
@@ -223,7 +259,6 @@ const internalOpen = computed(() => props.modelValue)
// 表单数据
const form = ref({
name: '',
display_name: '',
description: '',
website: '',
// 计费配置
@@ -232,19 +267,25 @@ const form = ref({
quota_reset_day: 30,
quota_last_reset_at: '', // 周期开始时间
quota_expires_at: '',
rpm_limit: undefined as string | number | undefined,
provider_priority: 999,
// 状态配置
is_active: true,
rate_limit: undefined as number | undefined,
concurrent_limit: undefined as number | undefined,
// 请求配置
timeout: undefined as number | undefined,
max_retries: undefined as number | undefined,
// 代理配置(扁平化便于表单绑定)
proxy_enabled: false,
proxy_url: '',
proxy_username: '',
proxy_password: '',
})
// 重置表单
function resetForm() {
form.value = {
name: '',
display_name: '',
description: '',
website: '',
billing_type: 'pay_as_you_go',
@@ -252,11 +293,18 @@ function resetForm() {
quota_reset_day: 30,
quota_last_reset_at: '',
quota_expires_at: '',
rpm_limit: undefined,
provider_priority: 999,
is_active: true,
rate_limit: undefined,
concurrent_limit: undefined,
// 请求配置
timeout: undefined,
max_retries: undefined,
// 代理配置
proxy_enabled: false,
proxy_url: '',
proxy_username: '',
proxy_password: '',
}
}
@@ -264,9 +312,9 @@ function resetForm() {
function loadProviderData() {
if (!props.provider) return
const proxy = props.provider.proxy
form.value = {
name: props.provider.name,
display_name: props.provider.display_name,
description: props.provider.description || '',
website: props.provider.website || '',
billing_type: (props.provider.billing_type as 'monthly_quota' | 'pay_as_you_go' | 'free_tier') || 'pay_as_you_go',
@@ -276,11 +324,18 @@ function loadProviderData() {
new Date(props.provider.quota_last_reset_at).toISOString().slice(0, 16) : '',
quota_expires_at: props.provider.quota_expires_at ?
new Date(props.provider.quota_expires_at).toISOString().slice(0, 16) : '',
rpm_limit: props.provider.rpm_limit ?? undefined,
provider_priority: props.provider.provider_priority || 999,
is_active: props.provider.is_active,
rate_limit: undefined,
concurrent_limit: undefined,
// 请求配置
timeout: props.provider.timeout ?? undefined,
max_retries: props.provider.max_retries ?? undefined,
// 代理配置
proxy_enabled: proxy?.enabled ?? false,
proxy_url: proxy?.url || '',
proxy_username: proxy?.username || '',
proxy_password: proxy?.password || '',
}
}
@@ -302,17 +357,37 @@ const handleSubmit = async () => {
return
}
// 启用代理时必须填写代理地址
if (form.value.proxy_enabled && !form.value.proxy_url) {
showError('启用代理时必须填写代理地址', '验证失败')
return
}
loading.value = true
try {
// 构建代理配置
const proxy = form.value.proxy_enabled ? {
url: form.value.proxy_url,
username: form.value.proxy_username || undefined,
password: form.value.proxy_password || undefined,
enabled: true,
} : null
const payload = {
...form.value,
rpm_limit:
form.value.rpm_limit === undefined || form.value.rpm_limit === ''
? null
: Number(form.value.rpm_limit),
// 空字符串时不发送
name: form.value.name,
description: form.value.description || undefined,
website: form.value.website || undefined,
billing_type: form.value.billing_type,
monthly_quota_usd: form.value.monthly_quota_usd,
quota_reset_day: form.value.quota_reset_day,
quota_last_reset_at: form.value.quota_last_reset_at || undefined,
quota_expires_at: form.value.quota_expires_at || undefined,
provider_priority: form.value.provider_priority,
is_active: form.value.is_active,
// 请求配置
timeout: form.value.timeout ?? undefined,
max_retries: form.value.max_retries ?? undefined,
proxy,
}
if (isEditMode.value && props.provider) {
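
Illustrative sketch (not part of this commit): with the proxy switch on, handleSubmit above sends a nested proxy object; with it off, proxy is sent as an explicit null. The two shapes, assuming exactly the field names built in the form above (the backend schema itself is not shown in this diff):

type ProxyPayload = {
  url: string
  username?: string
  password?: string
  enabled: true
} | null

// Switch on, only the URL filled in — optional credentials are left out:
const proxyOn: ProxyPayload = { url: 'http://proxy:8080', enabled: true }
// Switch off — the field is sent as null rather than omitted:
const proxyOff: ProxyPayload = null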

View File

@@ -2,6 +2,7 @@ export { default as ProviderFormDialog } from './ProviderFormDialog.vue'
export { default as EndpointFormDialog } from './EndpointFormDialog.vue'
export { default as KeyFormDialog } from './KeyFormDialog.vue'
export { default as KeyAllowedModelsDialog } from './KeyAllowedModelsDialog.vue'
export { default as KeyAllowedModelsEditDialog } from './KeyAllowedModelsEditDialog.vue'
export { default as PriorityManagementDialog } from './PriorityManagementDialog.vue'
export { default as ProviderModelFormDialog } from './ProviderModelFormDialog.vue'
export { default as ProviderDetailDrawer } from './ProviderDetailDrawer.vue'

View File

@@ -178,7 +178,7 @@
<Button
variant="ghost"
size="icon"
class="h-8 w-8 text-destructive hover:text-destructive"
class="h-8 w-8 hover:text-destructive"
title="删除"
@click="deleteModel(model)"
>

View File

@@ -289,14 +289,14 @@
/>
</div>
<!-- 错误信息卡片 -->
<!-- 响应客户端错误卡片 -->
<Card
v-if="detail.error_message"
class="border-red-200 dark:border-red-800"
>
<div class="p-4">
<h4 class="text-sm font-semibold text-red-600 dark:text-red-400 mb-2">
错误信息
响应客户端错误
</h4>
<div class="bg-red-50 dark:bg-red-900/20 rounded-lg p-3">
<p class="text-sm text-red-800 dark:text-red-300">
@@ -431,7 +431,7 @@
<TabsContent value="response-headers">
<JsonContent
:data="detail.response_headers"
:data="actualResponseHeaders"
:view-mode="viewMode"
:expand-depth="currentExpandDepth"
:is-dark="isDark"
@@ -614,6 +614,25 @@ const tabs = [
{ name: 'metadata', label: '元数据' },
]
// 判断数据是否有实际内容(非空对象/数组)
function hasContent(data: unknown): boolean {
if (data === null || data === undefined) return false
if (typeof data === 'object') {
return Object.keys(data as object).length > 0
}
return true
}
// 获取实际的响应头(优先 client_response_headers,回退到 response_headers)
const actualResponseHeaders = computed(() => {
if (!detail.value) return null
// 优先返回客户端响应头,如果没有则回退到提供商响应头
if (hasContent(detail.value.client_response_headers)) {
return detail.value.client_response_headers
}
return detail.value.response_headers
})
// 根据实际数据决定显示哪些 Tab
const visibleTabs = computed(() => {
if (!detail.value) return []
@@ -621,15 +640,15 @@ const visibleTabs = computed(() => {
return tabs.filter(tab => {
switch (tab.name) {
case 'request-headers':
return detail.value!.request_headers && Object.keys(detail.value!.request_headers).length > 0
return hasContent(detail.value!.request_headers)
case 'request-body':
return detail.value!.request_body !== null && detail.value!.request_body !== undefined
return hasContent(detail.value!.request_body)
case 'response-headers':
return detail.value!.response_headers && Object.keys(detail.value!.response_headers).length > 0
return hasContent(actualResponseHeaders.value)
case 'response-body':
return detail.value!.response_body !== null && detail.value!.response_body !== undefined
return hasContent(detail.value!.response_body)
case 'metadata':
return detail.value!.metadata && Object.keys(detail.value!.metadata).length > 0
return hasContent(detail.value!.metadata)
default:
return false
}
@@ -775,7 +794,7 @@ function copyJsonToClipboard(tabName: string) {
data = detail.value.request_body
break
case 'response-headers':
data = detail.value.response_headers
data = actualResponseHeaders.value
break
case 'response-body':
data = detail.value.response_body

View File

@@ -252,7 +252,7 @@
@click.stop
@change="toggleSelection('allowed_providers', provider.id)"
>
<span class="text-sm">{{ provider.display_name || provider.name }}</span>
<span class="text-sm">{{ provider.name }}</span>
</div>
<div
v-if="providers.length === 0"

View File

@@ -424,8 +424,7 @@ export const MOCK_ADMIN_API_KEYS: AdminApiKeysResponse = {
export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
{
id: 'provider-001',
name: 'duck_coding_free',
display_name: 'DuckCodingFree',
name: 'DuckCodingFree',
description: '',
website: 'https://duckcoding.com',
provider_priority: 1,
@@ -451,8 +450,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
},
{
id: 'provider-002',
name: 'open_claude_code',
display_name: 'OpenClaudeCode',
name: 'OpenClaudeCode',
description: '',
website: 'https://www.openclaudecode.cn',
provider_priority: 2,
@@ -477,8 +475,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
},
{
id: 'provider-003',
name: '88_code',
display_name: '88Code',
name: '88Code',
description: '',
website: 'https://www.88code.org/',
provider_priority: 3,
@@ -503,8 +500,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
},
{
id: 'provider-004',
name: 'ikun_code',
display_name: 'IKunCode',
name: 'IKunCode',
description: '',
website: 'https://api.ikuncode.cc',
provider_priority: 4,
@@ -531,8 +527,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
},
{
id: 'provider-005',
name: 'duck_coding',
display_name: 'DuckCoding',
name: 'DuckCoding',
description: '',
website: 'https://duckcoding.com',
provider_priority: 5,
@@ -561,8 +556,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
},
{
id: 'provider-006',
name: 'privnode',
display_name: 'Privnode',
name: 'Privnode',
description: '',
website: 'https://privnode.com',
provider_priority: 6,
@@ -584,8 +578,7 @@ export const MOCK_PROVIDERS: ProviderWithEndpointsSummary[] = [
},
{
id: 'provider-007',
name: 'undying_api',
display_name: 'UndyingAPI',
name: 'UndyingAPI',
description: '',
website: 'https://vip.undyingapi.com',
provider_priority: 7,

View File

@@ -418,16 +418,16 @@ const MOCK_ALIASES = [
// Mock Endpoint Keys
const MOCK_ENDPOINT_KEYS = [
{ id: 'ekey-001', endpoint_id: 'ep-001', api_key_masked: 'sk-ant...abc1', name: 'Primary Key', rate_multiplier: 1.0, internal_priority: 1, health_score: 98, consecutive_failures: 0, request_count: 5000, success_count: 4950, error_count: 50, success_rate: 99, avg_response_time_ms: 1200, is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ekey-002', endpoint_id: 'ep-001', api_key_masked: 'sk-ant...def2', name: 'Backup Key', rate_multiplier: 1.0, internal_priority: 2, health_score: 95, consecutive_failures: 1, request_count: 2000, success_count: 1950, error_count: 50, success_rate: 97.5, avg_response_time_ms: 1350, is_active: true, created_at: '2024-02-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ekey-003', endpoint_id: 'ep-002', api_key_masked: 'sk-oai...ghi3', name: 'OpenAI Main', rate_multiplier: 1.0, internal_priority: 1, health_score: 97, consecutive_failures: 0, request_count: 3500, success_count: 3450, error_count: 50, success_rate: 98.6, avg_response_time_ms: 900, is_active: true, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
{ id: 'ekey-001', provider_id: 'provider-001', api_formats: ['CLAUDE'], api_key_masked: 'sk-ant...abc1', name: 'Primary Key', rate_multiplier: 1.0, internal_priority: 1, health_score: 0.98, consecutive_failures: 0, request_count: 5000, success_count: 4950, error_count: 50, success_rate: 0.99, avg_response_time_ms: 1200, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ekey-002', provider_id: 'provider-001', api_formats: ['CLAUDE'], api_key_masked: 'sk-ant...def2', name: 'Backup Key', rate_multiplier: 1.0, internal_priority: 2, health_score: 0.95, consecutive_failures: 1, request_count: 2000, success_count: 1950, error_count: 50, success_rate: 0.975, avg_response_time_ms: 1350, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-02-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ekey-003', provider_id: 'provider-002', api_formats: ['OPENAI'], api_key_masked: 'sk-oai...ghi3', name: 'OpenAI Main', rate_multiplier: 1.0, internal_priority: 1, health_score: 0.97, consecutive_failures: 0, request_count: 3500, success_count: 3450, error_count: 50, success_rate: 0.986, avg_response_time_ms: 900, cache_ttl_minutes: 5, max_probe_interval_minutes: 32, is_active: true, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
]
// Mock Endpoints
const MOCK_ENDPOINTS = [
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'claude', base_url: 'https://api.anthropic.com', auth_type: 'bearer', timeout: 120, max_retries: 2, priority: 100, weight: 100, health_score: 98, consecutive_failures: 0, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'openai', base_url: 'https://api.openai.com', auth_type: 'bearer', timeout: 60, max_retries: 2, priority: 90, weight: 100, health_score: 97, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'gemini', base_url: 'https://generativelanguage.googleapis.com', auth_type: 'api_key', timeout: 60, max_retries: 2, priority: 80, weight: 100, health_score: 96, consecutive_failures: 0, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
{ id: 'ep-001', provider_id: 'provider-001', provider_name: 'anthropic', api_format: 'CLAUDE', base_url: 'https://api.anthropic.com', timeout: 300, max_retries: 2, is_active: true, total_keys: 2, active_keys: 2, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-002', provider_id: 'provider-002', provider_name: 'openai', api_format: 'OPENAI', base_url: 'https://api.openai.com', timeout: 60, max_retries: 2, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-01T00:00:00Z', updated_at: new Date().toISOString() },
{ id: 'ep-003', provider_id: 'provider-003', provider_name: 'google', api_format: 'GEMINI', base_url: 'https://generativelanguage.googleapis.com', timeout: 60, max_retries: 2, is_active: true, total_keys: 1, active_keys: 1, created_at: '2024-01-15T00:00:00Z', updated_at: new Date().toISOString() }
]
// Mock 能力定义
@@ -581,7 +581,6 @@ const mockHandlers: Record<string, (config: AxiosRequestConfig) => Promise<Axios
return createMockResponse(MOCK_PROVIDERS.map(p => ({
id: p.id,
name: p.name,
display_name: p.display_name,
is_active: p.is_active
})))
},
@@ -1222,13 +1221,8 @@ function generateMockEndpointsForProvider(providerId: string) {
base_url: format.includes('CLAUDE') ? 'https://api.anthropic.com' :
format.includes('OPENAI') ? 'https://api.openai.com' :
'https://generativelanguage.googleapis.com',
auth_type: format.includes('GEMINI') ? 'api_key' : 'bearer',
timeout: 120,
timeout: 300,
max_retries: 2,
priority: 100 - index * 10,
weight: 100,
health_score: healthDetail?.health_score ?? 1.0,
consecutive_failures: healthDetail?.health_score && healthDetail.health_score < 0.7 ? 2 : 0,
is_active: healthDetail?.is_active ?? true,
total_keys: Math.ceil(Math.random() * 3) + 1,
active_keys: Math.ceil(Math.random() * 2) + 1,
@@ -1238,11 +1232,16 @@ function generateMockEndpointsForProvider(providerId: string) {
})
}
// 为 endpoint 生成 keys
function generateMockKeysForEndpoint(endpointId: string, count: number = 2) {
// 为 provider 生成 keys(Key 归属 Provider,通过 api_formats 关联)
const PROVIDER_KEYS_CACHE: Record<string, any[]> = {}
function generateMockKeysForProvider(providerId: string, count: number = 2) {
const provider = MOCK_PROVIDERS.find(p => p.id === providerId)
const formats = provider?.api_formats || []
return Array.from({ length: count }, (_, i) => ({
id: `key-${endpointId}-${i + 1}`,
endpoint_id: endpointId,
id: `key-${providerId}-${i + 1}`,
provider_id: providerId,
api_formats: i === 0 ? formats : formats.slice(0, 1),
api_key_masked: `sk-***...${Math.random().toString(36).substring(2, 6)}`,
name: i === 0 ? 'Primary Key' : `Backup Key ${i}`,
rate_multiplier: 1.0,
@@ -1254,6 +1253,8 @@ function generateMockKeysForEndpoint(endpointId: string, count: number = 2) {
error_count: Math.floor(Math.random() * 100),
success_rate: 0.95 + Math.random() * 0.04, // 0.95-0.99
avg_response_time_ms: 800 + Math.floor(Math.random() * 600),
cache_ttl_minutes: 5,
max_probe_interval_minutes: 32,
is_active: true,
created_at: '2024-01-01T00:00:00Z',
updated_at: new Date().toISOString()
@@ -1463,29 +1464,63 @@ registerDynamicRoute('PUT', '/api/admin/endpoints/:endpointId', async (config, p
registerDynamicRoute('DELETE', '/api/admin/endpoints/:endpointId', async (_config, _params) => {
await delay()
requireAdmin()
return createMockResponse({ message: '删除成功(演示模式)' })
return createMockResponse({ message: '删除成功(演示模式)', affected_keys_count: 0 })
})
// Endpoint Keys 列表
registerDynamicRoute('GET', '/api/admin/endpoints/:endpointId/keys', async (_config, params) => {
// Provider Keys 列表
registerDynamicRoute('GET', '/api/admin/endpoints/providers/:providerId/keys', async (_config, params) => {
await delay()
requireAdmin()
const keys = generateMockKeysForEndpoint(params.endpointId, 2)
return createMockResponse(keys)
if (!PROVIDER_KEYS_CACHE[params.providerId]) {
PROVIDER_KEYS_CACHE[params.providerId] = generateMockKeysForProvider(params.providerId, 2)
}
return createMockResponse(PROVIDER_KEYS_CACHE[params.providerId])
})
// 创建 Key
registerDynamicRoute('POST', '/api/admin/endpoints/:endpointId/keys', async (config, params) => {
// 为 Provider 创建 Key
registerDynamicRoute('POST', '/api/admin/endpoints/providers/:providerId/keys', async (config, params) => {
await delay()
requireAdmin()
const body = JSON.parse(config.data || '{}')
return createMockResponse({
const apiKeyPlain = body.api_key || 'sk-demo'
const masked = apiKeyPlain.length >= 12
? `${apiKeyPlain.slice(0, 8)}***${apiKeyPlain.slice(-4)}`
: 'sk-***...demo'
const newKey = {
id: `key-demo-${Date.now()}`,
endpoint_id: params.endpointId,
api_key_masked: 'sk-***...demo',
...body,
created_at: new Date().toISOString()
})
provider_id: params.providerId,
api_formats: body.api_formats || [],
api_key_masked: masked,
api_key_plain: null,
name: body.name || 'New Key',
note: body.note,
rate_multiplier: body.rate_multiplier ?? 1.0,
rate_multipliers: body.rate_multipliers ?? null,
internal_priority: body.internal_priority ?? 50,
global_priority: body.global_priority ?? null,
rpm_limit: body.rpm_limit ?? null,
allowed_models: body.allowed_models ?? null,
capabilities: body.capabilities ?? null,
cache_ttl_minutes: body.cache_ttl_minutes ?? 5,
max_probe_interval_minutes: body.max_probe_interval_minutes ?? 32,
health_score: 1.0,
consecutive_failures: 0,
request_count: 0,
success_count: 0,
error_count: 0,
success_rate: 0.0,
avg_response_time_ms: 0.0,
is_active: true,
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
}
if (!PROVIDER_KEYS_CACHE[params.providerId]) {
PROVIDER_KEYS_CACHE[params.providerId] = []
}
PROVIDER_KEYS_CACHE[params.providerId].push(newKey)
return createMockResponse(newKey)
})
// Key 更新
@@ -1503,6 +1538,50 @@ registerDynamicRoute('DELETE', '/api/admin/endpoints/keys/:keyId', async (_confi
return createMockResponse({ message: '删除成功(演示模式)' })
})
// Key Reveal
registerDynamicRoute('GET', '/api/admin/endpoints/keys/:keyId/reveal', async (_config, _params) => {
await delay()
requireAdmin()
return createMockResponse({ api_key: 'sk-demo-reveal' })
})
// Keys grouped by format
mockHandlers['GET /api/admin/endpoints/keys/grouped-by-format'] = async () => {
await delay()
requireAdmin()
// 确保每个 provider 都有 key 数据
for (const provider of MOCK_PROVIDERS) {
if (!PROVIDER_KEYS_CACHE[provider.id]) {
PROVIDER_KEYS_CACHE[provider.id] = generateMockKeysForProvider(provider.id, 2)
}
}
const grouped: Record<string, any[]> = {}
for (const provider of MOCK_PROVIDERS) {
const endpoints = generateMockEndpointsForProvider(provider.id)
const baseUrlByFormat = Object.fromEntries(endpoints.map(e => [e.api_format, e.base_url]))
const keys = PROVIDER_KEYS_CACHE[provider.id] || []
for (const key of keys) {
const formats: string[] = key.api_formats || []
for (const fmt of formats) {
if (!grouped[fmt]) grouped[fmt] = []
grouped[fmt].push({
...key,
api_format: fmt,
provider_name: provider.name,
endpoint_base_url: baseUrlByFormat[fmt],
global_priority: key.global_priority ?? null,
circuit_breaker_open: false,
capabilities: [],
})
}
}
}
return createMockResponse(grouped)
}
// Provider Models 列表
registerDynamicRoute('GET', '/api/admin/providers/:providerId/models', async (_config, params) => {
await delay()

View File

@@ -20,7 +20,7 @@ interface ValidationError {
const fieldNameMap: Record<string, string> = {
'api_key': 'API 密钥',
'priority': '优先级',
'max_concurrent': '最大并发',
'rpm_limit': 'RPM 限制',
'rate_limit': '速率限制',
'daily_limit': '每日限制',
'monthly_limit': '每月限制',
@@ -44,7 +44,6 @@ const fieldNameMap: Record<string, string> = {
'monthly_quota_usd': '月度配额',
'quota_reset_day': '配额重置日',
'quota_expires_at': '配额过期时间',
'rpm_limit': 'RPM 限制',
'cache_ttl_minutes': '缓存 TTL',
'max_probe_interval_minutes': '最大探测间隔',
}
@@ -151,11 +150,18 @@ export function parseApiError(err: unknown, defaultMessage: string = '操作失
return '无法连接到服务器,请检查网络连接'
}
const detail = err.response?.data?.detail
const data = err.response?.data
// 1. 处理 {error: {type, message}} 格式(ProxyException 返回格式)
if (data?.error?.message) {
return data.error.message
}
const detail = data?.detail
// 如果没有 detail 字段
if (!detail) {
return err.response?.data?.message || err.message || defaultMessage
return data?.message || err.message || defaultMessage
}
// 1. 处理 Pydantic 验证错误(数组格式)
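
Illustrative payloads (values invented, not taken from this commit) for the two branches parseApiError now distinguishes — the wrapped error shape is checked before detail:

// Branch 1: ProxyException-style body → parseApiError returns data.error.message
const proxyStyleBody = { error: { type: 'upstream_error', message: 'Provider request failed' } }
// Branch 2: plain detail body → handled by the existing detail / Pydantic logic
const detailStyleBody = { detail: 'Resource not found' }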

View File

@@ -54,6 +54,57 @@ export function parseNumberInput(
return result
}
/**
* Parse number input value for nullable fields (like rpm_limit)
* Returns `null` when empty (to signal "use adaptive/default mode")
* Returns `undefined` when not provided (to signal "keep original value")
*
* @param value - Input value (string or number)
* @param options - Parse options
* @returns Parsed number, null (for empty/adaptive), or undefined
*/
export function parseNullableNumberInput(
value: string | number | null | undefined,
options: {
allowFloat?: boolean
min?: number
max?: number
} = {}
): number | null | undefined {
const { allowFloat = false, min, max } = options
// Empty string means "null" (adaptive mode)
if (value === '') {
return null
}
// null/undefined means "keep original value"
if (value === null || value === undefined) {
return undefined
}
// Parse the value
const num = typeof value === 'string'
? (allowFloat ? parseFloat(value) : parseInt(value, 10))
: value
// Handle NaN - treat as null (adaptive mode)
if (isNaN(num)) {
return null
}
// Apply min/max constraints
let result = num
if (min !== undefined && result < min) {
result = min
}
if (max !== undefined && result > max) {
result = max
}
return result
}
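// Illustrative usage, not part of this commit. The tri-state return distinguishes
// "clear the limit" (null) from "leave it unchanged" (undefined):
//   parseNullableNumberInput('')                           // => null      (empty → adaptive/default mode)
//   parseNullableNumberInput(undefined)                    // => undefined (not provided → keep stored value)
//   parseNullableNumberInput('120', { min: 1, max: 600 })  // => 120
//   parseNullableNumberInput('abc')                        // => null      (NaN treated as adaptive)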
/**
* Create a handler function for number input with specific field
* Useful for creating inline handlers in templates

View File

@@ -530,9 +530,6 @@
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ provider.display_name || provider.name }}
</p>
<p class="text-xs text-muted-foreground truncate">
{{ provider.name }}
</p>
</div>
@@ -645,10 +642,7 @@
/>
<div class="flex-1 min-w-0">
<p class="font-medium text-sm truncate">
{{ provider.display_name }}
</p>
<p class="text-xs text-muted-foreground truncate">
{{ provider.identifier }}
{{ provider.name }}
</p>
</div>
<Badge
@@ -679,7 +673,7 @@
<ProviderModelFormDialog
:open="editProviderDialogOpen"
:provider-id="editingProvider?.id || ''"
:provider-name="editingProvider?.display_name || ''"
:provider-name="editingProvider?.name || ''"
:editing-model="editingProviderModel"
@update:open="handleEditProviderDialogUpdate"
@saved="handleEditProviderSaved"
@@ -939,7 +933,7 @@ async function batchAddSelectedProviders() {
const errorMessages = result.errors
.map(e => {
const provider = providerOptions.value.find(p => p.id === e.provider_id)
const providerName = provider?.display_name || provider?.name || e.provider_id
const providerName = provider?.name || e.provider_id
return `${providerName}: ${e.error}`
})
.join('\n')
@@ -977,7 +971,7 @@ async function batchRemoveSelectedProviders() {
await deleteModel(providerId, provider.model_id)
successCount++
} catch (err: any) {
errors.push(`${provider.display_name}: ${parseApiError(err, '删除失败')}`)
errors.push(`${provider.name}: ${parseApiError(err, '删除失败')}`)
}
}
@@ -1088,8 +1082,7 @@ async function loadModelProviders(_globalModelId: string) {
selectedModelProviders.value = response.providers.map(p => ({
id: p.provider_id,
model_id: p.model_id,
display_name: p.provider_display_name || p.provider_name,
identifier: p.provider_name,
name: p.provider_name,
provider_type: 'API',
target_model: p.target_model,
is_active: p.is_active,
@@ -1219,7 +1212,7 @@ async function confirmDeleteProviderImplementation(provider: any) {
}
const confirmed = await confirmDanger(
`确定要删除 ${provider.display_name} 的模型关联吗?\n\n模型: ${provider.target_model}\n\n此操作不可恢复`,
`确定要删除 ${provider.name} 的模型关联吗?\n\n模型: ${provider.target_model}\n\n此操作不可恢复`,
'删除关联提供商'
)
if (!confirmed) return
@@ -1227,7 +1220,7 @@ async function confirmDeleteProviderImplementation(provider: any) {
try {
const { deleteModel } = await import('@/api/endpoints')
await deleteModel(provider.id, provider.model_id)
success(`已删除 ${provider.display_name} 的模型实现`)
success(`已删除 ${provider.name} 的模型实现`)
// 重新加载 Provider 列表
if (selectedModel.value) {
await loadModelProviders(selectedModel.value.id)

View File

@@ -134,10 +134,7 @@
@click="handleRowClick($event, provider.id)"
>
<TableCell class="py-3.5">
<div class="flex flex-col gap-0.5">
<span class="text-sm font-medium text-foreground">{{ provider.display_name }}</span>
<span class="text-xs text-muted-foreground/70 font-mono">{{ provider.name }}</span>
</div>
<span class="text-sm font-medium text-foreground">{{ provider.name }}</span>
</TableCell>
<TableCell class="py-3.5">
<Badge
@@ -219,17 +216,10 @@
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / <span class="font-medium">${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}</span>
</div>
<div
v-if="rpmUsage(provider)"
class="flex items-center gap-1"
>
<span class="text-muted-foreground/70">RPM:</span>
<span class="font-medium text-foreground/80">{{ rpmUsage(provider) }}</span>
</div>
<div
v-if="provider.billing_type !== 'monthly_quota' && !rpmUsage(provider)"
v-else
class="text-muted-foreground/50"
>
无限制
按量付费
</div>
</div>
</TableCell>
@@ -304,7 +294,7 @@
<div class="flex items-start justify-between gap-3">
<div class="flex-1 min-w-0">
<div class="flex items-center gap-2">
<span class="font-medium text-foreground truncate">{{ provider.display_name }}</span>
<span class="font-medium text-foreground truncate">{{ provider.name }}</span>
<Badge
:variant="provider.is_active ? 'success' : 'secondary'"
class="text-xs shrink-0"
@@ -312,7 +302,6 @@
{{ provider.is_active ? '活跃' : '停用' }}
</Badge>
</div>
<span class="text-xs text-muted-foreground/70 font-mono">{{ provider.name }}</span>
</div>
<div
class="flex items-center gap-0.5 shrink-0"
@@ -383,20 +372,17 @@
</span>
</div>
<!-- 第四行:配额/限流 -->
<!-- 第四行:配额 -->
<div
v-if="provider.billing_type === 'monthly_quota' || rpmUsage(provider)"
v-if="provider.billing_type === 'monthly_quota'"
class="flex items-center gap-3 text-xs text-muted-foreground"
>
<span v-if="provider.billing_type === 'monthly_quota'">
<span>
配额: <span
class="font-semibold"
:class="getQuotaUsedColorClass(provider)"
>${{ (provider.monthly_used_usd ?? 0).toFixed(2) }}</span> / ${{ (provider.monthly_quota_usd ?? 0).toFixed(2) }}
</span>
<span v-if="rpmUsage(provider)">
RPM: {{ rpmUsage(provider) }}
</span>
</div>
</div>
</div>
@@ -509,7 +495,7 @@ const filteredProviders = computed(() => {
if (searchQuery.value.trim()) {
const keywords = searchQuery.value.toLowerCase().split(/\s+/).filter(k => k.length > 0)
result = result.filter(p => {
const searchableText = `${p.display_name} ${p.name}`.toLowerCase()
const searchableText = `${p.name}`.toLowerCase()
return keywords.every(keyword => searchableText.includes(keyword))
})
}
@@ -525,7 +511,7 @@ const filteredProviders = computed(() => {
return a.provider_priority - b.provider_priority
}
// 3. 按名称排序
return a.display_name.localeCompare(b.display_name)
return a.name.localeCompare(b.name)
})
})
@@ -586,7 +572,10 @@ function sortEndpoints(endpoints: any[]) {
// 判断端点是否可用(有 key
function isEndpointAvailable(endpoint: any, _provider: ProviderWithEndpointsSummary): boolean {
// 检查端点是否有活跃的密钥
// 检查端点是否启用,以及是否有活跃的密钥
if (endpoint.is_active === false) {
return false
}
return (endpoint.active_keys ?? 0) > 0
}
@@ -639,21 +628,6 @@ function getQuotaUsedColorClass(provider: ProviderWithEndpointsSummary): string
return 'text-foreground'
}
function rpmUsage(provider: ProviderWithEndpointsSummary): string | null {
const rpmLimit = provider.rpm_limit
const rpmUsed = provider.rpm_used ?? 0
if (rpmLimit === null || rpmLimit === undefined) {
return rpmUsed > 0 ? `${rpmUsed}` : null
}
if (rpmLimit === 0) {
return '已完全禁止'
}
return `${rpmUsed} / ${rpmLimit}`
}
// 使用复用的行点击逻辑
const { handleMouseDown, shouldTriggerRowClick } = useRowClick()
@@ -706,7 +680,7 @@ function handleProviderAdded() {
async function handleDeleteProvider(provider: ProviderWithEndpointsSummary) {
const confirmed = await confirmDanger(
'删除提供商',
`确定要删除提供商 "${provider.display_name}" 吗?\n\n这将同时删除其所有端点、密钥和配置。此操作不可恢复`
`确定要删除提供商 "${provider.name}" 吗?\n\n这将同时删除其所有端点、密钥和配置。此操作不可恢复`
)
if (!confirmed) return

View File

@@ -511,7 +511,7 @@
端点: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.endpoints?.length || 0), 0) }}
</li>
<li>
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + p.endpoints?.reduce((s: number, e: any) => s + (e.keys?.length || 0), 0), 0) }}
API Keys: {{ importPreview.providers?.reduce((sum: number, p: any) => sum + (p.api_keys?.length || 0), 0) }}
</li>
</ul>
</div>
@@ -1144,7 +1144,7 @@ function handleConfigFileSelect(event: Event) {
const data = JSON.parse(content) as ConfigExportData
// 验证版本
if (data.version !== '1.0') {
if (data.version !== '2.0') {
error(`不支持的配置版本: ${data.version}`)
return
}

View File

@@ -16,7 +16,8 @@
<!-- 主要统计卡片 -->
<div class="grid grid-cols-2 gap-3 sm:gap-4 xl:grid-cols-4">
<template v-if="loading && stats.length === 0">
<!-- 加载中骨架屏 -->
<template v-if="loading">
<Card
v-for="i in 4"
:key="'skeleton-' + i"
@@ -27,9 +28,10 @@
<Skeleton class="h-4 w-16" />
</Card>
</template>
<!-- 有数据时显示统计卡片 -->
<template v-else-if="stats.length > 0">
<Card
v-for="(stat, index) in stats"
v-else
:key="stat.name"
class="relative overflow-hidden p-3 sm:p-5"
:class="statCardBorders[index % statCardBorders.length]"
@@ -83,6 +85,41 @@
</div>
</div>
</Card>
</template>
<!-- 无数据时显示占位卡片 -->
<template v-else>
<Card
v-for="(placeholder, index) in emptyStatPlaceholders"
:key="'empty-' + index"
class="relative overflow-hidden p-3 sm:p-5"
:class="statCardBorders[index % statCardBorders.length]"
>
<div
class="pointer-events-none absolute -right-4 -top-6 h-28 w-28 rounded-full blur-3xl opacity-20"
:class="statCardGlows[index % statCardGlows.length]"
/>
<div
class="absolute top-3 right-3 sm:top-5 sm:right-5 rounded-xl sm:rounded-2xl border border-border bg-card/50 p-2 sm:p-3 shadow-inner backdrop-blur-sm"
:class="getStatIconColor(index)"
>
<component
:is="placeholder.icon"
class="h-4 w-4 sm:h-5 sm:w-5"
/>
</div>
<div>
<p class="text-[9px] sm:text-[11px] font-semibold uppercase tracking-[0.2em] sm:tracking-[0.4em] text-muted-foreground pr-10 sm:pr-14">
{{ placeholder.name }}
</p>
<p class="mt-2 sm:mt-4 text-xl sm:text-3xl font-semibold text-muted-foreground/50">
--
</p>
<p class="mt-0.5 sm:mt-1 text-[10px] sm:text-sm text-muted-foreground/50">
暂无数据
</p>
</div>
</Card>
</template>
</div>
<!-- 管理员系统健康摘要 -->
@@ -872,6 +909,24 @@ const iconMap: Record<string, any> = {
Users, Activity, TrendingUp, DollarSign, Key, Hash, Database
}
// 空状态占位卡片
const emptyStatPlaceholders = computed(() => {
if (isAdmin.value) {
return [
{ name: '今日请求', icon: Activity },
{ name: '今日 Tokens', icon: Hash },
{ name: '活跃用户', icon: Users },
{ name: '今日费用', icon: DollarSign }
]
}
return [
{ name: '今日请求', icon: Activity },
{ name: '今日 Tokens', icon: Hash },
{ name: 'API Keys', icon: Key },
{ name: '今日费用', icon: DollarSign }
]
})
const totalStats = computed(() => {
if (dailyStats.value.length === 0) {
return { requests: 0, tokens: 0, cost: 0, avgResponseTime: 0 }

View File

@@ -78,6 +78,20 @@ export default {
md: "calc(var(--radius) - 2px)",
sm: "calc(var(--radius) - 4px)",
},
keyframes: {
"collapsible-down": {
from: { height: "0" },
to: { height: "var(--radix-collapsible-content-height)" },
},
"collapsible-up": {
from: { height: "var(--radix-collapsible-content-height)" },
to: { height: "0" },
},
},
animation: {
"collapsible-down": "collapsible-down 0.2s ease-out",
"collapsible-up": "collapsible-up 0.2s ease-out",
},
},
},
plugins: [require("tailwindcss-animate")],

View File

@@ -1,12 +1,12 @@
"""
自适应并发管理 API 端点
自适应 RPM 管理 API 端点
设计原则:
- 自适应模式由 max_concurrent 字段决定:
- max_concurrent = NULL:启用自适应模式,系统自动学习并调整并发限制
- max_concurrent = 数字:固定限制模式,使用用户指定的并发限制
- learned_max_concurrent:自适应模式下学习到的并发限制值
- adaptive_mode 是计算字段,基于 max_concurrent 是否为 NULL
- 自适应模式由 rpm_limit 字段决定:
- rpm_limit = NULL:启用自适应模式,系统自动学习并调整 RPM 限制
- rpm_limit = 数字:固定限制模式,使用用户指定的 RPM 限制
- learned_rpm_limit:自适应模式下学习到的 RPM 限制值
- adaptive_mode 是计算字段,基于 rpm_limit 是否为 NULL
"""
from dataclasses import dataclass
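
The design notes above reduce to one small rule for the limit a key is actually held to. A minimal sketch of that rule, added here by the editor for illustration; the concrete fallback value is an assumption (the real default lives in src.config.constants.RPMDefaults, imported further down in this file):

from typing import Optional

ASSUMED_INITIAL_LIMIT = 10  # placeholder; the real value is RPMDefaults.INITIAL_LIMIT

def effective_rpm_limit(rpm_limit: Optional[int], learned_rpm_limit: Optional[int]) -> int:
    """Mirror the fallback pattern used by the adapters in this module."""
    if rpm_limit is not None:            # fixed-limit mode: use the configured number
        return rpm_limit
    if learned_rpm_limit is not None:    # adaptive mode with a learned value
        return learned_rpm_limit
    return ASSUMED_INITIAL_LIMIT         # adaptive mode before anything has been learned
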
@@ -18,12 +18,13 @@ from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.config.constants import RPMDefaults
from src.core.exceptions import InvalidRequestException, translate_pydantic_error
from src.database import get_db
from src.models.database import ProviderAPIKey
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive Concurrency"])
router = APIRouter(prefix="/api/admin/adaptive", tags=["Adaptive RPM"])
pipeline = ApiRequestPipeline()
@@ -35,19 +36,19 @@ class EnableAdaptiveRequest(BaseModel):
    enabled: bool = Field(..., description="是否启用自适应模式(true=自适应,false=固定限制)")
fixed_limit: Optional[int] = Field(
None, ge=1, le=100, description="固定并发限制(仅当 enabled=false 时生效)"
None, ge=1, le=100, description="固定 RPM 限制(仅当 enabled=false 时生效1-100"
)
class AdaptiveStatsResponse(BaseModel):
"""自适应统计响应"""
adaptive_mode: bool = Field(..., description="是否为自适应模式(max_concurrent=NULL")
max_concurrent: Optional[int] = Field(None, description="用户配置的固定限制NULL=自适应)")
adaptive_mode: bool = Field(..., description="是否为自适应模式(rpm_limit=NULL")
rpm_limit: Optional[int] = Field(None, description="用户配置的固定限制NULL=自适应)")
effective_limit: Optional[int] = Field(
None, description="当前有效限制(自适应使用学习值,固定使用配置值)"
)
learned_limit: Optional[int] = Field(None, description="学习到的并发限制")
learned_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
concurrent_429_count: int
rpm_429_count: int
last_429_at: Optional[str]
@@ -61,11 +62,12 @@ class KeyListItem(BaseModel):
id: str
name: Optional[str]
endpoint_id: str
is_adaptive: bool = Field(..., description="是否为自适应模式max_concurrent=NULL")
max_concurrent: Optional[int] = Field(None, description="固定并发限制NULL=自适应")
provider_id: str
api_formats: List[str] = Field(default_factory=list)
is_adaptive: bool = Field(..., description="是否为自适应模式rpm_limit=NULL")
rpm_limit: Optional[int] = Field(None, description="固定 RPM 限制NULL=自适应)")
effective_limit: Optional[int] = Field(None, description="当前有效限制")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
concurrent_429_count: int
rpm_429_count: int
@@ -80,22 +82,22 @@ class KeyListItem(BaseModel):
)
async def list_adaptive_keys(
request: Request,
endpoint_id: Optional[str] = Query(None, description="Endpoint 过滤"),
provider_id: Optional[str] = Query(None, description="Provider 过滤"),
db: Session = Depends(get_db),
):
"""
获取所有启用自适应模式的Key列表
可选参数:
- endpoint_id: 按 Endpoint 过滤
- provider_id: 按 Provider 过滤
"""
adapter = ListAdaptiveKeysAdapter(endpoint_id=endpoint_id)
adapter = ListAdaptiveKeysAdapter(provider_id=provider_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.patch(
"/keys/{key_id}/mode",
summary="Toggle key's concurrency control mode",
summary="Toggle key's RPM control mode",
)
async def toggle_adaptive_mode(
key_id: str,
@@ -103,10 +105,10 @@ async def toggle_adaptive_mode(
db: Session = Depends(get_db),
):
"""
Toggle the concurrency control mode for a specific key
Toggle the RPM control mode for a specific key
Parameters:
- enabled: true=adaptive mode (max_concurrent=NULL), false=fixed limit mode
- enabled: true=adaptive mode (rpm_limit=NULL), false=fixed limit mode
- fixed_limit: fixed limit value (required when enabled=false)
"""
adapter = ToggleAdaptiveModeAdapter(key_id=key_id)
@@ -124,7 +126,7 @@ async def get_adaptive_stats(
db: Session = Depends(get_db),
):
"""
获取指定Key的自适应并发统计信息
获取指定Key的自适应 RPM 统计信息
包括:
- 当前配置
@@ -149,12 +151,12 @@ async def reset_adaptive_learning(
Reset the adaptive learning state for a specific key
Clears:
- Learned concurrency limit (learned_max_concurrent)
- Learned RPM limit (learned_rpm_limit)
- 429 error counts
- Adjustment history
Does not change:
- max_concurrent config (determines adaptive mode)
- rpm_limit config (determines adaptive mode)
"""
adapter = ResetAdaptiveLearningAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -162,40 +164,40 @@ async def reset_adaptive_learning(
@router.patch(
"/keys/{key_id}/limit",
summary="Set key to fixed concurrency limit mode",
summary="Set key to fixed RPM limit mode",
)
async def set_concurrent_limit(
async def set_rpm_limit(
key_id: str,
request: Request,
limit: int = Query(..., ge=1, le=100, description="Concurrency limit value"),
limit: int = Query(..., ge=1, le=100, description="RPM limit value (1-100)"),
db: Session = Depends(get_db),
):
"""
Set key to fixed concurrency limit mode
Set key to fixed RPM limit mode
Note:
- After setting this value, key switches to fixed limit mode and won't auto-adjust
- To restore adaptive mode, use PATCH /keys/{key_id}/mode
"""
adapter = SetConcurrentLimitAdapter(key_id=key_id, limit=limit)
adapter = SetRPMLimitAdapter(key_id=key_id, limit=limit)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get(
"/summary",
summary="获取自适应并发的全局统计",
summary="获取自适应 RPM 的全局统计",
)
async def get_adaptive_summary(
request: Request,
db: Session = Depends(get_db),
):
"""
获取自适应并发的全局统计摘要
获取自适应 RPM 的全局统计摘要
包括:
- 启用自适应模式的Key数量
- 总429错误数
- 并发限制调整次数
- RPM 限制调整次数
"""
adapter = AdaptiveSummaryAdapter()
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -206,26 +208,29 @@ async def get_adaptive_summary(
@dataclass
class ListAdaptiveKeysAdapter(AdminApiAdapter):
endpoint_id: Optional[str] = None
provider_id: Optional[str] = None
async def handle(self, context): # type: ignore[override]
# 自适应模式:max_concurrent = NULL
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None))
if self.endpoint_id:
query = query.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
# 自适应模式:rpm_limit = NULL
query = context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None))
if self.provider_id:
query = query.filter(ProviderAPIKey.provider_id == self.provider_id)
keys = query.all()
return [
KeyListItem(
id=key.id,
name=key.name,
endpoint_id=key.endpoint_id,
is_adaptive=key.max_concurrent is None,
max_concurrent=key.max_concurrent,
provider_id=key.provider_id,
api_formats=key.api_formats or [],
is_adaptive=key.rpm_limit is None,
rpm_limit=key.rpm_limit,
effective_limit=(
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if key.rpm_limit is None
else key.rpm_limit
),
learned_max_concurrent=key.learned_max_concurrent,
learned_rpm_limit=key.learned_rpm_limit,
concurrent_429_count=key.concurrent_429_count or 0,
rpm_429_count=key.rpm_429_count or 0,
)
@@ -252,28 +257,32 @@ class ToggleAdaptiveModeAdapter(AdminApiAdapter):
raise InvalidRequestException("请求数据验证失败")
if body.enabled:
# 启用自适应模式:将 max_concurrent 设为 NULL
key.max_concurrent = None
message = "已切换为自适应模式,系统将自动学习并调整并发限制"
# 启用自适应模式:将 rpm_limit 设为 NULL
key.rpm_limit = None
message = "已切换为自适应模式,系统将自动学习并调整 RPM 限制"
else:
# 禁用自适应模式:设置固定限制
if body.fixed_limit is None:
raise HTTPException(
status_code=400, detail="禁用自适应模式时必须提供 fixed_limit 参数"
)
key.max_concurrent = body.fixed_limit
message = f"已切换为固定限制模式,并发限制设为 {body.fixed_limit}"
key.rpm_limit = body.fixed_limit
message = f"已切换为固定限制模式,RPM 限制设为 {body.fixed_limit}"
context.db.commit()
context.db.refresh(key)
is_adaptive = key.max_concurrent is None
is_adaptive = key.rpm_limit is None
return {
"message": message,
"key_id": key.id,
"is_adaptive": is_adaptive,
"max_concurrent": key.max_concurrent,
"effective_limit": key.learned_max_concurrent if is_adaptive else key.max_concurrent,
"rpm_limit": key.rpm_limit,
"effective_limit": (
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if is_adaptive
else key.rpm_limit
),
}
@@ -286,13 +295,13 @@ class GetAdaptiveStatsAdapter(AdminApiAdapter):
if not key:
raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager()
adaptive_manager = get_adaptive_rpm_manager()
stats = adaptive_manager.get_adjustment_stats(key)
# 转换字段名以匹配响应模型
return AdaptiveStatsResponse(
adaptive_mode=stats["adaptive_mode"],
max_concurrent=stats["max_concurrent"],
rpm_limit=stats["rpm_limit"],
effective_limit=stats["effective_limit"],
learned_limit=stats["learned_limit"],
concurrent_429_count=stats["concurrent_429_count"],
@@ -313,13 +322,13 @@ class ResetAdaptiveLearningAdapter(AdminApiAdapter):
if not key:
raise HTTPException(status_code=404, detail="Key not found")
adaptive_manager = get_adaptive_manager()
adaptive_manager = get_adaptive_rpm_manager()
adaptive_manager.reset_learning(context.db, key)
return {"message": "学习状态已重置", "key_id": key.id}
@dataclass
class SetConcurrentLimitAdapter(AdminApiAdapter):
class SetRPMLimitAdapter(AdminApiAdapter):
key_id: str
limit: int
@@ -328,25 +337,25 @@ class SetConcurrentLimitAdapter(AdminApiAdapter):
if not key:
raise HTTPException(status_code=404, detail="Key not found")
was_adaptive = key.max_concurrent is None
key.max_concurrent = self.limit
was_adaptive = key.rpm_limit is None
key.rpm_limit = self.limit
context.db.commit()
context.db.refresh(key)
return {
"message": f"已设置为固定限制模式,并发限制为 {self.limit}",
"message": f"已设置为固定限制模式,RPM 限制为 {self.limit}",
"key_id": key.id,
"is_adaptive": False,
"max_concurrent": key.max_concurrent,
"rpm_limit": key.rpm_limit,
"previous_mode": "adaptive" if was_adaptive else "fixed",
}
class AdaptiveSummaryAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
# 自适应模式:max_concurrent = NULL
# 自适应模式:rpm_limit = NULL
adaptive_keys = (
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.max_concurrent.is_(None)).all()
context.db.query(ProviderAPIKey).filter(ProviderAPIKey.rpm_limit.is_(None)).all()
)
total_keys = len(adaptive_keys)

View File

@@ -1,9 +1,8 @@
"""
Endpoint 并发控制管理 API
Key RPM 限制管理 API
"""
from dataclasses import dataclass
from typing import Optional
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
@@ -12,83 +11,56 @@ from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.core.exceptions import NotFoundException
from src.database import get_db
from src.models.database import ProviderAPIKey, ProviderEndpoint
from src.models.endpoint_models import (
ConcurrencyStatusResponse,
ResetConcurrencyRequest,
)
from src.models.database import ProviderAPIKey
from src.models.endpoint_models import KeyRpmStatusResponse
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
router = APIRouter(tags=["Concurrency Control"])
router = APIRouter(tags=["RPM Control"])
pipeline = ApiRequestPipeline()
@router.get("/concurrency/endpoint/{endpoint_id}", response_model=ConcurrencyStatusResponse)
async def get_endpoint_concurrency(
endpoint_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
"""
获取 Endpoint 当前并发状态
查询指定 Endpoint 的实时并发使用情况,包括当前并发数和最大并发限制。
**路径参数**:
- `endpoint_id`: Endpoint ID
**返回字段**:
- `endpoint_id`: Endpoint ID
- `endpoint_current_concurrency`: 当前并发数
- `endpoint_max_concurrent`: 最大并发限制
"""
adapter = AdminEndpointConcurrencyAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.get("/concurrency/key/{key_id}", response_model=ConcurrencyStatusResponse)
async def get_key_concurrency(
@router.get("/rpm/key/{key_id}", response_model=KeyRpmStatusResponse)
async def get_key_rpm(
key_id: str,
request: Request,
db: Session = Depends(get_db),
) -> ConcurrencyStatusResponse:
) -> KeyRpmStatusResponse:
"""
获取 Key 当前并发状态
获取 Key 当前 RPM 状态
查询指定 API Key 的实时并发使用情况,包括当前并发数和最大并发限制。
查询指定 API Key 的实时 RPM 使用情况,包括当前 RPM 计数和最大 RPM 限制。
**路径参数**:
- `key_id`: API Key ID
**返回字段**:
- `key_id`: API Key ID
- `key_current_concurrency`: 当前并发
- `key_max_concurrent`: 最大并发限制
- `current_rpm`: 当前 RPM 计
- `rpm_limit`: RPM 限制
"""
adapter = AdminKeyConcurrencyAdapter(key_id=key_id)
adapter = AdminKeyRpmAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.delete("/concurrency")
async def reset_concurrency(
request: ResetConcurrencyRequest,
@router.delete("/rpm/key/{key_id}")
async def reset_key_rpm(
key_id: str,
http_request: Request,
db: Session = Depends(get_db),
) -> dict:
"""
重置并发计数器
重置 Key RPM 计数器
重置指定 Endpoint 或 Key 的并发计数器,用于解决计数不准确的问题。
重置指定 API Key 的 RPM 计数器,用于解决计数不准确的问题。
管理员功能,请谨慎使用。
**请求体字段**:
- `endpoint_id`: Endpoint ID可选
- `key_id`: API Key ID可选
**路径参数**:
- `key_id`: API Key ID
**返回字段**:
- `message`: 操作结果消息
"""
adapter = AdminResetConcurrencyAdapter(endpoint_id=request.endpoint_id, key_id=request.key_id)
adapter = AdminResetKeyRpmAdapter(key_id=key_id)
return await pipeline.run(adapter=adapter, http_request=http_request, db=db, mode=adapter.mode)
@@ -96,31 +68,7 @@ async def reset_concurrency(
@dataclass
class AdminEndpointConcurrencyAdapter(AdminApiAdapter):
endpoint_id: str
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
concurrency_manager = await get_concurrency_manager()
endpoint_count, _ = await concurrency_manager.get_current_concurrency(
endpoint_id=self.endpoint_id
)
return ConcurrencyStatusResponse(
endpoint_id=self.endpoint_id,
endpoint_current_concurrency=endpoint_count,
endpoint_max_concurrent=endpoint.max_concurrent,
)
@dataclass
class AdminKeyConcurrencyAdapter(AdminApiAdapter):
class AdminKeyRpmAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
@@ -130,23 +78,20 @@ class AdminKeyConcurrencyAdapter(AdminApiAdapter):
raise NotFoundException(f"Key {self.key_id} 不存在")
concurrency_manager = await get_concurrency_manager()
_, key_count = await concurrency_manager.get_current_concurrency(key_id=self.key_id)
key_count = await concurrency_manager.get_key_rpm_count(key_id=self.key_id)
return ConcurrencyStatusResponse(
return KeyRpmStatusResponse(
key_id=self.key_id,
key_current_concurrency=key_count,
key_max_concurrent=key.max_concurrent,
current_rpm=key_count,
rpm_limit=key.rpm_limit,
)
@dataclass
class AdminResetConcurrencyAdapter(AdminApiAdapter):
endpoint_id: Optional[str]
key_id: Optional[str]
class AdminResetKeyRpmAdapter(AdminApiAdapter):
key_id: str
async def handle(self, context): # type: ignore[override]
concurrency_manager = await get_concurrency_manager()
await concurrency_manager.reset_concurrency(
endpoint_id=self.endpoint_id, key_id=self.key_id
)
return {"message": "并发计数已重置"}
await concurrency_manager.reset_key_rpm(key_id=self.key_id)
return {"message": "RPM 计数已重置"}

View File

@@ -5,7 +5,7 @@ Endpoint 健康监控 API
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Dict, List
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, Query, Request
from sqlalchemy import func
@@ -128,29 +128,32 @@ async def get_api_format_health_monitor(
async def get_key_health(
key_id: str,
request: Request,
api_format: Optional[str] = Query(None, description="API 格式(可选,如 CLAUDE、OPENAI"),
db: Session = Depends(get_db),
) -> HealthStatusResponse:
"""
获取 Key 健康状态
获取指定 API Key 的健康状态详情,包括健康分数、连续失败次数、
熔断器状态等信息。
熔断器状态等信息。支持按 API 格式查询。
**路径参数**:
- `key_id`: API Key ID
**查询参数**:
    - `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)
- 指定时返回该格式的健康度详情
- 不指定时返回所有格式的健康度摘要
**返回字段**:
- `key_id`: API Key ID
    - `key_health_score`: 健康分数(0.0-1.0)
- `key_consecutive_failures`: 连续失败次数
- `key_last_failure_at`: 最后失败时间
- `key_is_active`: 是否活跃
- `key_statistics`: 统计信息
- `circuit_breaker_open`: 熔断器是否打开
- `circuit_breaker_open_at`: 熔断器打开时间
- `next_probe_at`: 下次探测时间
- `health_by_format`: 按格式的健康度数据(无 api_format 参数时)
- `circuit_breaker_open`: 熔断器是否打开(有 api_format 参数时)
"""
adapter = AdminKeyHealthAdapter(key_id=key_id)
adapter = AdminKeyHealthAdapter(key_id=key_id, api_format=api_format)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
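
Because health and circuit-breaker state are now JSON columns keyed by API format, the data behind this route is easiest to picture as literals. The field names below are taken from the adapter and key-grouping code elsewhere in this change set; the values are illustrative.

# Illustrative per-format state for one ProviderAPIKey (values made up)
health_by_format = {
    "CLAUDE": {"health_score": 0.85, "consecutive_failures": 1,
               "last_failure_at": "2026-01-10T10:12:00Z"},
    "OPENAI": {"health_score": 1.0, "consecutive_failures": 0, "last_failure_at": None},
}
circuit_breaker_by_format = {
    "CLAUDE": {"open": False, "open_at": None, "next_probe_at": None},
    "OPENAI": {"open": False, "open_at": None, "next_probe_at": None},
}

# Summary values used when no api_format query parameter is given
any_circuit_open = any(c.get("open", False) for c in circuit_breaker_by_format.values())
overall_health = min(h.get("health_score", 1.0) for h in health_by_format.values())
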
@@ -158,17 +161,23 @@ async def get_key_health(
async def recover_key_health(
key_id: str,
request: Request,
api_format: Optional[str] = Query(None, description="API 格式(可选,不指定则恢复所有格式)"),
db: Session = Depends(get_db),
) -> dict:
"""
恢复 Key 健康状态
手动恢复指定 Key 的健康状态,将健康分数重置为 1.0,关闭熔断器,
取消自动禁用,并重置所有失败计数。
取消自动禁用,并重置所有失败计数。支持按 API 格式恢复。
**路径参数**:
- `key_id`: API Key ID
**查询参数**:
    - `api_format`: 可选,指定 API 格式(如 CLAUDE、OPENAI)
- 指定时仅恢复该格式的健康度
- 不指定时恢复所有格式
**返回字段**:
- `message`: 操作结果消息
- `details`: 详细信息
@@ -176,7 +185,7 @@ async def recover_key_health(
- `circuit_breaker_open`: 熔断器状态
- `is_active`: 是否活跃
"""
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id)
adapter = AdminRecoverKeyHealthAdapter(key_id=key_id, api_format=api_format)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -276,34 +285,9 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
)
all_formats[api_format] = provider_count
# 1.1 获取所有活跃的 API 格式及其 API Key 数量
active_keys = (
db.query(
ProviderEndpoint.api_format,
func.count(ProviderAPIKey.id).label("key_count"),
)
.join(ProviderAPIKey, ProviderEndpoint.id == ProviderAPIKey.endpoint_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.group_by(ProviderEndpoint.api_format)
.all()
)
# 构建所有格式的 key_count 映射
key_counts: Dict[str, int] = {}
for api_format_enum, key_count in active_keys:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
key_counts[api_format] = key_count
# 1.2 建立每个 API 格式对应的 Endpoint ID 列表,供 Usage 时间线生成使用
# 1.1 建立每个 API 格式对应的 Endpoint ID 列表(用于时间线生成),并收集活跃的 provider+format 组合
endpoint_rows = (
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id)
db.query(ProviderEndpoint.api_format, ProviderEndpoint.id, ProviderEndpoint.provider_id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
.filter(
ProviderEndpoint.is_active.is_(True),
@@ -312,11 +296,32 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
.all()
)
endpoint_map: Dict[str, List[str]] = defaultdict(list)
for api_format_enum, endpoint_id in endpoint_rows:
active_provider_formats: set[tuple[str, str]] = set()
for api_format_enum, endpoint_id, provider_id in endpoint_rows:
api_format = (
api_format_enum.value if hasattr(api_format_enum, "value") else str(api_format_enum)
)
endpoint_map[api_format].append(endpoint_id)
active_provider_formats.add((str(provider_id), api_format))
        # 1.2 统计每个 API 格式可用的活跃 Key 数量(Key 属于 Provider,通过 api_formats 关联格式)
key_counts: Dict[str, int] = {}
if active_provider_formats:
active_provider_keys = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
.filter(
Provider.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.all()
)
for provider_id, api_formats in active_provider_keys:
pid = str(provider_id)
for fmt in (api_formats or []):
if (pid, fmt) not in active_provider_formats:
continue
key_counts[fmt] = key_counts.get(fmt, 0) + 1
# 2. 统计窗口内每个 API 格式的请求状态分布(真实统计)
        # 只统计最终状态(success, failed, skipped)
@@ -457,28 +462,45 @@ class AdminApiFormatHealthMonitorAdapter(AdminApiAdapter):
@dataclass
class AdminKeyHealthAdapter(AdminApiAdapter):
key_id: str
api_format: Optional[str] = None
async def handle(self, context): # type: ignore[override]
health_data = health_monitor.get_key_health(context.db, self.key_id)
health_data = health_monitor.get_key_health(context.db, self.key_id, self.api_format)
if not health_data:
raise NotFoundException(f"Key {self.key_id} 不存在")
return HealthStatusResponse(
key_id=health_data["key_id"],
key_health_score=health_data["health_score"],
key_consecutive_failures=health_data["consecutive_failures"],
key_last_failure_at=health_data["last_failure_at"],
key_is_active=health_data["is_active"],
key_statistics=health_data["statistics"],
circuit_breaker_open=health_data["circuit_breaker_open"],
circuit_breaker_open_at=health_data["circuit_breaker_open_at"],
next_probe_at=health_data["next_probe_at"],
)
# 构建响应
response_data = {
"key_id": health_data["key_id"],
"key_is_active": health_data["is_active"],
"key_statistics": health_data.get("statistics"),
"key_health_score": health_data.get("health_score", 1.0),
}
if self.api_format:
# 单格式查询
response_data["api_format"] = self.api_format
response_data["key_consecutive_failures"] = health_data.get("consecutive_failures")
response_data["key_last_failure_at"] = health_data.get("last_failure_at")
circuit = health_data.get("circuit_breaker", {})
response_data["circuit_breaker_open"] = circuit.get("open", False)
response_data["circuit_breaker_open_at"] = circuit.get("open_at")
response_data["next_probe_at"] = circuit.get("next_probe_at")
response_data["half_open_until"] = circuit.get("half_open_until")
response_data["half_open_successes"] = circuit.get("half_open_successes", 0)
response_data["half_open_failures"] = circuit.get("half_open_failures", 0)
else:
# 全格式查询
response_data["any_circuit_open"] = health_data.get("any_circuit_open", False)
response_data["health_by_format"] = health_data.get("health_by_format")
return HealthStatusResponse(**response_data)
@dataclass
class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
key_id: str
api_format: Optional[str] = None
async def handle(self, context): # type: ignore[override]
db = context.db
@@ -486,22 +508,32 @@ class AdminRecoverKeyHealthAdapter(AdminApiAdapter):
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
key.health_score = 1.0
key.consecutive_failures = 0
key.last_failure_at = None
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
# 使用 health_monitor.reset_health 重置健康度
success = health_monitor.reset_health(db, key_id=self.key_id, api_format=self.api_format)
if not success:
raise Exception("重置健康度失败")
# 如果 Key 被禁用,重新启用
if not key.is_active:
key.is_active = True
key.is_active = True # type: ignore[assignment]
db.commit()
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员恢复Key健康状态: {self.key_id} (health_score: 1.0, circuit_breaker: closed)")
if self.api_format:
logger.info(f"管理员恢复Key健康状态: {self.key_id}/{self.api_format}")
return {
"message": "Key已完全恢复",
"message": f"Key{self.api_format} 格式已恢复",
"details": {
"api_format": self.api_format,
"health_score": 1.0,
"circuit_breaker_open": False,
"is_active": True,
},
}
else:
logger.info(f"管理员恢复Key健康状态: {self.key_id} (所有格式)")
return {
"message": "Key 所有格式已恢复",
"details": {
"health_score": 1.0,
"circuit_breaker_open": False,
@@ -516,10 +548,17 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
# 查找所有熔断的 Key
circuit_open_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.circuit_breaker_open == True).all()
)
# 查找所有熔断格式的 Key(检查 circuit_breaker_by_format JSON 字段)
all_keys = db.query(ProviderAPIKey).all()
# 筛选出有任何格式熔断的 Key
circuit_open_keys = []
for key in all_keys:
circuit_by_format = key.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
circuit_open_keys.append(key)
break
if not circuit_open_keys:
return {
@@ -530,17 +569,15 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
recovered_keys = []
for key in circuit_open_keys:
key.health_score = 1.0
key.consecutive_failures = 0
key.last_failure_at = None
key.circuit_breaker_open = False
key.circuit_breaker_open_at = None
key.next_probe_at = None
# 重置所有格式的健康度
key.health_by_format = {} # type: ignore[assignment]
key.circuit_breaker_by_format = {} # type: ignore[assignment]
recovered_keys.append(
{
"key_id": key.id,
"key_name": key.name,
"endpoint_id": key.endpoint_id,
"provider_id": key.provider_id,
"api_formats": key.api_formats,
}
)
@@ -552,7 +589,6 @@ class AdminRecoverAllKeysHealthAdapter(AdminApiAdapter):
HealthMonitor._open_circuit_keys = 0
health_open_circuits.set(0)
admin_name = context.user.username if context.user else "admin"
logger.info(f"管理员批量恢复 {len(recovered_keys)} 个 Key 的健康状态")
return {

View File

@@ -1,5 +1,5 @@
"""
Endpoint API Keys 管理
Provider API Keys 管理
"""
import uuid
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from src.api.base.admin_adapter import AdminApiAdapter
from src.api.base.pipeline import ApiRequestPipeline
from src.config.constants import RPMDefaults
from src.core.crypto import crypto_service
from src.core.exceptions import InvalidRequestException, NotFoundException
from src.core.key_capabilities import get_capability
@@ -20,96 +21,14 @@ from src.database import get_db
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.provider_cache import ProviderCacheService
from src.models.endpoint_models import (
BatchUpdateKeyPriorityRequest,
EndpointAPIKeyCreate,
EndpointAPIKeyResponse,
EndpointAPIKeyUpdate,
)
router = APIRouter(tags=["Endpoint Keys"])
router = APIRouter(tags=["Provider Keys"])
pipeline = ApiRequestPipeline()
@router.get("/{endpoint_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_endpoint_keys(
endpoint_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""
获取 Endpoint 的所有 Keys
获取指定 Endpoint 下的所有 API Key 列表,包括 Key 的配置、统计信息等。
结果按优先级和创建时间排序。
**路径参数**:
- `endpoint_id`: Endpoint ID
**查询参数**:
    - `skip`: 跳过的记录数,用于分页(默认 0)
    - `limit`: 返回的最大记录数(1-1000,默认 100)
**返回字段**:
- `id`: Key ID
- `name`: Key 名称
- `api_key_masked`: 脱敏后的 API Key
- `internal_priority`: 内部优先级
- `global_priority`: 全局优先级
- `rate_multiplier`: 速率倍数
    - `max_concurrent`: 最大并发数(null 表示自适应模式)
- `is_adaptive`: 是否为自适应并发模式
- `effective_limit`: 有效并发限制
- `success_rate`: 成功率
- `avg_response_time_ms`: 平均响应时间(毫秒)
- 其他配置和统计字段
"""
adapter = AdminListEndpointKeysAdapter(
endpoint_id=endpoint_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/{endpoint_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_endpoint_key(
endpoint_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
为 Endpoint 添加 Key
    为指定 Endpoint 添加新的 API Key,支持配置并发限制、速率倍数、
优先级、配额限制、能力限制等。
**路径参数**:
- `endpoint_id`: Endpoint ID
**请求体字段**:
    - `endpoint_id`: Endpoint ID(必须与路径参数一致)
- `api_key`: API Key 原文(将被加密存储)
- `name`: Key 名称
- `note`: 备注(可选)
    - `rate_multiplier`: 速率倍数(默认 1.0)
    - `internal_priority`: 内部优先级(默认 100)
    - `max_concurrent`: 最大并发数(null 表示自适应模式)
- `rate_limit`: 每分钟请求限制(可选)
- `daily_limit`: 每日请求限制(可选)
- `monthly_limit`: 每月请求限制(可选)
- `allowed_models`: 允许的模型列表(可选)
- `capabilities`: 能力配置(可选)
**返回字段**:
- 包含完整的 Key 信息,其中 `api_key_plain` 为原文(仅在创建时返回)
"""
adapter = AdminCreateEndpointKeyAdapter(endpoint_id=endpoint_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/keys/{key_id}", response_model=EndpointAPIKeyResponse)
async def update_endpoint_key(
key_id: str,
@@ -118,7 +37,7 @@ async def update_endpoint_key(
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
更新 Endpoint Key
更新 Provider Key
更新指定 Key 的配置,支持修改并发限制、速率倍数、优先级、
配额限制、能力限制等。支持部分更新。
@@ -132,10 +51,7 @@ async def update_endpoint_key(
- `note`: 备注
- `rate_multiplier`: 速率倍数
- `internal_priority`: 内部优先级
- `max_concurrent`: 最大并发数(设置为 null 可切换到自适应模式)
- `rate_limit`: 每分钟请求限制
- `daily_limit`: 每日请求限制
- `monthly_limit`: 每月请求限制
- `rpm_limit`: RPM 限制(设置为 null 可切换到自适应模式)
- `allowed_models`: 允许的模型列表
- `capabilities`: 能力配置
- `is_active`: 是否活跃
@@ -210,7 +126,7 @@ async def delete_endpoint_key(
db: Session = Depends(get_db),
) -> dict:
"""
删除 Endpoint Key
删除 Provider Key
删除指定的 API Key。此操作不可逆请谨慎使用。
@@ -224,163 +140,66 @@ async def delete_endpoint_key(
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.put("/{endpoint_id}/keys/batch-priority")
async def batch_update_key_priority(
endpoint_id: str,
request: Request,
priority_data: BatchUpdateKeyPriorityRequest,
db: Session = Depends(get_db),
) -> dict:
"""
批量更新 Endpoint 下 Keys 的优先级
# ========== Provider Keys API ==========
批量更新指定 Endpoint 下多个 Key 的内部优先级,用于拖动排序。
所有 Key 必须属于指定的 Endpoint。
@router.get("/providers/{provider_id}/keys", response_model=List[EndpointAPIKeyResponse])
async def list_provider_keys(
provider_id: str,
request: Request,
skip: int = Query(0, ge=0, description="跳过的记录数"),
limit: int = Query(100, ge=1, le=1000, description="返回的最大记录数"),
db: Session = Depends(get_db),
) -> List[EndpointAPIKeyResponse]:
"""
获取 Provider 的所有 Keys
获取指定 Provider 下的所有 API Key 列表,支持多 API 格式。
结果按优先级和创建时间排序。
**路径参数**:
- `endpoint_id`: Endpoint ID
- `provider_id`: Provider ID
**查询参数**:
    - `skip`: 跳过的记录数,用于分页(默认 0)
    - `limit`: 返回的最大记录数(1-1000,默认 100)
"""
adapter = AdminListProviderKeysAdapter(
provider_id=provider_id,
skip=skip,
limit=limit,
)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@router.post("/providers/{provider_id}/keys", response_model=EndpointAPIKeyResponse)
async def add_provider_key(
provider_id: str,
key_data: EndpointAPIKeyCreate,
request: Request,
db: Session = Depends(get_db),
) -> EndpointAPIKeyResponse:
"""
为 Provider 添加 Key
    为指定 Provider 添加新的 API Key,支持配置多个 API 格式。
**路径参数**:
- `provider_id`: Provider ID
**请求体字段**:
- `priorities`: 优先级列表
- `key_id`: Key ID
- `internal_priority`: 新的内部优先级
**返回字段**:
- `message`: 操作结果消息
- `updated_count`: 实际更新的 Key 数量
- `api_formats`: 支持的 API 格式列表(必填)
- `api_key`: API Key 原文(将被加密存储)
- `name`: Key 名称
- 其他配置字段同 Key
"""
adapter = AdminBatchUpdateKeyPriorityAdapter(endpoint_id=endpoint_id, priority_data=priority_data)
adapter = AdminCreateProviderKeyAdapter(provider_id=provider_id, key_data=key_data)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
# -------- Adapters --------
@dataclass
class AdminListEndpointKeysAdapter(AdminApiAdapter):
endpoint_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == self.endpoint_id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.offset(self.skip)
.limit(self.limit)
.all()
)
result: List[EndpointAPIKeyResponse] = []
for key in keys:
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
result.append(EndpointAPIKeyResponse(**key_dict))
return result
@dataclass
class AdminCreateEndpointKeyAdapter(AdminApiAdapter):
endpoint_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
if self.key_data.endpoint_id != self.endpoint_id:
raise InvalidRequestException("endpoint_id 不匹配")
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
# max_concurrent=NULL 表示自适应模式,数字表示固定限制
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=self.endpoint_id,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
internal_priority=self.key_data.internal_priority,
max_concurrent=self.key_data.max_concurrent, # NULL=自适应模式
rate_limit=self.key_data.rate_limit,
daily_limit=self.key_data.daily_limit,
monthly_limit=self.key_data.monthly_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit()
db.refresh(new_key)
logger.info(f"[OK] 添加 Key: Endpoint={self.endpoint_id}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}")
masked_key = f"{self.key_data.api_key[:8]}***{self.key_data.api_key[-4:]}"
is_adaptive = new_key.max_concurrent is None
response_dict = new_key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": self.key_data.api_key,
"success_rate": 0.0,
"avg_response_time_ms": 0.0,
"is_adaptive": is_adaptive,
"effective_limit": (
new_key.learned_max_concurrent if is_adaptive else new_key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
@dataclass
class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
key_id: str
@@ -396,14 +215,21 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
if "api_key" in update_data:
update_data["api_key"] = crypto_service.encrypt(update_data["api_key"])
        # 特殊处理 max_concurrent,需要区分"未提供"和"显式设置为 null"
# 当 max_concurrent 被显式设置时(在 model_fields_set 中),即使值为 None 也应该更新
if "max_concurrent" in self.key_data.model_fields_set:
update_data["max_concurrent"] = self.key_data.max_concurrent
# 切换到自适应模式时,清空学习到的并发限制,让系统重新学习
if self.key_data.max_concurrent is None:
update_data["learned_max_concurrent"] = None
logger.info("Key %s 切换为自适应并发模式", self.key_id)
        # 特殊处理 rpm_limit,需要区分"未提供"和"显式设置为 null"
if "rpm_limit" in self.key_data.model_fields_set:
update_data["rpm_limit"] = self.key_data.rpm_limit
if self.key_data.rpm_limit is None:
update_data["learned_rpm_limit"] = None
logger.info("Key %s 切换为自适应 RPM 模式", self.key_id)
# 统一处理 allowed_models空列表/空字典 -> None表示不限制
if "allowed_models" in update_data:
am = update_data["allowed_models"]
if am is not None and (
(isinstance(am, list) and len(am) == 0)
or (isinstance(am, dict) and len(am) == 0)
):
update_data["allowed_models"] = None
for field, value in update_data.items():
setattr(key, field, value)
@@ -412,39 +238,13 @@ class AdminUpdateEndpointKeyAdapter(AdminApiAdapter):
db.commit()
db.refresh(key)
# 如果更新了 rate_multiplier清除缓存
if "rate_multiplier" in update_data:
# 任何字段更新都清除缓存,确保缓存一致性
# 包括 is_active、allowed_models、capabilities 等影响权限和行为的字段
await ProviderCacheService.invalidate_provider_api_key_cache(self.key_id)
logger.info("[OK] 更新 Key: ID=%s, Updates=%s", self.key_id, list(update_data.keys()))
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.max_concurrent is None
response_dict = key.__dict__.copy()
response_dict.pop("_sa_instance_state", None)
response_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": None,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
key.learned_max_concurrent if is_adaptive else key.max_concurrent
),
}
)
return EndpointAPIKeyResponse(**response_dict)
return _build_key_response(key)
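
The rpm_limit handling above leans on Pydantic v2's model_fields_set to tell "field omitted" apart from "field explicitly sent as null", which is what lets a partial update switch a key back to adaptive mode. A standalone sketch of the same idea with a made-up model; the dump options the project actually uses are not visible in this diff.

from typing import Optional
from pydantic import BaseModel

class KeyUpdate(BaseModel):          # hypothetical partial-update model
    name: Optional[str] = None
    rpm_limit: Optional[int] = None

payload = KeyUpdate.model_validate({"rpm_limit": None})     # client explicitly sent null

# A dump that drops unset/None fields would lose the explicit null ...
update_data = payload.model_dump(exclude_unset=True, exclude_none=True)   # {}

# ... so the explicit-null case is recovered by checking model_fields_set
if "rpm_limit" in payload.model_fields_set:
    update_data["rpm_limit"] = payload.rpm_limit             # None => switch to adaptive mode
    if payload.rpm_limit is None:
        update_data["learned_rpm_limit"] = None              # force the limit to be re-learned
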
@dataclass
@@ -481,7 +281,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
if not key:
raise NotFoundException(f"Key {self.key_id} 不存在")
endpoint_id = key.endpoint_id
provider_id = key.provider_id
try:
db.delete(key)
db.commit()
@@ -490,7 +290,7 @@ class AdminDeleteEndpointKeyAdapter(AdminApiAdapter):
logger.error(f"删除 Key 失败: ID={self.key_id}, Error={exc}")
raise
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Endpoint={endpoint_id}")
logger.warning(f"[DELETE] 删除 Key: ID={self.key_id}, Provider={provider_id}")
return {"message": f"Key {self.key_id} 已删除"}
@@ -498,31 +298,51 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
async def handle(self, context): # type: ignore[override]
db = context.db
        # Key 属于 Provider,按 key.api_formats 分组展示
keys = (
db.query(ProviderAPIKey, ProviderEndpoint, Provider)
.join(ProviderEndpoint, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
.join(Provider, ProviderEndpoint.provider_id == Provider.id)
db.query(ProviderAPIKey, Provider)
.join(Provider, ProviderAPIKey.provider_id == Provider.id)
.filter(
ProviderAPIKey.is_active.is_(True),
ProviderEndpoint.is_active.is_(True),
Provider.is_active.is_(True),
)
.order_by(
ProviderAPIKey.global_priority.asc().nullslast(), ProviderAPIKey.internal_priority.asc()
ProviderAPIKey.global_priority.asc().nullslast(),
ProviderAPIKey.internal_priority.asc(),
)
.all()
)
provider_ids = {str(provider.id) for _key, provider in keys}
endpoints = (
db.query(
ProviderEndpoint.provider_id,
ProviderEndpoint.api_format,
ProviderEndpoint.base_url,
)
.filter(
ProviderEndpoint.provider_id.in_(provider_ids),
ProviderEndpoint.is_active.is_(True),
)
.all()
)
endpoint_base_url_map: Dict[tuple[str, str], str] = {}
for provider_id, api_format, base_url in endpoints:
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
endpoint_base_url_map[(str(provider_id), fmt)] = base_url
grouped: Dict[str, List[dict]] = {}
for key, endpoint, provider in keys:
api_format = endpoint.api_format
if api_format not in grouped:
grouped[api_format] = []
for key, provider in keys:
api_formats = key.api_formats or []
if not api_formats:
continue # 跳过没有 API 格式的 Key
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
except Exception as e:
logger.error(f"解密 Key 失败: key_id={key.id}, error={e}")
masked_key = "***ERROR***"
# 计算健康度指标
@@ -541,8 +361,8 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
cap_def = get_capability(cap_name)
caps_list.append(cap_def.short_name if cap_def else cap_name)
grouped[api_format].append(
{
# 构建 Key 信息(基础数据)
key_info = {
"id": key.id,
"name": key.name,
"api_key_masked": masked_key,
@@ -550,64 +370,200 @@ class AdminGetKeysGroupedByFormatAdapter(AdminApiAdapter):
"global_priority": key.global_priority,
"rate_multiplier": key.rate_multiplier,
"is_active": key.is_active,
"circuit_breaker_open": key.circuit_breaker_open,
"provider_name": provider.display_name or provider.name,
"endpoint_base_url": endpoint.base_url,
"api_format": api_format,
"provider_name": provider.name,
"api_formats": api_formats,
"capabilities": caps_list,
"health_score": key.health_score,
"success_rate": success_rate,
"avg_response_time_ms": avg_response_time_ms,
"request_count": key.request_count,
}
# 将 Key 添加到每个支持的格式分组中,并附加格式特定的健康度数据
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
provider_id = str(provider.id)
for api_format in api_formats:
if api_format not in grouped:
grouped[api_format] = []
# 为每个格式创建副本,设置当前格式
format_key_info = key_info.copy()
format_key_info["api_format"] = api_format
format_key_info["endpoint_base_url"] = endpoint_base_url_map.get(
(provider_id, api_format)
)
# 添加格式特定的健康度数据
format_health = health_by_format.get(api_format, {})
format_circuit = circuit_by_format.get(api_format, {})
format_key_info["health_score"] = float(format_health.get("health_score") or 1.0)
format_key_info["circuit_breaker_open"] = bool(format_circuit.get("open", False))
grouped[api_format].append(format_key_info)
# 直接返回分组对象,供前端使用
return grouped
# ========== Adapters ==========
def _build_key_response(
key: ProviderAPIKey, api_key_plain: str | None = None
) -> EndpointAPIKeyResponse:
"""构建 Key 响应对象的辅助函数"""
try:
decrypted_key = crypto_service.decrypt(key.api_key)
masked_key = f"{decrypted_key[:8]}***{decrypted_key[-4:]}"
except Exception:
masked_key = "***ERROR***"
success_rate = key.success_count / key.request_count if key.request_count > 0 else 0.0
avg_response_time_ms = (
key.total_response_time_ms / key.success_count if key.success_count > 0 else 0.0
)
is_adaptive = key.rpm_limit is None
key_dict = key.__dict__.copy()
key_dict.pop("_sa_instance_state", None)
# 从 health_by_format 计算汇总字段(便于列表展示)
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
# 计算整体健康度(取所有格式中的最低值)
if health_by_format:
health_scores = [
float(h.get("health_score") or 1.0) for h in health_by_format.values()
]
min_health_score = min(health_scores) if health_scores else 1.0
# 取最大的连续失败次数
max_consecutive = max(
(int(h.get("consecutive_failures") or 0) for h in health_by_format.values()),
default=0,
)
# 取最近的失败时间
failure_times = [
h.get("last_failure_at")
for h in health_by_format.values()
if h.get("last_failure_at")
]
last_failure = max(failure_times) if failure_times else None
else:
min_health_score = 1.0
max_consecutive = 0
last_failure = None
# 检查是否有任何格式的熔断器打开
any_circuit_open = any(c.get("open", False) for c in circuit_by_format.values())
key_dict.update(
{
"api_key_masked": masked_key,
"api_key_plain": api_key_plain,
"success_rate": success_rate,
"avg_response_time_ms": round(avg_response_time_ms, 2),
"is_adaptive": is_adaptive,
"effective_limit": (
(key.learned_rpm_limit if key.learned_rpm_limit is not None else RPMDefaults.INITIAL_LIMIT)
if is_adaptive
else key.rpm_limit
),
# 汇总字段
"health_score": min_health_score,
"consecutive_failures": max_consecutive,
"last_failure_at": last_failure,
"circuit_breaker_open": any_circuit_open,
}
)
# 防御性:确保 api_formats 存在(历史数据可能为空/缺失)
if "api_formats" not in key_dict or key_dict["api_formats"] is None:
key_dict["api_formats"] = []
return EndpointAPIKeyResponse(**key_dict)
@dataclass
class AdminBatchUpdateKeyPriorityAdapter(AdminApiAdapter):
endpoint_id: str
priority_data: BatchUpdateKeyPriorityRequest
class AdminListProviderKeysAdapter(AdminApiAdapter):
"""获取 Provider 的所有 Keys"""
provider_id: str
skip: int
limit: int
async def handle(self, context): # type: ignore[override]
db = context.db
endpoint = (
db.query(ProviderEndpoint).filter(ProviderEndpoint.id == self.endpoint_id).first()
)
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
# 获取所有需要更新的 Key ID
key_ids = [item.key_id for item in self.priority_data.priorities]
# 验证所有 Key 都属于该 Endpoint
keys = (
db.query(ProviderAPIKey)
.filter(
ProviderAPIKey.id.in_(key_ids),
ProviderAPIKey.endpoint_id == self.endpoint_id,
)
.filter(ProviderAPIKey.provider_id == self.provider_id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.offset(self.skip)
.limit(self.limit)
.all()
)
if len(keys) != len(key_ids):
found_ids = {k.id for k in keys}
missing_ids = set(key_ids) - found_ids
raise InvalidRequestException(f"Keys 不属于该 Endpoint 或不存在: {missing_ids}")
return [_build_key_response(key) for key in keys]
# 批量更新优先级
key_map = {k.id: k for k in keys}
updated_count = 0
for item in self.priority_data.priorities:
key = key_map.get(item.key_id)
if key and key.internal_priority != item.internal_priority:
key.internal_priority = item.internal_priority
key.updated_at = datetime.now(timezone.utc)
updated_count += 1
@dataclass
class AdminCreateProviderKeyAdapter(AdminApiAdapter):
"""为 Provider 添加 Key"""
provider_id: str
key_data: EndpointAPIKeyCreate
async def handle(self, context): # type: ignore[override]
db = context.db
provider = db.query(Provider).filter(Provider.id == self.provider_id).first()
if not provider:
raise NotFoundException(f"Provider {self.provider_id} 不存在")
# 验证 api_formats 必填
if not self.key_data.api_formats:
raise InvalidRequestException("api_formats 为必填字段")
# 允许同一个 API Key 在同一 Provider 下添加多次
# 用户可以为不同的 API 格式创建独立的配置记录,便于分开管理
encrypted_key = crypto_service.encrypt(self.key_data.api_key)
now = datetime.now(timezone.utc)
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
provider_id=self.provider_id,
api_formats=self.key_data.api_formats,
api_key=encrypted_key,
name=self.key_data.name,
note=self.key_data.note,
rate_multiplier=self.key_data.rate_multiplier,
rate_multipliers=self.key_data.rate_multipliers, # 按 API 格式的成本倍率
internal_priority=self.key_data.internal_priority,
rpm_limit=self.key_data.rpm_limit,
allowed_models=self.key_data.allowed_models if self.key_data.allowed_models else None,
capabilities=self.key_data.capabilities if self.key_data.capabilities else None,
cache_ttl_minutes=self.key_data.cache_ttl_minutes,
max_probe_interval_minutes=self.key_data.max_probe_interval_minutes,
request_count=0,
success_count=0,
error_count=0,
total_response_time_ms=0,
health_by_format={}, # 按格式存储健康度
circuit_breaker_by_format={}, # 按格式存储熔断器状态
is_active=True,
last_used_at=None,
created_at=now,
updated_at=now,
)
db.add(new_key)
db.commit()
db.refresh(new_key)
logger.info(f"[OK] 批量更新 Key 优先级: Endpoint={self.endpoint_id}, Updated={updated_count}/{len(key_ids)}")
return {"message": f"已更新 {updated_count} 个 Key 的优先级", "updated_count": updated_count}
logger.info(
f"[OK] 添加 Key: Provider={self.provider_id}, "
f"Formats={self.key_data.api_formats}, Key=***{self.key_data.api_key[-4:]}, ID={new_key.id}"
)
return _build_key_response(new_key, api_key_plain=self.key_data.api_key)
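
With keys attached to providers instead of endpoints, creating one means calling the provider-level route with the list of API formats the key should serve. A hedged httpx sketch: the base URL, token, provider ID and the router's mount prefix are placeholders (the prefix is not visible in this diff); the payload fields mirror the create adapter above.

import httpx

BASE = "http://localhost:8000"                        # assumed deployment address
HEADERS = {"Authorization": "Bearer <admin-token>"}   # placeholder admin credential
PREFIX = "/api/admin/endpoints"                       # assumption: wherever this router is mounted

payload = {
    "api_formats": ["CLAUDE", "OPENAI"],              # required: formats this key may serve
    "api_key": "sk-...",                              # stored encrypted server-side
    "name": "backup key",
    "rate_multiplier": 1.0,
    "rate_multipliers": {"CLAUDE": 1.0, "OPENAI": 0.8},  # optional per-format cost multipliers
    "internal_priority": 100,
    "rpm_limit": None,                                # NULL => adaptive RPM mode
}

resp = httpx.post(f"{BASE}{PREFIX}/providers/<provider-id>/keys", headers=HEADERS, json=payload)
resp.raise_for_status()
print(resp.json()["api_key_masked"])                  # plaintext is only echoed at creation time
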

View File

@@ -67,8 +67,6 @@ async def list_provider_endpoints(
- `custom_path`: 自定义路径
- `timeout`: 超时时间(秒)
- `max_retries`: 最大重试次数
- `max_concurrent`: 最大并发数
- `rate_limit`: 速率限制
- `is_active`: 是否活跃
- `total_keys`: Key 总数
- `active_keys`: 活跃 Key 数量
@@ -107,8 +105,6 @@ async def create_provider_endpoint(
- `headers`: 自定义请求头(可选)
    - `timeout`: 超时时间(秒,默认 300)
    - `max_retries`: 最大重试次数(默认 2)
- `max_concurrent`: 最大并发数(可选)
- `rate_limit`: 速率限制(可选)
- `config`: 额外配置(可选)
- `proxy`: 代理配置(可选)
@@ -145,8 +141,6 @@ async def get_endpoint(
- `custom_path`: 自定义路径
- `timeout`: 超时时间(秒)
- `max_retries`: 最大重试次数
- `max_concurrent`: 最大并发数
- `rate_limit`: 速率限制
- `is_active`: 是否活跃
- `total_keys`: Key 总数
- `active_keys`: 活跃 Key 数量
@@ -178,8 +172,6 @@ async def update_endpoint(
- `headers`: 自定义请求头
- `timeout`: 超时时间(秒)
- `max_retries`: 最大重试次数
- `max_concurrent`: 最大并发数
- `rate_limit`: 速率限制
- `is_active`: 是否活跃
- `config`: 额外配置
- `proxy`: 代理配置(设置为 null 可清除代理)
@@ -203,15 +195,15 @@ async def delete_endpoint(
"""
删除 Endpoint
    删除指定的 Endpoint,同时级联删除所有关联的 API Keys
此操作不可逆,请谨慎使用
    删除指定的 Endpoint,会影响该 Provider 在该 API 格式下的路由能力
    Key 不会被删除,但包含该 API 格式的 Key 将无法被调度使用(直到重新创建该格式的 Endpoint)
**路径参数**:
- `endpoint_id`: Endpoint ID
**返回字段**:
- `message`: 操作结果消息
- `deleted_keys_count`: 同时删除的 Key 数量
- `affected_keys_count`: 受影响的 Key 数量(包含该 API 格式)
"""
adapter = AdminDeleteProviderEndpointAdapter(endpoint_id=endpoint_id)
return await pipeline.run(adapter=adapter, http_request=request, db=db, mode=adapter.mode)
@@ -241,39 +233,33 @@ class AdminListProviderEndpointsAdapter(AdminApiAdapter):
.all()
)
endpoint_ids = [ep.id for ep in endpoints]
total_keys_map = {}
active_keys_map = {}
if endpoint_ids:
total_rows = (
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("total"))
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
.group_by(ProviderAPIKey.endpoint_id)
# Key 是 Provider 级别资源:按 key.api_formats 归类到各 Endpoint.api_format 下
keys = (
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter(ProviderAPIKey.provider_id == self.provider_id)
.all()
)
total_keys_map = {row.endpoint_id: row.total for row in total_rows}
active_rows = (
db.query(ProviderAPIKey.endpoint_id, func.count(ProviderAPIKey.id).label("active"))
.filter(
and_(
ProviderAPIKey.endpoint_id.in_(endpoint_ids),
ProviderAPIKey.is_active.is_(True),
)
)
.group_by(ProviderAPIKey.endpoint_id)
.all()
)
active_keys_map = {row.endpoint_id: row.active for row in active_rows}
total_keys_map: dict[str, int] = {}
active_keys_map: dict[str, int] = {}
for api_formats, is_active in keys:
for fmt in (api_formats or []):
total_keys_map[fmt] = total_keys_map.get(fmt, 0) + 1
if is_active:
active_keys_map[fmt] = active_keys_map.get(fmt, 0) + 1
result: List[ProviderEndpointResponse] = []
for endpoint in endpoints:
endpoint_format = (
endpoint.api_format
if isinstance(endpoint.api_format, str)
else endpoint.api_format.value
)
endpoint_dict = {
**endpoint.__dict__,
"provider_name": provider.name,
"api_format": endpoint.api_format,
"total_keys": total_keys_map.get(endpoint.id, 0),
"active_keys": active_keys_map.get(endpoint.id, 0),
"total_keys": total_keys_map.get(endpoint_format, 0),
"active_keys": active_keys_map.get(endpoint_format, 0),
"proxy": mask_proxy_password(endpoint.proxy),
}
endpoint_dict.pop("_sa_instance_state", None)
@@ -321,8 +307,6 @@ class AdminCreateProviderEndpointAdapter(AdminApiAdapter):
headers=self.endpoint_data.headers,
timeout=self.endpoint_data.timeout,
max_retries=self.endpoint_data.max_retries,
max_concurrent=self.endpoint_data.max_concurrent,
rate_limit=self.endpoint_data.rate_limit,
is_active=True,
config=self.endpoint_data.config,
proxy=self.endpoint_data.proxy.model_dump() if self.endpoint_data.proxy else None,
@@ -367,19 +351,23 @@ class AdminGetProviderEndpointAdapter(AdminApiAdapter):
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
endpoint_obj, provider = endpoint
total_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
endpoint_format = (
endpoint_obj.api_format
if isinstance(endpoint_obj.api_format, str)
else endpoint_obj.api_format.value
)
active_keys = (
db.query(ProviderAPIKey)
.filter(
and_(
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
keys = (
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter(ProviderAPIKey.provider_id == endpoint_obj.provider_id)
.all()
)
total_keys = 0
active_keys = 0
for api_formats, is_active in keys:
if endpoint_format in (api_formats or []):
total_keys += 1
if is_active:
active_keys += 1
endpoint_dict = {
k: v
@@ -431,19 +419,21 @@ class AdminUpdateProviderEndpointAdapter(AdminApiAdapter):
provider = db.query(Provider).filter(Provider.id == endpoint.provider_id).first()
logger.info(f"[OK] 更新 Endpoint: ID={self.endpoint_id}, Updates={list(update_data.keys())}")
total_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
endpoint_format = (
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
)
active_keys = (
db.query(ProviderAPIKey)
.filter(
and_(
ProviderAPIKey.endpoint_id == self.endpoint_id,
ProviderAPIKey.is_active.is_(True),
)
)
.count()
keys = (
db.query(ProviderAPIKey.api_formats, ProviderAPIKey.is_active)
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
.all()
)
total_keys = 0
active_keys = 0
for api_formats, is_active in keys:
if endpoint_format in (api_formats or []):
total_keys += 1
if is_active:
active_keys += 1
endpoint_dict = {
k: v
@@ -472,12 +462,26 @@ class AdminDeleteProviderEndpointAdapter(AdminApiAdapter):
if not endpoint:
raise NotFoundException(f"Endpoint {self.endpoint_id} 不存在")
keys_count = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == self.endpoint_id).count()
endpoint_format = (
endpoint.api_format if isinstance(endpoint.api_format, str) else endpoint.api_format.value
)
keys = (
db.query(ProviderAPIKey.api_formats)
.filter(ProviderAPIKey.provider_id == endpoint.provider_id)
.all()
)
affected_keys_count = sum(
1 for (api_formats,) in keys if endpoint_format in (api_formats or [])
)
db.delete(endpoint)
db.commit()
logger.warning(f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, 同时删除了 {keys_count} 个 Keys")
logger.warning(
f"[DELETE] 删除 Endpoint: ID={self.endpoint_id}, Format={endpoint_format}, "
f"AffectedKeys={affected_keys_count}"
)
return {"message": f"Endpoint {self.endpoint_id} 已删除", "deleted_keys_count": keys_count}
return {
"message": f"Endpoint {self.endpoint_id} 已删除",
"affected_keys_count": affected_keys_count,
}

View File

@@ -125,7 +125,6 @@ class AdminGetModelCatalogAdapter(AdminApiAdapter):
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
# 显示有效价格

View File

@@ -452,7 +452,6 @@ class AdminGetGlobalModelProvidersAdapter(AdminApiAdapter):
ModelCatalogProviderDetail(
provider_id=provider.id,
provider_name=provider.name,
provider_display_name=provider.display_name,
model_id=model.id,
target_model=model.provider_model_name,
input_price_per_1m=model.get_effective_input_price(),

View File

@@ -819,7 +819,7 @@ class AdminListAffinitiesAdapter(AdminApiAdapter):
"username": user.username if user else None,
"email": user.email if user else None,
"provider_id": provider_id,
"provider_name": provider.display_name if provider else None,
"provider_name": provider.name if provider else None,
"endpoint_id": endpoint_id,
"endpoint_api_format": (
endpoint.api_format if endpoint and endpoint.api_format else None
@@ -1369,9 +1369,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
for model, provider in models:
# 检查是否是主模型名称
if model.provider_model_name == mapping_name:
provider_names.append(
provider.display_name or provider.name
)
provider_names.append(provider.name)
continue
# 检查是否在映射列表中
if model.provider_model_mappings:
@@ -1381,9 +1379,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
if isinstance(a, dict)
]
if mapping_name in mapping_list:
provider_names.append(
provider.display_name or provider.name
)
provider_names.append(provider.name)
provider_names = sorted(list(set(provider_names)))
mappings.append({
@@ -1473,7 +1469,7 @@ class AdminModelMappingCacheStatsAdapter(AdminApiAdapter):
provider_model_mappings.append({
"provider_id": provider_id,
"provider_name": provider.display_name or provider.name,
"provider_name": provider.name,
"global_model_id": global_model_id,
"global_model_name": global_model.name,
"global_model_display_name": global_model.display_name,

View File

@@ -13,10 +13,11 @@ from sqlalchemy.orm import Session, joinedload
from src.api.handlers.base.chat_adapter_base import get_adapter_class
from src.api.handlers.base.cli_adapter_base import get_cli_adapter_class
from src.config.constants import TimeoutDefaults
from src.core.crypto import crypto_service
from src.core.logger import logger
from src.database.database import get_db
from src.models.database import Provider, ProviderEndpoint, User
from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint, User
from src.utils.auth_utils import get_current_user
router = APIRouter(prefix="/api/admin/provider-query", tags=["Provider Query"])
@@ -81,10 +82,13 @@ async def query_available_models(
Returns:
Merged list of models across all endpoints
"""
# Fetch the provider together with its endpoints
# Fetch the provider together with its endpoints and API keys
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.options(
joinedload(Provider.endpoints),
joinedload(Provider.api_keys),
)
.filter(Provider.id == request.provider_id)
.first()
)
@@ -95,42 +99,63 @@ async def query_available_models(
# Collect the configuration of every active endpoint
endpoint_configs: list[dict] = []
if request.api_key_id:
# A specific API key was given: only use the endpoint that key belongs to
# Build an api_format -> endpoint mapping
format_to_endpoint: dict[str, ProviderEndpoint] = {}
for endpoint in provider.endpoints:
for api_key in endpoint.api_keys:
if api_key.id == request.api_key_id:
if endpoint.is_active:
format_to_endpoint[endpoint.api_format] = endpoint
if request.api_key_id:
# A specific API key was given (looked up from provider.api_keys)
api_key = next(
(key for key in provider.api_keys if key.id == request.api_key_id),
None
)
if not api_key:
raise HTTPException(status_code=404, detail="API Key not found")
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 根据 Key 的 api_formats 找对应的 Endpoint
key_formats = api_key.api_formats or []
for fmt in key_formats:
endpoint = format_to_endpoint.get(fmt)
if endpoint:
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"api_format": fmt,
"extra_headers": endpoint.headers,
})
break
if endpoint_configs:
break
if not endpoint_configs:
raise HTTPException(status_code=404, detail="API Key not found")
raise HTTPException(
status_code=400,
detail="No matching endpoint found for this API Key's formats"
)
else:
# Iterate over all active endpoints and take the first usable key for each
# Iterate over all active endpoints and find a key that supports each endpoint's format
for endpoint in provider.endpoints:
if not endpoint.is_active or not endpoint.api_keys:
if not endpoint.is_active:
continue
# Take the first usable key
for api_key in endpoint.api_keys:
if api_key.is_active:
# Take the first active key that supports this format
for api_key in provider.api_keys:
if not api_key.is_active:
continue
key_formats = api_key.api_formats or []
if endpoint.api_format not in key_formats:
continue
try:
api_key_value = crypto_service.decrypt(api_key.api_key)
except Exception as e:
logger.error(f"Failed to decrypt API key: {e}")
continue # 尝试下一个 Key
continue
endpoint_configs.append({
"api_key": api_key_value,
"base_url": endpoint.base_url,
@@ -214,7 +239,6 @@ async def query_available_models(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
}
@@ -229,17 +253,14 @@ async def test_model(
Test model connectivity
Send a test request to the given model of the given provider to verify that the model is usable.
Args:
request: Test request
Returns:
Test result
"""
# Fetch the provider together with its endpoints
# Fetch the provider together with its endpoints and keys
provider = (
db.query(Provider)
.options(joinedload(Provider.endpoints).joinedload(ProviderEndpoint.api_keys))
.options(
joinedload(Provider.endpoints),
joinedload(Provider.api_keys),
)
.filter(Provider.id == request.provider_id)
.first()
)
@@ -247,28 +268,38 @@ async def test_model(
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# 构建 api_format -> endpoint 映射
format_to_endpoint: dict[str, ProviderEndpoint] = {}
for ep in provider.endpoints:
if ep.is_active:
format_to_endpoint[ep.api_format] = ep
# 找到合适的端点和 API Key
endpoint_config = None
endpoint = None
api_key = None
if request.api_key_id:
# 使用指定的 API Key
for ep in provider.endpoints:
for key in ep.api_keys:
if key.id == request.api_key_id and key.is_active and ep.is_active:
endpoint = ep
api_key = key
break
if endpoint:
api_key = next(
(key for key in provider.api_keys if key.id == request.api_key_id and key.is_active),
None
)
if api_key:
# 找到该 Key 支持的第一个活跃 Endpoint
for fmt in (api_key.api_formats or []):
if fmt in format_to_endpoint:
endpoint = format_to_endpoint[fmt]
break
else:
# 使用第一个可用的端点和密钥
for ep in provider.endpoints:
if not ep.is_active or not ep.api_keys:
if not ep.is_active:
continue
for key in ep.api_keys:
if key.is_active:
# 找支持该格式的第一个可用 Key
for key in provider.api_keys:
if not key.is_active:
continue
if ep.api_format in (key.api_formats or []):
endpoint = ep
api_key = key
break
@@ -284,14 +315,14 @@ async def test_model(
logger.error(f"[test-model] Failed to decrypt API key: {e}")
raise HTTPException(status_code=500, detail="Failed to decrypt API key")
# 构建请求配置
# 构建请求配置timeout 从 Provider 读取)
endpoint_config = {
"api_key": api_key_value,
"api_key_id": api_key.id, # 添加API Key ID用于用量记录
"base_url": endpoint.base_url,
"api_format": endpoint.api_format,
"extra_headers": endpoint.headers,
"timeout": endpoint.timeout or 30.0,
"timeout": provider.timeout or TimeoutDefaults.HTTP_REQUEST,
}
try:
@@ -304,7 +335,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
@@ -325,7 +355,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
}
@@ -415,7 +444,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {
@@ -433,7 +461,6 @@ async def test_model(
"provider": {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
},
"model": request.model_name,
"endpoint": {

View File

@@ -78,7 +78,7 @@ async def get_provider_stats(
"""
Get provider statistics
Return the billing info, RPM usage and usage statistics for the given provider.
Return the billing info and usage statistics for the given provider.
**Path parameters**:
- `provider_id`: Provider ID
@@ -96,10 +96,6 @@ async def get_provider_stats(
- `monthly_used_usd`: Monthly spend so far
- `quota_remaining_usd`: Remaining quota
- `quota_expires_at`: Quota expiry time
- `rpm_info`: RPM info
- `rpm_limit`: RPM limit
- `rpm_used`: RPM used
- `rpm_reset_at`: RPM reset time
- `usage_stats`: Usage statistics
- `total_requests`: Total requests
- `successful_requests`: Successful requests
@@ -165,7 +161,6 @@ class AdminProviderBillingAdapter(AdminApiAdapter):
provider.billing_type = config.billing_type
provider.monthly_quota_usd = config.monthly_quota_usd
provider.quota_reset_day = config.quota_reset_day
provider.rpm_limit = config.rpm_limit
provider.provider_priority = config.provider_priority
from dateutil import parser
@@ -262,13 +257,6 @@ class AdminProviderStatsAdapter(AdminApiAdapter):
provider.quota_expires_at.isoformat() if provider.quota_expires_at else None
),
},
"rpm_info": {
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": (
provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None
),
},
"usage_stats": {
"total_requests": total_requests,
"successful_requests": total_success,
@@ -296,8 +284,6 @@ class AdminProviderResetQuotaAdapter(AdminApiAdapter):
old_used = provider.monthly_used_usd
provider.monthly_used_usd = 0.0
provider.rpm_used = 0
provider.rpm_reset_at = None
db.commit()
logger.info(f"Manually reset quota for provider {provider.name}")

View File

@@ -338,27 +338,29 @@ async def import_models_from_upstream(
"""
Import models from an upstream provider
Import a list of models from the upstream provider. If a global model does not exist, it is created automatically.
Import a list of models from the upstream provider. Imported models are stored as independent ProviderModel rows;
no GlobalModel is created automatically. A model only takes part in routing after it is manually linked to a GlobalModel.
**Flow** (an illustrative request/result sketch follows after the field lists below):
1. Check whether a global model exists for each of the model_ids (matched by name)
2. If it does not exist, create a new GlobalModel (with the default free pricing)
3. Create a Model linked to the current Provider
4. If the model is already linked, record it in the success list
1. Check whether the model already exists on the current Provider (matched by provider_model_name)
2. Create a new ProviderModel (global_model_id = NULL)
3. Optional price overrides are supported (tiered_pricing, price_per_request)
**Path parameters**:
- `provider_id`: Provider ID
**Request body fields**:
- `model_ids`: Array of model IDs (required, each ID 1-100 characters long)
- `tiered_pricing`: Optional tiered-pricing configuration (applied to all imported models)
- `price_per_request`: Optional per-request price (applied to all imported models)
**Return fields**:
- `success`: Array of successfully imported models, each item containing:
- `model_id`: Model ID
- `global_model_id`: Global model ID
- `global_model_name`: Global model name
- `provider_model_id`: Provider model ID
- `created_global_model`: Whether a new global model was created
- `global_model_id`: Global model ID (if already linked)
- `global_model_name`: Global model name (if already linked)
- `created_global_model`: Whether a new global model was created (always false)
- `errors`: Array of failed models, each item containing:
- `model_id`: Model ID
- `error`: Error message
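**Example** (illustrative sketch based on the fields above): a request like
    {"model_ids": ["claude-3-5-haiku"], "price_per_request": 0.002}
creates, for each imported ID, an independent row roughly equivalent to
    Model(provider_id=<provider_id>, global_model_id=None,
          provider_model_name="claude-3-5-haiku",
          tiered_pricing=None, price_per_request=0.002, is_active=True)
and reports it under `success` with empty `global_model_id` / `global_model_name` and `created_global_model` = false.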
@@ -638,7 +640,7 @@ class AdminBatchAssignModelsToProviderAdapter(AdminApiAdapter):
@dataclass
class AdminImportFromUpstreamAdapter(AdminApiAdapter):
"""从上游提供商导入模型"""
"""从上游提供商导入模型(不创建 GlobalModel作为独立 ProviderModel"""
provider_id: str
payload: ImportFromUpstreamRequest
@@ -652,16 +654,13 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
success: list[ImportFromUpstreamSuccessItem] = []
errors: list[ImportFromUpstreamErrorItem] = []
# 默认阶梯计费配置(免费)
default_tiered_pricing = {
"tiers": [
{
"up_to": None,
"input_price_per_1m": 0.0,
"output_price_per_1m": 0.0,
}
]
}
# 获取价格覆盖配置
tiered_pricing = None
price_per_request = None
if hasattr(self.payload, 'tiered_pricing') and self.payload.tiered_pricing:
tiered_pricing = self.payload.tiered_pricing
if hasattr(self.payload, 'price_per_request') and self.payload.price_per_request is not None:
price_per_request = self.payload.price_per_request
for model_id in self.payload.model_ids:
# 输入验证:检查 model_id 长度
@@ -678,56 +677,37 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
# 使用 savepoint 确保单个模型导入的原子性
savepoint = db.begin_nested()
try:
# 1. 检查是否已存在同名的 GlobalModel
global_model = (
db.query(GlobalModel).filter(GlobalModel.name == model_id).first()
)
created_global_model = False
if not global_model:
# 2. 创建新的 GlobalModel
global_model = GlobalModel(
name=model_id,
display_name=model_id,
default_tiered_pricing=default_tiered_pricing,
is_active=True,
)
db.add(global_model)
db.flush()
created_global_model = True
logger.info(
f"Created new GlobalModel: {model_id} during upstream import"
)
# 3. Check whether the link already exists
# 1. Check whether a ProviderModel with the same name already exists
existing = (
db.query(Model)
.filter(
Model.provider_id == self.provider_id,
Model.global_model_id == global_model.id,
Model.provider_model_name == model_id,
)
.first()
)
if existing:
# Link already exists: commit the savepoint and record success
# Already exists: commit the savepoint and record success
savepoint.commit()
success.append(
ImportFromUpstreamSuccessItem(
model_id=model_id,
global_model_id=global_model.id,
global_model_name=global_model.name,
global_model_id=existing.global_model_id or "",
global_model_name=existing.global_model.name if existing.global_model else "",
provider_model_id=existing.id,
created_global_model=created_global_model,
created_global_model=False,
)
)
continue
# 4. Create a new Model record
# 2. Create a new Model record (not linked to a GlobalModel)
new_model = Model(
provider_id=self.provider_id,
global_model_id=global_model.id,
provider_model_name=global_model.name,
global_model_id=None,  # independent model, not linked to a GlobalModel
provider_model_name=model_id,
is_active=True,
tiered_pricing=tiered_pricing,
price_per_request=price_per_request,
)
db.add(new_model)
db.flush()
@@ -737,12 +717,15 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
success.append(
ImportFromUpstreamSuccessItem(
model_id=model_id,
global_model_id=global_model.id,
global_model_name=global_model.name,
global_model_id="", # 未关联
global_model_name="", # 未关联
provider_model_id=new_model.id,
created_global_model=created_global_model,
created_global_model=False,
)
)
logger.info(
f"Created independent ProviderModel: {model_id} for provider {provider.name}"
)
except Exception as e:
# 回滚到 savepoint
savepoint.rollback()
@@ -753,11 +736,9 @@ class AdminImportFromUpstreamAdapter(AdminApiAdapter):
db.commit()
logger.info(
f"Imported {len(success)} models from upstream to provider {provider.name} by {context.user.username}"
f"Imported {len(success)} independent models to provider {provider.name} by {context.user.username}"
)
# Invalidate the /v1/models list cache
if success:
await invalidate_models_list_cache()
# No need to invalidate the /v1/models cache: independent models do not take part in routing
return ImportFromUpstreamResponse(success=success, errors=errors)

View File

@@ -41,8 +41,7 @@ async def list_providers(
**Return fields**:
- `id`: Provider ID
- `name`: Provider name (unique identifier)
- `display_name`: Display name
- `name`: Provider name (unique)
- `api_format`: API format (e.g. claude, openai, gemini)
- `base_url`: API base URL
- `api_key`: API key (masked)
@@ -63,8 +62,7 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
Create a new AI model provider configuration.
**Request body fields**:
- `name`: Provider name (required, unique, used as the system identifier)
- `display_name`: Display name (required)
- `name`: Provider name (required, unique)
- `description`: Description (optional)
- `website`: Website URL (optional)
- `billing_type`: Billing type (optional: pay_as_you_go/subscription/prepaid, default pay_as_you_go)
@@ -72,16 +70,17 @@ async def create_provider(request: Request, db: Session = Depends(get_db)):
- `quota_reset_day`: Quota reset day (1-31, optional)
- `quota_last_reset_at`: Last quota reset time (optional)
- `quota_expires_at`: Quota expiry time (optional)
- `rpm_limit`: Requests-per-minute limit (optional)
- `provider_priority`: Provider priority (smaller numbers mean higher priority, default 100)
- `is_active`: Whether the provider is enabled (default true)
- `concurrent_limit`: Concurrency limit (optional)
- `timeout`: Request timeout in seconds (optional)
- `max_retries`: Maximum number of retries (optional)
- `proxy`: Proxy configuration (optional)
- `config`: Extra configuration (JSON, optional)
**Return fields**:
- `id`: ID of the newly created provider
- `name`: Provider name
- `display_name`: Display name
- `message`: Success message
"""
adapter = AdminCreateProviderAdapter()
@@ -100,7 +99,6 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
**Request body fields** (all optional):
- `name`: Provider name
- `display_name`: Display name
- `description`: Description
- `website`: Website URL
- `billing_type`: Billing type (pay_as_you_go/subscription/prepaid)
@@ -108,10 +106,12 @@ async def update_provider(provider_id: str, request: Request, db: Session = Depe
- `quota_reset_day`: Quota reset day (1-31)
- `quota_last_reset_at`: Last quota reset time
- `quota_expires_at`: Quota expiry time
- `rpm_limit`: Requests-per-minute limit
- `provider_priority`: Provider priority
- `is_active`: Whether the provider is enabled
- `concurrent_limit`: Concurrency limit
- `timeout`: Request timeout (seconds)
- `max_retries`: Maximum number of retries
- `proxy`: Proxy configuration
- `config`: Extra configuration (JSON)
**Return fields**:
@@ -165,7 +165,6 @@ class AdminListProvidersAdapter(AdminApiAdapter):
{
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"api_format": api_format.value if api_format else None,
"base_url": base_url,
"api_key": "***" if api_key else None,
@@ -217,7 +216,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
# 创建 Provider 对象
provider = Provider(
name=validated_data.name,
display_name=validated_data.display_name,
description=validated_data.description,
website=validated_data.website,
billing_type=billing_type,
@@ -225,10 +223,12 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
quota_reset_day=validated_data.quota_reset_day,
quota_last_reset_at=validated_data.quota_last_reset_at,
quota_expires_at=validated_data.quota_expires_at,
rpm_limit=validated_data.rpm_limit,
provider_priority=validated_data.provider_priority,
is_active=validated_data.is_active,
concurrent_limit=validated_data.concurrent_limit,
timeout=validated_data.timeout,
max_retries=validated_data.max_retries,
proxy=validated_data.proxy.model_dump() if validated_data.proxy else None,
config=validated_data.config,
)
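# A minimal sketch (editor's illustration) of a create-provider request body that maps
# onto the constructor above; field names follow the documented request schema, and the
# proxy field is omitted because its exact shape is not spelled out in this diff.
example_create_provider_payload = {
    "name": "my-upstream",
    "billing_type": "pay_as_you_go",
    "provider_priority": 100,
    "is_active": True,
    "timeout": 300,
    "max_retries": 2,
    "config": {"region": "us"},
}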
@@ -248,7 +248,6 @@ class AdminCreateProviderAdapter(AdminApiAdapter):
return {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"message": "提供商创建成功",
}
except InvalidRequestException:
@@ -291,6 +290,9 @@ class AdminUpdateProviderAdapter(AdminApiAdapter):
if field == "billing_type" and value is not None:
# billing_type 需要转换为枚举
setattr(provider, field, ProviderBillingType(value))
elif field == "proxy" and value is not None:
# proxy 需要转换为 dict如果是 Pydantic 模型)
setattr(provider, field, value if isinstance(value, dict) else value.model_dump())
else:
setattr(provider, field, value)

View File

@@ -48,7 +48,6 @@ async def get_providers_summary(
**Return fields** (array; each item contains):
- `id`: Provider ID
- `name`: Provider name
- `display_name`: Display name
- `description`: Description
- `website`: Website URL
- `provider_priority`: Priority
@@ -59,9 +58,9 @@ async def get_providers_summary(
- `quota_reset_day`: Quota reset day
- `quota_last_reset_at`: Last quota reset time
- `quota_expires_at`: Quota expiry time
- `rpm_limit`: RPM limit
- `rpm_used`: RPM used
- `rpm_reset_at`: RPM reset time
- `timeout`: Default request timeout (seconds)
- `max_retries`: Default maximum retries
- `proxy`: Default proxy configuration
- `total_endpoints`: Total number of endpoints
- `active_endpoints`: Number of active endpoints
- `total_keys`: Total number of keys
@@ -96,7 +95,6 @@ async def get_provider_summary(
**Return fields**:
- `id`: Provider ID
- `name`: Provider name
- `display_name`: Display name
- `description`: Description
- `website`: Website URL
- `provider_priority`: Priority
@@ -107,9 +105,9 @@ async def get_provider_summary(
- `quota_reset_day`: Quota reset day
- `quota_last_reset_at`: Last quota reset time
- `quota_expires_at`: Quota expiry time
- `rpm_limit`: RPM limit
- `rpm_used`: RPM used
- `rpm_reset_at`: RPM reset time
- `timeout`: Default request timeout (seconds)
- `max_retries`: Default maximum retries
- `proxy`: Default proxy configuration
- `total_endpoints`: Total number of endpoints
- `active_endpoints`: Number of active endpoints
- `total_keys`: Total number of keys
@@ -185,13 +183,13 @@ async def update_provider_settings(
"""
Update a provider's base configuration
Update the provider's basic configuration, such as the display name, description and priority. Only pass the fields that need to be updated.
Update the provider's basic configuration, such as the name, description and priority. Only pass the fields that need to be updated.
**Path parameters**:
- `provider_id`: Provider ID
**Request body fields** (all optional):
- `display_name`: Display name
- `name`: Provider name
- `description`: Description
- `website`: Website URL
- `provider_priority`: Priority
@@ -199,9 +197,10 @@ async def update_provider_settings(
- `billing_type`: Billing type
- `monthly_quota_usd`: Monthly quota (USD)
- `quota_reset_day`: Quota reset day
- `quota_last_reset_at`: Last quota reset time
- `quota_expires_at`: Quota expiry time
- `rpm_limit`: RPM limit
- `timeout`: Default request timeout (seconds)
- `max_retries`: Default maximum retries
- `proxy`: Default proxy configuration
**Return fields**: the updated provider summary (same shape as the GET /summary endpoint)
"""
@@ -215,16 +214,16 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
total_endpoints = len(endpoints)
active_endpoints = sum(1 for e in endpoints if e.is_active)
endpoint_ids = [e.id for e in endpoints]
# Key 统计(合并为单个查询)
total_keys = 0
active_keys = 0
if endpoint_ids:
key_stats = db.query(
key_stats = (
db.query(
func.count(ProviderAPIKey.id).label("total"),
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).first()
)
.filter(ProviderAPIKey.provider_id == provider.id)
.first()
)
total_keys = key_stats.total or 0
active_keys = int(key_stats.active or 0)
@@ -238,25 +237,34 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
api_formats = [e.api_format for e in endpoints]
# 优化: 一次性加载所有 endpoint 的 keys避免 N+1 查询
all_keys = []
if endpoint_ids:
all_keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids)).all()
)
# 优化: 一次性加载 Provider 的 keys避免 N+1 查询
all_keys = db.query(ProviderAPIKey).filter(ProviderAPIKey.provider_id == provider.id).all()
# 按 endpoint_id 分组 keys
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {}
# 按 api_formats 分组 keys通过 api_formats 关联)
format_to_endpoint_id: dict[str, str] = {e.api_format: e.id for e in endpoints}
keys_by_endpoint: dict[str, list[ProviderAPIKey]] = {e.id: [] for e in endpoints}
for key in all_keys:
if key.endpoint_id not in keys_by_endpoint:
keys_by_endpoint[key.endpoint_id] = []
keys_by_endpoint[key.endpoint_id].append(key)
formats = key.api_formats or []
for fmt in formats:
endpoint_id = format_to_endpoint_id.get(fmt)
if endpoint_id:
keys_by_endpoint[endpoint_id].append(key)
endpoint_health_map: dict[str, float] = {}
for endpoint in endpoints:
keys = keys_by_endpoint.get(endpoint.id, [])
if keys:
health_scores = [k.health_score for k in keys if k.health_score is not None]
# Read this format's health score from health_by_format
api_fmt = endpoint.api_format
health_scores = []
for k in keys:
health_by_format = k.health_by_format or {}
if api_fmt in health_by_format:
score = health_by_format[api_fmt].get("health_score")
if score is not None:
health_scores.append(float(score))
else:
health_scores.append(1.0) # 默认健康度
avg_health = sum(health_scores) / len(health_scores) if health_scores else 1.0
endpoint_health_map[endpoint.id] = avg_health
else:
@@ -284,7 +292,6 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
return ProviderWithEndpointsSummary(
id=provider.id,
name=provider.name,
display_name=provider.display_name,
description=provider.description,
website=provider.website,
provider_priority=provider.provider_priority,
@@ -295,9 +302,9 @@ def _build_provider_summary(db: Session, provider: Provider) -> ProviderWithEndp
quota_reset_day=provider.quota_reset_day,
quota_last_reset_at=provider.quota_last_reset_at,
quota_expires_at=provider.quota_expires_at,
rpm_limit=provider.rpm_limit,
rpm_used=provider.rpm_used,
rpm_reset_at=provider.rpm_reset_at,
timeout=provider.timeout,
max_retries=provider.max_retries,
proxy=provider.proxy,
total_endpoints=total_endpoints,
active_endpoints=active_endpoints,
total_keys=total_keys,
@@ -341,7 +348,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
if not endpoint_ids:
response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id,
provider_name=provider.display_name or provider.name,
provider_name=provider.name,
generated_at=now,
endpoints=[],
)
@@ -416,7 +423,7 @@ class AdminProviderHealthMonitorAdapter(AdminApiAdapter):
response = ProviderEndpointHealthMonitorResponse(
provider_id=provider.id,
provider_name=provider.display_name or provider.name,
provider_name=provider.name,
generated_at=now,
endpoints=endpoint_monitors,
)
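A toy sketch (editor's illustration) of the per-format health bookkeeping used by _build_provider_summary above, assuming health_by_format holds one entry per API format with a "health_score" field, exactly as the code reads it:

keys = [
    {"health_by_format": {"CLAUDE": {"health_score": 0.8}, "OPENAI": {"health_score": 1.0}}},
    {"health_by_format": {}},  # no entry for the format -> treated as fully healthy (1.0)
]
api_fmt = "CLAUDE"
scores = [
    (k["health_by_format"].get(api_fmt) or {}).get("health_score", 1.0)
    for k in keys
]
avg_health = sum(scores) / len(scores)  # (0.8 + 1.0) / 2 = 0.9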

View File

@@ -730,9 +730,26 @@ class AdminExportConfigAdapter(AdminApiAdapter):
)
endpoints_data = []
for ep in endpoints:
# 导出 Endpoint Keys
endpoints_data.append(
{
"api_format": ep.api_format,
"base_url": ep.base_url,
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"proxy": ep.proxy,
}
)
# 导出 Provider Keys按 provider_id 归属,包含 api_formats
keys = (
db.query(ProviderAPIKey).filter(ProviderAPIKey.endpoint_id == ep.id).all()
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.provider_id == provider.id)
.order_by(ProviderAPIKey.internal_priority.asc(), ProviderAPIKey.created_at.asc())
.all()
)
keys_data = []
for key in keys:
@@ -747,35 +764,20 @@ class AdminExportConfigAdapter(AdminApiAdapter):
"api_key": decrypted_key,
"name": key.name,
"note": key.note,
"api_formats": key.api_formats or [],
"rate_multiplier": key.rate_multiplier,
"rate_multipliers": key.rate_multipliers,
"internal_priority": key.internal_priority,
"global_priority": key.global_priority,
"max_concurrent": key.max_concurrent,
"rate_limit": key.rate_limit,
"daily_limit": key.daily_limit,
"monthly_limit": key.monthly_limit,
"rpm_limit": key.rpm_limit,
"allowed_models": key.allowed_models,
"capabilities": key.capabilities,
"cache_ttl_minutes": key.cache_ttl_minutes,
"max_probe_interval_minutes": key.max_probe_interval_minutes,
"is_active": key.is_active,
}
)
endpoints_data.append(
{
"api_format": ep.api_format,
"base_url": ep.base_url,
"headers": ep.headers,
"timeout": ep.timeout,
"max_retries": ep.max_retries,
"max_concurrent": ep.max_concurrent,
"rate_limit": ep.rate_limit,
"is_active": ep.is_active,
"custom_path": ep.custom_path,
"config": ep.config,
"keys": keys_data,
}
)
# 导出 Provider Models
models = db.query(Model).filter(Model.provider_id == provider.id).all()
models_data = []
@@ -804,24 +806,26 @@ class AdminExportConfigAdapter(AdminApiAdapter):
providers_data.append(
{
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"website": provider.website,
"billing_type": provider.billing_type.value if provider.billing_type else None,
"monthly_quota_usd": provider.monthly_quota_usd,
"quota_reset_day": provider.quota_reset_day,
"rpm_limit": provider.rpm_limit,
"provider_priority": provider.provider_priority,
"is_active": provider.is_active,
"concurrent_limit": provider.concurrent_limit,
"timeout": provider.timeout,
"max_retries": provider.max_retries,
"proxy": provider.proxy,
"config": provider.config,
"endpoints": endpoints_data,
"api_keys": keys_data,
"models": models_data,
}
)
return {
"version": "1.0",
"version": "2.0",
"exported_at": datetime.now(timezone.utc).isoformat(),
"global_models": global_models_data,
"providers": providers_data,
@@ -850,7 +854,7 @@ class AdminImportConfigAdapter(AdminApiAdapter):
# Validate the configuration version
version = payload.get("version")
if version != "1.0":
if version != "2.0":
raise InvalidRequestException(f"不支持的配置版本: {version}")
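# A minimal sketch (editor's illustration, reduced to fields visible in the export code
# above) of a version 2.0 configuration document accepted by this importer:
example_config_v2 = {
    "version": "2.0",
    "exported_at": "2026-01-10T12:00:00+00:00",
    "global_models": [],
    "providers": [
        {
            "name": "my-upstream",
            "billing_type": "pay_as_you_go",
            "provider_priority": 100,
            "is_active": True,
            "timeout": 300,
            "max_retries": 2,
            "proxy": None,
            "config": None,
            "endpoints": [
                {"api_format": "CLAUDE", "base_url": "https://api.example.com", "is_active": True},
            ],
            "api_keys": [
                {"api_key": "sk-xxxx", "name": "primary", "api_formats": ["CLAUDE"], "rpm_limit": 60, "is_active": True},
            ],
            "models": [],
        },
    ],
}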
# Read the import options
@@ -939,8 +943,8 @@ class AdminImportConfigAdapter(AdminApiAdapter):
)
elif merge_mode == "overwrite":
# 更新现有记录
existing_provider.display_name = prov_data.get(
"display_name", existing_provider.display_name
existing_provider.name = prov_data.get(
"name", existing_provider.name
)
existing_provider.description = prov_data.get("description")
existing_provider.website = prov_data.get("website")
@@ -954,7 +958,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_provider.quota_reset_day = prov_data.get(
"quota_reset_day", 30
)
existing_provider.rpm_limit = prov_data.get("rpm_limit")
existing_provider.provider_priority = prov_data.get(
"provider_priority", 100
)
@@ -962,6 +965,11 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_provider.concurrent_limit = prov_data.get(
"concurrent_limit"
)
existing_provider.timeout = prov_data.get("timeout", existing_provider.timeout)
existing_provider.max_retries = prov_data.get(
"max_retries", existing_provider.max_retries
)
existing_provider.proxy = prov_data.get("proxy", existing_provider.proxy)
existing_provider.config = prov_data.get("config")
existing_provider.updated_at = datetime.now(timezone.utc)
stats["providers"]["updated"] += 1
@@ -974,16 +982,17 @@ class AdminImportConfigAdapter(AdminApiAdapter):
new_provider = Provider(
id=str(uuid.uuid4()),
name=prov_data["name"],
display_name=prov_data.get("display_name", prov_data["name"]),
description=prov_data.get("description"),
website=prov_data.get("website"),
billing_type=billing_type,
monthly_quota_usd=prov_data.get("monthly_quota_usd"),
quota_reset_day=prov_data.get("quota_reset_day", 30),
rpm_limit=prov_data.get("rpm_limit"),
provider_priority=prov_data.get("provider_priority", 100),
is_active=prov_data.get("is_active", True),
concurrent_limit=prov_data.get("concurrent_limit"),
timeout=prov_data.get("timeout"),
max_retries=prov_data.get("max_retries"),
proxy=prov_data.get("proxy"),
config=prov_data.get("config"),
)
db.add(new_provider)
@@ -1003,7 +1012,6 @@ class AdminImportConfigAdapter(AdminApiAdapter):
)
if existing_ep:
endpoint_id = existing_ep.id
if merge_mode == "skip":
stats["endpoints"]["skipped"] += 1
elif merge_mode == "error":
@@ -1017,11 +1025,10 @@ class AdminImportConfigAdapter(AdminApiAdapter):
existing_ep.headers = ep_data.get("headers")
existing_ep.timeout = ep_data.get("timeout", 300)
existing_ep.max_retries = ep_data.get("max_retries", 2)
existing_ep.max_concurrent = ep_data.get("max_concurrent")
existing_ep.rate_limit = ep_data.get("rate_limit")
existing_ep.is_active = ep_data.get("is_active", True)
existing_ep.custom_path = ep_data.get("custom_path")
existing_ep.config = ep_data.get("config")
existing_ep.proxy = ep_data.get("proxy")
existing_ep.updated_at = datetime.now(timezone.utc)
stats["endpoints"]["updated"] += 1
else:
@@ -1033,25 +1040,30 @@ class AdminImportConfigAdapter(AdminApiAdapter):
headers=ep_data.get("headers"),
timeout=ep_data.get("timeout", 300),
max_retries=ep_data.get("max_retries", 2),
max_concurrent=ep_data.get("max_concurrent"),
rate_limit=ep_data.get("rate_limit"),
is_active=ep_data.get("is_active", True),
custom_path=ep_data.get("custom_path"),
config=ep_data.get("config"),
proxy=ep_data.get("proxy"),
)
db.add(new_ep)
db.flush()
endpoint_id = new_ep.id
stats["endpoints"]["created"] += 1
# Import keys
# Fetch all existing keys under the current endpoint (for de-duplication)
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id == endpoint_id)
# Import Provider keys (owned via provider_id)
endpoint_format_rows = (
db.query(ProviderEndpoint.api_format)
.filter(ProviderEndpoint.provider_id == provider_id)
.all()
)
endpoint_formats: set[str] = set()
for (api_format,) in endpoint_format_rows:
fmt = api_format.value if hasattr(api_format, "value") else str(api_format)
endpoint_formats.add(fmt.strip().upper())
existing_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.provider_id == provider_id)
.all()
)
# 解密已有 keys 用于比对
existing_key_values = set()
for ek in existing_keys:
try:
@@ -1060,40 +1072,73 @@ class AdminImportConfigAdapter(AdminApiAdapter):
except Exception:
pass
for key_data in ep_data.get("keys", []):
for key_data in prov_data.get("api_keys", []):
if not key_data.get("api_key"):
stats["errors"].append(
f"跳过空 API Key (Endpoint: {ep_data['api_format']})"
f"跳过空 API Key (Provider: {prov_data['name']})"
)
continue
# 检查是否已存在相同的 Key通过明文比对
if key_data["api_key"] in existing_key_values:
plaintext_key = key_data["api_key"]
if plaintext_key in existing_key_values:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(key_data["api_key"])
raw_formats = key_data.get("api_formats") or []
if not isinstance(raw_formats, list) or len(raw_formats) == 0:
stats["errors"].append(
f"跳过无 api_formats 的 Key (Provider: {prov_data['name']})"
)
continue
normalized_formats: list[str] = []
seen: set[str] = set()
missing_formats: list[str] = []
for fmt in raw_formats:
if not isinstance(fmt, str):
continue
fmt_upper = fmt.strip().upper()
if not fmt_upper or fmt_upper in seen:
continue
seen.add(fmt_upper)
if endpoint_formats and fmt_upper not in endpoint_formats:
missing_formats.append(fmt_upper)
continue
normalized_formats.append(fmt_upper)
if missing_formats:
stats["errors"].append(
f"Key (Provider: {prov_data['name']}) 的 api_formats 未配置对应 Endpoint已跳过: {missing_formats}"
)
if len(normalized_formats) == 0:
stats["keys"]["skipped"] += 1
continue
encrypted_key = crypto_service.encrypt(plaintext_key)
new_key = ProviderAPIKey(
id=str(uuid.uuid4()),
endpoint_id=endpoint_id,
provider_id=provider_id,
api_formats=normalized_formats,
api_key=encrypted_key,
name=key_data.get("name"),
name=key_data.get("name") or "Imported Key",
note=key_data.get("note"),
rate_multiplier=key_data.get("rate_multiplier", 1.0),
internal_priority=key_data.get("internal_priority", 100),
rate_multipliers=key_data.get("rate_multipliers"),
internal_priority=key_data.get("internal_priority", 50),
global_priority=key_data.get("global_priority"),
max_concurrent=key_data.get("max_concurrent"),
rate_limit=key_data.get("rate_limit"),
daily_limit=key_data.get("daily_limit"),
monthly_limit=key_data.get("monthly_limit"),
rpm_limit=key_data.get("rpm_limit"),
allowed_models=key_data.get("allowed_models"),
capabilities=key_data.get("capabilities"),
cache_ttl_minutes=key_data.get("cache_ttl_minutes", 5),
max_probe_interval_minutes=key_data.get("max_probe_interval_minutes", 32),
is_active=key_data.get("is_active", True),
health_by_format={},
circuit_breaker_by_format={},
)
db.add(new_key)
# 添加到已有集合,防止同一批导入中重复
existing_key_values.add(key_data["api_key"])
existing_key_values.add(plaintext_key)
stats["keys"]["created"] += 1
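# A toy walk-through (editor's illustration) of the api_formats normalization above,
# assuming the provider only has CLAUDE and OPENAI endpoints:
demo_endpoint_formats = {"CLAUDE", "OPENAI"}
demo_raw = ["claude", " openai ", "GEMINI", "claude"]
demo_normalized, demo_seen, demo_missing = [], set(), []
for demo_fmt in demo_raw:
    demo_upper = demo_fmt.strip().upper()
    if not demo_upper or demo_upper in demo_seen:
        continue
    demo_seen.add(demo_upper)
    if demo_upper not in demo_endpoint_formats:
        demo_missing.append(demo_upper)
        continue
    demo_normalized.append(demo_upper)
# -> demo_normalized == ["CLAUDE", "OPENAI"], demo_missing == ["GEMINI"]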
# Import models

View File

@@ -247,7 +247,8 @@ async def get_usage_detail(
- `request_headers`: Request headers
- `request_body`: Request body
- `provider_request_headers`: Provider request headers
- `response_headers`: Response headers
- `response_headers`: Provider response headers
- `client_response_headers`: Response headers returned to the client
- `response_body`: Response body
- `metadata`: Provider response metadata
- `tiered_pricing`: Tiered-pricing information (if applicable)
@@ -916,6 +917,7 @@ class AdminUsageDetailAdapter(AdminApiAdapter):
"request_body": usage_record.get_request_body(),
"provider_request_headers": usage_record.provider_request_headers,
"response_headers": usage_record.response_headers,
"client_response_headers": usage_record.client_response_headers,
"response_body": usage_record.get_response_body(),
"metadata": usage_record.request_metadata,
"tiered_pricing": tiered_pricing_info,

View File

@@ -202,20 +202,59 @@ def get_available_provider_ids(db: Session, api_formats: list[str]) -> set[str]:
Conditions:
- The endpoint's api_format matches
- The endpoint is active
- The endpoint has an active key
- The Provider has an active key supporting the api_format (keys belong directly to the Provider and are filtered via api_formats)
"""
rows = (
db.query(ProviderEndpoint.provider_id)
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
target_formats = {f.upper() for f in api_formats}
# 1) First find providers that have active endpoints (recording the set of formats each provider supports)
endpoint_rows = (
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
.filter(
ProviderEndpoint.api_format.in_(api_formats),
ProviderEndpoint.api_format.in_(list(target_formats)),
ProviderEndpoint.is_active.is_(True),
ProviderAPIKey.is_active.is_(True),
)
.distinct()
.all()
)
return {row[0] for row in rows}
if not endpoint_rows:
return set()
provider_to_formats: dict[str, set[str]] = {}
for provider_id, fmt in endpoint_rows:
if not provider_id or not fmt:
continue
provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())
provider_ids_with_endpoints = set(provider_to_formats.keys())
if not provider_ids_with_endpoints:
return set()
# 2) Then check that each of these providers has at least one active key supporting a matching format
key_rows = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.api_formats)
.filter(
ProviderAPIKey.provider_id.in_(provider_ids_with_endpoints),
ProviderAPIKey.is_active.is_(True),
)
.all()
)
available_provider_ids: set[str] = set()
for provider_id, key_formats in key_rows:
if not provider_id:
continue
endpoint_formats = provider_to_formats.get(provider_id)
if not endpoint_formats:
continue
formats_list = key_formats if isinstance(key_formats, list) else []
key_formats_upper = {str(f).upper() for f in formats_list}
# A provider only counts as available when request formats ∩ provider endpoint formats ∩ key-supported formats is non-empty
if key_formats_upper & endpoint_formats & target_formats:
available_provider_ids.add(provider_id)
return available_provider_ids
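# A toy example (editor's illustration) of the availability rule above: a provider is
# usable only when the request formats, its active endpoint formats, and at least one
# active key's formats all intersect.
demo_target = {"CLAUDE"}
demo_provider_endpoint_formats = {"prov-1": {"CLAUDE", "OPENAI"}, "prov-2": {"OPENAI"}}
demo_key_formats = {"prov-1": {"CLAUDE"}, "prov-2": {"OPENAI"}}
demo_available = {
    pid
    for pid, ep_formats in demo_provider_endpoint_formats.items()
    if demo_key_formats.get(pid, set()) & ep_formats & demo_target
}
# -> {"prov-1"}  (prov-2 offers no CLAUDE endpoint or key, so it is filtered out)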
def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) -> set[str]:
@@ -228,36 +267,64 @@ def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) ->
3. **The endpoint's Provider is linked to the model**
4. The key's allowed_models permits the model (null = all models linked to that Provider are allowed)
"""
# 查询所有匹配格式的活跃端点及其活跃 Key同时获取 endpoint_id
endpoint_keys = (
db.query(
ProviderEndpoint.id.label("endpoint_id"),
ProviderEndpoint.provider_id,
ProviderAPIKey.allowed_models,
)
.join(ProviderAPIKey, ProviderAPIKey.endpoint_id == ProviderEndpoint.id)
target_formats = {f.upper() for f in api_formats}
# 1) Find providers that have active endpoints (recording the set of formats each provider supports)
endpoint_rows = (
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
.filter(
ProviderEndpoint.api_format.in_(api_formats),
ProviderEndpoint.api_format.in_(list(target_formats)),
ProviderEndpoint.is_active.is_(True),
)
.all()
)
if not endpoint_rows:
return set()
provider_to_formats: dict[str, set[str]] = {}
for provider_id, fmt in endpoint_rows:
if not provider_id or not fmt:
continue
provider_to_formats.setdefault(provider_id, set()).add(str(fmt).upper())
provider_ids_with_endpoints = set(provider_to_formats.keys())
if not provider_ids_with_endpoints:
return set()
# 2) Collect, for each provider, the allowed_models of active keys that support a matching format
# Keys belong directly to the Provider; filtered by intersecting key.api_formats with the provider's endpoint formats
key_rows = (
db.query(ProviderAPIKey.provider_id, ProviderAPIKey.allowed_models, ProviderAPIKey.api_formats)
.filter(
ProviderAPIKey.provider_id.in_(provider_ids_with_endpoints),
ProviderAPIKey.is_active.is_(True),
)
.all()
)
if not endpoint_keys:
# provider_id -> list[(allowed_models, usable_formats)]
provider_key_rules: dict[str, list[tuple[object, set[str]]]] = {}
for provider_id, allowed_models, key_formats in key_rows:
if not provider_id:
continue
endpoint_formats = provider_to_formats.get(provider_id)
if not endpoint_formats:
continue
formats_list = key_formats if isinstance(key_formats, list) else []
key_formats_upper = {str(f).upper() for f in formats_list}
usable_formats = key_formats_upper & endpoint_formats & target_formats
if not usable_formats:
continue
provider_key_rules.setdefault(provider_id, []).append((allowed_models, usable_formats))
provider_ids_with_format = set(provider_key_rules.keys())
if not provider_ids_with_format:
return set()
# 收集每个 (provider_id, endpoint_id) 对应的 allowed_models
# 使用 provider_id 作为 key因为模型是关联到 Provider 的
provider_allowed_models: dict[str, list[Optional[list[str]]]] = {}
provider_ids_with_format: set[str] = set()
for endpoint_id, provider_id, allowed_models in endpoint_keys:
provider_ids_with_format.add(provider_id)
if provider_id not in provider_allowed_models:
provider_allowed_models[provider_id] = []
provider_allowed_models[provider_id].append(allowed_models)
# 只查询那些有匹配格式端点的 Provider 下的模型
models = (
db.query(Model)
@@ -285,20 +352,28 @@ def _get_available_model_ids_for_format(db: Session, api_formats: list[str]) ->
if model_provider_id not in provider_ids_with_format:
continue
# 检查该 provider 下是否有 Key 允许这个模型
allowed_lists = provider_allowed_models.get(model_provider_id, [])
for allowed_models in allowed_lists:
# 检查该 provider 下是否有 Key 允许这个模型(支持 list/dict 两种 allowed_models
from src.core.model_permissions import check_model_allowed
rules = provider_key_rules.get(model_provider_id, [])
for allowed_models, usable_formats in rules:
# None = 不限制
if allowed_models is None:
# null = 允许该 Provider 关联的所有模型(已通过上面的查询限制)
available_model_ids.add(model_id)
break
elif model_id in allowed_models:
# 明确在允许列表中
# 对于支持多个格式的 Key任意一个可用格式允许即可
for fmt in usable_formats:
if check_model_allowed(
model_name=model_id,
allowed_models=allowed_models, # type: ignore[arg-type]
api_format=fmt,
resolved_model_name=(model.provider_model_name if global_model else None),
):
available_model_ids.add(model_id)
break
elif global_model and model.provider_model_name in allowed_models:
# 也检查 provider_model_name
available_model_ids.add(model_id)
else:
continue
break
return available_model_ids
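# A reduced sketch (editor's illustration) of how provider_key_rules gates models above:
# allowed_models=None means the key accepts every model its provider exposes, while any
# concrete value is delegated to check_model_allowed (the exact schema of the per-format
# dict form lives in src.core.model_permissions and is not shown in this diff).
demo_rules = {
    "prov-1": [(None, {"CLAUDE"})],        # unrestricted key -> every provider model passes
    "prov-2": [(["gpt-4o"], {"OPENAI"})],  # restricted key -> only listed models pass
}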

View File

@@ -155,7 +155,7 @@ async def get_daily_stats(
class DashboardAdapter(ApiAdapter):
"""需要登录的仪表盘适配器基类。"""
mode = ApiMode.ADMIN
mode = ApiMode.USER # 普通用户也可访问仪表盘
def authorize(self, context): # type: ignore[override]
if not context.user:

View File

@@ -98,6 +98,7 @@ class MessageTelemetry:
request_headers: Dict[str, Any],
response_body: Any,
response_headers: Dict[str, Any],
client_response_headers: Optional[Dict[str, Any]] = None,
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
is_stream: bool = False,
@@ -143,6 +144,7 @@ class MessageTelemetry:
request_body=request_body,
provider_request_headers=provider_request_headers or {},
response_headers=response_headers,
client_response_headers=client_response_headers,
response_body=response_body,
request_id=self.request_id,
# Provider 侧追踪信息(用于记录真实成本)
@@ -192,6 +194,8 @@ class MessageTelemetry:
cache_creation_tokens: int = 0,
cache_read_tokens: int = 0,
response_body: Optional[Dict[str, Any]] = None,
response_headers: Optional[Dict[str, Any]] = None,
client_response_headers: Optional[Dict[str, Any]] = None,
# 模型映射信息
target_model: Optional[str] = None,
) -> None:
@@ -207,6 +211,8 @@ class MessageTelemetry:
cache_creation_tokens: 缓存创建 tokens
cache_read_tokens: 缓存读取 tokens
response_body: 响应体(如果有部分响应)
response_headers: 响应头Provider 返回的原始响应头)
client_response_headers: 返回给客户端的响应头
target_model: 映射后的目标模型名(如果发生了映射)
"""
provider_name = provider or "unknown"
@@ -232,7 +238,8 @@ class MessageTelemetry:
request_headers=request_headers,
request_body=request_body,
provider_request_headers=provider_request_headers or {},
response_headers={},
response_headers=response_headers or {},
client_response_headers=client_response_headers,
response_body=response_body or {"error": error_message},
request_id=self.request_id,
# 模型映射信息

View File

@@ -351,9 +351,9 @@ class ChatAdapterBase(ApiAdapter):
# 确定错误消息
if isinstance(e, ProviderAuthException):
error_message = (
f"提供商认证失败: {str(e)}"
"上游服务认证失败"
if result.metadata.provider != "unknown"
else "服务端错误: 无可用提供商"
else "服务暂时不可用"
)
result.error_message = error_message

View File

@@ -37,7 +37,7 @@ from src.api.handlers.base.stream_processor import StreamProcessor
from src.api.handlers.base.stream_telemetry import StreamTelemetryRecorder
from src.api.handlers.base.utils import build_sse_headers
from src.config.settings import config
from src.core.error_utils import extract_error_message
from src.core.error_utils import extract_client_error_message
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -382,10 +382,17 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
http_request.is_disconnected,
)
# Pass the provider's response headers through to the client,
# and add the headers required for SSE so that streaming works correctly
client_headers = dict(ctx.response_headers) if ctx.response_headers else {}
# Add/override the headers required for SSE
client_headers.update(build_sse_headers())
client_headers["content-type"] = "text/event-stream"
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
headers=client_headers,
background=background_tasks,
)
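# A small sketch (editor's illustration) of the header merge above: provider headers are
# passed through, then the SSE headers (the no-cache / no-buffering pair set explicitly
# elsewhere in this diff, assumed here to be what build_sse_headers() returns) overwrite
# any conflicting values.
demo_provider_headers = {"content-type": "application/json", "x-request-id": "abc123"}
demo_client_headers = dict(demo_provider_headers)
demo_client_headers.update({"Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no"})
demo_client_headers["content-type"] = "text/event-stream"
# -> x-request-id is preserved; content-type is forced to text/event-stream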
@@ -463,7 +470,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
# Configure HTTP timeouts
# Note: the read timeout detects dropped connections; it is not the overall request timeout
# The overall request timeout is enforced via asyncio.wait_for using endpoint.timeout
# The overall request timeout is enforced via asyncio.wait_for using provider.timeout
timeout_config = httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout, # 使用全局配置,用于检测连接断开
@@ -471,14 +478,14 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
pool=config.http_pool_timeout,
)
# endpoint.timeout serves as the overall request timeout (connect + first byte)
request_timeout = float(endpoint.timeout or 300)
# provider.timeout serves as the overall request timeout (connect + first byte)
request_timeout = float(provider.timeout or 300)
# Create the HTTP client (with proxy support)
# Create the HTTP client (with proxy support; the proxy is read from the Provider)
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
proxy_config=provider.proxy,
timeout=timeout_config,
)
@@ -514,7 +521,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
try:
# Wrap the whole "connect + fetch first byte" phase in asyncio.wait_for
# endpoint.timeout bounds the overall time so the upstream cannot hang indefinitely
# provider.timeout bounds the overall time so the upstream cannot hang indefinitely
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
except asyncio.TimeoutError:
@@ -590,17 +597,22 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
actual_request_body = ctx.provider_request_body or original_request_body
# 失败时返回给客户端的是 JSON 错误响应
client_response_headers = {"content-type": "application/json"}
await self.telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=extract_error_message(error),
error_message=extract_client_error_message(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
target_model=ctx.mapped_model,
)
@@ -691,26 +703,32 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f" [{self.request_id}] 请求体stream字段: {provider_payload.get('stream', 'N/A')}"
)
# Create the HTTP client (with proxy support)
# endpoint.timeout serves as the overall request timeout
# Obtain a reused HTTP client (with proxy support; the proxy is read from the Provider)
# Note: get_proxy_client reuses the connection pool instead of creating a new client per request
from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
request_timeout = float(provider.timeout or 300)
http_client = await HTTPClientPool.get_proxy_client(
proxy_config=provider.proxy,
)
# Note: no async with here, because a pooled (reused) client must not be closed
# The timeout is controlled via the timeout parameter
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_hdrs,
timeout=httpx.Timeout(request_timeout),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_hdrs)
status_code = resp.status_code
response_headers = dict(resp.headers)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
raise ProviderAuthException(str(provider.name))
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
"请求过于频繁,请稍后重试",
provider_name=str(provider.name),
response_headers=response_headers,
)
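# A minimal usage sketch (editor's illustration) of the pooled-client pattern above: the
# client returned for a given proxy configuration is shared, so it is not closed per
# request and the overall timeout is passed on each call instead.
pooled_client = await HTTPClientPool.get_proxy_client(proxy_config=provider.proxy)
pooled_resp = await pooled_client.post(
    url,
    json=provider_payload,
    headers=provider_hdrs,
    timeout=httpx.Timeout(float(provider.timeout or 300)),
)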
@@ -725,7 +743,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}",
f"上游服务暂时不可用 (HTTP {resp.status_code})",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
@@ -741,13 +759,41 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
except Exception:
pass
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
f"上游服务返回错误 (HTTP {resp.status_code})",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_body,
)
# Parse the JSON response defensively, handling possible encoding errors
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# Capture the raw response content for debugging (stored in upstream_response)
raw_content = ""
try:
raw_content = resp.text[:500] if resp.text else "(empty)"
except Exception:
try:
raw_content = repr(resp.content[:500]) if resp.content else "(empty)"
except Exception:
raw_content = "(unable to read)"
logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, 原始内容: {raw_content}"
)
# Classify the error and produce a client-friendly message (without exposing provider details)
if raw_content == "(empty)" or not raw_content.strip():
client_message = "上游服务返回了空响应"
elif raw_content.strip().startswith(("<", "<!doctype", "<!DOCTYPE")):
client_message = "上游服务返回了非预期的响应格式"
else:
client_message = "上游服务返回了无效的响应"
raise ProviderNotAvailableException(
client_message,
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=raw_content,
)
return response_json if isinstance(response_json, dict) else {}
try:
@@ -792,6 +838,11 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
actual_request_body = provider_request_body or original_request_body
# On non-streaming success the provider's response headers are passed through to the client
# JSONResponse sets content-type automatically, but we record the full headers actually returned
client_response_headers = dict(response_headers) if response_headers else {}
client_response_headers["content-type"] = "application/json"
total_cost = await self.telemetry.record_success(
provider=provider_name,
model=model,
@@ -802,6 +853,7 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
request_headers=original_headers,
request_body=actual_request_body,
response_headers=response_headers,
client_response_headers=client_response_headers,
response_body=response_json,
cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cached_tokens,
@@ -823,7 +875,12 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
f"in:{input_tokens or 0} out:{output_tokens or 0}"
)
return JSONResponse(status_code=status_code, content=response_json)
# 透传提供商的响应头
return JSONResponse(
status_code=status_code,
content=response_json,
headers=response_headers if response_headers else None,
)
except Exception as e:
response_time_ms = self.elapsed_ms()
@@ -838,17 +895,27 @@ class ChatHandlerBase(BaseMessageHandler, ABC):
actual_request_body = provider_request_body or original_request_body
# 尝试从异常中提取响应头
error_response_headers: Dict[str, str] = {}
if isinstance(e, ProviderRateLimitException) and e.response_headers:
error_response_headers = e.response_headers
elif isinstance(e, httpx.HTTPStatusError) and hasattr(e, "response"):
error_response_headers = dict(e.response.headers)
await self.telemetry.record_failure(
provider=provider_name or "unknown",
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=extract_error_message(e),
error_message=extract_client_error_message(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
api_format=api_format,
provider_request_headers=provider_request_headers,
response_headers=error_response_headers,
# 非流式失败返回给客户端的是 JSON 错误响应
client_response_headers={"content-type": "application/json"},
# 模型映射信息
target_model=mapped_model_result,
)

View File

@@ -306,9 +306,9 @@ class CliAdapterBase(ApiAdapter):
# 确定错误消息
if isinstance(e, ProviderAuthException):
error_message = (
f"提供商认证失败: {str(e)}"
"上游服务认证失败"
if result.metadata.provider != "unknown"
else "服务端错误: 无可用提供商"
else "服务暂时不可用"
)
result.error_message = error_message

View File

@@ -47,7 +47,7 @@ from src.api.handlers.base.utils import (
)
from src.config.constants import StreamDefaults
from src.config.settings import config
from src.core.error_utils import extract_error_message
from src.core.error_utils import extract_client_error_message
from src.core.exceptions import (
EmbeddedErrorException,
ProviderAuthException,
@@ -376,10 +376,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 创建监控流
monitored_stream = self._create_monitored_stream(ctx, stream_generator)
# 透传提供商的响应头给客户端
# 同时添加必要的 SSE 头以确保流式传输正常工作
client_headers = dict(ctx.response_headers) if ctx.response_headers else {}
# 添加/覆盖 SSE 必需的头
client_headers.update(build_sse_headers())
client_headers["content-type"] = "text/event-stream"
ctx.client_response_headers = client_headers
return StreamingResponse(
monitored_stream,
media_type="text/event-stream",
headers=build_sse_headers(),
headers=client_headers,
background=background_tasks,
)
@@ -475,8 +483,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
pool=config.http_pool_timeout,
)
# endpoint.timeout 作为整体请求超时(建立连接 + 获取首字节)
request_timeout = float(endpoint.timeout or 300)
# provider.timeout 作为整体请求超时(建立连接 + 获取首字节)
request_timeout = float(provider.timeout or 300)
logger.debug(
f" └─ [{self.request_id}] 发送流式请求: "
@@ -486,11 +494,11 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"timeout={request_timeout}s"
)
# 创建 HTTP 客户端(支持代理配置)
# 创建 HTTP 客户端(支持代理配置,从 Provider 读取
from src.clients.http_client import HTTPClientPool
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
proxy_config=provider.proxy,
timeout=timeout_config,
)
@@ -524,7 +532,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
try:
# 使用 asyncio.wait_for 包裹整个"建立连接 + 获取首字节"阶段
# endpoint.timeout 控制整体超时,避免上游长时间无响应
# provider.timeout 控制整体超时,避免上游长时间无响应
await asyncio.wait_for(_connect_and_prefetch(), timeout=request_timeout)
except asyncio.TimeoutError:
@@ -636,12 +644,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
logger.warning(f"Provider '{ctx.provider_name}' 流超时且无数据")
# 设置错误状态用于后续记录
ctx.status_code = 504
ctx.error_message = "流式响应超时,未收到有效数据"
ctx.upstream_response = f"流超时: Provider={ctx.provider_name}, elapsed={elapsed:.1f}s, chunk_count={ctx.chunk_count}, data_count=0"
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
"message": ctx.error_message,
},
}
self._mark_first_output(ctx, output_state)
@@ -682,12 +694,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.data_count == 0:
# 流已开始,无法抛出异常进行故障转移
# 发送错误事件并记录日志
logger.warning(f"提供商 '{ctx.provider_name}' 返回空流式响应")
logger.warning(f"Provider '{ctx.provider_name}' 返回空流式响应")
# 设置错误状态用于后续记录
ctx.status_code = 503
ctx.error_message = "上游服务返回了空的流式响应"
ctx.upstream_response = f"空流式响应: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
error_event = {
"type": "error",
"error": {
"type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应",
"message": ctx.error_message,
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -699,12 +715,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
except httpx.StreamClosed:
if ctx.data_count == 0:
# 流已开始,发送错误事件而不是抛出异常
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
logger.warning(f"Provider '{ctx.provider_name}' 流连接关闭且无数据")
# 设置错误状态用于后续记录
ctx.status_code = 503
ctx.error_message = "上游服务连接关闭且未返回数据"
ctx.upstream_response = f"流连接关闭: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
error_event = {
"type": "error",
"error": {
"type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
"message": ctx.error_message,
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -824,8 +844,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
"上游服务返回了非预期的响应格式",
provider_name=str(provider.name),
upstream_status=200,
upstream_response=normalized_line[:500] if normalized_line else "(empty)",
)
if not normalized_line or normalized_line.startswith(":"):
@@ -1024,12 +1046,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.chunk_count > self.EMPTY_CHUNK_THRESHOLD and ctx.data_count == 0:
elapsed = time.time() - last_data_time
if elapsed > self.DATA_TIMEOUT:
logger.warning(f"提供商 '{ctx.provider_name}' 流超时且无数据")
logger.warning(f"Provider '{ctx.provider_name}' 流超时且无数据")
# 设置错误状态用于后续记录
ctx.status_code = 504
ctx.error_message = "流式响应超时,未收到有效数据"
ctx.upstream_response = f"流超时: Provider={ctx.provider_name}, elapsed={elapsed:.1f}s, chunk_count={ctx.chunk_count}, data_count=0"
error_event = {
"type": "error",
"error": {
"type": "empty_stream_timeout",
"message": f"提供商 '{ctx.provider_name}' 流超时且未返回有效数据",
"message": ctx.error_message,
},
}
self._mark_first_output(ctx, output_state)
@@ -1071,14 +1097,18 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.data_count == 0:
# 空流通常意味着配置错误(如 base_url 指向了网页而非 API
logger.error(
f"提供商 '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
f"Provider '{ctx.provider_name}' 返回空流式响应 (收到 {ctx.chunk_count} 个非数据行), "
f"可能是 endpoint base_url 配置错误"
)
# 设置错误状态用于后续记录
ctx.status_code = 503
ctx.error_message = "上游服务返回了空的流式响应"
ctx.upstream_response = f"空流式响应: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0, 可能是 base_url 配置错误"
error_event = {
"type": "error",
"error": {
"type": "empty_response",
"message": f"提供商 '{ctx.provider_name}' 返回了空的流式响应 (收到 {ctx.chunk_count} 行非 SSE 数据),请检查 endpoint 的 base_url 配置是否指向了正确的 API 地址",
"message": ctx.error_message,
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -1089,12 +1119,16 @@ class CliMessageHandlerBase(BaseMessageHandler):
raise
except httpx.StreamClosed:
if ctx.data_count == 0:
logger.warning(f"提供商 '{ctx.provider_name}' 流连接关闭且无数据")
logger.warning(f"Provider '{ctx.provider_name}' 流连接关闭且无数据")
# 设置错误状态用于后续记录
ctx.status_code = 503
ctx.error_message = "上游服务连接关闭且未返回数据"
ctx.upstream_response = f"流连接关闭: Provider={ctx.provider_name}, chunk_count={ctx.chunk_count}, data_count=0"
error_event = {
"type": "error",
"error": {
"type": "stream_closed",
"message": f"提供商 '{ctx.provider_name}' 连接关闭且未返回数据",
"message": ctx.error_message,
},
}
yield f"event: error\ndata: {json.dumps(error_event)}\n\n".encode("utf-8")
@@ -1289,6 +1323,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
if ctx.status_code and ctx.status_code >= 400:
# 记录失败的 Usage但使用已收到的预估 token 信息(来自 message_start
# 这样即使请求中断,也能记录预估成本
# 失败时返回给客户端的是 JSON 错误响应,如果没有设置则使用默认值
client_response_headers = ctx.client_response_headers or {"content-type": "application/json"}
await bg_telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
@@ -1306,6 +1342,8 @@ class CliMessageHandlerBase(BaseMessageHandler):
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
response_body=response_body,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
# 模型映射信息
target_model=ctx.mapped_model,
)
@@ -1319,6 +1357,14 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 在记录统计前,允许子类从 parsed_chunks 中提取额外的元数据
self._finalize_stream_metadata(ctx)
# 流式成功时,返回给客户端的是提供商响应头 + SSE 必需头
client_response_headers = dict(ctx.response_headers) if ctx.response_headers else {}
client_response_headers.update({
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"content-type": "text/event-stream",
})
total_cost = await bg_telemetry.record_success(
provider=ctx.provider_name,
model=ctx.model,
@@ -1330,6 +1376,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
@@ -1367,13 +1414,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 499 = 客户端断开连接,应标记为失败
# 503 = 服务不可用(如流中断),应标记为失败
if ctx.status_code and ctx.status_code >= 400:
# 请求链路追踪使用 upstream_response(原始响应),回退到 error_message(友好消息)
trace_error_message = ctx.upstream_response or ctx.error_message or f"HTTP {ctx.status_code}"
RequestCandidateService.mark_candidate_failed(
db=bg_db,
candidate_id=ctx.attempt_id,
error_type=(
"client_disconnected" if ctx.status_code == 499 else "stream_error"
),
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
error_message=trace_error_message,
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={
@@ -1426,17 +1475,22 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = ctx.provider_request_body or original_request_body
# 失败时返回给客户端的是 JSON 错误响应
client_response_headers = {"content-type": "application/json"}
await self.telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=extract_error_message(error),
error_message=extract_client_error_message(error),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=True,
api_format=ctx.api_format,
provider_request_headers=ctx.provider_request_headers,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
# 模型映射信息
target_model=ctx.mapped_model,
)
@@ -1534,26 +1588,32 @@ class CliMessageHandlerBase(BaseMessageHandler):
f"原始模型={model}, 映射后={mapped_model or '无映射'}, URL模型={url_model}"
)
# 创建 HTTP 客户端(支持代理配置)
# endpoint.timeout 作为整体请求超时
# 获取复用的 HTTP 客户端(支持代理配置,从 Provider 读取)
# 注意:使用 get_proxy_client 复用连接池,不再每次创建新客户端
from src.clients.http_client import HTTPClientPool
request_timeout = float(endpoint.timeout or 300)
http_client = HTTPClientPool.create_client_with_proxy(
proxy_config=endpoint.proxy,
request_timeout = float(provider.timeout or 300)
http_client = await HTTPClientPool.get_proxy_client(
proxy_config=provider.proxy,
)
# 注意:不使用 async with,因为复用的客户端不应该被关闭
# 超时通过 timeout 参数控制
resp = await http_client.post(
url,
json=provider_payload,
headers=provider_headers,
timeout=httpx.Timeout(request_timeout),
)
async with http_client:
resp = await http_client.post(url, json=provider_payload, headers=provider_headers)
status_code = resp.status_code
response_headers = dict(resp.headers)
if resp.status_code == 401:
raise ProviderAuthException(f"提供商认证失败: {provider.name}")
raise ProviderAuthException(str(provider.name))
elif resp.status_code == 429:
raise ProviderRateLimitException(
f"提供商速率限制: {provider.name}",
"请求过于频繁,请稍后重试",
provider_name=str(provider.name),
response_headers=response_headers,
retry_after=int(resp.headers.get("retry-after", 0)) or None,
@@ -1561,7 +1621,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
elif resp.status_code >= 500:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商服务不可用: {provider.name}, 状态: {resp.status_code}",
f"上游服务暂时不可用 (HTTP {resp.status_code})",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
@@ -1569,12 +1629,15 @@ class CliMessageHandlerBase(BaseMessageHandler):
elif 300 <= resp.status_code < 400:
redirect_url = resp.headers.get("location", "unknown")
raise ProviderNotAvailableException(
f"提供商配置错误: {provider.name}, 返回重定向 {resp.status_code} -> {redirect_url}"
"上游服务返回重定向响应",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=f"重定向 {resp.status_code} -> {redirect_url}",
)
elif resp.status_code != 200:
error_text = resp.text
raise ProviderNotAvailableException(
f"提供商返回错误: {provider.name}, 状态: {resp.status_code}",
f"上游服务返回错误 (HTTP {resp.status_code})",
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=error_text,
@@ -1584,16 +1647,34 @@ class CliMessageHandlerBase(BaseMessageHandler):
try:
response_json = resp.json()
except (UnicodeDecodeError, json.JSONDecodeError) as e:
# 记录原始响应信息用于调试
# 获取原始响应内容用于调试(存入 upstream_response)
content_type = resp.headers.get("content-type", "unknown")
content_encoding = resp.headers.get("content-encoding", "none")
raw_content = ""
try:
raw_content = resp.text[:500] if resp.text else "(empty)"
except Exception:
try:
raw_content = repr(resp.content[:500]) if resp.content else "(empty)"
except Exception:
raw_content = "(unable to read)"
logger.error(
f"[{self.request_id}] 无法解析响应 JSON: {e}, "
f"Content-Type: {content_type}, Content-Encoding: {content_encoding}, "
f"响应长度: {len(resp.content)} bytes"
f"响应长度: {len(resp.content)} bytes, 原始内容: {raw_content}"
)
# 判断错误类型,生成友好的客户端错误消息(不暴露提供商信息)
if raw_content == "(empty)" or not raw_content.strip():
client_message = "上游服务返回了空响应"
elif raw_content.strip().startswith(("<", "<!doctype", "<!DOCTYPE")):
client_message = "上游服务返回了非预期的响应格式"
else:
client_message = "上游服务返回了无效的响应"
raise ProviderNotAvailableException(
f"提供商返回无效响应: {provider.name}, 无法解析 JSON: {str(e)[:100]}"
client_message,
provider_name=str(provider.name),
upstream_status=resp.status_code,
upstream_response=raw_content,
)
# 提取 Provider 响应元数据(子类可覆盖)
@@ -1663,6 +1744,10 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = provider_request_body or original_request_body
# 非流式成功时,返回给客户端的是提供商响应头(透传)
client_response_headers = dict(response_headers) if response_headers else {}
client_response_headers["content-type"] = "application/json"
total_cost = await self.telemetry.record_success(
provider=provider_name,
model=model,
@@ -1673,6 +1758,7 @@ class CliMessageHandlerBase(BaseMessageHandler):
request_headers=original_headers,
request_body=actual_request_body,
response_headers=response_headers,
client_response_headers=client_response_headers,
response_body=response_json,
cache_creation_tokens=cache_creation_tokens,
cache_read_tokens=cached_tokens,
@@ -1691,7 +1777,12 @@ class CliMessageHandlerBase(BaseMessageHandler):
logger.info(f"{self.FORMAT_ID} 非流式响应处理完成")
return JSONResponse(status_code=status_code, content=response_json)
# 透传提供商的响应头
return JSONResponse(
status_code=status_code,
content=response_json,
headers=response_headers if response_headers else None,
)
except Exception as e:
response_time_ms = int((time.time() - sync_start_time) * 1000)
@@ -1707,17 +1798,27 @@ class CliMessageHandlerBase(BaseMessageHandler):
# 使用实际发送给 Provider 的请求体(如果有),否则用原始请求体
actual_request_body = provider_request_body or original_request_body
# 尝试从异常中提取响应头
error_response_headers: Dict[str, str] = {}
if isinstance(e, ProviderRateLimitException) and e.response_headers:
error_response_headers = e.response_headers
elif isinstance(e, httpx.HTTPStatusError) and hasattr(e, "response"):
error_response_headers = dict(e.response.headers)
await self.telemetry.record_failure(
provider=provider_name or "unknown",
model=model,
response_time_ms=response_time_ms,
status_code=status_code,
error_message=extract_error_message(e),
error_message=extract_client_error_message(e),
request_headers=original_headers,
request_body=actual_request_body,
is_stream=False,
api_format=api_format,
provider_request_headers=provider_request_headers,
response_headers=error_response_headers,
# 非流式失败返回给客户端的是 JSON 错误响应
client_response_headers={"content-type": "application/json"},
# 模型映射信息
target_model=mapped_model_result,
)
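For reference, the SSE error frames produced in the hunks above all share one shape: an `event: error` line followed by a JSON payload carrying `type` and `error.{type, message}`. A minimal sketch of that frame; the helper name `sse_error_frame` is illustrative and not part of the codebase:

import json

def sse_error_frame(error_type: str, message: str) -> bytes:
    # Mirrors the error_event dicts yielded above (stream_closed / empty_stream_timeout / empty_response)
    payload = {"type": "error", "error": {"type": error_type, "message": message}}
    return f"event: error\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")

# What the client receives when the upstream stream closes without returning data:
print(sse_error_frame("stream_closed", "上游服务连接关闭且未返回数据").decode("utf-8"))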

View File

@@ -74,7 +74,6 @@ def build_safe_headers(
return headers
# 保持向后兼容的run_endpoint_check函数使用新架构
async def run_endpoint_check(
*,
client: httpx.AsyncClient, # 保持兼容性,但内部不使用
@@ -176,10 +175,16 @@ async def _calculate_and_record_usage(
logger.warning(f"Provider API Key not found for usage calculation: {api_key_id}")
return {"error": "Provider API Key not found"}
# 获取Provider Endpoint信息
# 获取Provider Endpoint信息(通过 api_format 查找)
provider_endpoint = None
if provider_api_key.endpoint_id:
provider_endpoint = db.query(ProviderEndpoint).filter(ProviderEndpoint.id == provider_api_key.endpoint_id).first()
if api_format and provider_api_key.provider_id:
from src.models.database import Provider
provider = db.query(Provider).filter(Provider.id == provider_api_key.provider_id).first()
if provider:
for ep in provider.endpoints:
if ep.api_format == api_format and ep.is_active:
provider_endpoint = ep
break
# 获取用户的API Key用于记录关联即使实际使用的是Provider API Key
user_api_key = None

View File

@@ -61,11 +61,13 @@ class StreamContext:
# 响应状态
status_code: int = 200
error_message: Optional[str] = None
error_message: Optional[str] = None # 客户端友好的错误消息
upstream_response: Optional[str] = None # 原始 Provider 响应(用于请求链路追踪)
has_completion: bool = False
# 请求/响应数据
response_headers: Dict[str, str] = field(default_factory=dict)
response_headers: Dict[str, str] = field(default_factory=dict) # 提供商响应头
client_response_headers: Dict[str, str] = field(default_factory=dict) # 返回给客户端的响应头
provider_request_headers: Dict[str, str] = field(default_factory=dict)
provider_request_body: Optional[Dict[str, Any]] = None
@@ -97,9 +99,11 @@ class StreamContext:
self.cached_tokens = 0
self.cache_creation_tokens = 0
self.error_message = None
self.upstream_response = None
self.status_code = 200
self.first_byte_time_ms = None
self.response_headers = {}
self.client_response_headers = {}
self.provider_request_headers = {}
self.provider_request_body = None
self.response_id = None
@@ -174,10 +178,24 @@ class StreamContext:
):
self.cache_creation_tokens = cache_creation_tokens
def mark_failed(self, status_code: int, error_message: str) -> None:
"""标记请求失败"""
def mark_failed(
self,
status_code: int,
error_message: str,
upstream_response: Optional[str] = None,
) -> None:
"""
标记请求失败
Args:
status_code: HTTP 状态码
error_message: 客户端友好的错误消息
upstream_response: 原始 Provider 响应(用于请求链路追踪)
"""
self.status_code = status_code
self.error_message = error_message
if upstream_response:
self.upstream_response = upstream_response
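A self-contained sketch of the error_message / upstream_response split that mark_failed now carries. The stub below only mirrors the failure-related fields, not the full StreamContext:

from dataclasses import dataclass
from typing import Optional

@dataclass
class _CtxStub:
    # Only the failure-related fields of StreamContext are mirrored here
    status_code: int = 200
    error_message: Optional[str] = None       # client-facing, provider-neutral message (Usage)
    upstream_response: Optional[str] = None   # raw provider payload (request trace / RequestCandidate)

    def mark_failed(self, status_code: int, error_message: str,
                    upstream_response: Optional[str] = None) -> None:
        self.status_code = status_code
        self.error_message = error_message
        if upstream_response:
            self.upstream_response = upstream_response

ctx = _CtxStub()
ctx.mark_failed(503, "上游服务返回了空的流式响应",
                upstream_response="空流式响应: Provider=example, chunk_count=12, data_count=0")
assert ctx.error_message != ctx.upstream_response  # the two audiences get different strings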
def record_first_byte_time(self, start_time: float) -> None:
"""

View File

@@ -251,8 +251,10 @@ class StreamProcessor:
f"base_url={endpoint.base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider.name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
"上游服务返回了非预期的响应格式",
provider_name=str(provider.name),
upstream_status=200,
upstream_response=line[:500] if line else "(empty)",
)
# 跳过空行和注释行

View File

@@ -154,6 +154,14 @@ class StreamTelemetryRecorder:
response_time_ms: int,
) -> None:
"""记录成功的请求"""
# 流式成功时,返回给客户端的是提供商响应头 + SSE 必需头
client_response_headers = dict(ctx.response_headers) if ctx.response_headers else {}
client_response_headers.update({
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"content-type": "text/event-stream",
})
await telemetry.record_success(
provider=ctx.provider_name or "unknown",
model=ctx.model,
@@ -165,6 +173,7 @@ class StreamTelemetryRecorder:
request_headers=original_headers,
request_body=actual_request_body,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
response_body=response_body,
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
@@ -190,6 +199,9 @@ class StreamTelemetryRecorder:
response_time_ms: int,
) -> None:
"""记录失败的请求"""
# 失败时返回给客户端的是 JSON 错误响应,如果没有设置则使用默认值
client_response_headers = ctx.client_response_headers or {"content-type": "application/json"}
await telemetry.record_failure(
provider=ctx.provider_name or "unknown",
model=ctx.model,
@@ -206,6 +218,8 @@ class StreamTelemetryRecorder:
cache_creation_tokens=ctx.cache_creation_tokens,
cache_read_tokens=ctx.cached_tokens,
response_body=response_body,
response_headers=ctx.response_headers,
client_response_headers=client_response_headers,
target_model=ctx.mapped_model,
)
@@ -239,11 +253,13 @@ class StreamTelemetryRecorder:
)
else:
error_type = "client_disconnected" if ctx.status_code == 499 else "stream_error"
# 请求链路追踪使用 upstream_response(原始响应),回退到 error_message(友好消息)
trace_error_message = ctx.upstream_response or ctx.error_message or f"HTTP {ctx.status_code}"
RequestCandidateService.mark_candidate_failed(
db=db,
candidate_id=ctx.attempt_id,
error_type=error_type,
error_message=ctx.error_message or f"HTTP {ctx.status_code}",
error_message=trace_error_message,
status_code=ctx.status_code,
latency_ms=response_time_ms,
extra_data={

View File

@@ -26,8 +26,10 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
3. **旧格式(优先级第三)**
usage.cache_creation_input_tokens
优先使用嵌套格式,如果嵌套格式字段存在但值为 0,则智能 fallback 到旧格式。
扁平格式和嵌套格式互斥,按顺序检查
说明:
- 只要检测到新格式字段(嵌套/扁平),即视为权威来源:哪怕值为 0 也不回退到旧字段
- 仅当新格式字段完全不存在时,才回退到旧字段。
- 扁平格式和嵌套格式互斥,按顺序检查。
Args:
usage: API 响应中的 usage 字典
@@ -37,28 +39,21 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
"""
# 1. 检查嵌套格式(最新格式)
cache_creation = usage.get("cache_creation")
if isinstance(cache_creation, dict):
has_nested_format = isinstance(cache_creation, dict) and (
"ephemeral_5m_input_tokens" in cache_creation
or "ephemeral_1h_input_tokens" in cache_creation
)
if has_nested_format:
cache_5m = int(cache_creation.get("ephemeral_5m_input_tokens", 0))
cache_1h = int(cache_creation.get("ephemeral_1h_input_tokens", 0))
total = cache_5m + cache_1h
if total > 0:
logger.debug(
f"Using nested cache_creation: 5m={cache_5m}, 1h={cache_1h}, total={total}"
)
return total
# 嵌套格式存在但为 0,fallback 到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(
f"Nested cache_creation is 0, using old format: {old_format}"
)
return old_format
# 都是 0,返回 0
return 0
# 2. 检查扁平新格式
has_flat_format = (
"claude_cache_creation_5_m_tokens" in usage
@@ -70,23 +65,11 @@ def extract_cache_creation_tokens(usage: Dict[str, Any]) -> int:
cache_1h = int(usage.get("claude_cache_creation_1_h_tokens", 0))
total = cache_5m + cache_1h
if total > 0:
logger.debug(
f"Using flat new format: 5m={cache_5m}, 1h={cache_1h}, total={total}"
)
return total
# 扁平格式存在但为 0,fallback 到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
logger.debug(
f"Flat cache_creation is 0, using old format: {old_format}"
)
return old_format
# 都是 0,返回 0
return 0
# 3. 回退到旧格式
old_format = int(usage.get("cache_creation_input_tokens", 0))
if old_format > 0:
@@ -173,8 +156,10 @@ def check_prefetched_response_error(
f"base_url={base_url}"
)
raise ProviderNotAvailableException(
f"提供商 '{provider_name}' 返回了 HTML 页面而非 API 响应,"
f"请检查 endpoint 的 base_url 配置是否正确"
"上游服务返回了非预期的响应格式",
provider_name=provider_name,
upstream_status=200,
upstream_response=stripped.decode("utf-8", errors="replace")[:500],
)
# 纯 JSON可能无换行/多行 JSON
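The precedence rules above can be summarised with a small reference re-implementation; this is an executable restatement of the documented behaviour, not the project function:

from typing import Any, Dict

def reference_cache_creation_tokens(usage: Dict[str, Any]) -> int:
    # 1. nested new format is authoritative when present, 2. flat new format, 3. legacy field
    cc = usage.get("cache_creation")
    if isinstance(cc, dict) and (
        "ephemeral_5m_input_tokens" in cc or "ephemeral_1h_input_tokens" in cc
    ):
        return int(cc.get("ephemeral_5m_input_tokens", 0)) + int(cc.get("ephemeral_1h_input_tokens", 0))
    if "claude_cache_creation_5_m_tokens" in usage or "claude_cache_creation_1_h_tokens" in usage:
        return int(usage.get("claude_cache_creation_5_m_tokens", 0)) + int(usage.get("claude_cache_creation_1_h_tokens", 0))
    return int(usage.get("cache_creation_input_tokens", 0))

# New-format fields are authoritative even when they report 0 (no fallback to the legacy field):
assert reference_cache_creation_tokens(
    {"cache_creation": {"ephemeral_5m_input_tokens": 0}, "cache_creation_input_tokens": 300}
) == 0
# The legacy field is only used when no new-format field exists at all:
assert reference_cache_creation_tokens({"cache_creation_input_tokens": 300}) == 300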

View File

@@ -2,7 +2,6 @@
Claude CLI Message Handler - 基于通用 CLI Handler 基类的简化实现
继承 CliMessageHandlerBase只需覆盖格式特定的配置和事件处理逻辑。
验证新架构的有效性:代码量从数百行减少到 ~80 行。
"""
from typing import Any, Dict, Optional

View File

@@ -102,7 +102,6 @@ async def get_public_models(
- id: 模型唯一标识符
- provider_id: 所属提供商 ID
- provider_name: 提供商名称
- provider_display_name: 提供商显示名称
- name: 模型统一名称(优先使用 GlobalModel 名称)
- display_name: 模型显示名称
- description: 模型描述信息
@@ -300,10 +299,20 @@ class PublicProvidersAdapter(PublicApiAdapter):
providers = query.offset(self.skip).limit(self.limit).all()
result = []
for provider in providers:
models_count = db.query(Model).filter(Model.provider_id == provider.id).count()
models_count = (
db.query(Model)
.filter(Model.provider_id == provider.id, Model.global_model_id.isnot(None))
.count()
)
active_models_count = (
db.query(Model)
.filter(and_(Model.provider_id == provider.id, Model.is_active.is_(True)))
.filter(
and_(
Model.provider_id == provider.id,
Model.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
.count()
)
endpoints_count = len(provider.endpoints) if provider.endpoints else 0
@@ -313,7 +322,6 @@ class PublicProvidersAdapter(PublicApiAdapter):
provider_data = PublicProviderResponse(
id=provider.id,
name=provider.name,
display_name=provider.display_name,
description=provider.description,
is_active=provider.is_active,
provider_priority=provider.provider_priority,
@@ -342,7 +350,13 @@ class PublicModelsAdapter(PublicApiAdapter):
db.query(Model, Provider)
.options(joinedload(Model.global_model))
.join(Provider)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
.filter(
and_(
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
)
if self.provider_id is not None:
query = query.filter(Model.provider_id == self.provider_id)
@@ -357,7 +371,6 @@ class PublicModelsAdapter(PublicApiAdapter):
id=model.id,
provider_id=model.provider_id,
provider_name=provider.name,
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.config.get("description") if global_model and global_model.config else None,
@@ -386,7 +399,13 @@ class PublicStatsAdapter(PublicApiAdapter):
active_models = (
db.query(Model)
.join(Provider)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
.filter(
and_(
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
.count()
)
formats = (
@@ -418,7 +437,13 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
.options(joinedload(Model.global_model))
.join(Provider)
.outerjoin(GlobalModel, Model.global_model_id == GlobalModel.id)
.filter(and_(Model.is_active.is_(True), Provider.is_active.is_(True)))
.filter(
and_(
Model.is_active.is_(True),
Provider.is_active.is_(True),
Model.global_model_id.isnot(None),
)
)
)
search_filter = (
Model.provider_model_name.ilike(f"%{self.query}%")
@@ -439,7 +464,6 @@ class PublicSearchModelsAdapter(PublicApiAdapter):
id=model.id,
provider_id=model.provider_id,
provider_name=provider.name,
provider_display_name=provider.display_name,
name=unified_name,
display_name=display_name,
description=global_model.config.get("description") if global_model and global_model.config else None,

View File

@@ -43,7 +43,6 @@ def _serialize_provider(
provider_data: Dict[str, Any] = {
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"is_active": provider.is_active,
"provider_priority": provider.provider_priority,
}

View File

@@ -1023,7 +1023,6 @@ class ListAvailableProvidersAdapter(AuthenticatedApiAdapter):
{
"id": provider.id,
"name": provider.name,
"display_name": provider.display_name,
"description": provider.description,
"provider_priority": provider.provider_priority,
"endpoints": endpoints_data,

View File

@@ -1,10 +1,18 @@
"""
全局HTTP客户端池管理
避免每次请求都创建新的AsyncClient,提高性能
性能优化说明:
1. 默认客户端:无代理场景,全局复用单一客户端
2. 代理客户端缓存:相同代理配置复用同一客户端,避免重复创建
3. 连接池复用Keep-alive 连接减少 TCP 握手开销
"""
import asyncio
import hashlib
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib.parse import quote, urlparse
import httpx
@@ -12,6 +20,32 @@ import httpx
from src.config import config
from src.core.logger import logger
# 模块级锁,避免类属性延迟初始化的竞态条件
_proxy_clients_lock = asyncio.Lock()
_default_client_lock = asyncio.Lock()
def _compute_proxy_cache_key(proxy_config: Optional[Dict[str, Any]]) -> str:
"""
计算代理配置的缓存键
Args:
proxy_config: 代理配置字典
Returns:
缓存键字符串,无代理时返回 "__no_proxy__"
"""
if not proxy_config:
return "__no_proxy__"
# 构建代理 URL 作为缓存键的基础
proxy_url = build_proxy_url(proxy_config)
if not proxy_url:
return "__no_proxy__"
# 使用 MD5 哈希来避免过长的键名
return f"proxy:{hashlib.md5(proxy_url.encode()).hexdigest()[:16]}"
def build_proxy_url(proxy_config: Dict[str, Any]) -> Optional[str]:
"""
@@ -61,11 +95,20 @@ class HTTPClientPool:
全局HTTP客户端池单例
管理可重用的httpx.AsyncClient实例,避免频繁创建/销毁连接
性能优化:
1. 默认客户端:无代理场景复用
2. 代理客户端缓存:相同代理配置复用同一客户端
3. LRU 淘汰:代理客户端超过上限时淘汰最久未使用的
"""
_instance: Optional["HTTPClientPool"] = None
_default_client: Optional[httpx.AsyncClient] = None
_clients: Dict[str, httpx.AsyncClient] = {}
# 代理客户端缓存:{cache_key: (client, last_used_time)}
_proxy_clients: Dict[str, Tuple[httpx.AsyncClient, float]] = {}
# 代理客户端缓存上限(避免内存泄漏)
_max_proxy_clients: int = 50
def __new__(cls):
if cls._instance is None:
@@ -73,12 +116,50 @@ class HTTPClientPool:
return cls._instance
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
async def get_default_client_async(cls) -> httpx.AsyncClient:
"""
获取默认的HTTP客户端
获取默认的HTTP客户端(异步线程安全版本)
用于大多数HTTP请求,具有合理的默认配置
"""
if cls._default_client is not None:
return cls._default_client
async with _default_client_lock:
# 双重检查,避免重复创建
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性
verify=True, # 启用SSL验证
timeout=httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
limits=httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
follow_redirects=True, # 跟随重定向
)
logger.info(
f"全局HTTP客户端池已初始化: "
f"max_connections={config.http_max_connections}, "
f"keepalive={config.http_keepalive_connections}, "
f"keepalive_expiry={config.http_keepalive_expiry}s"
)
return cls._default_client
@classmethod
def get_default_client(cls) -> httpx.AsyncClient:
"""
获取默认的HTTP客户端同步版本向后兼容
⚠️ 注意:此方法在高并发首次调用时可能存在竞态条件,
推荐使用 get_default_client_async() 异步版本。
"""
if cls._default_client is None:
cls._default_client = httpx.AsyncClient(
http2=False, # 暂时禁用HTTP/2以提高兼容性
@@ -135,6 +216,101 @@ class HTTPClientPool:
return cls._clients[name]
@classmethod
def _get_proxy_clients_lock(cls) -> asyncio.Lock:
"""获取代理客户端缓存锁(模块级单例,避免竞态条件)"""
return _proxy_clients_lock
@classmethod
async def _evict_lru_proxy_client(cls) -> None:
"""淘汰最久未使用的代理客户端"""
if len(cls._proxy_clients) < cls._max_proxy_clients:
return
# 找到最久未使用的客户端
oldest_key = min(cls._proxy_clients.keys(), key=lambda k: cls._proxy_clients[k][1])
old_client, _ = cls._proxy_clients.pop(oldest_key)
# 异步关闭旧客户端
try:
await old_client.aclose()
logger.debug(f"淘汰代理客户端: {oldest_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
@classmethod
async def get_proxy_client(
cls,
proxy_config: Optional[Dict[str, Any]] = None,
) -> httpx.AsyncClient:
"""
获取代理客户端(带缓存复用)
相同代理配置会复用同一个客户端,大幅减少连接建立开销。
注意:返回的客户端使用默认超时配置,如需自定义超时请在请求时传递 timeout 参数。
Args:
proxy_config: 代理配置字典,包含 url, username, password
Returns:
可复用的 httpx.AsyncClient 实例
"""
cache_key = _compute_proxy_cache_key(proxy_config)
# 无代理时返回默认客户端
if cache_key == "__no_proxy__":
return await cls.get_default_client_async()
lock = cls._get_proxy_clients_lock()
async with lock:
# 检查缓存
if cache_key in cls._proxy_clients:
client, _ = cls._proxy_clients[cache_key]
# 健康检查:如果客户端已关闭,移除并重新创建
if client.is_closed:
del cls._proxy_clients[cache_key]
logger.debug(f"代理客户端已关闭,将重新创建: {cache_key}")
else:
# 更新最后使用时间
cls._proxy_clients[cache_key] = (client, time.time())
return client
# 淘汰旧客户端(如果超过上限)
await cls._evict_lru_proxy_client()
# 创建新客户端(使用默认超时,请求时可覆盖)
client_config: Dict[str, Any] = {
"http2": False,
"verify": True,
"follow_redirects": True,
"limits": httpx.Limits(
max_connections=config.http_max_connections,
max_keepalive_connections=config.http_keepalive_connections,
keepalive_expiry=config.http_keepalive_expiry,
),
"timeout": httpx.Timeout(
connect=config.http_connect_timeout,
read=config.http_read_timeout,
write=config.http_write_timeout,
pool=config.http_pool_timeout,
),
}
# 添加代理配置
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
client_config["proxy"] = proxy_url
client = httpx.AsyncClient(**client_config)
cls._proxy_clients[cache_key] = (client, time.time())
logger.debug(
f"创建代理客户端(缓存): {proxy_config.get('url', 'unknown') if proxy_config else 'none'}, "
f"缓存数量: {len(cls._proxy_clients)}"
)
return client
@classmethod
async def close_all(cls):
"""关闭所有HTTP客户端"""
@@ -148,6 +324,16 @@ class HTTPClientPool:
logger.debug(f"命名HTTP客户端已关闭: {name}")
cls._clients.clear()
# 关闭代理客户端缓存
for cache_key, (client, _) in cls._proxy_clients.items():
try:
await client.aclose()
logger.debug(f"代理客户端已关闭: {cache_key}")
except Exception as e:
logger.warning(f"关闭代理客户端失败: {e}")
cls._proxy_clients.clear()
logger.info("所有HTTP客户端已关闭")
@classmethod
@@ -190,13 +376,15 @@ class HTTPClientPool:
"""
创建带代理配置的HTTP客户端
⚠️ 性能警告:此方法每次都创建新客户端,推荐使用 get_proxy_client() 复用连接。
Args:
proxy_config: 代理配置字典,包含 url, username, password
timeout: 超时配置
**kwargs: 其他 httpx.AsyncClient 配置参数
Returns:
配置好的 httpx.AsyncClient 实例
配置好的 httpx.AsyncClient 实例(调用者需要负责关闭)
"""
client_config: Dict[str, Any] = {
"http2": False,
@@ -218,11 +406,21 @@ class HTTPClientPool:
proxy_url = build_proxy_url(proxy_config) if proxy_config else None
if proxy_url:
client_config["proxy"] = proxy_url
logger.debug(f"创建带代理的HTTP客户端: {proxy_config.get('url', 'unknown')}")
logger.debug(f"创建带代理的HTTP客户端(一次性): {proxy_config.get('url', 'unknown')}")
client_config.update(kwargs)
return httpx.AsyncClient(**client_config)
@classmethod
def get_pool_stats(cls) -> Dict[str, Any]:
"""获取连接池统计信息"""
return {
"default_client_active": cls._default_client is not None,
"named_clients_count": len(cls._clients),
"proxy_clients_count": len(cls._proxy_clients),
"max_proxy_clients": cls._max_proxy_clients,
}
# 便捷访问函数
def get_http_client() -> httpx.AsyncClient:
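A usage sketch of the pooled client as the CLI handler hunk earlier in this changeset consumes it: the client is shared and must not be closed, and the per-request timeout comes from provider.timeout. The URL and proxy values below are placeholders:

from typing import Any, Dict, Optional

import httpx

from src.clients.http_client import HTTPClientPool

async def call_upstream(payload: Dict[str, Any],
                        proxy: Optional[Dict[str, Any]] = None,
                        timeout_s: float = 120.0) -> httpx.Response:
    # Shared, cached client: do NOT wrap in `async with` and do NOT close it
    client = await HTTPClientPool.get_proxy_client(proxy_config=proxy)
    return await client.post(
        "https://api.example.invalid/v1/messages",  # placeholder URL
        json=payload,
        timeout=httpx.Timeout(timeout_s),            # per-request override; pool defaults stay intact
    )

# The same proxy config always maps to the same cached client; see get_pool_stats() for counters.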

View File

@@ -52,19 +52,33 @@ class StreamDefaults:
MAX_PREFETCH_BYTES = 64 * 1024 # 64KB
class ConcurrencyDefaults:
"""并发控制默认值
class RPMDefaults:
"""RPM(每分钟请求数)控制默认值
算法说明:边界记忆 + 渐进探测
- 触发 429 时记录边界(last_concurrent_peak),新限制 = 边界 - 1
- 触发 429 时记录边界(last_rpm_peak),新限制 = 边界 - 1
- 扩容时不超过边界,除非是探测性扩容(长时间无 429)
- 这样可以快速收敛到真实限制附近,避免过度保守
初始值 50 RPM
- 系统会根据实际使用自动调整
"""
# 自适应并发初始限制(宽松起步,遇到 429 再降低)
INITIAL_LIMIT = 50
# 自适应 RPM 初始限制
INITIAL_LIMIT = 50 # 每分钟 50 次请求
# 429错误后的冷却时间(分钟)- 在此期间不会增加并发限制
# === 内存模式 RPM 计数器配置 ===
# 内存模式下的最大条目限制(防止内存泄漏)
# 每个条目约占 100 字节,10000 条目 = ~1MB
# 计算依据:1000 Key × 5 API 格式 × 2 (buffer) = 10000
# 可通过环境变量 RPM_MAX_MEMORY_ENTRIES 覆盖
MAX_MEMORY_RPM_ENTRIES = 10000
# 内存使用告警阈值(达到此比例时记录警告日志)
# 可通过环境变量 RPM_MEMORY_WARNING_THRESHOLD 覆盖
MEMORY_WARNING_THRESHOLD = 0.6 # 60%
# 429错误后的冷却时间(分钟)- 在此期间不会增加 RPM 限制
COOLDOWN_AFTER_429_MINUTES = 5
# 探测间隔上限(分钟)- 用于长期探测策略
@@ -86,30 +100,30 @@ class ConcurrencyDefaults:
# 最小采样数 - 窗口内至少需要这么多采样才能做出扩容决策
MIN_SAMPLES_FOR_DECISION = 5
# 扩容步长 - 每次扩容增加的并发数
INCREASE_STEP = 2
# 扩容步长 - 每次扩容增加的 RPM
INCREASE_STEP = 5 # 每次增加 5 RPM
# 最大并发限制上限
MAX_CONCURRENT_LIMIT = 200
# 最大 RPM 限制上限(不设上限,让系统自适应学习)
MAX_RPM_LIMIT = 10000
# 最小并发限制下限
# 设置为 3 而不是 1因为预留机制10%预留给缓存用户)会导致
# 当 learned_max_concurrent=1 时新用户实际可用槽位为 0永远无法命中
# 注意:当 limit < 10 时,预留机制实际不生效(预留槽位 = 0这是可接受的
MIN_CONCURRENT_LIMIT = 3
# 最小 RPM 限制下限
MIN_RPM_LIMIT = 5
# 缓存用户预留比例(默认 10%,新用户可用 90%
# 已被动态预留机制 (AdaptiveReservationDefaults) 替代,保留用于向后兼容
CACHE_RESERVATION_RATIO = 0.1
# === 探测性扩容参数 ===
# 探测性扩容间隔(分钟)- 长时间无 429 且有流量时尝试扩容
# 探测性扩容可以突破已知边界,尝试更高的并发
# 探测性扩容可以突破已知边界,尝试更高的 RPM
PROBE_INCREASE_INTERVAL_MINUTES = 30
# 探测性扩容最小请求数 - 在探测间隔内至少需要这么多请求
PROBE_INCREASE_MIN_REQUESTS = 10
# === 缓存用户预留比例 ===
# 缓存用户槽位预留比例(新用户可用 1 - 此值)
# 0.1 表示缓存用户预留 10%,新用户可用 90%
CACHE_RESERVATION_RATIO = 0.1
# 向后兼容别名
ConcurrencyDefaults = RPMDefaults
class CircuitBreakerDefaults:
@@ -193,10 +207,19 @@ class AdaptiveReservationDefaults:
class TimeoutDefaults:
"""超时配置默认值(秒)"""
"""超时配置默认值(秒)
# HTTP 请求默认超时
HTTP_REQUEST = 300 # 5分钟
超时配置说明:
- 全局默认值和 Provider 默认值统一为 120 秒
- 120 秒是 LLM API 的合理默认值:
* 大多数请求在 30 秒内完成
* 复杂推理(如 Claude extended thinking)可能需要 60-90 秒
* 120 秒足够覆盖大部分场景,同时避免线程池被长时间占用
- 如需更长超时,可在 Provider 级别单独配置
"""
# HTTP 请求默认超时(与 Provider 默认值保持一致)
HTTP_REQUEST = 120 # 2分钟
# 数据库连接池获取超时
DB_POOL = 30
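A rough, self-contained restatement of the "boundary memory + progressive probing" rules documented in RPMDefaults, using the constants above; the real adaptive logic lives elsewhere in the codebase and is not shown in this diff:

from dataclasses import dataclass
from typing import Optional

@dataclass
class RpmState:
    limit: int = 50                      # INITIAL_LIMIT
    last_rpm_peak: Optional[int] = None  # boundary remembered on 429
    minutes_since_429: float = float("inf")
    minutes_since_probe: float = 0.0
    requests_since_probe: int = 0

def adjust_rpm(state: RpmState, got_429: bool, current_rpm: int) -> int:
    if got_429:
        state.last_rpm_peak = current_rpm             # remember the boundary
        state.limit = max(5, current_rpm - 1)         # MIN_RPM_LIMIT; new limit = boundary - 1
        state.minutes_since_429 = 0.0
        return state.limit
    if state.minutes_since_429 < 5:                   # COOLDOWN_AFTER_429_MINUTES
        return state.limit
    proposed = state.limit + 5                        # INCREASE_STEP
    probing = (state.minutes_since_probe >= 30        # PROBE_INCREASE_INTERVAL_MINUTES
               and state.requests_since_probe >= 10)  # PROBE_INCREASE_MIN_REQUESTS
    if state.last_rpm_peak is not None and not probing:
        proposed = min(proposed, state.last_rpm_peak)  # stay at or below the known boundary
    state.limit = min(proposed, 10000)                # MAX_RPM_LIMIT
    return state.limit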

View File

@@ -36,6 +36,12 @@ class CacheService:
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
# 非 JSON 值:统一返回字符串,避免上层出现 bytes/str 混用
if isinstance(value, (bytes, bytearray)):
try:
return value.decode("utf-8")
except Exception:
return value
return value
return None
@@ -98,6 +104,48 @@ class CacheService:
logger.warning(f"缓存删除失败: {key} - {e}")
return False
@staticmethod
async def delete_pattern(pattern: str, batch_size: int = 100) -> int:
"""
删除匹配模式的所有缓存
使用 SCAN 遍历并分批删除,避免阻塞 Redis
Args:
pattern: 缓存键模式(支持 * 通配符)
batch_size: 每批删除的最大键数量
Returns:
删除的键数量
"""
try:
redis = await get_redis_client(require_redis=False)
if not redis:
return 0
# 使用 SCAN 遍历匹配的键
deleted_count = 0
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=batch_size)
if keys:
# 分批删除,避免单次删除过多键导致 Redis 阻塞
for i in range(0, len(keys), batch_size):
batch = keys[i : i + batch_size]
await redis.delete(*batch)
deleted_count += len(batch)
if cursor == 0:
break
if deleted_count > 0:
logger.debug(f"缓存模式删除成功: {pattern}, 删除 {deleted_count} 个键")
return deleted_count
except Exception as e:
logger.warning(f"缓存模式删除失败: {pattern} - {e}")
return 0
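A usage sketch for delete_pattern; the import path and the cache-key pattern are assumptions for illustration only:

from src.services.cache import CacheService  # import path assumed, not shown in this diff

async def invalidate_provider_cache(provider_id: str) -> int:
    # SCAN-based, batched deletion: safe on large keyspaces, does not block Redis
    return await CacheService.delete_pattern(f"provider:{provider_id}:*", batch_size=100)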
@staticmethod
async def exists(key: str) -> bool:
"""

View File

@@ -7,16 +7,19 @@ from typing import Optional
def extract_error_message(error: Exception, status_code: Optional[int] = None) -> str:
"""
从异常中提取错误消息,优先使用上游响应内容
从异常中提取错误消息,优先使用上游原始响应(用于链路追踪/调试)
此函数用于 RequestCandidate 表的 error_message 字段,
用于请求链路追踪中显示原始 Provider 响应。
Args:
error: 异常对象
status_code: 可选的 HTTP 状态码,用于构建更详细的错误消息
Returns:
错误消息字符串
错误消息字符串(原始 Provider 响应)
"""
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误)
# 优先使用 upstream_response 属性(包含上游 Provider 的原始错误,用于调试
upstream_response = getattr(error, "upstream_response", None)
if upstream_response and isinstance(upstream_response, str) and upstream_response.strip():
return str(upstream_response)
@@ -26,3 +29,25 @@ def extract_error_message(error: Exception, status_code: Optional[int] = None) -
if status_code is not None:
return f"HTTP {status_code}: {error_str}"
return error_str
def extract_client_error_message(error: Exception) -> str:
"""
从异常中提取客户端友好的错误消息(用于返回给客户端/Usage 记录)
此函数用于 Usage 表的 error_message 字段,
用于显示给最终用户的友好错误消息。
Args:
error: 异常对象
Returns:
友好的错误消息字符串
"""
# 优先使用 message 属性(已经是友好处理过的消息)
message = getattr(error, "message", None)
if message and isinstance(message, str) and message.strip():
return message
# 回退到异常的字符串表示
return str(error) or repr(error)
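How the two extractors above diverge for the same exception; the stub class only mimics the attributes they look for:

class _FakeProviderError(Exception):
    # Carries the two attributes the extractors above look for
    def __init__(self) -> None:
        super().__init__("provider xyz failed")
        self.message = "上游服务返回了无效的响应"  # friendly, provider-neutral
        self.upstream_response = '{"error": {"code": 502, "msg": "bad gateway from origin"}}'

err = _FakeProviderError()
# extract_client_error_message(err) -> err.message            (Usage.error_message, user-facing)
# extract_error_message(err)        -> err.upstream_response  (RequestCandidate.error_message, tracing)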

View File

@@ -205,7 +205,7 @@ class ProviderTimeoutException(ProviderException):
def __init__(self, provider_name: str, timeout: int, request_metadata: Optional[Any] = None):
super().__init__(
message=f"提供商 '{provider_name}' 请求超时({timeout}秒)",
message=f"请求超时({timeout}秒)",
provider_name=provider_name,
request_metadata=request_metadata,
timeout=timeout,
@@ -217,7 +217,7 @@ class ProviderAuthException(ProviderException):
def __init__(self, provider_name: str, request_metadata: Optional[Any] = None):
super().__init__(
message=f"提供商 '{provider_name}' 认证失败请检查API密钥",
message="上游服务认证失败",
provider_name=provider_name,
request_metadata=request_metadata,
)
@@ -292,9 +292,8 @@ class ModelNotSupportedException(ProxyException):
"""模型不支持"""
def __init__(self, model: str, provider_name: Optional[str] = None):
# 客户端消息不暴露提供商信息
message = f"模型 '{model}' 不受支持"
if provider_name:
message = f"提供商 '{provider_name}' 不支持模型 '{model}'"
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
error_type="model_not_supported",
@@ -307,9 +306,7 @@ class StreamingNotSupportedException(ProxyException):
"""流式请求不支持"""
def __init__(self, model: str, provider_name: Optional[str] = None):
if provider_name:
message = f"模型 '{model}' 在提供商 '{provider_name}' 上不支持流式请求"
else:
# 客户端消息不暴露提供商信息
message = f"模型 '{model}' 不支持流式请求"
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -389,7 +386,7 @@ class JSONParseException(ProviderException):
details["response_content"] = response_content
super().__init__(
message=f"提供商 '{provider_name}' 返回了无效的JSON响应",
message="上游服务返回了无效的响应",
provider_name=provider_name,
request_metadata=request_metadata,
**details,
@@ -406,7 +403,7 @@ class EmptyStreamException(ProviderException):
request_metadata: Optional[Any] = None,
):
super().__init__(
message=f"提供商 '{provider_name}' 返回了空的流式响应status=200 但无数据)",
message="上游服务返回了空的流式响应",
provider_name=provider_name,
request_metadata=request_metadata,
chunk_count=chunk_count,
@@ -428,11 +425,10 @@ class EmbeddedErrorException(ProviderException):
error_status: Optional[str] = None,
request_metadata: Optional[Any] = None,
):
message = f"提供商 '{provider_name}' 返回了嵌套错误"
# 客户端消息不暴露提供商信息
message = "上游服务返回了错误"
if error_code:
message += f" (code={error_code})"
if error_message:
message += f": {error_message}"
super().__init__(
message=message,
@@ -549,12 +545,14 @@ class ErrorResponse:
if isinstance(e, ProxyException):
details = e.details.copy() if e.details else {}
status_code = e.status_code
message = e.message
# 如果是 ProviderNotAvailableException 且有上游错误,直接透传上游信息
if isinstance(e, ProviderNotAvailableException) and e.upstream_response:
message = e.message # 使用友好的错误消息
# 如果是 ProviderNotAvailableException 且有上游错误信息
if isinstance(e, ProviderNotAvailableException):
if e.upstream_status:
status_code = e.upstream_status
message = e.upstream_response
# upstream_response 存入 details 供请求链路追踪使用,不作为客户端消息
if e.upstream_response:
details["upstream_response"] = e.upstream_response
return ErrorResponse.create(
error_type=e.error_type,
message=message,

View File

@@ -0,0 +1,286 @@
"""
模型权限工具
支持两种 allowed_models 格式:
1. 简单模式(列表): ["claude-sonnet-4", "gpt-4o"]
2. 按格式模式(字典): {"OPENAI": ["gpt-4o"], "CLAUDE": ["claude-sonnet-4"]}
使用 None/null 表示不限制(允许所有模型)
"""
from typing import Any, Dict, List, Optional, Set, Union
# 类型别名
AllowedModels = Optional[Union[List[str], Dict[str, List[str]]]]
def normalize_allowed_models(
allowed_models: AllowedModels,
api_format: Optional[str] = None,
) -> Optional[Set[str]]:
"""
将 allowed_models 规范化为模型名称集合
Args:
allowed_models: 允许的模型配置(列表或字典)
api_format: 当前请求的 API 格式(用于字典模式)
Returns:
- None: 不限制(允许所有模型)
- Set[str]: 允许的模型名称集合(可能为空集,表示拒绝所有)
"""
if allowed_models is None:
return None
# 简单模式:直接是列表
if isinstance(allowed_models, list):
return set(allowed_models)
# 按格式模式:字典
if isinstance(allowed_models, dict):
if api_format is None:
# 没有指定格式,合并所有格式的模型
all_models: Set[str] = set()
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
return all_models if all_models else None
# 查找指定格式的模型列表
api_format_upper = api_format.upper()
models = allowed_models.get(api_format_upper)
if models is None:
# 该格式未配置,检查是否有通配符 "*"
models = allowed_models.get("*")
if models is None:
# 字典模式下未配置的格式 = 不限制该格式
return None
return set(models) if isinstance(models, list) else None
# 未知类型,视为不限制
return None
def check_model_allowed(
model_name: str,
allowed_models: AllowedModels,
api_format: Optional[str] = None,
resolved_model_name: Optional[str] = None,
) -> bool:
"""
检查模型是否被允许
Args:
model_name: 请求的模型名称
allowed_models: 允许的模型配置
api_format: 当前请求的 API 格式
resolved_model_name: 解析后的 GlobalModel.name可选
Returns:
True: 允许使用该模型
False: 不允许使用该模型
"""
allowed_set = normalize_allowed_models(allowed_models, api_format)
if allowed_set is None:
# 不限制
return True
if len(allowed_set) == 0:
# 空集合 = 拒绝所有
return False
# 检查请求的模型名或解析后的名称是否在白名单中
if model_name in allowed_set:
return True
if resolved_model_name and resolved_model_name in allowed_set:
return True
return False
def merge_allowed_models(
allowed_models_1: AllowedModels,
allowed_models_2: AllowedModels,
) -> AllowedModels:
"""
合并两个 allowed_models 配置,取交集
规则:
- 如果任一为 None返回另一个
- 如果都有值,取交集
- 如果都是列表,取列表交集
- 如果有字典,按 API 格式分别取交集(保持字典语义,不丢失格式区分信息)
Args:
allowed_models_1: 第一个配置
allowed_models_2: 第二个配置
Returns:
合并后的配置
"""
if allowed_models_1 is None:
return allowed_models_2
if allowed_models_2 is None:
return allowed_models_1
# 两个都是简单列表:直接取交集(返回确定性顺序)
if isinstance(allowed_models_1, list) and isinstance(allowed_models_2, list):
intersection = set(allowed_models_1) & set(allowed_models_2)
return sorted(intersection) if intersection else []
# 任一为字典模式:按 API 格式分别取交集,避免把 dict 合并成 list 导致权限过宽
from src.core.enums import APIFormat
def merge_sets(a: Optional[Set[str]], b: Optional[Set[str]]) -> Optional[Set[str]]:
# None 表示不限制:交集规则下等价于“只受另一方限制”
if a is None:
return b
if b is None:
return a
return a & b
known_formats = [fmt.value for fmt in APIFormat]
per_format: Dict[str, Optional[Set[str]]] = {}
for fmt in known_formats:
s1 = normalize_allowed_models(allowed_models_1, api_format=fmt)
s2 = normalize_allowed_models(allowed_models_2, api_format=fmt)
per_format[fmt] = merge_sets(s1, s2)
# 计算默认(未知格式)的交集,用 "*" 作为默认值以覆盖未枚举的格式
default_s1 = normalize_allowed_models(allowed_models_1, api_format="__DEFAULT__")
default_s2 = normalize_allowed_models(allowed_models_2, api_format="__DEFAULT__")
default_set = merge_sets(default_s1, default_s2)
# 如果 default_set 非 None 且不存在“某些格式不限制”的情况,可用 "*" 作为默认规则并按需覆盖
can_use_wildcard = default_set is not None and all(v is not None for v in per_format.values())
merged_dict: Dict[str, List[str]] = {}
if can_use_wildcard and default_set is not None:
merged_dict["*"] = sorted(default_set)
for fmt, s in per_format.items():
# can_use_wildcard 保证 s 非 None
if s is not None and s != default_set:
merged_dict[fmt] = sorted(s)
else:
for fmt, s in per_format.items():
if s is None:
continue
merged_dict[fmt] = sorted(s)
if not merged_dict:
# 全部不限制
return None
return merged_dict
def get_allowed_models_preview(
allowed_models: AllowedModels,
max_items: int = 3,
) -> str:
"""
获取 allowed_models 的预览字符串(用于日志和错误消息)
Args:
allowed_models: 允许的模型配置
max_items: 最多显示的模型数
Returns:
预览字符串,如 "gpt-4o, claude-sonnet-4, ..."
"""
if allowed_models is None:
return "(不限制)"
all_models: Set[str] = set()
if isinstance(allowed_models, list):
all_models = set(allowed_models)
elif isinstance(allowed_models, dict):
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
if not all_models:
return "(无)"
sorted_models = sorted(all_models)
preview = ", ".join(sorted_models[:max_items])
if len(sorted_models) > max_items:
preview += f", ...共{len(sorted_models)}"
return preview
def is_format_mode(allowed_models: AllowedModels) -> bool:
"""
判断 allowed_models 是否为按格式模式
Args:
allowed_models: 允许的模型配置
Returns:
True: 按格式模式(字典)
False: 简单模式(列表或 None
"""
return isinstance(allowed_models, dict)
def convert_to_format_mode(
allowed_models: AllowedModels,
api_formats: Optional[List[str]] = None,
) -> Dict[str, List[str]]:
"""
将 allowed_models 转换为按格式模式
Args:
allowed_models: 原始配置
api_formats: 要应用的 API 格式列表
Returns:
按格式模式的配置
"""
if allowed_models is None:
return {}
if isinstance(allowed_models, dict):
return allowed_models
# 简单列表模式 -> 按格式模式
if isinstance(allowed_models, list):
if not api_formats:
return {"*": allowed_models}
return {fmt.upper(): list(allowed_models) for fmt in api_formats}
return {}
def convert_to_simple_mode(allowed_models: AllowedModels) -> Optional[List[str]]:
"""
将 allowed_models 转换为简单列表模式
Args:
allowed_models: 原始配置
Returns:
简单列表或 None
"""
if allowed_models is None:
return None
if isinstance(allowed_models, list):
return allowed_models
if isinstance(allowed_models, dict):
all_models: Set[str] = set()
for models in allowed_models.values():
if isinstance(models, list):
all_models.update(models)
return sorted(all_models) if all_models else None
return None
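Example expectations for the new module, assuming it is importable; the import path below is a guess, since only the module body appears in this diff:

from src.utils.model_permissions import check_model_allowed, merge_allowed_models

# Simple list mode: one whitelist for every API format
assert check_model_allowed("gpt-4o", ["gpt-4o", "claude-sonnet-4"], api_format="OPENAI")

# Per-format dict mode: formats that are not listed stay unrestricted
allowed = {"OPENAI": ["gpt-4o"], "CLAUDE": ["claude-sonnet-4"]}
assert check_model_allowed("claude-sonnet-4", allowed, api_format="CLAUDE")
assert not check_model_allowed("gpt-4o", allowed, api_format="CLAUDE")
assert check_model_allowed("anything", allowed, api_format="GEMINI")  # unconfigured format -> no limit

# Merging a list config with a dict config keeps per-format semantics (intersection per format)
merged = merge_allowed_models(["gpt-4o", "claude-sonnet-4"], {"OPENAI": ["gpt-4o", "gpt-4o-mini"]})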

View File

@@ -369,7 +369,6 @@ def init_db():
_ensure_engine()
# 数据库表结构由 Alembic 迁移管理
# 首次部署或更新后请运行: ./migrate.sh
db = _SessionLocal()
try:

View File

@@ -52,15 +52,31 @@ class ProxyConfig(BaseModel):
class CreateProviderRequest(BaseModel):
"""创建 Provider 请求"""
name: str = Field(
...,
min_length=1,
max_length=100,
description="Provider 名称(英文字母、数字、下划线、连字符)",
)
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
name: str = Field(..., min_length=1, max_length=100, description="提供商名称(唯一)")
description: Optional[str] = Field(None, max_length=1000, description="描述")
website: Optional[str] = Field(None, max_length=500, description="官网地址")
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""验证名称格式,防止注入攻击"""
v = v.strip()
# 只允许安全的字符:字母、数字、下划线、连字符、空格、中文
if not re.match(r"^[\w\s\u4e00-\u9fff-]+$", v):
raise ValueError("名称只能包含字母、数字、下划线、连字符、空格和中文")
# 检查 SQL 注入关键字(不区分大小写)
sql_keywords = [
"SELECT", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
"ALTER", "TRUNCATE", "UNION", "EXEC", "EXECUTE", "--", "/*", "*/"
]
v_upper = v.upper()
for keyword in sql_keywords:
if keyword in v_upper:
raise ValueError(f"名称包含非法关键字: {keyword}")
return v
billing_type: Optional[str] = Field(
ProviderBillingType.PAY_AS_YOU_GO.value, description="计费类型"
)
@@ -68,47 +84,16 @@ class CreateProviderRequest(BaseModel):
quota_reset_day: Optional[int] = Field(30, ge=1, le=365, description="配额重置周期(天数)")
quota_last_reset_at: Optional[datetime] = Field(None, description="当前周期开始时间")
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
provider_priority: Optional[int] = Field(100, ge=0, le=1000, description="提供商优先级(数字越小越优先)")
is_active: Optional[bool] = Field(True, description="是否启用")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
# 请求配置(从 Endpoint 迁移)
timeout: Optional[int] = Field(300, ge=1, le=600, description="请求超时(秒)")
max_retries: Optional[int] = Field(2, ge=0, le=10, description="最大重试次数")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
"""验证名称格式"""
if not re.match(r"^[a-zA-Z0-9_-]+$", v):
raise ValueError("名称只能包含英文字母、数字、下划线和连字符")
# SQL 注入防护:检查危险关键字
dangerous_keywords = [
"SELECT",
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"CREATE",
"ALTER",
"EXEC",
"UNION",
"OR",
"AND",
"--",
";",
"'",
'"',
"<",
">",
]
upper_name = v.upper()
for keyword in dangerous_keywords:
if keyword in upper_name:
raise ValueError(f"名称包含禁止的字符或关键字: {keyword}")
return v
@field_validator("display_name", "description")
@field_validator("name", "description")
@classmethod
def sanitize_text(cls, v: Optional[str]) -> Optional[str]:
"""清理文本输入,防止 XSS"""
@@ -162,7 +147,7 @@ class CreateProviderRequest(BaseModel):
class UpdateProviderRequest(BaseModel):
"""更新 Provider 请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=1000)
website: Optional[str] = Field(None, max_length=500)
billing_type: Optional[str] = None
@@ -170,14 +155,17 @@ class UpdateProviderRequest(BaseModel):
quota_reset_day: Optional[int] = Field(None, ge=1, le=365)
quota_last_reset_at: Optional[datetime] = None
quota_expires_at: Optional[datetime] = None
rpm_limit: Optional[int] = Field(None, ge=0)
provider_priority: Optional[int] = Field(None, ge=0, le=1000)
is_active: Optional[bool] = None
concurrent_limit: Optional[int] = Field(None, ge=0)
# 请求配置(从 Endpoint 迁移)
timeout: Optional[int] = Field(None, ge=1, le=600, description="请求超时(秒)")
max_retries: Optional[int] = Field(None, ge=0, le=10, description="最大重试次数")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
config: Optional[Dict[str, Any]] = None
# 复用相同的验证器
_sanitize_text = field_validator("display_name", "description")(
_sanitize_text = field_validator("name", "description")(
CreateProviderRequest.sanitize_text.__func__
)
_validate_website = field_validator("website")(CreateProviderRequest.validate_website.__func__)
@@ -196,7 +184,6 @@ class CreateEndpointRequest(BaseModel):
custom_path: Optional[str] = Field(None, max_length=200, description="自定义路径")
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
is_active: Optional[bool] = Field(True, description="是否启用")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制")
concurrent_limit: Optional[int] = Field(None, ge=0, description="并发限制")
config: Optional[Dict[str, Any]] = Field(None, description="其他配置")
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
@@ -252,7 +239,6 @@ class UpdateEndpointRequest(BaseModel):
custom_path: Optional[str] = Field(None, max_length=200)
priority: Optional[int] = Field(None, ge=0, le=1000)
is_active: Optional[bool] = None
rpm_limit: Optional[int] = Field(None, ge=0)
concurrent_limit: Optional[int] = Field(None, ge=0)
config: Optional[Dict[str, Any]] = None
proxy: Optional[ProxyConfig] = Field(None, description="代理配置")
@@ -277,8 +263,7 @@ class CreateAPIKeyRequest(BaseModel):
api_key: str = Field(..., min_length=1, max_length=500, description="API Key")
priority: Optional[int] = Field(100, ge=0, le=1000, description="优先级")
is_active: Optional[bool] = Field(True, description="是否启用")
max_rpm: Optional[int] = Field(None, ge=0, description="最大 RPM")
max_concurrent: Optional[int] = Field(None, ge=0, description="最大并发")
rpm_limit: Optional[int] = Field(None, ge=0, description="RPM 限制(NULL=自适应)")
notes: Optional[str] = Field(None, max_length=500, description="备注")
@field_validator("api_key")

View File

@@ -376,14 +376,13 @@ class ApiKeyResponse(BaseModel):
class ProviderCreate(BaseModel):
"""创建提供商请求
架构说明:
架构说明:
- Provider 仅包含提供商的元数据和计费配置
- API格式、URL、认证等配置应在 ProviderEndpoint 中设置
- API密钥应在 ProviderAPIKey 中设置
"""
name: str = Field(..., min_length=1, max_length=100, description="提供商唯一标识")
display_name: str = Field(..., min_length=1, max_length=100, description="显示名称")
name: str = Field(..., min_length=1, max_length=100, description="提供商名称(唯一)")
description: Optional[str] = Field(None, description="提供商描述")
website: Optional[str] = Field(None, max_length=500, description="主站网站")
@@ -397,7 +396,7 @@ class ProviderCreate(BaseModel):
class ProviderUpdate(BaseModel):
"""更新提供商请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
website: Optional[str] = Field(None, max_length=500)
api_format: Optional[str] = None
@@ -418,7 +417,6 @@ class ProviderResponse(BaseModel):
id: str
name: str
display_name: str
description: Optional[str]
website: Optional[str]
api_format: str
@@ -609,7 +607,6 @@ class PublicProviderResponse(BaseModel):
id: str
name: str
display_name: str
description: Optional[str]
website: Optional[str]
is_active: bool
@@ -627,7 +624,6 @@ class PublicModelResponse(BaseModel):
id: str
provider_id: str
provider_name: str
provider_display_name: str
name: str
display_name: str
description: Optional[str] = None

View File

@@ -13,9 +13,7 @@ class ProviderAPIKeyBase(BaseModel):
name: Optional[str] = Field(None, description="密钥名称/备注")
api_key: str = Field(..., description="API密钥")
rate_limit: Optional[int] = Field(None, description="速率限制(每分钟请求数)")
daily_limit: Optional[int] = Field(None, description="每日请求限制")
monthly_limit: Optional[int] = Field(None, description="每月请求限制")
rpm_limit: Optional[int] = Field(None, description="RPM限制(每分钟请求数)NULL=自适应模式")
priority: int = Field(0, description="优先级(越高越优先使用)")
is_active: bool = Field(True, description="是否启用")
expires_at: Optional[datetime] = Field(None, description="过期时间")
@@ -32,9 +30,7 @@ class ProviderAPIKeyUpdate(BaseModel):
name: Optional[str] = None
api_key: Optional[str] = None
rate_limit: Optional[int] = None
daily_limit: Optional[int] = None
monthly_limit: Optional[int] = None
rpm_limit: Optional[int] = None
priority: Optional[int] = None
is_active: Optional[bool] = None
expires_at: Optional[datetime] = None
@@ -67,5 +63,3 @@ class ProviderAPIKeyStats(BaseModel):
last_used_at: Optional[datetime]
is_active: bool
is_expired: bool
remaining_daily: Optional[int] = Field(None, description="今日剩余请求数")
remaining_monthly: Optional[int] = Field(None, description="本月剩余请求数")

View File

@@ -338,7 +338,8 @@ class Usage(Base):
request_headers = Column(JSON, nullable=True) # 客户端请求头
request_body = Column(JSON, nullable=True) # 请求体7天内未压缩
provider_request_headers = Column(JSON, nullable=True) # 向提供商发送的请求头
response_headers = Column(JSON, nullable=True) # 响应头
response_headers = Column(JSON, nullable=True) # 提供商响应头
client_response_headers = Column(JSON, nullable=True) # 返回给客户端的响应头
response_body = Column(JSON, nullable=True) # 响应体7天内未压缩
# 压缩存储字段7天后自动压缩到这里
@@ -513,8 +514,7 @@ class Provider(Base):
__tablename__ = "providers"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
name = Column(String(100), unique=True, nullable=False, index=True) # 提供商唯一标识
display_name = Column(String(100), nullable=False) # 显示名称
name = Column(String(100), unique=True, nullable=False, index=True) # 提供商名称(唯一)
description = Column(Text, nullable=True) # 提供商描述
website = Column(String(500), nullable=True) # 主站网站
@@ -537,11 +537,6 @@ class Provider(Base):
quota_last_reset_at = Column(DateTime(timezone=True), nullable=True) # 上次额度重置时间
quota_expires_at = Column(DateTime(timezone=True), nullable=True) # 月卡过期时间
# RPM限制(NULL=无限制,0=禁止请求)
rpm_limit = Column(Integer, nullable=True) # 每分钟请求数限制(NULL=无限制,0=禁止请求)
rpm_used = Column(Integer, default=0) # 当前分钟已用请求数
rpm_reset_at = Column(DateTime(timezone=True), nullable=True) # RPM重置时间
# 提供商优先级 (数字越小越优先,用于提供商优先模式下的 Provider 排序)
# 0-10: 急需消耗(如即将过期的月卡)
# 11-50: 优先消耗(月卡)
@@ -555,6 +550,15 @@ class Provider(Base):
# 限制
concurrent_limit = Column(Integer, nullable=True) # 并发请求限制
# 请求配置(从 Endpoint 迁移,作为全局默认值)
# 超时 300 秒对于 LLM API 是合理的默认值:
# - 大多数请求在 30 秒内完成
# - 复杂推理(如 Claude thinking)可能需要 60-120 秒
# - 300 秒足够覆盖极端场景(如超长上下文、复杂工具调用)
timeout = Column(Integer, default=300, nullable=True) # 请求超时(秒)
max_retries = Column(Integer, default=2, nullable=True) # 最大重试次数
proxy = Column(JSONB, nullable=True) # 代理配置: {url, username, password, enabled}
# 配置
config = Column(JSON, nullable=True) # 额外配置如Azure deployment name等
@@ -574,6 +578,9 @@ class Provider(Base):
endpoints = relationship(
"ProviderEndpoint", back_populates="provider", cascade="all, delete-orphan"
)
api_keys = relationship(
"ProviderAPIKey", back_populates="provider", cascade="all, delete-orphan"
)
api_key_mappings = relationship(
"ApiKeyProviderMapping", back_populates="provider", cascade="all, delete-orphan"
)
@@ -599,12 +606,6 @@ class ProviderEndpoint(Base):
timeout = Column(Integer, default=300) # 超时(秒)
max_retries = Column(Integer, default=2) # 最大重试次数
# 限制
max_concurrent = Column(
Integer, nullable=True, default=None
) # 该端点的最大并发数NULL=不限制)
rate_limit = Column(Integer, nullable=True) # 每分钟请求限制
# 状态
is_active = Column(Boolean, default=True, nullable=False)
@@ -632,9 +633,6 @@ class ProviderEndpoint(Base):
# 关系
provider = relationship("Provider", back_populates="endpoints")
api_keys = relationship(
"ProviderAPIKey", back_populates="endpoint", cascade="all, delete-orphan"
)
# 唯一约束和索引在表定义后
__table_args__ = (
@@ -734,9 +732,11 @@ class GlobalModel(Base):
class Model(Base):
"""Provider 模型配置表 - Provider 如何使用某个 GlobalModel
设计原则 (方案 A):
- 每个 Model 必须关联一个 GlobalModel (global_model_id 不可为空)
- Model 表示 Provider 对某个 GlobalModel 的具体实现
设计原则:
- Model 表示 Provider 对某个模型的具体实现
- global_model_id 可为空:
- 为空时:模型尚未关联到 GlobalModel,不参与路由
- 不为空时:模型已关联 GlobalModel,参与路由
- provider_model_name 是 Provider 侧的实际模型名称 (可能与 GlobalModel.name 不同)
- 价格和能力配置可为空,为空时使用 GlobalModel 的默认值
"""
@@ -745,7 +745,8 @@ class Model(Base):
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
provider_id = Column(String(36), ForeignKey("providers.id"), nullable=False)
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=False, index=True)
# 可为空:NULL 表示未关联,不参与路由;非 NULL 表示已关联,参与路由
global_model_id = Column(String(36), ForeignKey("global_models.id"), nullable=True, index=True)
# Provider 映射配置
provider_model_name = Column(String(200), nullable=False) # Provider 侧的主模型名称
@@ -983,17 +984,20 @@ class Model(Base):
class ProviderAPIKey(Base):
"""Provider API密钥表 - 归属于特定 ProviderEndpoint"""
"""Provider API密钥表 - 直接归属于 Provider,支持多种 API 格式"""
__tablename__ = "provider_api_keys"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()), index=True)
# 外键关系
endpoint_id = Column(
String(36), ForeignKey("provider_endpoints.id", ondelete="CASCADE"), nullable=False
# 外键关系 - 直接关联 Provider
provider_id = Column(
String(36), ForeignKey("providers.id", ondelete="CASCADE"), nullable=False, index=True
)
# API 格式支持列表(核心字段)
api_formats = Column(JSON, nullable=False, default=list) # ["CLAUDE", "CLAUDE_CLI"]
# API密钥信息
api_key = Column(String(500), nullable=False) # API密钥加密存储
name = Column(String(100), nullable=False) # 密钥名称(必填,用于识别)
@@ -1002,7 +1006,10 @@ class ProviderAPIKey(Base):
# 成本计算
rate_multiplier = Column(
Float, default=1.0, nullable=False
) # 成本倍率(真实成本 = 表面成本 × 倍率)
) # 默认成本倍率(真实成本 = 表面成本 × 倍率)
rate_multipliers = Column(
JSON, nullable=True
) # 按 API 格式的成本倍率 {"CLAUDE": 1.0, "OPENAI": 0.8}
# 优先级配置 (数字越小越优先)
internal_priority = Column(
@@ -1012,14 +1019,11 @@ class ProviderAPIKey(Base):
Integer, nullable=True
) # 全局 Key 优先级(用于全局 Key 优先模式,跨 Provider 的 Key 排序NULL=未配置使用默认排序)
# 并发限制配置
# max_concurrent 决定并发控制模式:
# - NULL: 自适应模式,系统自动学习并调整(使用 learned_max_concurrent)
# RPM 限制配置(自适应学习)
# rpm_limit 决定 RPM 控制模式:
# - NULL: 自适应模式,系统自动学习并调整(使用 learned_rpm_limit)
# - 数字: 固定限制模式,使用用户指定的值
max_concurrent = Column(Integer, nullable=True, default=None)
rate_limit = Column(Integer, nullable=True) # 速率限制(每分钟请求数)
daily_limit = Column(Integer, nullable=True) # 每日请求限制
monthly_limit = Column(Integer, nullable=True) # 每月请求限制
rpm_limit = Column(Integer, nullable=True, default=None)
# 模型权限控制
allowed_models = Column(JSON, nullable=True) # 允许使用的模型列表(null = 支持所有模型)
@@ -1028,16 +1032,16 @@ class ProviderAPIKey(Base):
capabilities = Column(JSON, nullable=True) # Key 拥有的能力
# 示例: {"cache_1h": true, "context_1m": true}
# 自适应并发调整(仅当 max_concurrent = NULL 时生效)
learned_max_concurrent = Column(
# 自适应 RPM 调整(仅当 rpm_limit = NULL 时生效)
learned_rpm_limit = Column(
Integer, nullable=True
) # 学习到的并发限制(自适应模式下的有效值)
) # 学习到的 RPM 限制(自适应模式下的有效值)
concurrent_429_count = Column(Integer, default=0, nullable=False) # 因并发导致的429次数
rpm_429_count = Column(Integer, default=0, nullable=False) # 因RPM导致的429次数
last_429_at = Column(DateTime(timezone=True), nullable=True) # 最后429时间
last_429_type = Column(String(50), nullable=True) # 最后429类型: concurrent/rpm/unknown
last_concurrent_peak = Column(Integer, nullable=True) # 触发429时的并发数
adjustment_history = Column(JSON, nullable=True) # 并发调整历史
last_rpm_peak = Column(Integer, nullable=True) # 触发429时的RPM峰值
adjustment_history = Column(JSON, nullable=True) # RPM调整历史
# 基于滑动窗口的利用率追踪
utilization_samples = Column(
JSON, nullable=True
@@ -1046,12 +1050,9 @@ class ProviderAPIKey(Base):
DateTime(timezone=True), nullable=True
) # 上次探测性扩容时间
# 健康度追踪(基于滑动窗口
health_score = Column(Float, default=1.0) # 0.0-1.0(保留用于展示,实际熔断基于滑动窗口)
consecutive_failures = Column(Integer, default=0)
last_failure_at = Column(DateTime(timezone=True), nullable=True) # 最后失败时间
# 滑动窗口:记录最近 N 次请求的结果 [{"ts": timestamp, "ok": true/false}, ...]
request_results_window = Column(JSON, nullable=True)
# 健康度追踪(按 API 格式存储
# 结构: {"CLAUDE": {"health_score": 1.0, "consecutive_failures": 0, "last_failure_at": null, "request_results_window": []}, ...}
health_by_format = Column(JSON, nullable=True, default=dict)
# 缓存与熔断配置
cache_ttl_minutes = Column(
@@ -1061,14 +1062,9 @@ class ProviderAPIKey(Base):
Integer, default=32, nullable=False
) # 最大探测间隔(分钟)默认32分钟硬上限
# 熔断器字段(滑动窗口 + 半开状态模式
circuit_breaker_open = Column(Boolean, default=False, nullable=False) # 熔断器是否打开
circuit_breaker_open_at = Column(DateTime(timezone=True), nullable=True) # 熔断器打开时间
next_probe_at = Column(DateTime(timezone=True), nullable=True) # 下次探测时间
# 半开状态:允许少量请求通过验证服务是否恢复
half_open_until = Column(DateTime(timezone=True), nullable=True) # 半开状态结束时间
half_open_successes = Column(Integer, default=0) # 半开状态下的成功次数
half_open_failures = Column(Integer, default=0) # 半开状态下的失败次数
# 熔断器状态(按 API 格式存储
# 结构: {"CLAUDE": {"open": false, "open_at": null, "next_probe_at": null, "half_open_until": null, "half_open_successes": 0, "half_open_failures": 0}, ...}
circuit_breaker_by_format = Column(JSON, nullable=True, default=dict)
# 使用统计
request_count = Column(Integer, default=0) # 请求次数
@@ -1095,7 +1091,7 @@ class ProviderAPIKey(Base):
)
# 关系
endpoint = relationship("ProviderEndpoint", back_populates="api_keys")
provider = relationship("Provider", back_populates="api_keys")
class UserPreference(Base):
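A minimal sketch of reading the per-format JSON documented above. The helpers are illustrative; actual access goes through the health monitor, which is not part of this excerpt:

from typing import Any, Dict

def get_format_health(key_row: Any, api_format: str) -> Dict[str, Any]:
    # Per-format health bucket of a ProviderAPIKey row, filled with the documented defaults
    defaults = {
        "health_score": 1.0,
        "consecutive_failures": 0,
        "last_failure_at": None,
        "request_results_window": [],
    }
    buckets = key_row.health_by_format or {}
    return {**defaults, **buckets.get(api_format.upper(), {})}

def is_circuit_open(key_row: Any, api_format: str) -> bool:
    breaker = (key_row.circuit_breaker_by_format or {}).get(api_format.upper(), {})
    return bool(breaker.get("open", False))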

View File

@@ -4,7 +4,7 @@ ProviderEndpoint 相关的 API 模型定义
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -26,10 +26,6 @@ class ProviderEndpointCreate(BaseModel):
timeout: int = Field(default=300, ge=10, le=600, description="超时时间(秒)")
max_retries: int = Field(default=2, ge=0, le=10, description="最大重试次数")
# 限制
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制(请求/秒)")
# 额外配置
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置JSON")
@@ -67,8 +63,6 @@ class ProviderEndpointUpdate(BaseModel):
headers: Optional[Dict[str, str]] = Field(default=None, description="自定义请求头")
timeout: Optional[int] = Field(default=None, ge=10, le=600, description="超时时间(秒)")
max_retries: Optional[int] = Field(default=None, ge=0, le=10, description="最大重试次数")
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
is_active: Optional[bool] = Field(default=None, description="是否启用")
config: Optional[Dict[str, Any]] = Field(default=None, description="额外配置")
proxy: Optional[ProxyConfig] = Field(default=None, description="代理配置")
@@ -103,10 +97,6 @@ class ProviderEndpointResponse(BaseModel):
timeout: int
max_retries: int
# 限制
max_concurrent: Optional[int] = None
rate_limit: Optional[int] = None
# 状态
is_active: bool
@@ -127,32 +117,37 @@ class ProviderEndpointResponse(BaseModel):
model_config = ConfigDict(from_attributes=True)
# ========== ProviderAPIKey 相关(新架构) ==========
# ========== ProviderAPIKey 相关 ==========
class EndpointAPIKeyCreate(BaseModel):
"""Endpoint 添加 API Key"""
"""Provider 添加 API Key"""
provider_id: Optional[str] = Field(default=None, description="Provider ID(从 URL 获取)")
api_formats: Optional[List[str]] = Field(
default=None, min_length=1, description="支持的 API 格式列表(必填,路由层校验)"
)
endpoint_id: str = Field(..., description="Endpoint ID")
api_key: str = Field(..., min_length=3, max_length=500, description="API Key将自动加密")
name: str = Field(..., min_length=1, max_length=100, description="密钥名称(必填,用于识别)")
# 成本计算
rate_multiplier: float = Field(
default=1.0, ge=0.01, description="成本倍率(真实成本 = 表面成本 × 倍率)"
default=1.0, ge=0.01, description="默认成本倍率(真实成本 = 表面成本 × 倍率)"
)
rate_multipliers: Optional[Dict[str, float]] = Field(
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
)
# 优先级和限制(数字越小越优先)
internal_priority: int = Field(default=50, description="Endpoint 内部优先级(提供商优先模式)")
# max_concurrent: NULL=自适应模式(系统自动学习),数字=固定限制模式
max_concurrent: Optional[int] = Field(
default=None, ge=1, description="最大并发数(NULL=自适应模式)"
internal_priority: int = Field(default=50, description="Key 内部优先级(提供商优先模式)")
# rpm_limit: NULL=自适应模式(系统自动学习),数字=固定限制模式
rpm_limit: Optional[int] = Field(
default=None, ge=1, le=10000, description="RPM 限制(NULL=自适应模式)"
)
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
allowed_models: Optional[List[str]] = Field(
default=None, description="允许使用的模型列表null = 支持所有模型)"
allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = Field(
default=None,
description="允许使用的模型列表null=不限制,列表=简单白名单,字典=按API格式区分",
)
# 能力标签
@@ -171,6 +166,92 @@ class EndpointAPIKeyCreate(BaseModel):
# 备注
note: Optional[str] = Field(default=None, max_length=500, description="备注说明(可选)")
@field_validator("api_formats")
@classmethod
def validate_api_formats(cls, v: Optional[List[str]]) -> Optional[List[str]]:
"""验证 API 格式列表"""
if v is None:
return v
from src.core.enums import APIFormat
allowed = [fmt.value for fmt in APIFormat]
validated = []
seen = set()
for fmt in v:
fmt_upper = fmt.upper()
if fmt_upper not in allowed:
raise ValueError(f"API 格式必须是 {allowed} 之一,当前值: {fmt}")
if fmt_upper in seen:
continue # 静默去重
seen.add(fmt_upper)
validated.append(fmt_upper)
return validated
@field_validator("allowed_models")
@classmethod
def validate_allowed_models(
cls, v: Optional[Union[List[str], Dict[str, List[str]]]]
) -> Optional[Union[List[str], Dict[str, List[str]]]]:
"""
规范化 allowed_models
- 列表模式:去空、去重、保留顺序
- 字典模式:key 统一大写(支持 "*"),value 去空、去重、保留顺序
"""
if v is None:
return v
if isinstance(v, list):
cleaned: List[str] = []
seen: set[str] = set()
for item in v:
if not isinstance(item, str):
raise ValueError("allowed_models 列表必须为字符串数组")
name = item.strip()
if not name or name in seen:
continue
seen.add(name)
cleaned.append(name)
return cleaned
if isinstance(v, dict):
from src.core.enums import APIFormat
allowed_formats = {fmt.value for fmt in APIFormat}
normalized: Dict[str, List[str]] = {}
for raw_key, models in v.items():
if not isinstance(raw_key, str):
raise ValueError("allowed_models 字典的 key 必须为字符串")
key = raw_key.upper()
if key != "*" and key not in allowed_formats:
raise ValueError(
f"allowed_models 字典的 key 必须是 {sorted(allowed_formats)}'*',当前值: {raw_key}"
)
if models is None:
# null 表示该格式不限制,跳过(不加入字典)
continue
if not isinstance(models, list):
raise ValueError("allowed_models 字典的 value 必须为字符串数组")
cleaned: List[str] = []
seen: set[str] = set()
for item in models:
if not isinstance(item, str):
raise ValueError("allowed_models 字典的 value 必须为字符串数组")
name = item.strip()
if not name or name in seen:
continue
seen.add(name)
cleaned.append(name)
normalized[key] = cleaned
return normalized
raise ValueError("allowed_models 必须是列表或字典")
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v: str) -> str:
@@ -214,26 +295,35 @@ class EndpointAPIKeyCreate(BaseModel):
class EndpointAPIKeyUpdate(BaseModel):
"""更新 Endpoint API Key"""
api_formats: Optional[List[str]] = Field(
default=None, min_length=1, description="支持的 API 格式列表"
)
api_key: Optional[str] = Field(
default=None, min_length=3, max_length=500, description="API Key将自动加密"
)
name: Optional[str] = Field(default=None, min_length=1, max_length=100, description="密钥名称")
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="成本倍率")
rate_multiplier: Optional[float] = Field(default=None, ge=0.01, description="默认成本倍率")
rate_multipliers: Optional[Dict[str, float]] = Field(
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
)
internal_priority: Optional[int] = Field(
default=None, description="Endpoint 内部优先级(提供商优先模式,数字越小越优先)"
default=None, description="Key 内部优先级(提供商优先模式,数字越小越优先)"
)
global_priority: Optional[int] = Field(
default=None, description="全局 Key 优先级(全局 Key 优先模式,数字越小越优先)"
)
# max_concurrent: 使用特殊标记区分"未提供"和"设置为 null自适应模式"
# rpm_limit: 使用特殊标记区分"未提供"和"设置为 null自适应模式"
# - 不提供字段:不更新
# - 提供 null切换为自适应模式
# - 提供数字:设置固定并发限制
max_concurrent: Optional[int] = Field(default=None, ge=1, description="最大并发数(null=自适应模式)")
rate_limit: Optional[int] = Field(default=None, ge=1, description="速率限制")
daily_limit: Optional[int] = Field(default=None, ge=1, description="每日限制")
monthly_limit: Optional[int] = Field(default=None, ge=1, description="每月限制")
allowed_models: Optional[List[str]] = Field(default=None, description="允许使用的模型列表")
# - 提供数字:设置固定 RPM 限制
rpm_limit: Optional[int] = Field(
default=None, ge=1, le=10000, description="RPM 限制null=自适应模式)"
)
allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = Field(
default=None,
description="允许使用的模型列表null=不限制,列表=简单白名单,字典=按API格式区分",
)
capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签,如 {'cache_1h': true, 'context_1m': true}"
)
@@ -246,6 +336,36 @@ class EndpointAPIKeyUpdate(BaseModel):
is_active: Optional[bool] = Field(default=None, description="是否启用")
note: Optional[str] = Field(default=None, max_length=500, description="备注说明")
@field_validator("api_formats")
@classmethod
def validate_api_formats(cls, v: Optional[List[str]]) -> Optional[List[str]]:
"""验证 API 格式列表"""
if v is None:
return v
from src.core.enums import APIFormat
allowed = [fmt.value for fmt in APIFormat]
validated = []
seen = set()
for fmt in v:
fmt_upper = fmt.upper()
if fmt_upper not in allowed:
raise ValueError(f"API 格式必须是 {allowed} 之一,当前值: {fmt}")
if fmt_upper in seen:
continue # 静默去重
seen.add(fmt_upper)
validated.append(fmt_upper)
return validated
@field_validator("allowed_models")
@classmethod
def validate_allowed_models(
cls, v: Optional[Union[List[str], Dict[str, List[str]]]]
) -> Optional[Union[List[str], Dict[str, List[str]]]]:
# 与 EndpointAPIKeyCreate 保持一致
return EndpointAPIKeyCreate.validate_allowed_models(v)
@field_validator("api_key")
@classmethod
def validate_api_key(cls, v: Optional[str]) -> Optional[str]:
@@ -295,7 +415,9 @@ class EndpointAPIKeyResponse(BaseModel):
"""Endpoint API Key 响应"""
id: str
endpoint_id: str
provider_id: str = Field(..., description="Provider ID")
api_formats: List[str] = Field(default=[], description="支持的 API 格式列表")
# Key 信息(脱敏)
api_key_masked: str = Field(..., description="脱敏后的 Key")
@@ -303,31 +425,37 @@ class EndpointAPIKeyResponse(BaseModel):
name: str = Field(..., description="密钥名称")
# 成本计算
rate_multiplier: float = Field(default=1.0, description="成本倍率")
rate_multiplier: float = Field(default=1.0, description="默认成本倍率")
rate_multipliers: Optional[Dict[str, float]] = Field(
default=None, description="按 API 格式的成本倍率,如 {'CLAUDE': 1.0, 'OPENAI': 0.8}"
)
# 优先级和限制
internal_priority: int = Field(default=50, description="Endpoint 内部优先级")
global_priority: Optional[int] = Field(default=None, description="全局 Key 优先级")
max_concurrent: Optional[int] = None
rate_limit: Optional[int] = None
daily_limit: Optional[int] = None
monthly_limit: Optional[int] = None
allowed_models: Optional[List[str]] = None
capabilities: Optional[Dict[str, bool]] = Field(
default=None, description="Key 能力标签"
)
rpm_limit: Optional[int] = None
allowed_models: Optional[Union[List[str], Dict[str, List[str]]]] = None
capabilities: Optional[Dict[str, bool]] = Field(default=None, description="Key 能力标签")
# 缓存与熔断配置
cache_ttl_minutes: int = Field(default=5, description="缓存 TTL(分钟,0=禁用)")
max_probe_interval_minutes: int = Field(default=32, description="熔断探测间隔(分钟)")
# 健康度
health_score: float
consecutive_failures: int
# 按格式的健康度数据
health_by_format: Optional[Dict[str, Any]] = Field(
default=None, description="按 API 格式存储的健康度数据"
)
circuit_breaker_by_format: Optional[Dict[str, Any]] = Field(
default=None, description="按 API 格式存储的熔断器状态"
)
# 聚合字段(从 health_by_format 计算,用于列表显示)
health_score: float = Field(default=1.0, description="健康度(所有格式中的最低值)")
consecutive_failures: int = Field(default=0, description="连续失败次数")
last_failure_at: Optional[datetime] = None
# 熔断器状态(滑动窗口 + 半开模式)
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开")
# 聚合熔断器字段
circuit_breaker_open: bool = Field(default=False, description="熔断器是否打开(任何格式)")
circuit_breaker_open_at: Optional[datetime] = Field(default=None, description="熔断器打开时间")
next_probe_at: Optional[datetime] = Field(default=None, description="下次进入半开状态时间")
half_open_until: Optional[datetime] = Field(default=None, description="半开状态结束时间")
@@ -345,9 +473,9 @@ class EndpointAPIKeyResponse(BaseModel):
# 状态
is_active: bool
# 自适应并发信息
is_adaptive: bool = Field(default=False, description="是否为自适应模式(max_concurrent=NULL")
learned_max_concurrent: Optional[int] = Field(None, description="学习到的并发限制")
# 自适应 RPM 信息
is_adaptive: bool = Field(default=False, description="是否为自适应模式(rpm_limit=NULL")
learned_rpm_limit: Optional[int] = Field(None, description="学习到的 RPM 限制")
effective_limit: Optional[int] = Field(None, description="当前有效限制")
# 滑动窗口利用率采样
utilization_samples: Optional[List[dict]] = Field(None, description="利用率采样窗口")
@@ -371,22 +499,42 @@ class EndpointAPIKeyResponse(BaseModel):
# ========== 健康监控相关 ==========
class HealthStatusResponse(BaseModel):
"""健康状态响应(仅 Key 级别)"""
class FormatHealthData(BaseModel):
"""单个 API 格式的健康度数据"""
# Key 健康状态
health_score: float = 1.0
error_rate: float = 0.0
window_size: int = 0
consecutive_failures: int = 0
last_failure_at: Optional[str] = None
circuit_breaker: Dict[str, Any] = Field(default_factory=dict)
class HealthStatusResponse(BaseModel):
"""健康状态响应(支持按格式查询)"""
# 基础信息
key_id: str
key_health_score: float
key_consecutive_failures: int
key_last_failure_at: Optional[datetime] = None
key_is_active: bool
key_statistics: Optional[Dict[str, Any]] = None
# 熔断器状态(滑动窗口 + 半开模式
# 整体健康度(取所有格式中的最低值)
key_health_score: float = 1.0
any_circuit_open: bool = False
# 按格式的健康度数据
health_by_format: Optional[Dict[str, FormatHealthData]] = None
# 单格式查询时的字段
api_format: Optional[str] = None
key_consecutive_failures: Optional[int] = None
key_last_failure_at: Optional[str] = None
# 单格式查询时的熔断器状态
circuit_breaker_open: bool = False
circuit_breaker_open_at: Optional[datetime] = None
next_probe_at: Optional[datetime] = None
half_open_until: Optional[datetime] = None
circuit_breaker_open_at: Optional[str] = None
next_probe_at: Optional[str] = None
half_open_until: Optional[str] = None
half_open_successes: int = 0
half_open_failures: int = 0
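# Hedged illustration of the two response shapes above (all values are made up).
# Aggregate view, returned when no specific format is requested:
health_status_aggregate_example = {
    "key_id": "5f2c-example",
    "key_is_active": True,
    "key_health_score": 0.8,        # lowest score across all formats
    "any_circuit_open": False,
    "health_by_format": {
        "CLAUDE": {
            "health_score": 0.8,
            "error_rate": 0.1,
            "window_size": 20,
            "consecutive_failures": 0,
            "last_failure_at": None,
            "circuit_breaker": {"open": False},
        },
    },
}

# Single-format view (e.g. when an api_format filter is passed -- the exact query
# parameter name is an assumption): the flattened fields are filled instead.
health_status_single_format_example = {
    "key_id": "5f2c-example",
    "key_is_active": True,
    "api_format": "CLAUDE",
    "key_health_score": 0.8,
    "key_consecutive_failures": 0,
    "circuit_breaker_open": False,
    "half_open_successes": 0,
    "half_open_failures": 0,
}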
@@ -398,33 +546,22 @@ class HealthSummaryResponse(BaseModel):
keys: Dict[str, int] = Field(..., description="Key 统计 (total, active, unhealthy)")
# ========== 并发控制相关 ==========
# ========== RPM 控制相关 ==========
class ConcurrencyStatusResponse(BaseModel):
"""并发状态响应"""
class KeyRpmStatusResponse(BaseModel):
"""Key RPM 状态响应"""
endpoint_id: Optional[str] = None
endpoint_current_concurrency: int = Field(default=0, description="Endpoint 当前并发")
endpoint_max_concurrent: Optional[int] = Field(default=None, description="Endpoint 最大并发数")
key_id: Optional[str] = None
key_current_concurrency: int = Field(default=0, description="Key 当前并发数")
key_max_concurrent: Optional[int] = Field(default=None, description="Key 最大并发数")
class ResetConcurrencyRequest(BaseModel):
"""重置并发计数请求"""
endpoint_id: Optional[str] = Field(default=None, description="Endpoint ID(可选)")
key_id: Optional[str] = Field(default=None, description="Key ID(可选)")
key_id: str = Field(..., description="Key ID")
current_rpm: int = Field(default=0, description="当前 RPM 计数")
rpm_limit: Optional[int] = Field(default=None, description="RPM 限制")
class KeyPriorityItem(BaseModel):
"""单个 Key 优先级项"""
key_id: str = Field(..., description="Key ID")
internal_priority: int = Field(..., ge=0, description="Endpoint 内部优先级(数字越小越优先)")
internal_priority: int = Field(..., ge=0, description="Key 内部优先级(数字越小越优先)")
class BatchUpdateKeyPriorityRequest(BaseModel):
@@ -439,11 +576,9 @@ class BatchUpdateKeyPriorityRequest(BaseModel):
class ProviderUpdateRequest(BaseModel):
"""Provider 基础配置更新请求"""
display_name: Optional[str] = Field(None, min_length=1, max_length=100)
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = None
website: Optional[str] = Field(None, max_length=500, description="主站网站")
priority: Optional[int] = None
weight: Optional[float] = Field(None, gt=0)
provider_priority: Optional[int] = Field(None, description="提供商优先级(数字越小越优先)")
is_active: Optional[bool] = None
billing_type: Optional[str] = Field(
@@ -452,9 +587,10 @@ class ProviderUpdateRequest(BaseModel):
monthly_quota_usd: Optional[float] = Field(None, ge=0, description="订阅配额(美元)")
quota_reset_day: Optional[int] = Field(None, ge=1, le=31, description="配额重置日(1-31)")
quota_expires_at: Optional[datetime] = Field(None, description="配额过期时间")
rpm_limit: Optional[int] = Field(
None, ge=0, description="每分钟请求数限制(NULL=无限制,0=禁止请求)"
)
# 请求配置(从 Endpoint 迁移)
timeout: Optional[int] = Field(None, ge=1, le=600, description="请求超时(秒)")
max_retries: Optional[int] = Field(None, ge=0, le=10, description="最大重试次数")
proxy: Optional[Dict[str, Any]] = Field(None, description="代理配置")
class ProviderWithEndpointsSummary(BaseModel):
@@ -463,7 +599,6 @@ class ProviderWithEndpointsSummary(BaseModel):
# Provider 基本信息
id: str
name: str
display_name: str
description: Optional[str] = None
website: Optional[str] = None
provider_priority: int = Field(default=100, description="提供商优先级(数字越小越优先)")
@@ -477,12 +612,10 @@ class ProviderWithEndpointsSummary(BaseModel):
quota_last_reset_at: Optional[datetime] = Field(default=None, description="当前周期开始时间")
quota_expires_at: Optional[datetime] = Field(default=None, description="配额过期时间")
# RPM 限制
rpm_limit: Optional[int] = Field(
default=None, description="每分钟请求数限制NULL=无限制0=禁止请求)"
)
rpm_used: Optional[int] = Field(default=None, description="当前分钟已用请求数")
rpm_reset_at: Optional[datetime] = Field(default=None, description="RPM 重置时间")
# 请求配置(从 Endpoint 迁移)
timeout: Optional[int] = Field(default=300, description="请求超时(秒)")
max_retries: Optional[int] = Field(default=2, description="最大重试次数")
proxy: Optional[Dict[str, Any]] = Field(default=None, description="代理配置")
# Endpoint 统计
total_endpoints: int = Field(default=0, description="总 Endpoint 数量")
@@ -617,12 +750,8 @@ class PublicApiFormatHealthMonitor(BaseModel):
default_factory=list,
description="Usage 表生成的健康时间线healthy/warning/unhealthy/unknown",
)
time_range_start: Optional[datetime] = Field(
default=None, description="时间线覆盖区间开始时间"
)
time_range_end: Optional[datetime] = Field(
default=None, description="时间线覆盖区间结束时间"
)
time_range_start: Optional[datetime] = Field(default=None, description="时间线覆盖区间开始时间")
time_range_end: Optional[datetime] = Field(default=None, description="时间线覆盖区间结束时间")
class PublicApiFormatHealthMonitorResponse(BaseModel):

View File

@@ -114,7 +114,6 @@ class ModelCatalogProviderDetail(BaseModel):
provider_id: str
provider_name: str
provider_display_name: Optional[str]
model_id: Optional[str]
target_model: str
input_price_per_1m: Optional[float]
@@ -312,16 +311,26 @@ class ImportFromUpstreamRequest(BaseModel):
"""从上游提供商导入模型请求"""
model_ids: List[str] = Field(..., min_length=1, description="上游模型 ID 列表")
# 价格覆盖配置(应用于所有导入的模型)
tiered_pricing: Optional[Dict] = Field(
None,
description="阶梯计费配置(可选),格式: {tiers: [{up_to, input_price_per_1m, output_price_per_1m, ...}]}"
)
price_per_request: Optional[float] = Field(
None,
ge=0,
description="按次计费价格(可选,单位:美元)"
)
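# Hedged example request body (not taken from the diff): importing two upstream
# models with a tiered price override. Model IDs and prices are made up; the tier
# field names follow the description above, and up_to=None is assumed to mean
# "no upper bound" for the last tier.
import_from_upstream_example = {
    "model_ids": ["claude-sonnet-4", "claude-haiku-3"],
    "tiered_pricing": {
        "tiers": [
            {"up_to": 200_000, "input_price_per_1m": 3.0, "output_price_per_1m": 15.0},
            {"up_to": None, "input_price_per_1m": 6.0, "output_price_per_1m": 22.5},
        ]
    },
    # "price_per_request": 0.01,  # alternative: flat per-request pricing
}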
class ImportFromUpstreamSuccessItem(BaseModel):
"""导入成功的模型信息"""
model_id: str = Field(..., description="上游模型 ID")
global_model_id: str = Field(..., description="GlobalModel ID")
global_model_name: str = Field(..., description="GlobalModel 名称")
provider_model_id: str = Field(..., description="Provider Model ID")
created_global_model: bool = Field(..., description="是否新创建了 GlobalModel")
global_model_id: Optional[str] = Field("", description="GlobalModel ID如果已关联")
global_model_name: Optional[str] = Field("", description="GlobalModel 名称(如果已关联)")
created_global_model: bool = Field(False, description="是否新创建了 GlobalModel(始终为 false)")
class ImportFromUpstreamErrorItem(BaseModel):

View File

@@ -34,7 +34,7 @@ import hashlib
import random
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from sqlalchemy.orm import Session, selectinload
@@ -80,8 +80,6 @@ class ProviderCandidate:
@dataclass
class ConcurrencySnapshot:
endpoint_current: int
endpoint_limit: Optional[int]
key_current: int
key_limit: Optional[int]
is_cached_user: bool = False
@@ -91,11 +89,9 @@ class ConcurrencySnapshot:
reservation_confidence: float = 0.0
def describe(self) -> str:
endpoint_limit_text = str(self.endpoint_limit) if self.endpoint_limit is not None else "inf"
key_limit_text = str(self.key_limit) if self.key_limit is not None else "inf"
reservation_text = f"{self.reservation_ratio:.0%}" if self.reservation_ratio > 0 else "N/A"
return (
f"endpoint={self.endpoint_current}/{endpoint_limit_text}, "
f"key={self.key_current}/{key_limit_text}, "
f"cached={self.is_cached_user}, "
f"reserve={reservation_text}({self.reservation_phase})"
@@ -246,9 +242,8 @@ class CacheAwareScheduler:
if not candidates:
if provider_offset == 0:
# 没有找到任何候选,提供友好的错误提示
error_msg = f"模型 '{model_name}' 不可用"
raise ProviderNotAvailableException(error_msg)
# 没有找到任何候选,提供友好的错误提示(不暴露内部信息)
raise ProviderNotAvailableException("请求的模型当前不可用")
break
self._metrics["total_batches"] += 1
@@ -270,7 +265,6 @@ class CacheAwareScheduler:
is_cached_user = bool(candidate.is_cached)
can_use, snapshot = await self._check_concurrent_available(
endpoint,
key,
is_cached_user=is_cached_user,
)
@@ -312,47 +306,51 @@ class CacheAwareScheduler:
provider_offset += provider_batch_size
raise ProviderNotAvailableException(f"所有Provider的资源当前不可用 (model={model_name})")
raise ProviderNotAvailableException("服务暂时繁忙,请稍后重试")
def _get_effective_concurrent_limit(self, key: ProviderAPIKey) -> Optional[int]:
def _get_effective_rpm_limit(self, key: ProviderAPIKey) -> Optional[int]:
"""
获取有效的并发限制
获取有效的 RPM 限制
新逻辑:
- max_concurrent=NULL: 启用自适应,使用 learned_max_concurrent(如无学习记录则为 None)
- max_concurrent=数字: 固定限制,直接使用该值
- rpm_limit=NULL: 启用自适应,使用 learned_rpm_limit(如无学习记录则使用默认初始值)
- rpm_limit=数字: 固定限制,直接使用该值
Args:
key: API Key对象
Returns:
有效的并发限制(None 表示不限制)
有效的 RPM 限制(None 表示不限制)
"""
if key.max_concurrent is None:
if key.rpm_limit is None:
# 自适应模式:使用学习到的值
learned = key.learned_max_concurrent
return int(learned) if learned is not None else None
learned = key.learned_rpm_limit
if learned is not None:
return int(learned)
# 未学习到值时,使用默认初始限制,避免无限制打爆上游
from src.config.constants import RPMDefaults
return int(RPMDefaults.INITIAL_LIMIT)
else:
# 固定限制模式
return int(key.max_concurrent)
return int(key.rpm_limit)
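# Hedged usage illustration of the rules above. SimpleNamespace stands in for a
# ProviderAPIKey row, and the fallback value 60 is only an assumed stand-in for
# RPMDefaults.INITIAL_LIMIT (the real constant lives in src.config.constants).
from types import SimpleNamespace

_ASSUMED_INITIAL_LIMIT = 60

def _effective_rpm(key) -> int:
    if key.rpm_limit is None:
        return int(key.learned_rpm_limit) if key.learned_rpm_limit is not None else _ASSUMED_INITIAL_LIMIT
    return int(key.rpm_limit)

assert _effective_rpm(SimpleNamespace(rpm_limit=None, learned_rpm_limit=120)) == 120   # adaptive, learned
assert _effective_rpm(SimpleNamespace(rpm_limit=None, learned_rpm_limit=None)) == 60   # adaptive, default
assert _effective_rpm(SimpleNamespace(rpm_limit=300, learned_rpm_limit=50)) == 300     # fixed limit wins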
async def _check_concurrent_available(
self,
endpoint: ProviderEndpoint,
key: ProviderAPIKey,
is_cached_user: bool = False,
) -> Tuple[bool, ConcurrencySnapshot]:
"""
检查并发是否可用(使用动态预留机制)
检查 RPM 限制是否可用(使用动态预留机制)
核心逻辑 - 动态缓存预留机制:
- 总槽位: 有效并发限制(固定值或学习到的值)
- 总槽位: 有效 RPM 限制(固定值或学习到的值)
- 预留比例: 由 AdaptiveReservationManager 根据置信度和负载动态计算
- 缓存用户可用: 全部槽位
- 新用户可用: 总槽位 × (1 - 动态预留比例)
Args:
endpoint: ProviderEndpoint对象
key: ProviderAPIKey对象
is_cached_user: 是否是缓存用户
@@ -360,7 +358,7 @@ class CacheAwareScheduler:
(是否可用, 并发快照)
"""
# 获取有效的并发限制
effective_key_limit = self._get_effective_concurrent_limit(key)
effective_key_limit = self._get_effective_rpm_limit(key)
logger.debug(
f" -> 并发检查: _concurrency_manager={self._concurrency_manager is not None}, "
@@ -371,33 +369,23 @@ class CacheAwareScheduler:
# 并发管理器不可用,直接返回True
logger.debug(f" -> 无并发管理器,直接通过")
snapshot = ConcurrencySnapshot(
endpoint_current=0,
endpoint_limit=(
int(endpoint.max_concurrent) if endpoint.max_concurrent is not None else None
),
key_current=0,
key_limit=effective_key_limit,
is_cached_user=is_cached_user,
)
return True, snapshot
# 获取当前并发
endpoint_count, key_count = await self._concurrency_manager.get_current_concurrency(
endpoint_id=str(endpoint.id),
# 获取当前 RPM 计数
key_count = await self._concurrency_manager.get_key_rpm_count(
key_id=str(key.id),
)
can_use = True
# 检查Endpoint级别限制
if endpoint.max_concurrent is not None:
if endpoint_count >= endpoint.max_concurrent:
can_use = False
# 计算动态预留比例
reservation_result = self._reservation_manager.calculate_reservation(
key=key,
current_concurrent=key_count,
current_usage=key_count,
effective_limit=effective_key_limit,
)
@@ -440,7 +428,8 @@ class CacheAwareScheduler:
# 使用 max 确保至少有 1 个槽位可用
import math
available_for_new = max(1, math.ceil(effective_key_limit * (1 - reservation_ratio)))
# 与 ConcurrencyManager 的 Lua 脚本保持一致:使用 floor 计算新用户可用槽位
available_for_new = max(1, math.floor(effective_key_limit * (1 - reservation_ratio)))
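# Worked example (illustrative numbers only): with effective_key_limit=100 and a
# dynamic reservation_ratio of 0.3, floor(100 * 0.7) = 70 slots are open to new
# users while 30 stay reserved for cached users; a new request is rejected once
# key_count reaches 70, but cached users may still use all 100 slots.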
if key_count >= available_for_new:
logger.debug(
f"Key {key.id[:8]}... 新用户配额已满 "
@@ -460,8 +449,6 @@ class CacheAwareScheduler:
key_limit_for_snapshot = None
snapshot = ConcurrencySnapshot(
endpoint_current=endpoint_count,
endpoint_limit=endpoint.max_concurrent,
key_current=key_count,
key_limit=key_limit_for_snapshot,
is_cached_user=is_cached_user,
@@ -475,7 +462,7 @@ class CacheAwareScheduler:
def _get_effective_restrictions(
self,
user_api_key: Optional[ApiKey],
) -> Dict[str, Optional[set]]:
) -> Dict[str, Any]:
"""
获取有效的访问限制(合并 ApiKey 和 User 的限制)
@@ -536,7 +523,10 @@ class CacheAwareScheduler:
)
# 合并 allowed_models
result["allowed_models"] = merge_restrictions(
# allowed_models 支持 list/dict 两种结构,不能转成 set 否则会导致权限校验失效
from src.core.model_permissions import merge_allowed_models
result["allowed_models"] = merge_allowed_models(
user_api_key.allowed_models, user.allowed_models if user else None
)
@@ -617,11 +607,14 @@ class CacheAwareScheduler:
)
return [], global_model_id
# 0.2 检查模型是否被允许
if allowed_models is not None:
if (
requested_model_name not in allowed_models
and resolved_model_name not in allowed_models
# 0.2 检查模型是否被允许(支持简单列表和按格式字典两种模式)
from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
if not check_model_allowed(
model_name=requested_model_name,
allowed_models=allowed_models,
api_format=target_format.value,
resolved_model_name=resolved_model_name,
):
resolved_note = (
f" (解析为 {resolved_model_name})"
@@ -630,7 +623,7 @@ class CacheAwareScheduler:
)
logger.debug(
f"用户/API Key 不允许使用模型 {requested_model_name}{resolved_note}, "
f"允许的模型: {allowed_models}"
f"允许的模型: {get_allowed_models_preview(allowed_models)}"
)
return [], global_model_id
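# Hedged sketch of the permission semantics referenced above. The real logic lives
# in src.core.model_permissions; the wildcard/fallback behaviour shown here is an
# assumption for illustration, not a verified copy of that module.
from typing import Dict, List, Optional, Union

def _allowed_models_sketch(
    model_name: str,
    allowed_models: Optional[Union[List[str], Dict[str, Optional[List[str]]]]],
    api_format: Optional[str],
    resolved_model_name: Optional[str] = None,
) -> bool:
    if allowed_models is None:                      # no restriction configured
        return True
    names = {n for n in (model_name, resolved_model_name) if n}
    if isinstance(allowed_models, list):            # simple whitelist
        return bool(names & set(allowed_models))
    # dict mode: look up the requesting format, assumed to fall back to "*"
    whitelist = allowed_models.get((api_format or "").upper(), allowed_models.get("*"))
    if whitelist is None:                           # this format is unrestricted
        return True
    return bool(names & set(whitelist))

# _allowed_models_sketch("gpt-4o", {"OPENAI": ["gpt-4o"]}, "OPENAI")  -> True
# _allowed_models_sketch("gpt-4o", {"OPENAI": ["gpt-4o"]}, "CLAUDE")  -> True (format unrestricted)
# _allowed_models_sketch("gpt-4o", ["claude-sonnet-4"], "OPENAI")     -> False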
@@ -724,8 +717,11 @@ class CacheAwareScheduler:
provider_query = (
db.query(Provider)
.options(
selectinload(Provider.endpoints).selectinload(ProviderEndpoint.api_keys),
# 同时加载 models 和 global_model 关系,以便 get_effective_* 方法能正确继承默认值
# 预加载 Provider 级别的 api_keys
selectinload(Provider.api_keys),
# 预加载 endpoints用于按 api_format 选择请求配置)
selectinload(Provider.endpoints),
# 同时加载 models 和 global_model 关系
selectinload(Provider.models).selectinload(Model.global_model),
)
.filter(Provider.is_active == True)
@@ -852,6 +848,7 @@ class CacheAwareScheduler:
def _check_key_availability(
self,
key: ProviderAPIKey,
api_format: Optional[str],
model_name: str,
capability_requirements: Optional[Dict[str, bool]] = None,
resolved_model_name: Optional[str] = None,
@@ -871,20 +868,24 @@ class CacheAwareScheduler:
Returns:
(is_available, skip_reason)
"""
# 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因)
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(key)
# 检查熔断器状态(使用详细状态方法获取更丰富的跳过原因,按 API 格式
is_available, circuit_reason = health_monitor.get_circuit_breaker_status(
key, api_format=api_format
)
if not is_available:
return False, circuit_reason or "熔断器已打开"
# 模型权限检查:使用 allowed_models 白名单
# 模型权限检查:使用 allowed_models 白名单(支持简单列表和按格式字典两种模式)
# None = 允许所有模型,[] = 拒绝所有模型,["a","b"] = 只允许指定模型
if key.allowed_models is not None and (
model_name not in key.allowed_models
and (not resolved_model_name or resolved_model_name not in key.allowed_models)
from src.core.model_permissions import check_model_allowed, get_allowed_models_preview
if not check_model_allowed(
model_name=model_name,
allowed_models=key.allowed_models,
api_format=api_format,
resolved_model_name=resolved_model_name,
):
allowed_preview = ", ".join(key.allowed_models[:3]) if key.allowed_models else "(无)"
suffix = "..." if len(key.allowed_models) > 3 else ""
return False, f"模型权限不匹配(允许: {allowed_preview}{suffix})"
return False, f"模型权限不匹配(允许: {get_allowed_models_preview(key.allowed_models)})"
# Key 级别的能力匹配检查
# 注意:模型级别的能力检查已在 _check_model_support 中完成
@@ -914,6 +915,8 @@ class CacheAwareScheduler:
"""
构建候选列表
Key 直属 Provider,通过 api_formats 筛选符合目标格式的 Key。
Args:
db: 数据库会话
providers: Provider 列表
@@ -929,10 +932,10 @@ class CacheAwareScheduler:
候选列表
"""
candidates: List[ProviderCandidate] = []
target_format_str = target_format.value
for provider in providers:
# 检查模型支持(同时检查流式支持和模型能力需求)
# 模型能力检查在 Provider 级别进行,如果模型不支持所需能力,整个 Provider 被跳过
supports_model, skip_reason, _model_caps = await self._check_model_support(
db, provider, model_name, is_stream, capability_requirements
)
@@ -940,33 +943,47 @@ class CacheAwareScheduler:
logger.debug(f"Provider {provider.name} 不支持模型 {model_name}: {skip_reason}")
continue
# 查找目标格式对应的 Endpoint(获取请求配置)
target_endpoint = None
for endpoint in provider.endpoints:
# endpoint.api_format 是字符串,target_format 是枚举
endpoint_format_str = (
endpoint.api_format
if isinstance(endpoint.api_format, str)
else endpoint.api_format.value
)
if not endpoint.is_active or endpoint_format_str != target_format.value:
if endpoint.is_active and endpoint_format_str == target_format_str:
target_endpoint = endpoint
break
if not target_endpoint:
logger.debug(f"Provider {provider.name} 没有活跃的 {target_format_str} 端点")
continue
# Key 直属 Provider,通过 api_formats 筛选
active_keys = [
key for key in provider.api_keys
if key.is_active and target_format_str in (key.api_formats or [])
]
if not active_keys:
logger.debug(f"Provider {provider.name} 没有支持 {target_format_str} 的活跃 Key")
continue
# 获取活跃的 Key 并按 internal_priority + 负载均衡排序
active_keys = [key for key in endpoint.api_keys if key.is_active]
# 检查是否所有 Key 都是 TTL=0(轮换模式)
# 如果所有 Key 的 cache_ttl_minutes 都是 0 或 None则使用随机排序
use_random = all(
(key.cache_ttl_minutes or 0) == 0 for key in active_keys
) if active_keys else False
if use_random and len(active_keys) > 1:
logger.debug(
f" Endpoint {endpoint.id[:8]}... 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
f" Provider {provider.name} 启用 Key 轮换模式 (TTL=0, {len(active_keys)} keys)"
)
keys = self._shuffle_keys_by_internal_priority(active_keys, affinity_key, use_random)
for key in keys:
# Key 级别的能力检查(模型级别的能力检查已在上面完成)
# Key 级别的能力检查
is_available, skip_reason = self._check_key_availability(
key,
target_format_str,
model_name,
capability_requirements,
resolved_model_name=resolved_model_name,
@@ -974,7 +991,7 @@ class CacheAwareScheduler:
candidate = ProviderCandidate(
provider=provider,
endpoint=endpoint,
endpoint=target_endpoint,
key=key,
is_skipped=not is_available,
skip_reason=skip_reason,
@@ -1187,7 +1204,6 @@ class CacheAwareScheduler:
from collections import defaultdict
# 使用 tuple 作为统一的 key 类型,兼容两种模式
priority_groups: Dict[tuple, List[ProviderCandidate]] = defaultdict(list)
# 根据优先级模式选择分组方式

View File

@@ -27,24 +27,29 @@ class ProviderCacheService:
@staticmethod
async def get_provider_api_key_rate_multiplier(
db: Session, provider_api_key_id: str
db: Session, provider_api_key_id: str, api_format: Optional[str] = None
) -> Optional[float]:
"""
获取 ProviderAPIKey 的 rate_multiplier带缓存
优先返回指定 API 格式的倍率,如果没有则返回默认倍率。
Args:
db: 数据库会话
provider_api_key_id: ProviderAPIKey ID
api_format: API 格式(可选),如 "CLAUDE""OPENAI"
Returns:
rate_multiplier 或 None如果找不到
"""
cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}"
# 缓存键包含 api_format
format_suffix = api_format.upper() if api_format else "default"
cache_key = f"provider_api_key:rate_multiplier:{provider_api_key_id}:{format_suffix}"
# 1. 尝试从缓存获取
cached_data = await CacheService.get(cache_key)
if cached_data is not None:
logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}...")
logger.debug(f"ProviderAPIKey rate_multiplier 缓存命中: {provider_api_key_id[:8]}... format={format_suffix}")
# 缓存的 "NOT_FOUND" 表示数据库中不存在
if cached_data == "NOT_FOUND":
return None
@@ -52,18 +57,24 @@ class ProviderCacheService:
# 2. 缓存未命中,查询数据库
provider_key = (
db.query(ProviderAPIKey.rate_multiplier)
db.query(ProviderAPIKey.rate_multiplier, ProviderAPIKey.rate_multipliers)
.filter(ProviderAPIKey.id == provider_api_key_id)
.first()
)
# 3. 写入缓存
# 3. 计算倍率并写入缓存
if provider_key:
# 优先使用 rate_multipliers[api_format],回退到 rate_multiplier
rate_multiplier = provider_key.rate_multiplier or 1.0
if api_format and provider_key.rate_multipliers:
format_upper = api_format.upper()
if format_upper in provider_key.rate_multipliers:
rate_multiplier = provider_key.rate_multipliers[format_upper]
await CacheService.set(
cache_key, rate_multiplier, ttl_seconds=ProviderCacheService.CACHE_TTL
)
logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}...")
logger.debug(f"ProviderAPIKey rate_multiplier 已缓存: {provider_api_key_id[:8]}... format={format_suffix} value={rate_multiplier}")
return rate_multiplier
else:
# 缓存负结果
@@ -125,6 +136,7 @@ class ProviderCacheService:
db: Session,
provider_api_key_id: Optional[str],
provider_id: Optional[str],
api_format: Optional[str] = None,
) -> Tuple[float, bool]:
"""
获取费率倍数和是否免费套餐(带缓存)
@@ -135,6 +147,7 @@ class ProviderCacheService:
db: 数据库会话
provider_api_key_id: ProviderAPIKey ID可选
provider_id: Provider ID可选
api_format: API 格式(可选),用于获取按格式配置的倍率
Returns:
(rate_multiplier, is_free_tier) 元组
@@ -142,10 +155,10 @@ class ProviderCacheService:
actual_rate_multiplier = 1.0
is_free_tier = False
# 获取费率倍数
# 获取费率倍数(支持按 API 格式查询)
if provider_api_key_id:
rate_multiplier = await ProviderCacheService.get_provider_api_key_rate_multiplier(
db, provider_api_key_id
db, provider_api_key_id, api_format
)
if rate_multiplier is not None:
actual_rate_multiplier = rate_multiplier
@@ -160,8 +173,9 @@ class ProviderCacheService:
@staticmethod
async def invalidate_provider_api_key_cache(provider_api_key_id: str) -> None:
"""清除 ProviderAPIKey 缓存"""
await CacheService.delete(f"provider_api_key:rate_multiplier:{provider_api_key_id}")
"""清除 ProviderAPIKey 缓存(包括所有 API 格式的缓存)"""
# 使用模式匹配删除所有格式的缓存
await CacheService.delete_pattern(f"provider_api_key:rate_multiplier:{provider_api_key_id}:*")
logger.debug(f"ProviderAPIKey 缓存已清除: {provider_api_key_id[:8]}...")
@staticmethod

View File

@@ -70,20 +70,21 @@ class EndpointHealthService:
db.query(ProviderEndpoint).join(Provider).filter(Provider.is_active.is_(True)).all()
)
# 收集所有 endpoint_ids
all_endpoint_ids = [ep.id for ep in endpoints]
# 收集所有 provider_ids
all_provider_ids = list(set(ep.provider_id for ep in endpoints))
# 批量查询所有密钥
# 批量查询所有密钥(通过 provider_id 关联)
all_keys = (
db.query(ProviderAPIKey)
.filter(ProviderAPIKey.endpoint_id.in_(all_endpoint_ids))
.filter(ProviderAPIKey.provider_id.in_(all_provider_ids))
.all()
) if all_endpoint_ids else []
) if all_provider_ids else []
# 按 endpoint_id 分组密钥
keys_by_endpoint: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
# 按 api_format 分组密钥(通过 api_formats 字段)
keys_by_format: Dict[str, List[ProviderAPIKey]] = defaultdict(list)
for key in all_keys:
keys_by_endpoint[key.endpoint_id].append(key)
for fmt in (key.api_formats or []):
keys_by_format[fmt].append(key)
# 按 API 格式聚合
format_stats = defaultdict(
@@ -106,17 +107,35 @@ class EndpointHealthService:
format_stats[api_format]["endpoint_ids"].append(ep.id)
format_stats[api_format]["provider_ids"].add(ep.provider_id)
# 从预加载的密钥中获取
keys = keys_by_endpoint.get(ep.id, [])
format_stats[api_format]["total_keys"] += len(keys)
# 统计每个格式的密钥(直接从 keys_by_format 获取)
for api_format, keys in keys_by_format.items():
if api_format not in format_stats:
# 如果有 Key 但没有对应的 Endpoint,跳过
continue
# 统计活跃密钥和健康度
if ep.is_active:
# 去重(同一个 Key 可能支持多个格式)
seen_key_ids = set()
unique_keys = []
for key in keys:
if key.id not in seen_key_ids:
seen_key_ids.add(key.id)
unique_keys.append(key)
format_stats[api_format]["total_keys"] = len(unique_keys)
for key in unique_keys:
format_stats[api_format]["key_ids"].append(key.id)
if key.is_active and not key.circuit_breaker_open:
# 检查该格式的熔断器状态
circuit_by_format = key.circuit_breaker_by_format or {}
format_circuit = circuit_by_format.get(api_format, {})
is_circuit_open = format_circuit.get("open", False)
if key.is_active and not is_circuit_open:
format_stats[api_format]["active_keys"] += 1
health_score = key.health_score if key.health_score is not None else 1.0
# 获取该格式的健康度
health_by_format = key.health_by_format or {}
format_health = health_by_format.get(api_format, {})
health_score = float(format_health.get("health_score") or 1.0)
format_stats[api_format]["health_scores"].append(health_score)
# 批量生成所有格式的时间线数据
@@ -372,7 +391,7 @@ class EndpointHealthService:
segments: int = 100,
) -> Dict[str, Any]:
"""
从真实使用记录生成时间线数据(兼容旧接口,使用批量查询优化)
从真实使用记录生成时间线数据(使用批量查询优化)
Args:
db: 数据库会话
@@ -391,13 +410,34 @@ class EndpointHealthService:
"time_range_end": None,
}
# 先查询该 API 格式下的所有密钥
key_ids = [
k.id
for k in db.query(ProviderAPIKey.id)
.filter(ProviderAPIKey.endpoint_id.in_(endpoint_ids))
# 基于 endpoint_ids 反推 provider_ids 与 api_format,再选出支持该格式的 keys
endpoint_rows = (
db.query(ProviderEndpoint.provider_id, ProviderEndpoint.api_format)
.filter(ProviderEndpoint.id.in_(endpoint_ids))
.all()
]
)
if not endpoint_rows:
return {
"timeline": ["unknown"] * 100,
"time_range_start": None,
"time_range_end": None,
}
provider_ids = {str(pid) for pid, _fmt in endpoint_rows}
# 同一调用中 endpoint_ids 来自同一 api_format(上层已按格式分组)
api_format = (
endpoint_rows[0][1].value
if hasattr(endpoint_rows[0][1], "value")
else str(endpoint_rows[0][1])
)
keys = (
db.query(ProviderAPIKey.id, ProviderAPIKey.api_formats)
.filter(ProviderAPIKey.provider_id.in_(provider_ids))
.all()
)
key_ids = [str(key_id) for key_id, formats in keys if api_format in (formats or [])]
if not key_ids:
return {

View File

@@ -1,11 +1,15 @@
"""
健康监控器 - Endpoint 和 Key 的健康度追踪
健康监控器 - Endpoint 和 Key 的健康度追踪(按 API 格式区分)
功能:
1. 基于滑动窗口的错误率计算
2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭
1. 基于滑动窗口的错误率计算(按 API 格式独立)
2. 三态熔断器:关闭 -> 打开 -> 半开 -> 关闭(按 API 格式独立)
3. 半开状态允许少量请求验证服务恢复
4. 提供健康度查询和管理 API
数据结构:
- health_by_format: {"CLAUDE": {"health_score": 1.0, "consecutive_failures": 0, ...}, ...}
- circuit_breaker_by_format: {"CLAUDE": {"open": false, "open_at": null, ...}, ...}
"""
import os
@@ -30,8 +34,30 @@ class CircuitState:
HALF_OPEN = "half_open" # 半开(验证恢复)
# 默认健康度数据结构
def _default_health_data() -> Dict[str, Any]:
return {
"health_score": 1.0,
"consecutive_failures": 0,
"last_failure_at": None,
"request_results_window": [],
}
# 默认熔断器数据结构
def _default_circuit_data() -> Dict[str, Any]:
return {
"open": False,
"open_at": None,
"next_probe_at": None,
"half_open_until": None,
"half_open_successes": 0,
"half_open_failures": 0,
}
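# Hedged illustration: seeding the per-format structures for a key that supports
# several formats. In the monitor itself, missing entries simply fall back to
# these defaults via _get_health_data/_get_circuit_data, so this helper is only
# a sketch of the stored shape.
def _seed_format_state(api_formats):
    return (
        {fmt: _default_health_data() for fmt in api_formats},
        {fmt: _default_circuit_data() for fmt in api_formats},
    )

# health, circuits = _seed_format_state(["CLAUDE", "OPENAI"])
# health["CLAUDE"]["health_score"]  -> 1.0
# circuits["OPENAI"]["open"]        -> False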
class HealthMonitor:
"""健康监控器(滑动窗口 + 半开状态模式)"""
"""健康监控器(滑动窗口 + 半开状态模式,按 API 格式区分"""
# === 滑动窗口配置 ===
WINDOW_SIZE = int(os.getenv("HEALTH_WINDOW_SIZE", str(CircuitBreakerDefaults.WINDOW_SIZE)))
@@ -96,6 +122,38 @@ class HealthMonitor:
_circuit_history: List[Dict[str, Any]] = []
_open_circuit_keys: int = 0
# ==================== 数据访问辅助方法 ====================
@classmethod
def _get_health_data(cls, key: ProviderAPIKey, api_format: str) -> Dict[str, Any]:
"""获取指定格式的健康度数据,不存在则返回默认值"""
health_by_format = key.health_by_format or {}
if api_format not in health_by_format:
return _default_health_data()
return health_by_format[api_format]
@classmethod
def _set_health_data(cls, key: ProviderAPIKey, api_format: str, data: Dict[str, Any]) -> None:
"""设置指定格式的健康度数据"""
health_by_format = dict(key.health_by_format or {})
health_by_format[api_format] = data
key.health_by_format = health_by_format # type: ignore[assignment]
@classmethod
def _get_circuit_data(cls, key: ProviderAPIKey, api_format: str) -> Dict[str, Any]:
"""获取指定格式的熔断器数据,不存在则返回默认值"""
circuit_by_format = key.circuit_breaker_by_format or {}
if api_format not in circuit_by_format:
return _default_circuit_data()
return circuit_by_format[api_format]
@classmethod
def _set_circuit_data(cls, key: ProviderAPIKey, api_format: str, data: Dict[str, Any]) -> None:
"""设置指定格式的熔断器数据"""
circuit_by_format = dict(key.circuit_breaker_by_format or {})
circuit_by_format[api_format] = data
key.circuit_breaker_by_format = circuit_by_format # type: ignore[assignment]
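# Note on the copy-then-reassign pattern used by the two setters above: a plain
# JSON column does not detect in-place mutation, so building a new dict and
# reassigning the attribute is what marks the field dirty for SQLAlchemy. A
# hedged alternative (not used here) would be mutating in place and flagging the
# change explicitly, assuming the column already holds a dict:
#
#     from sqlalchemy.orm.attributes import flag_modified
#     key.health_by_format[api_format] = data
#     flag_modified(key, "health_by_format")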
# ==================== 核心方法 ====================
@classmethod
@@ -103,9 +161,21 @@ class HealthMonitor:
cls,
db: Session,
key_id: Optional[str] = None,
api_format: Optional[str] = None,
response_time_ms: Optional[int] = None,
) -> None:
"""记录成功请求"""
"""记录成功请求(按 API 格式)
Args:
db: 数据库会话
key_id: Key ID必需
api_format: API 格式(必需,用于区分不同格式的健康度)
response_time_ms: 响应时间(可选)
Note:
api_format 在逻辑上是必需的,但为了向后兼容保持 Optional 签名。
如果未提供,会尝试从 Key 的 api_formats 中获取第一个格式作为 fallback。
"""
try:
if not key_id:
return
@@ -114,39 +184,96 @@ class HealthMonitor:
if not key:
return
# api_format 兼容处理:如果未提供,尝试使用 Key 的第一个格式
effective_api_format = api_format
if not effective_api_format:
if key.api_formats and len(key.api_formats) > 0:
effective_api_format = key.api_formats[0]
logger.debug(
f"record_success: api_format 未提供,使用默认格式 {effective_api_format}"
)
else:
logger.warning(
f"record_success: api_format 未提供且 Key 无可用格式: key_id={key_id[:8]}..."
)
return
now = datetime.now(timezone.utc)
now_ts = now.timestamp()
# 获取当前格式的健康度数据
health_data = cls._get_health_data(key, effective_api_format)
circuit_data = cls._get_circuit_data(key, effective_api_format)
# 1. 更新滑动窗口
cls._add_to_window(key, now_ts, success=True)
window = health_data.get("request_results_window") or []
window.append({"ts": now_ts, "ok": True})
cutoff_ts = now_ts - cls.WINDOW_SECONDS
window = [r for r in window if r["ts"] > cutoff_ts]
if len(window) > cls.WINDOW_SIZE:
window = window[-cls.WINDOW_SIZE :]
health_data["request_results_window"] = window
# 2. 更新健康度(用于展示)
new_score = min(float(key.health_score or 0) + cls.SUCCESS_INCREMENT, 1.0)
key.health_score = new_score # type: ignore[assignment]
current_score = float(health_data.get("health_score") or 0)
new_score = min(current_score + cls.SUCCESS_INCREMENT, 1.0)
health_data["health_score"] = new_score
# 3. 更新统计
key.consecutive_failures = 0 # type: ignore[assignment]
key.last_failure_at = None # type: ignore[assignment]
health_data["consecutive_failures"] = 0
health_data["last_failure_at"] = None
# 4. 处理熔断器状态
state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN:
# 半开状态:记录成功
circuit_data["half_open_successes"] = int(
circuit_data.get("half_open_successes") or 0
) + 1
if circuit_data["half_open_successes"] >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
# 达到成功阈值,关闭熔断器
cls._close_circuit_data(circuit_data, health_data, reason="半开状态验证成功")
cls._push_circuit_event(
{
"event": "closed",
"key_id": key.id,
"api_format": effective_api_format,
"reason": "半开状态验证成功",
"timestamp": now.isoformat(),
}
)
logger.info(
f"[CLOSED] Key 熔断器关闭: {key.id[:8]}.../{effective_api_format} | 原因: 半开状态验证成功"
)
elif state == CircuitState.OPEN:
# 打开状态下的成功(探测成功),进入半开状态
cls._enter_half_open_data(circuit_data, now)
cls._push_circuit_event(
{
"event": "half_open",
"key_id": key.id,
"api_format": effective_api_format,
"timestamp": now.isoformat(),
}
)
logger.info(
f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}.../{effective_api_format} | "
f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
)
# 保存数据
cls._set_health_data(key, effective_api_format, health_data)
cls._set_circuit_data(key, effective_api_format, circuit_data)
# 更新全局统计
key.success_count = int(key.success_count or 0) + 1 # type: ignore[assignment]
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
if response_time_ms:
key.total_response_time_ms = int(key.total_response_time_ms or 0) + response_time_ms # type: ignore[assignment]
# 4. 处理熔断器状态
state = cls._get_circuit_state(key, now)
if state == CircuitState.HALF_OPEN:
# 半开状态:记录成功
key.half_open_successes = int(key.half_open_successes or 0) + 1 # type: ignore[assignment]
if int(key.half_open_successes or 0) >= cls.HALF_OPEN_SUCCESS_THRESHOLD:
# 达到成功阈值,关闭熔断器
cls._close_circuit(key, now, reason="半开状态验证成功")
elif state == CircuitState.OPEN:
# 打开状态下的成功(探测成功),进入半开状态
cls._enter_half_open(key, now)
db.flush()
get_batch_committer().mark_dirty(db)
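# Self-contained illustration of the window bookkeeping used above (and repeated
# in record_failure below): append the new sample, drop entries older than the
# window, then cap the length. The constants here are assumed example values,
# not the monitor's real WINDOW_SECONDS/WINDOW_SIZE defaults.
_EXAMPLE_WINDOW_SECONDS, _EXAMPLE_WINDOW_SIZE = 300, 100

def _trim_window(window, now_ts):
    cutoff = now_ts - _EXAMPLE_WINDOW_SECONDS
    window = [r for r in window if r["ts"] > cutoff]
    return window[-_EXAMPLE_WINDOW_SIZE:] if len(window) > _EXAMPLE_WINDOW_SIZE else window

assert _trim_window([{"ts": 0, "ok": True}, {"ts": 290, "ok": False}], now_ts=310) == [
    {"ts": 290, "ok": False}  # the ts=0 sample is older than the 300 s cutoff
]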
@@ -159,9 +286,21 @@ class HealthMonitor:
cls,
db: Session,
key_id: Optional[str] = None,
api_format: Optional[str] = None,
error_type: Optional[str] = None,
) -> None:
"""记录失败请求"""
"""记录失败请求(按 API 格式)
Args:
db: 数据库会话
key_id: Key ID(必需)
api_format: API 格式(必需,用于区分不同格式的健康度)
error_type: 错误类型(可选)
Note:
api_format 在逻辑上是必需的,但为了向后兼容保持 Optional 签名。
如果未提供,会尝试从 Key 的 api_formats 中获取第一个格式作为 fallback。
"""
try:
if not key_id:
return
@@ -170,46 +309,117 @@ class HealthMonitor:
if not key:
return
# api_format 兼容处理:如果未提供,尝试使用 Key 的第一个格式
effective_api_format = api_format
if not effective_api_format:
if key.api_formats and len(key.api_formats) > 0:
effective_api_format = key.api_formats[0]
logger.debug(
f"record_failure: api_format 未提供,使用默认格式 {effective_api_format}"
)
else:
logger.warning(
f"record_failure: api_format 未提供且 Key 无可用格式: key_id={key_id[:8]}..."
)
return
now = datetime.now(timezone.utc)
now_ts = now.timestamp()
# 获取当前格式的健康度数据
health_data = cls._get_health_data(key, effective_api_format)
circuit_data = cls._get_circuit_data(key, effective_api_format)
# 1. 更新滑动窗口
cls._add_to_window(key, now_ts, success=False)
window = health_data.get("request_results_window") or []
window.append({"ts": now_ts, "ok": False})
cutoff_ts = now_ts - cls.WINDOW_SECONDS
window = [r for r in window if r["ts"] > cutoff_ts]
if len(window) > cls.WINDOW_SIZE:
window = window[-cls.WINDOW_SIZE :]
health_data["request_results_window"] = window
# 2. 更新健康度(用于展示)
new_score = max(float(key.health_score or 1) - cls.FAILURE_DECREMENT, 0.0)
key.health_score = new_score # type: ignore[assignment]
current_score = float(health_data.get("health_score") or 1)
new_score = max(current_score - cls.FAILURE_DECREMENT, 0.0)
health_data["health_score"] = new_score
# 3. 更新统计
key.consecutive_failures = int(key.consecutive_failures or 0) + 1 # type: ignore[assignment]
key.last_failure_at = now # type: ignore[assignment]
key.error_count = int(key.error_count or 0) + 1 # type: ignore[assignment]
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
health_data["consecutive_failures"] = (
int(health_data.get("consecutive_failures") or 0) + 1
)
health_data["last_failure_at"] = now.isoformat()
# 4. 处理熔断器状态
state = cls._get_circuit_state(key, now)
state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN:
# 半开状态:记录失败
key.half_open_failures = int(key.half_open_failures or 0) + 1 # type: ignore[assignment]
circuit_data["half_open_failures"] = int(
circuit_data.get("half_open_failures") or 0
) + 1
if int(key.half_open_failures or 0) >= cls.HALF_OPEN_FAILURE_THRESHOLD:
if circuit_data["half_open_failures"] >= cls.HALF_OPEN_FAILURE_THRESHOLD:
# 达到失败阈值,重新打开熔断器
cls._open_circuit(key, now, reason="半开状态验证失败")
# 注意:半开状态本身就是打开状态的子状态,不需要增加计数
consecutive = int(health_data.get("consecutive_failures") or 0)
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
cls._open_circuit_data(
circuit_data, now, recovery_seconds, reason="半开状态验证失败"
)
cls._push_circuit_event(
{
"event": "opened",
"key_id": key.id,
"api_format": effective_api_format,
"reason": "半开状态验证失败",
"recovery_seconds": recovery_seconds,
"timestamp": now.isoformat(),
}
)
logger.warning(
f"[OPEN] Key 熔断器打开: {key.id[:8]}.../{effective_api_format} | 原因: 半开状态验证失败 | "
f"{recovery_seconds}秒后进入半开状态"
)
elif state == CircuitState.CLOSED:
# 关闭状态:检查是否需要打开熔断器
error_rate = cls._calculate_error_rate(key, now_ts)
window = key.request_results_window or []
error_rate = cls._calculate_error_rate_from_window(window, now_ts)
if len(window) >= cls.MIN_REQUESTS and error_rate >= cls.ERROR_RATE_THRESHOLD:
cls._open_circuit(
key, now, reason=f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}"
consecutive = int(health_data.get("consecutive_failures") or 0)
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
reason = f"错误率 {error_rate:.0%} 超过阈值 {cls.ERROR_RATE_THRESHOLD:.0%}"
cls._open_circuit_data(circuit_data, now, recovery_seconds, reason=reason)
cls._open_circuit_keys += 1
health_open_circuits.set(cls._open_circuit_keys)
cls._push_circuit_event(
{
"event": "opened",
"key_id": key.id,
"api_format": effective_api_format,
"reason": reason,
"recovery_seconds": recovery_seconds,
"timestamp": now.isoformat(),
}
)
logger.warning(
f"[OPEN] Key 熔断器打开: {key.id[:8]}.../{effective_api_format} | 原因: {reason} | "
f"{recovery_seconds}秒后进入半开状态"
)
# 保存数据
cls._set_health_data(key, effective_api_format, health_data)
cls._set_circuit_data(key, effective_api_format, circuit_data)
# 更新全局统计
key.error_count = int(key.error_count or 0) + 1 # type: ignore[assignment]
key.request_count = int(key.request_count or 0) + 1 # type: ignore[assignment]
key.last_error_at = now # type: ignore[assignment]
logger.debug(
f"[WARN] Key 健康度下降: {key_id[:8]}... -> {new_score:.2f} "
f"(连续失败 {key.consecutive_failures} 次, error_type={error_type})"
f"[WARN] Key 健康度下降: {key_id[:8]}.../{effective_api_format} -> {new_score:.2f} "
f"(连续失败 {health_data['consecutive_failures']} 次, error_type={error_type})"
)
db.flush()
@@ -222,31 +432,13 @@ class HealthMonitor:
# ==================== 滑动窗口方法 ====================
@classmethod
def _add_to_window(cls, key: ProviderAPIKey, now_ts: float, success: bool) -> None:
"""添加请求结果到滑动窗口"""
window: List[Dict[str, Any]] = key.request_results_window or []
# 添加新记录
window.append({"ts": now_ts, "ok": success})
# 清理过期记录
cutoff_ts = now_ts - cls.WINDOW_SECONDS
window = [r for r in window if r["ts"] > cutoff_ts]
# 限制窗口大小
if len(window) > cls.WINDOW_SIZE:
window = window[-cls.WINDOW_SIZE :]
key.request_results_window = window # type: ignore[assignment]
@classmethod
def _calculate_error_rate(cls, key: ProviderAPIKey, now_ts: float) -> float:
"""计算滑动窗口内的错误率"""
window: List[Dict[str, Any]] = key.request_results_window or []
def _calculate_error_rate_from_window(
cls, window: List[Dict[str, Any]], now_ts: float
) -> float:
"""从窗口数据计算错误率"""
if not window:
return 0.0
# 过滤过期记录
cutoff_ts = now_ts - cls.WINDOW_SECONDS
valid_records = [r for r in window if r["ts"] > cutoff_ts]
@@ -256,157 +448,158 @@ class HealthMonitor:
failures = sum(1 for r in valid_records if not r["ok"])
return failures / len(valid_records)
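# Worked example: a window holding 10 still-valid samples of which 3 failed gives
# an error rate of 3 / 10 = 0.3; the breaker above only opens once the window has
# at least MIN_REQUESTS samples and the rate reaches ERROR_RATE_THRESHOLD (both
# constants are configured elsewhere and not shown in this hunk).
import time

_now_ts = time.time()
_example_window = [{"ts": _now_ts - i, "ok": i >= 3} for i in range(10)]  # 3 newest failed
assert sum(1 for r in _example_window if not r["ok"]) / len(_example_window) == 0.3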
# ==================== 熔断器状态方法 ====================
# ==================== 熔断器状态方法(操作数据字典)====================
@classmethod
def _get_circuit_state(cls, key: ProviderAPIKey, now: datetime) -> str:
"""获取当前熔断器状态"""
if not key.circuit_breaker_open:
def _get_circuit_state_from_data(cls, circuit_data: Dict[str, Any], now: datetime) -> str:
"""从数据字典获取当前熔断器状态"""
if not circuit_data.get("open"):
return CircuitState.CLOSED
# 检查是否在半开状态
if key.half_open_until and now < key.half_open_until:
half_open_until_str = circuit_data.get("half_open_until")
if half_open_until_str:
half_open_until = datetime.fromisoformat(half_open_until_str)
if now < half_open_until:
return CircuitState.HALF_OPEN
# 检查是否到了探测时间(进入半开)
if key.next_probe_at and now >= key.next_probe_at:
next_probe_str = circuit_data.get("next_probe_at")
if next_probe_str:
next_probe_at = datetime.fromisoformat(next_probe_str)
if now >= next_probe_at:
return CircuitState.HALF_OPEN
return CircuitState.OPEN
@classmethod
def _open_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None:
"""打开熔断器"""
was_open = key.circuit_breaker_open
key.circuit_breaker_open = True # type: ignore[assignment]
key.circuit_breaker_open_at = now # type: ignore[assignment]
key.half_open_until = None # type: ignore[assignment]
key.half_open_successes = 0 # type: ignore[assignment]
key.half_open_failures = 0 # type: ignore[assignment]
# 计算下次探测时间(进入半开状态的时间)
consecutive = int(key.consecutive_failures or 0)
recovery_seconds = cls._calculate_recovery_seconds(consecutive)
key.next_probe_at = now + timedelta(seconds=recovery_seconds) # type: ignore[assignment]
if not was_open:
cls._open_circuit_keys += 1
health_open_circuits.set(cls._open_circuit_keys)
logger.warning(
f"[OPEN] Key 熔断器打开: {key.id[:8]}... | 原因: {reason} | "
f"{recovery_seconds}秒后进入半开状态"
)
cls._push_circuit_event(
{
"event": "opened",
"key_id": key.id,
"reason": reason,
"recovery_seconds": recovery_seconds,
"timestamp": now.isoformat(),
}
)
def _open_circuit_data(
cls,
circuit_data: Dict[str, Any],
now: datetime,
recovery_seconds: int,
reason: str,
) -> None:
"""打开熔断器(操作数据字典)"""
circuit_data["open"] = True
circuit_data["open_at"] = now.isoformat()
circuit_data["half_open_until"] = None
circuit_data["half_open_successes"] = 0
circuit_data["half_open_failures"] = 0
circuit_data["next_probe_at"] = (now + timedelta(seconds=recovery_seconds)).isoformat()
@classmethod
def _enter_half_open(cls, key: ProviderAPIKey, now: datetime) -> None:
"""进入半开状态"""
key.half_open_until = now + timedelta(seconds=cls.HALF_OPEN_DURATION) # type: ignore[assignment]
key.half_open_successes = 0 # type: ignore[assignment]
key.half_open_failures = 0 # type: ignore[assignment]
logger.info(
f"[HALF-OPEN] Key 进入半开状态: {key.id[:8]}... | "
f"需要 {cls.HALF_OPEN_SUCCESS_THRESHOLD} 次成功关闭熔断器"
)
cls._push_circuit_event(
{
"event": "half_open",
"key_id": key.id,
"timestamp": now.isoformat(),
}
)
def _enter_half_open_data(cls, circuit_data: Dict[str, Any], now: datetime) -> None:
"""进入半开状态(操作数据字典)"""
circuit_data["half_open_until"] = (
now + timedelta(seconds=cls.HALF_OPEN_DURATION)
).isoformat()
circuit_data["half_open_successes"] = 0
circuit_data["half_open_failures"] = 0
@classmethod
def _close_circuit(cls, key: ProviderAPIKey, now: datetime, reason: str) -> None:
"""关闭熔断器"""
key.circuit_breaker_open = False # type: ignore[assignment]
key.circuit_breaker_open_at = None # type: ignore[assignment]
key.next_probe_at = None # type: ignore[assignment]
key.half_open_until = None # type: ignore[assignment]
key.half_open_successes = 0 # type: ignore[assignment]
key.half_open_failures = 0 # type: ignore[assignment]
def _close_circuit_data(
cls, circuit_data: Dict[str, Any], health_data: Dict[str, Any], reason: str
) -> None:
"""关闭熔断器(操作数据字典)"""
circuit_data["open"] = False
circuit_data["open_at"] = None
circuit_data["next_probe_at"] = None
circuit_data["half_open_until"] = None
circuit_data["half_open_successes"] = 0
circuit_data["half_open_failures"] = 0
# 快速恢复健康度
key.health_score = max(float(key.health_score or 0), cls.PROBE_RECOVERY_SCORE) # type: ignore[assignment]
current_score = float(health_data.get("health_score") or 0)
health_data["health_score"] = max(current_score, cls.PROBE_RECOVERY_SCORE)
cls._open_circuit_keys = max(0, cls._open_circuit_keys - 1)
health_open_circuits.set(cls._open_circuit_keys)
logger.info(f"[CLOSED] Key 熔断器关闭: {key.id[:8]}... | 原因: {reason}")
cls._push_circuit_event(
{
"event": "closed",
"key_id": key.id,
"reason": reason,
"timestamp": now.isoformat(),
}
)
@classmethod
def _calculate_recovery_seconds(cls, consecutive_failures: int) -> int:
"""计算恢复等待时间(指数退避)"""
# 指数退避:30s -> 60s -> 120s -> 240s -> 300s(上限)
exponent = min(consecutive_failures // 5, 4) # 每5次失败增加一级
exponent = min(consecutive_failures // 5, 4)
seconds = cls.INITIAL_RECOVERY_SECONDS * (cls.RECOVERY_BACKOFF**exponent)
return min(int(seconds), cls.MAX_RECOVERY_SECONDS)
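# Hedged worked example of the backoff above, assuming INITIAL_RECOVERY_SECONDS=30,
# RECOVERY_BACKOFF=2 and MAX_RECOVERY_SECONDS=300 (the real defaults come from
# CircuitBreakerDefaults and may differ):
def _recovery_seconds_example(consecutive, initial=30, backoff=2, max_seconds=300):
    exponent = min(consecutive // 5, 4)
    return min(int(initial * backoff**exponent), max_seconds)

assert _recovery_seconds_example(3) == 30     # fewer than 5 failures -> base delay
assert _recovery_seconds_example(12) == 120   # 30 * 2**2
assert _recovery_seconds_example(40) == 300   # exponent capped at 4, then clamped to max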
# ==================== 状态查询方法 ====================
@classmethod
def is_circuit_breaker_closed(cls, resource: ProviderAPIKey) -> bool:
"""检查熔断器是否允许请求通过"""
if not resource.circuit_breaker_open:
def is_circuit_breaker_closed(
cls, resource: ProviderAPIKey, api_format: Optional[str] = None
) -> bool:
"""检查熔断器是否允许请求通过(按 API 格式)"""
if not api_format:
# 兼容旧调用:检查是否有任何格式的熔断器开启
circuit_by_format = resource.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
return False
return True
circuit_data = cls._get_circuit_data(resource, api_format)
if not circuit_data.get("open"):
return True
now = datetime.now(timezone.utc)
state = cls._get_circuit_state(resource, now)
state = cls._get_circuit_state_from_data(circuit_data, now)
# 半开状态允许请求通过
if state == CircuitState.HALF_OPEN:
return True
# 检查是否到了探测时间
if resource.next_probe_at and now >= resource.next_probe_at:
next_probe_str = circuit_data.get("next_probe_at")
if next_probe_str:
next_probe_at = datetime.fromisoformat(next_probe_str)
if now >= next_probe_at:
# 自动进入半开状态
cls._enter_half_open(resource, now)
cls._enter_half_open_data(circuit_data, now)
cls._set_circuit_data(resource, api_format, circuit_data)
return True
return False
@classmethod
def get_circuit_breaker_status(
cls, resource: ProviderAPIKey
cls, resource: ProviderAPIKey, api_format: Optional[str] = None
) -> Tuple[bool, Optional[str]]:
"""获取熔断器详细状态"""
if not resource.circuit_breaker_open:
"""获取熔断器详细状态(按 API 格式)"""
if not api_format:
# 兼容旧调用:返回第一个开启的熔断器状态
circuit_by_format = resource.circuit_breaker_by_format or {}
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
return cls._get_status_from_circuit_data(circuit_data)
return True, None
circuit_data = cls._get_circuit_data(resource, api_format)
return cls._get_status_from_circuit_data(circuit_data)
@classmethod
def _get_status_from_circuit_data(
cls, circuit_data: Dict[str, Any]
) -> Tuple[bool, Optional[str]]:
"""从熔断器数据获取状态描述"""
if not circuit_data.get("open"):
return True, None
now = datetime.now(timezone.utc)
state = cls._get_circuit_state(resource, now)
state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN:
successes = int(resource.half_open_successes or 0)
successes = int(circuit_data.get("half_open_successes") or 0)
return True, f"半开状态({successes}/{cls.HALF_OPEN_SUCCESS_THRESHOLD}成功)"
if resource.next_probe_at:
if now >= resource.next_probe_at:
next_probe_str = circuit_data.get("next_probe_at")
if next_probe_str:
next_probe_at = datetime.fromisoformat(next_probe_str)
if now >= next_probe_at:
return True, None
remaining = resource.next_probe_at - now
remaining = next_probe_at - now
remaining_seconds = int(remaining.total_seconds())
if remaining_seconds >= 60:
time_str = f"{remaining_seconds // 60}min{remaining_seconds % 60}s"
@@ -417,8 +610,10 @@ class HealthMonitor:
return False, "熔断中"
@classmethod
def get_key_health(cls, db: Session, key_id: str) -> Optional[Dict[str, Any]]:
"""获取 Key 健康状态"""
def get_key_health(
cls, db: Session, key_id: str, api_format: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""获取 Key 健康状态(支持按格式查询)"""
try:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if not key:
@@ -427,24 +622,15 @@ class HealthMonitor:
now = datetime.now(timezone.utc)
now_ts = now.timestamp()
# 计算当前错误率
error_rate = cls._calculate_error_rate(key, now_ts)
window = key.request_results_window or []
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
avg_response_time_ms = (
int(key.total_response_time_ms or 0) / int(key.success_count or 1)
if key.success_count
else 0
)
return {
# 全局统计
result = {
"key_id": key.id,
"health_score": float(key.health_score or 1.0),
"error_rate": error_rate,
"window_size": len(valid_window),
"consecutive_failures": int(key.consecutive_failures or 0),
"last_failure_at": key.last_failure_at.isoformat() if key.last_failure_at else None,
"is_active": key.is_active,
"statistics": {
"request_count": int(key.request_count or 0),
@@ -457,25 +643,84 @@ class HealthMonitor:
),
"avg_response_time_ms": round(avg_response_time_ms, 2),
},
}
# 按格式的健康度数据
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
if api_format:
# 查询单个格式
health_data = cls._get_health_data(key, api_format)
circuit_data = cls._get_circuit_data(key, api_format)
window = health_data.get("request_results_window") or []
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
result["api_format"] = api_format
result["health_score"] = float(health_data.get("health_score") or 1.0)
result["error_rate"] = cls._calculate_error_rate_from_window(window, now_ts)
result["window_size"] = len(valid_window)
result["consecutive_failures"] = int(
health_data.get("consecutive_failures") or 0
)
result["last_failure_at"] = health_data.get("last_failure_at")
result["circuit_breaker"] = {
"state": cls._get_circuit_state_from_data(circuit_data, now),
"open": circuit_data.get("open", False),
"open_at": circuit_data.get("open_at"),
"next_probe_at": circuit_data.get("next_probe_at"),
"half_open_until": circuit_data.get("half_open_until"),
"half_open_successes": int(circuit_data.get("half_open_successes") or 0),
"half_open_failures": int(circuit_data.get("half_open_failures") or 0),
}
else:
# 返回所有格式的健康度数据
formats_health = {}
for fmt in (key.api_formats or []):
health_data = health_by_format.get(fmt, _default_health_data())
circuit_data = circuit_by_format.get(fmt, _default_circuit_data())
window = health_data.get("request_results_window") or []
valid_window = [r for r in window if r["ts"] > now_ts - cls.WINDOW_SECONDS]
formats_health[fmt] = {
"health_score": float(health_data.get("health_score") or 1.0),
"error_rate": cls._calculate_error_rate_from_window(window, now_ts),
"window_size": len(valid_window),
"consecutive_failures": int(
health_data.get("consecutive_failures") or 0
),
"last_failure_at": health_data.get("last_failure_at"),
"circuit_breaker": {
"state": cls._get_circuit_state(key, now),
"open": key.circuit_breaker_open,
"open_at": (
key.circuit_breaker_open_at.isoformat()
if key.circuit_breaker_open_at
else None
"state": cls._get_circuit_state_from_data(circuit_data, now),
"open": circuit_data.get("open", False),
"open_at": circuit_data.get("open_at"),
"next_probe_at": circuit_data.get("next_probe_at"),
"half_open_until": circuit_data.get("half_open_until"),
"half_open_successes": int(
circuit_data.get("half_open_successes") or 0
),
"next_probe_at": (
key.next_probe_at.isoformat() if key.next_probe_at else None
"half_open_failures": int(
circuit_data.get("half_open_failures") or 0
),
"half_open_until": (
key.half_open_until.isoformat() if key.half_open_until else None
),
"half_open_successes": int(key.half_open_successes or 0),
"half_open_failures": int(key.half_open_failures or 0),
},
}
result["health_by_format"] = formats_health
# 计算整体健康度(取最低值)
if formats_health:
result["health_score"] = min(
h["health_score"] for h in formats_health.values()
)
result["any_circuit_open"] = any(
h["circuit_breaker"]["open"] for h in formats_health.values()
)
else:
result["health_score"] = 1.0
result["any_circuit_open"] = False
return result
except Exception as e:
logger.error(f"获取 Key 健康状态失败: {e}")
return None
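# Illustrative sketch (not part of this diff): the per-format JSON shape that
# get_key_health() reads above. Field names are taken from the accessors in this file;
# the format name "CLAUDE" and the concrete values are invented, and the exact defaults
# of _default_health_data()/_default_circuit_data() are not shown in this hunk.
example_health_by_format = {
    "CLAUDE": {
        "health_score": 0.85,
        "consecutive_failures": 2,
        "last_failure_at": "2026-01-10T10:00:00+00:00",
        "request_results_window": [{"ts": 1767002400.0}],  # entries carry a "ts" timestamp
    }
}
example_circuit_breaker_by_format = {
    "CLAUDE": {
        "open": True,
        "open_at": "2026-01-10T10:00:05+00:00",
        "next_probe_at": "2026-01-10T10:00:35+00:00",
        "half_open_until": None,
        "half_open_successes": 0,
        "half_open_failures": 0,
    }
}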
@@ -507,23 +752,24 @@ class HealthMonitor:
# ==================== 管理方法 ====================
@classmethod
def reset_health(cls, db: Session, key_id: Optional[str] = None) -> bool:
"""重置健康度"""
def reset_health(
cls, db: Session, key_id: Optional[str] = None, api_format: Optional[str] = None
) -> bool:
"""重置健康度(支持按格式重置)"""
try:
if key_id:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if key:
key.health_score = 1.0 # type: ignore[assignment]
key.consecutive_failures = 0 # type: ignore[assignment]
key.last_failure_at = None # type: ignore[assignment]
key.request_results_window = [] # type: ignore[assignment]
key.circuit_breaker_open = False # type: ignore[assignment]
key.circuit_breaker_open_at = None # type: ignore[assignment]
key.next_probe_at = None # type: ignore[assignment]
key.half_open_until = None # type: ignore[assignment]
key.half_open_successes = 0 # type: ignore[assignment]
key.half_open_failures = 0 # type: ignore[assignment]
logger.info(f"[RESET] 重置 Key 健康度: {key_id}")
if api_format:
# 重置单个格式
cls._set_health_data(key, api_format, _default_health_data())
cls._set_circuit_data(key, api_format, _default_circuit_data())
logger.info(f"[RESET] 重置 Key 健康度: {key_id}/{api_format}")
else:
# 重置所有格式
key.health_by_format = {} # type: ignore[assignment]
key.circuit_breaker_by_format = {} # type: ignore[assignment]
logger.info(f"[RESET] 重置 Key 所有格式健康度: {key_id}")
db.flush()
get_batch_committer().mark_dirty(db)
@@ -542,7 +788,9 @@ class HealthMonitor:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if key and not key.is_active:
key.is_active = True # type: ignore[assignment]
key.consecutive_failures = 0 # type: ignore[assignment]
# 重置所有格式的健康度
key.health_by_format = {} # type: ignore[assignment]
key.circuit_breaker_by_format = {} # type: ignore[assignment]
logger.info(f"[OK] 手动启用 Key: {key_id}")
db.flush()
@@ -566,14 +814,28 @@ class HealthMonitor:
),
).first()
key_stats = db.query(
func.count(ProviderAPIKey.id).label("total"),
func.sum(case((ProviderAPIKey.is_active == True, 1), else_=0)).label("active"),
func.sum(case((ProviderAPIKey.health_score < 0.5, 1), else_=0)).label("unhealthy"),
func.sum(case((ProviderAPIKey.circuit_breaker_open == True, 1), else_=0)).label(
"circuit_open"
),
).first()
# 统计 Key(需要遍历 JSON 字段计算熔断状态)
keys = db.query(ProviderAPIKey).all()
total_keys = len(keys)
active_keys = sum(1 for k in keys if k.is_active)
unhealthy_keys = 0
circuit_open_keys = 0
for key in keys:
health_by_format = key.health_by_format or {}
circuit_by_format = key.circuit_breaker_by_format or {}
# 检查是否有任何格式健康度低于 0.5
for fmt, health_data in health_by_format.items():
if float(health_data.get("health_score") or 1.0) < 0.5:
unhealthy_keys += 1
break
# 检查是否有任何格式熔断器开启
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
circuit_open_keys += 1
break
return {
"endpoints": {
@@ -582,10 +844,10 @@ class HealthMonitor:
"unhealthy": int(endpoint_stats.unhealthy or 0) if endpoint_stats else 0,
},
"keys": {
"total": key_stats.total or 0 if key_stats else 0,
"active": int(key_stats.active or 0) if key_stats else 0,
"unhealthy": int(key_stats.unhealthy or 0) if key_stats else 0,
"circuit_open": int(key_stats.circuit_open or 0) if key_stats else 0,
"total": total_keys,
"active": active_keys,
"unhealthy": unhealthy_keys,
"circuit_open": circuit_open_keys,
},
}
@@ -618,8 +880,9 @@ class HealthMonitor:
db: Session,
endpoint_id: Optional[str] = None,
key_id: Optional[str] = None,
api_format: Optional[str] = None,
) -> bool:
"""检查是否有资格进行探测(兼容旧接口"""
"""检查是否有资格进行探测(按 API 格式"""
if not cls.ALLOW_AUTO_RECOVER:
return False
@@ -628,13 +891,53 @@ class HealthMonitor:
if key_id:
key = db.query(ProviderAPIKey).filter(ProviderAPIKey.id == key_id).first()
if key and key.circuit_breaker_open:
if key:
if api_format:
circuit_data = cls._get_circuit_data(key, api_format)
if circuit_data.get("open"):
now = datetime.now(timezone.utc)
state = cls._get_circuit_state(key, now)
state = cls._get_circuit_state_from_data(circuit_data, now)
return state == CircuitState.HALF_OPEN
else:
# 兼容旧调用:检查是否有任何格式处于半开状态
circuit_by_format = key.circuit_breaker_by_format or {}
now = datetime.now(timezone.utc)
for fmt, circuit_data in circuit_by_format.items():
if circuit_data.get("open"):
state = cls._get_circuit_state_from_data(circuit_data, now)
if state == CircuitState.HALF_OPEN:
return True
return False
# ==================== 便捷方法 ====================
@classmethod
def get_health_score(
cls, key: ProviderAPIKey, api_format: Optional[str] = None
) -> float:
"""获取指定格式的健康度分数"""
if not api_format:
# 返回所有格式中的最低健康度
health_by_format = key.health_by_format or {}
if not health_by_format:
return 1.0
return min(
float(h.get("health_score") or 1.0) for h in health_by_format.values()
)
health_data = cls._get_health_data(key, api_format)
return float(health_data.get("health_score") or 1.0)
@classmethod
def is_any_circuit_open(cls, key: ProviderAPIKey) -> bool:
"""检查是否有任何格式的熔断器开启"""
circuit_by_format = key.circuit_breaker_by_format or {}
for circuit_data in circuit_by_format.values():
if circuit_data.get("open"):
return True
return False
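# Illustrative sketch (not part of this diff): how a scheduler might combine the two
# helpers above when ranking keys for one API format ("OPENAI" is a placeholder name;
# the ranking rule itself is an assumption, not behaviour shown in this commit).
def _rank_key_for_format(key: ProviderAPIKey) -> float:
    if HealthMonitor.is_any_circuit_open(key):
        return 0.0  # at least one format is circuit-broken; de-prioritise the key
    return HealthMonitor.get_health_score(key, api_format="OPENAI")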
# 全局健康监控器实例
health_monitor = HealthMonitor()

View File

@@ -216,7 +216,7 @@ class ModelService:
def delete_model(db: Session, model_id: str): # UUID
"""删除模型
新架构删除逻辑:
删除逻辑:
- Model 只是 Provider 对 GlobalModel 的实现,删除不影响 GlobalModel
- 检查是否是该 GlobalModel 的最后一个实现(如果是,警告但允许删除)
"""
@@ -384,7 +384,7 @@ class ModelService:
@staticmethod
def convert_to_response(model: Model) -> ModelResponse:
"""转换为响应模型(新架构:从 GlobalModel 获取显示信息和默认值)"""
"""转换为响应模型(从 GlobalModel 获取显示信息和默认值)"""
return ModelResponse(
id=model.id,
provider_id=model.provider_id,

View File

@@ -171,7 +171,8 @@ class CandidateResolver:
)
candidate_record_map[(candidate_index, 0)] = record_id
else:
max_retries_for_candidate = endpoint.max_retries if candidate.is_cached else 1
# max_retries 已从 Endpoint 迁移到 Provider(Endpoint 仍可能保留旧字段用于兼容)
max_retries_for_candidate = int(provider.max_retries or 2) if candidate.is_cached else 1
for retry_index in range(max_retries_for_candidate):
record_id = str(uuid.uuid4())
@@ -236,7 +237,7 @@ class CandidateResolver:
total = 0
for candidate in all_candidates:
if not candidate.is_skipped:
endpoint = candidate.endpoint
max_retries = int(endpoint.max_retries) if candidate.is_cached else 1
provider = candidate.provider
max_retries = int(provider.max_retries or 2) if candidate.is_cached else 1
total += max_retries
return total
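# Illustrative example (not part of this diff): the retry budget above for a
# hypothetical candidate list (attribute semantics mirror the code, values invented).
_candidates = [
    {"is_skipped": False, "is_cached": True, "max_retries": 3},     # cached, explicit limit -> 3
    {"is_skipped": False, "is_cached": True, "max_retries": None},  # cached, falls back to 2
    {"is_skipped": False, "is_cached": False, "max_retries": 5},    # not cached -> single attempt
    {"is_skipped": True, "is_cached": True, "max_retries": 3},      # skipped -> ignored
]
_total = sum(
    (int(c["max_retries"] or 2) if c["is_cached"] else 1)
    for c in _candidates
    if not c["is_skipped"]
)
assert _total == 3 + 2 + 1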

View File

@@ -26,7 +26,7 @@ from src.models.database import Provider, ProviderAPIKey, ProviderEndpoint
from src.services.cache.aware_scheduler import CacheAwareScheduler
from src.services.health.monitor import health_monitor
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
from src.services.rate_limit.detector import RateLimitType, detect_rate_limit_type
@@ -112,7 +112,7 @@ class ErrorClassifier:
cache_scheduler: 缓存调度器(可选)
"""
self.db = db
self.adaptive_manager = adaptive_manager or get_adaptive_manager()
self.adaptive_manager = adaptive_manager or get_adaptive_rpm_manager()
self.cache_scheduler = cache_scheduler
# 表示客户端错误的 error type不区分大小写
@@ -361,7 +361,7 @@ class ErrorClassifier:
self,
key: ProviderAPIKey,
provider_name: str,
current_concurrent: Optional[int],
current_rpm: Optional[int],
exception: ProviderRateLimitException,
request_id: Optional[str] = None,
) -> str:
@@ -371,7 +371,7 @@ class ErrorClassifier:
Args:
key: API Key 对象
provider_name: 提供商名称
current_concurrent: 当前并发
current_rpm: 当前分钟内的请求
exception: 速率限制异常
request_id: 请求 ID用于日志
@@ -388,27 +388,27 @@ class ErrorClassifier:
rate_limit_info = detect_rate_limit_type(
headers=response_headers,
provider_name=provider_name,
current_concurrent=current_concurrent,
current_usage=current_rpm,
)
logger.info(f" [{request_id}] 429错误分析: "
f"类型={rate_limit_info.limit_type}, "
f"retry_after={rate_limit_info.retry_after}s, "
f"当前并发={current_concurrent}")
f"当前RPM={current_rpm}")
# 调用自适应管理器处理
new_limit = self.adaptive_manager.handle_429_error(
db=self.db,
key=key,
rate_limit_info=rate_limit_info,
current_concurrent=current_concurrent,
current_rpm=current_rpm,
)
if rate_limit_info.limit_type == RateLimitType.CONCURRENT:
logger.warning(f" [{request_id}] 自适应调整: " f"Key {key.id[:8]}... 并发限制 -> {new_limit}")
logger.warning(f" [{request_id}] 并发限制触发不调整RPM")
return "concurrent"
elif rate_limit_info.limit_type == RateLimitType.RPM:
logger.info(f" [{request_id}] [RPM] RPM限制需要切换Provider")
logger.warning(f" [{request_id}] 自适应调整: Key {key.id[:8]}... RPM限制 -> {new_limit}")
return "rpm"
else:
return "unknown"
@@ -439,18 +439,18 @@ class ErrorClassifier:
# 提取可读的错误消息
extracted_message = self._extract_error_message(error_response_text)
# 构建详细错误信息
# 构建详细错误信息(仅用于日志,不暴露给客户端)
if extracted_message:
detailed_message = f"提供商 '{provider_name}' 返回错误 {status}: {extracted_message}"
detailed_message = f"上游服务返回错误 {status}: {extracted_message}"
else:
detailed_message = f"提供商 '{provider_name}' 返回错误: {status}"
detailed_message = f"上游服务返回错误: {status}"
if status == 401:
return ProviderAuthException(provider_name=provider_name)
if status == 429:
return ProviderRateLimitException(
message=error_response_text or f"提供商 '{provider_name}' 速率限制",
message="请求过于频繁,请稍后重试",
provider_name=provider_name,
response_headers=dict(error.response.headers) if error.response else None,
retry_after=(
@@ -583,6 +583,7 @@ class ErrorClassifier:
health_monitor.record_failure(
db=self.db,
key_id=str(key.id),
api_format=api_format_str,
error_type="ProviderAuthException",
)
return extra_data
@@ -592,7 +593,7 @@ class ErrorClassifier:
await self.handle_rate_limit(
key=key,
provider_name=provider_name,
current_concurrent=captured_key_concurrent,
current_rpm=captured_key_concurrent,
exception=converted_error,
request_id=request_id,
)
@@ -620,6 +621,7 @@ class ErrorClassifier:
health_monitor.record_failure(
db=self.db,
key_id=str(key.id),
api_format=api_format_str,
error_type=type(converted_error).__name__,
)
@@ -675,7 +677,7 @@ class ErrorClassifier:
await self.handle_rate_limit(
key=key,
provider_name=provider_name,
current_concurrent=captured_key_concurrent,
current_rpm=captured_key_concurrent,
exception=error,
request_id=request_id,
)
@@ -702,5 +704,6 @@ class ErrorClassifier:
health_monitor.record_failure(
db=self.db,
key_id=str(key.id),
api_format=api_format_str,
error_type=type(error).__name__,
)

View File

@@ -44,7 +44,7 @@ from src.services.cache.aware_scheduler import (
get_cache_aware_scheduler,
)
from src.services.provider.format import normalize_api_format
from src.services.rate_limit.adaptive_concurrency import get_adaptive_manager
from src.services.rate_limit.adaptive_rpm import get_adaptive_rpm_manager
from src.services.rate_limit.concurrency_manager import get_concurrency_manager
from src.services.request.candidate import RequestCandidateService
from src.services.request.executor import ExecutionError, RequestExecutor
@@ -87,7 +87,7 @@ class FallbackOrchestrator:
self.redis = redis_client
self.cache_scheduler: Optional[CacheAwareScheduler] = None
self.concurrency_manager: Any = None
self.adaptive_manager = get_adaptive_manager() # 自适应并发管理器
self.adaptive_manager = get_adaptive_rpm_manager() # 自适应 RPM 管理器
self.request_executor: Optional[RequestExecutor] = None
# 拆分后的组件(延迟初始化)
@@ -558,7 +558,8 @@ class FallbackOrchestrator:
"""尝试单个候选(含重试逻辑),返回执行结果"""
provider = candidate.provider
endpoint = candidate.endpoint
max_retries_for_candidate = int(endpoint.max_retries) if candidate.is_cached else 1
# 从 Provider 读取 max_retries已从 Endpoint 迁移)
max_retries_for_candidate = int(provider.max_retries or 2) if candidate.is_cached else 1
last_error: Optional[Exception] = None
for retry_index in range(max_retries_for_candidate):
@@ -710,7 +711,7 @@ class FallbackOrchestrator:
upstream_status = getattr(last_error, "upstream_status", None)
upstream_response = getattr(last_error, "upstream_response", None)
# 如果响应为空或无效,使用异常的字符串表示
# 如果响应为空或无效,使用异常的字符串表示作为 upstream_response
if (
not upstream_response
or not upstream_response.strip()
@@ -718,8 +719,17 @@ class FallbackOrchestrator:
):
upstream_response = str(last_error)
# 构建友好的错误消息(用于返回给客户端,不暴露内部信息)
# 如果 last_error 有 message 属性,优先使用(已经是友好提示)
# 否则使用通用提示
friendly_message = "服务暂时不可用,请稍后重试"
if last_error:
last_error_message = getattr(last_error, "message", None)
if last_error_message and isinstance(last_error_message, str):
friendly_message = last_error_message
raise ProviderNotAvailableException(
f"所有Provider均不可用已尝试{max_attempts}个组合",
friendly_message,
request_metadata=request_metadata,
upstream_status=upstream_status,
upstream_response=upstream_response,

View File

@@ -1,19 +1,23 @@
"""
限流服务模块
包含自适应并发控制、RPM限流、IP限流等功能。
包含自适应 RPM 控制、并发管理、IP限流等功能。
"""
from src.services.rate_limit.adaptive_concurrency import AdaptiveConcurrencyManager
from src.services.rate_limit.adaptive_rpm import (
AdaptiveConcurrencyManager, # 向后兼容别名
AdaptiveRPMManager,
get_adaptive_rpm_manager,
)
from src.services.rate_limit.concurrency_manager import ConcurrencyManager
from src.services.rate_limit.detector import RateLimitDetector
from src.services.rate_limit.ip_limiter import IPRateLimiter
from src.services.rate_limit.rpm_limiter import RPMLimiter
__all__ = [
"AdaptiveConcurrencyManager",
"AdaptiveConcurrencyManager", # 向后兼容
"AdaptiveRPMManager",
"ConcurrencyManager",
"IPRateLimiter",
"RPMLimiter",
"RateLimitDetector",
"get_adaptive_rpm_manager",
]
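A quick import-level sanity check (not part of the diff) for the backward-compatibility surface above; it assumes the alias assignments at the bottom of adaptive_rpm.py shown later in this commit.

from src.services.rate_limit import AdaptiveConcurrencyManager, AdaptiveRPMManager
from src.services.rate_limit.adaptive_rpm import get_adaptive_manager, get_adaptive_rpm_manager

assert AdaptiveConcurrencyManager is AdaptiveRPMManager          # legacy name is an alias
assert get_adaptive_manager() is get_adaptive_rpm_manager()      # same singleton either way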

View File

@@ -98,7 +98,7 @@ class AdaptiveReservationManager:
def calculate_reservation(
self,
key: "ProviderAPIKey",
current_concurrent: int = 0,
current_usage: int = 0,
effective_limit: Optional[int] = None,
) -> ReservationResult:
"""
@@ -106,8 +106,8 @@ class AdaptiveReservationManager:
Args:
key: ProviderAPIKey 对象
current_concurrent: 当前并发数
effective_limit: 有效并发限制(学习值或配置值)
current_usage: 当前使用量RPM 计数)
effective_limit: 有效限制(学习值或配置值)
Returns:
ReservationResult 包含预留比例和详细信息
@@ -116,7 +116,7 @@ class AdaptiveReservationManager:
total_requests = self._get_total_requests(key)
# 计算负载率
load_ratio = self._calculate_load_ratio(current_concurrent, effective_limit)
load_ratio = self._calculate_load_ratio(current_usage, effective_limit)
# 阶段1: 探测阶段
if total_requests < self.config.probe_phase_requests:
@@ -165,12 +165,12 @@ class AdaptiveReservationManager:
return request_count
def _calculate_load_ratio(
self, current_concurrent: int, effective_limit: Optional[int]
self, current_usage: int, effective_limit: Optional[int]
) -> float:
"""计算当前负载率"""
if not effective_limit or effective_limit <= 0:
return 0.0
return min(current_concurrent / effective_limit, 1.0)
return min(current_usage / effective_limit, 1.0)
def _calculate_confidence(self, key: "ProviderAPIKey") -> float:
"""

View File

@@ -1,16 +1,16 @@
"""
自适应并发调整器 - 基于边界记忆的并发限制调整
自适应 RPM 调整器 - 基于边界记忆的 RPM 限制调整
核心算法:边界记忆 + 渐进探测
- 触发 429 时记录边界(last_concurrent_peak),这就是真实上限
- 缩容策略:新限制 = 边界 - 1(而非乘性减少)
- 触发 429 时记录边界(last_rpm_peak),这就是真实上限
- 缩容策略:新限制 = 边界 - 步长(而非乘性减少)
- 扩容策略:不超过已知边界(除非是探测性扩容)
- 探测性扩容:长时间无 429 时尝试突破边界
设计原则:
1. 快速收敛:一次 429 就能找到接近真实的限制
2. 避免过度保守:不会因为多次 429 而无限下降
3. 安全探测:允许在稳定后尝试更高并发
3. 安全探测:允许在稳定后尝试更高 RPM
"""
from datetime import datetime, timezone
@@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, cast
from sqlalchemy.orm import Session
from src.config.constants import ConcurrencyDefaults
from src.config.constants import RPMDefaults
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import ProviderAPIKey
@@ -33,14 +33,14 @@ class AdaptiveStrategy:
AGGRESSIVE = "aggressive" # 激进策略(快速探测)
class AdaptiveConcurrencyManager:
class AdaptiveRPMManager:
"""
自适应并发管理器
自适应 RPM 管理器
核心算法:边界记忆 + 渐进探测
- 触发 429 时记录边界:last_concurrent_peak = 触发时的并发数
- 缩容:新限制 = 边界 - 1(快速收敛到真实限制附近)
- 扩容:不超过边界 last_concurrent_peak(允许回到边界值尝试)
- 触发 429 时记录边界:last_rpm_peak = 触发时的 RPM 数
- 缩容:新限制 = 边界 - 步长(快速收敛到真实限制附近)
- 扩容:不超过边界 last_rpm_peak(允许回到边界值尝试)
- 探测性扩容:长时间(30分钟)无 429 可以尝试 +1 突破边界
扩容条件(满足任一即可):
@@ -50,35 +50,35 @@ class AdaptiveConcurrencyManager:
关键特性
1. 快速收敛一次 429 就能学到接近真实的限制值
2. 边界保护普通扩容不会超过已知边界
3. 安全探测长时间稳定后允许尝试更高并发
3. 安全探测长时间稳定后允许尝试更高 RPM
4. 区分并发限制和 RPM 限制
"""
# 默认配置 - 使用统一常量
DEFAULT_INITIAL_LIMIT = ConcurrencyDefaults.INITIAL_LIMIT
MIN_CONCURRENT_LIMIT = ConcurrencyDefaults.MIN_CONCURRENT_LIMIT
MAX_CONCURRENT_LIMIT = ConcurrencyDefaults.MAX_CONCURRENT_LIMIT
DEFAULT_INITIAL_LIMIT = RPMDefaults.INITIAL_LIMIT
MIN_RPM_LIMIT = RPMDefaults.MIN_RPM_LIMIT
MAX_RPM_LIMIT = RPMDefaults.MAX_RPM_LIMIT
# AIMD 参数
INCREASE_STEP = ConcurrencyDefaults.INCREASE_STEP
INCREASE_STEP = RPMDefaults.INCREASE_STEP
# 滑动窗口参数
UTILIZATION_WINDOW_SIZE = ConcurrencyDefaults.UTILIZATION_WINDOW_SIZE
UTILIZATION_WINDOW_SECONDS = ConcurrencyDefaults.UTILIZATION_WINDOW_SECONDS
UTILIZATION_THRESHOLD = ConcurrencyDefaults.UTILIZATION_THRESHOLD
HIGH_UTILIZATION_RATIO = ConcurrencyDefaults.HIGH_UTILIZATION_RATIO
MIN_SAMPLES_FOR_DECISION = ConcurrencyDefaults.MIN_SAMPLES_FOR_DECISION
UTILIZATION_WINDOW_SIZE = RPMDefaults.UTILIZATION_WINDOW_SIZE
UTILIZATION_WINDOW_SECONDS = RPMDefaults.UTILIZATION_WINDOW_SECONDS
UTILIZATION_THRESHOLD = RPMDefaults.UTILIZATION_THRESHOLD
HIGH_UTILIZATION_RATIO = RPMDefaults.HIGH_UTILIZATION_RATIO
MIN_SAMPLES_FOR_DECISION = RPMDefaults.MIN_SAMPLES_FOR_DECISION
# 探测性扩容参数
PROBE_INCREASE_INTERVAL_MINUTES = ConcurrencyDefaults.PROBE_INCREASE_INTERVAL_MINUTES
PROBE_INCREASE_MIN_REQUESTS = ConcurrencyDefaults.PROBE_INCREASE_MIN_REQUESTS
PROBE_INCREASE_INTERVAL_MINUTES = RPMDefaults.PROBE_INCREASE_INTERVAL_MINUTES
PROBE_INCREASE_MIN_REQUESTS = RPMDefaults.PROBE_INCREASE_MIN_REQUESTS
# 记录历史数量
MAX_HISTORY_RECORDS = 20
def __init__(self, strategy: str = AdaptiveStrategy.AIMD):
"""
初始化自适应并发管理器
初始化自适应 RPM 管理器
Args:
strategy: 调整策略
@@ -90,54 +90,54 @@ class AdaptiveConcurrencyManager:
db: Session,
key: ProviderAPIKey,
rate_limit_info: RateLimitInfo,
current_concurrent: Optional[int] = None,
current_rpm: Optional[int] = None,
) -> int:
"""
处理429错误调整并发限制
处理429错误调整 RPM 限制
Args:
db: 数据库会话
key: API Key对象
rate_limit_info: 速率限制信息
current_concurrent: 当前并发
current_rpm: 当前分钟内的请求
Returns:
调整后的并发限制
调整后的 RPM 限制
"""
# max_concurrent=NULL 表示启用自适应,max_concurrent=数字 表示固定限制
is_adaptive = key.max_concurrent is None
# rpm_limit=NULL 表示启用自适应,rpm_limit=数字 表示固定限制
is_adaptive = key.rpm_limit is None
if not is_adaptive:
logger.debug(
f"Key {key.id} 设置了固定并发限制 ({key.max_concurrent}),跳过自适应调整"
f"Key {key.id} 设置了固定 RPM 限制 ({key.rpm_limit}),跳过自适应调整"
)
return int(key.max_concurrent) # type: ignore[arg-type]
return int(key.rpm_limit) # type: ignore[arg-type]
# 更新429统计
key.last_429_at = datetime.now(timezone.utc) # type: ignore[assignment]
key.last_429_type = rate_limit_info.limit_type # type: ignore[assignment]
# 仅在并发限制且拿到并发数时记录边界(RPM/UNKNOWN 不应覆盖并发边界记忆)
# 仅在 RPM 限制且拿到 RPM 数时记录边界
if (
rate_limit_info.limit_type == RateLimitType.CONCURRENT
and current_concurrent is not None
and current_concurrent > 0
rate_limit_info.limit_type == RateLimitType.RPM
and current_rpm is not None
and current_rpm > 0
):
key.last_concurrent_peak = current_concurrent # type: ignore[assignment]
key.last_rpm_peak = current_rpm # type: ignore[assignment]
# 遇到 429 错误,清空利用率采样窗口(重新开始收集)
key.utilization_samples = [] # type: ignore[assignment]
if rate_limit_info.limit_type == RateLimitType.CONCURRENT:
# 并发限制:减少并发数
key.concurrent_429_count = int(key.concurrent_429_count or 0) + 1 # type: ignore[assignment]
if rate_limit_info.limit_type == RateLimitType.RPM:
# RPM 限制:减少 RPM 限制
key.rpm_429_count = int(key.rpm_429_count or 0) + 1 # type: ignore[assignment]
# 获取当前有效限制(自适应模式使用 learned_max_concurrent
old_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
new_limit = self._decrease_limit(old_limit, current_concurrent)
# 获取当前有效限制(自适应模式使用 learned_rpm_limit
old_limit = int(key.learned_rpm_limit or self.DEFAULT_INITIAL_LIMIT)
new_limit = self._decrease_limit(old_limit, current_rpm)
logger.warning(
f"[CONCURRENT] 并发限制触发: Key {key.id[:8]}... | "
f"当前并发: {current_concurrent} | "
f"[RPM] RPM 限制触发: Key {key.id[:8]}... | "
f"当前 RPM: {current_rpm} | "
f"调整: {old_limit} -> {new_limit}"
)
@@ -146,79 +146,78 @@ class AdaptiveConcurrencyManager:
key,
old_limit=old_limit,
new_limit=new_limit,
reason="concurrent_429",
current_concurrent=current_concurrent,
reason="rpm_429",
current_rpm=current_rpm,
)
# 更新学习到的并发限制
key.learned_max_concurrent = new_limit # type: ignore[assignment]
# 更新学习到的 RPM 限制
key.learned_rpm_limit = new_limit # type: ignore[assignment]
elif rate_limit_info.limit_type == RateLimitType.RPM:
# RPM限制:不调整并发,只记录
key.rpm_429_count = int(key.rpm_429_count or 0) + 1 # type: ignore[assignment]
elif rate_limit_info.limit_type == RateLimitType.CONCURRENT:
# 并发限制:不调整 RPM,只记录
key.concurrent_429_count = int(key.concurrent_429_count or 0) + 1 # type: ignore[assignment]
logger.info(
f"[RPM] RPM限制触发: Key {key.id[:8]}... | "
f"retry_after: {rate_limit_info.retry_after}s | "
f"不调整并发限制"
f"[CONCURRENT] 并发限制触发: Key {key.id[:8]}... | "
f"不调整 RPM 限制(这是并发问题,非 RPM 问题)"
)
else:
# 未知类型:保守处理,轻微减少
logger.warning(
f"[UNKNOWN] 未知429类型: Key {key.id[:8]}... | "
f"当前并发: {current_concurrent} | "
f"保守减少并发"
f"当前 RPM: {current_rpm} | "
f"保守减少 RPM"
)
old_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
new_limit = max(int(old_limit * 0.9), self.MIN_CONCURRENT_LIMIT) # 减少10%
old_limit = int(key.learned_rpm_limit or self.DEFAULT_INITIAL_LIMIT)
new_limit = max(int(old_limit * 0.9), self.MIN_RPM_LIMIT) # 减少10%
self._record_adjustment(
key,
old_limit=old_limit,
new_limit=new_limit,
reason="unknown_429",
current_concurrent=current_concurrent,
current_rpm=current_rpm,
)
key.learned_max_concurrent = new_limit # type: ignore[assignment]
key.learned_rpm_limit = new_limit # type: ignore[assignment]
db.flush()
get_batch_committer().mark_dirty(db)
return int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
return int(key.learned_rpm_limit or self.DEFAULT_INITIAL_LIMIT)
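# Illustrative sketch (not part of this diff): the NULL-vs-number contract used above.
# rpm_limit=None means "adaptive, use the learned value"; a number means "fixed, never
# adjusted". The default of 60 stands in for RPMDefaults.INITIAL_LIMIT, whose real
# value is not shown in this diff.
def _effective_rpm_limit(rpm_limit, learned_rpm_limit, default_initial_limit=60):
    if rpm_limit is not None:
        return int(rpm_limit)                                   # operator-fixed limit wins
    return int(learned_rpm_limit or default_initial_limit)      # adaptive mode

assert _effective_rpm_limit(100, 45) == 100
assert _effective_rpm_limit(None, 45) == 45
assert _effective_rpm_limit(None, None) == 60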
def handle_success(
self,
db: Session,
key: ProviderAPIKey,
current_concurrent: int,
current_rpm: int,
) -> Optional[int]:
"""
处理成功请求基于滑动窗口利用率考虑增加并发限制
处理成功请求基于滑动窗口利用率考虑增加 RPM 限制
Args:
db: 数据库会话
key: API Key对象
current_concurrent: 当前并发必需用于计算利用率
current_rpm: 当前分钟内的请求必需用于计算利用率
Returns:
调整后的并发限制如果有调整否则返回 None
调整后的 RPM 限制如果有调整否则返回 None
"""
# max_concurrent=NULL 表示启用自适应
is_adaptive = key.max_concurrent is None
# rpm_limit=NULL 表示启用自适应
is_adaptive = key.rpm_limit is None
if not is_adaptive:
return None
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
current_limit = int(key.learned_rpm_limit or self.DEFAULT_INITIAL_LIMIT)
# 获取已知边界(上次触发 429 时的并发数
known_boundary = key.last_concurrent_peak
# 获取已知边界(上次触发 429 时的 RPM
known_boundary = key.last_rpm_peak
# 计算当前利用率
utilization = float(current_concurrent / current_limit) if current_limit > 0 else 0.0
utilization = float(current_rpm / current_limit) if current_limit > 0 else 0.0
now = datetime.now(timezone.utc)
now_ts = now.timestamp()
@@ -229,7 +228,7 @@ class AdaptiveConcurrencyManager:
# 检查是否满足扩容条件
increase_reason = self._check_increase_conditions(key, samples, now, known_boundary)
if increase_reason and current_limit < self.MAX_CONCURRENT_LIMIT:
if increase_reason and current_limit < self.MAX_RPM_LIMIT:
old_limit = current_limit
is_probe = increase_reason == "probe_increase"
new_limit = self._increase_limit(current_limit, known_boundary, is_probe)
@@ -262,12 +261,12 @@ class AdaptiveConcurrencyManager:
avg_utilization=round(avg_util, 2),
high_util_ratio=round(high_util_ratio, 2),
sample_count=len(samples),
current_concurrent=current_concurrent,
current_rpm=current_rpm,
known_boundary=known_boundary,
)
# 更新限制
key.learned_max_concurrent = new_limit # type: ignore[assignment]
key.learned_rpm_limit = new_limit # type: ignore[assignment]
# 如果是探测性扩容,更新探测时间
if is_probe:
@@ -334,7 +333,7 @@ class AdaptiveConcurrencyManager:
key: API Key对象
samples: 利用率采样列表
now: 当前时间
known_boundary: 已知边界触发 429 时的并发数
known_boundary: 已知边界触发 429 时的 RPM
Returns:
扩容原因如果满足条件否则返回 None
@@ -343,7 +342,7 @@ class AdaptiveConcurrencyManager:
if self._is_in_cooldown(key):
return None
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
current_limit = int(key.learned_rpm_limit or self.DEFAULT_INITIAL_LIMIT)
# 条件1滑动窗口扩容不超过边界
if len(samples) >= self.MIN_SAMPLES_FOR_DECISION:
@@ -353,7 +352,7 @@ class AdaptiveConcurrencyManager:
if high_util_ratio >= self.HIGH_UTILIZATION_RATIO:
# 检查是否还有扩容空间(边界保护)
if known_boundary:
# 允许扩容到边界值(而非 boundary - 1,因为缩容时已经 -1)
# 允许扩容到边界值(而非 boundary - 1,因为缩容时已经 -步长)
if current_limit < known_boundary:
return "high_utilization"
# 已达边界,不触发普通扩容
@@ -429,34 +428,37 @@ class AdaptiveConcurrencyManager:
last_429_at = cast(datetime, key.last_429_at)
time_since_429 = (datetime.now(timezone.utc) - last_429_at).total_seconds()
cooldown_seconds = ConcurrencyDefaults.COOLDOWN_AFTER_429_MINUTES * 60
cooldown_seconds = RPMDefaults.COOLDOWN_AFTER_429_MINUTES * 60
return bool(time_since_429 < cooldown_seconds)
def _decrease_limit(
self,
current_limit: int,
current_concurrent: Optional[int] = None,
current_rpm: Optional[int] = None,
) -> int:
"""
减少并发限制(基于边界记忆策略)
减少 RPM 限制(基于边界记忆策略)
策略:
- 如果知道触发 429 时的并发数:新限制 = 并发数 - 1
- 这样可以快速收敛到真实限制附近,而不会过度保守
- 例如:真实限制 8,触发时并发 8 -> 新限制 7(而非 8*0.85=6)
- 如果知道触发 429 时的 RPM:新限制 = RPM * 0.90(保留 10% 安全边际)
- 10% 的安全边际更保守,考虑到:
1. RPM 报告可能存在延迟,实际触发时 RPM 可能略高于报告值
2. 上游 API 的限制可能有波动
3. 避免频繁在边界附近触发 429
- 相比固定步长,百分比方式更适应不同量级的限制值
"""
if current_concurrent is not None and current_concurrent > 0:
# 边界记忆策略:新限制 = 触发边界 - 1
candidate = current_concurrent - 1
if current_rpm is not None and current_rpm > 0:
# 边界记忆策略:新限制 = 触发边界 * 0.9010% 安全边际)
candidate = int(current_rpm * 0.90)
else:
# 没有并发信息时,保守减少 1
candidate = current_limit - 1
# 没有 RPM 信息时,减少 10%
candidate = int(current_limit * 0.9)
# 保证不会“缩容变扩容”(例如 current_concurrent > current_limit 的异常场景)
# 保证不会"缩容变扩容"
candidate = min(candidate, current_limit - 1)
new_limit = max(candidate, self.MIN_CONCURRENT_LIMIT)
new_limit = max(candidate, self.MIN_RPM_LIMIT)
return new_limit
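# Illustrative worked example (not part of this diff) of the 10% margin above,
# with MIN_RPM_LIMIT assumed to be 1 for the arithmetic.
def _sketch_decrease(current_limit, current_rpm=None, min_rpm_limit=1):
    if current_rpm:
        candidate = int(current_rpm * 0.90)        # boundary memory with safety margin
    else:
        candidate = int(current_limit * 0.9)       # no RPM info: plain 10% cut
    candidate = min(candidate, current_limit - 1)  # never turn a shrink into a growth
    return max(candidate, min_rpm_limit)

assert _sketch_decrease(60, current_rpm=50) == 45
assert _sketch_decrease(60) == 54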
@@ -467,16 +469,15 @@ class AdaptiveConcurrencyManager:
is_probe: bool = False,
) -> int:
"""
增加并发限制(考虑边界保护)
增加 RPM 限制(考虑边界保护)
策略:
- 普通扩容:每次 +INCREASE_STEP,但不超过 known_boundary
(因为缩容时已经 -1,这里允许回到边界值尝试)
- 探测性扩容:每次只 +1,可以突破边界(但要谨慎)
Args:
current_limit: 当前限制
known_boundary: 已知边界(last_concurrent_peak,即触发 429 时的并发数)
known_boundary: 已知边界(last_rpm_peak,即触发 429 时的 RPM 数)
is_probe: 是否是探测性扩容(可以突破边界)
"""
if is_probe:
@@ -486,13 +487,13 @@ class AdaptiveConcurrencyManager:
# 普通模式:每次 +INCREASE_STEP
new_limit = current_limit + self.INCREASE_STEP
# 边界保护:普通扩容不超过 known_boundary(允许回到边界值尝试)
# 边界保护:普通扩容不超过 known_boundary
if known_boundary:
if new_limit > known_boundary:
new_limit = known_boundary
# 全局上限保护
new_limit = min(new_limit, self.MAX_CONCURRENT_LIMIT)
new_limit = min(new_limit, self.MAX_RPM_LIMIT)
# 确保有增长(否则返回原值表示不扩容)
if new_limit <= current_limit:
@@ -509,7 +510,7 @@ class AdaptiveConcurrencyManager:
**extra_data: Any,
) -> None:
"""
记录并发调整历史
记录 RPM 调整历史
Args:
key: API Key对象
@@ -548,10 +549,10 @@ class AdaptiveConcurrencyManager:
history: List[Dict[str, Any]] = list(key.adjustment_history or [])
samples: List[Dict[str, Any]] = list(key.utilization_samples or [])
# max_concurrent=NULL 表示自适应,否则为固定限制
is_adaptive = key.max_concurrent is None
current_limit = int(key.learned_max_concurrent or self.DEFAULT_INITIAL_LIMIT)
effective_limit = current_limit if is_adaptive else int(key.max_concurrent) # type: ignore
# rpm_limit=NULL 表示自适应,否则为固定限制
is_adaptive = key.rpm_limit is None
current_limit = int(key.learned_rpm_limit or self.DEFAULT_INITIAL_LIMIT)
effective_limit = current_limit if is_adaptive else int(key.rpm_limit) # type: ignore
# 计算窗口统计
avg_utilization: Optional[float] = None
@@ -570,15 +571,15 @@ class AdaptiveConcurrencyManager:
last_probe_at_str = cast(datetime, key.last_probe_increase_at).isoformat()
# 边界信息
known_boundary = key.last_concurrent_peak
known_boundary = key.last_rpm_peak
return {
"adaptive_mode": is_adaptive,
"max_concurrent": key.max_concurrent, # NULL=自适应,数字=固定限制
"rpm_limit": key.rpm_limit, # NULL=自适应,数字=固定限制
"effective_limit": effective_limit, # 当前有效限制
"learned_limit": key.learned_max_concurrent, # 学习到的限制
"learned_limit": key.learned_rpm_limit, # 学习到的限制
# 边界记忆相关
"known_boundary": known_boundary, # 触发 429 时的并发数(已知上限)
"known_boundary": known_boundary, # 触发 429 时的 RPM(已知上限)
"concurrent_429_count": int(key.concurrent_429_count or 0),
"rpm_429_count": int(key.rpm_429_count or 0),
"last_429_at": last_429_at_str,
@@ -607,12 +608,12 @@ class AdaptiveConcurrencyManager:
"""
logger.info(f"[RESET] 重置学习状态: Key {key.id[:8]}...")
key.learned_max_concurrent = None # type: ignore[assignment]
key.learned_rpm_limit = None # type: ignore[assignment]
key.concurrent_429_count = 0 # type: ignore[assignment]
key.rpm_429_count = 0 # type: ignore[assignment]
key.last_429_at = None # type: ignore[assignment]
key.last_429_type = None # type: ignore[assignment]
key.last_concurrent_peak = None # type: ignore[assignment]
key.last_rpm_peak = None # type: ignore[assignment]
key.adjustment_history = [] # type: ignore[assignment]
key.utilization_samples = [] # type: ignore[assignment]
key.last_probe_increase_at = None # type: ignore[assignment]
@@ -622,12 +623,17 @@ class AdaptiveConcurrencyManager:
# 全局单例
_adaptive_manager: Optional[AdaptiveConcurrencyManager] = None
_adaptive_rpm_manager: Optional[AdaptiveRPMManager] = None
def get_adaptive_manager() -> AdaptiveConcurrencyManager:
"""获取全局自适应管理器单例"""
global _adaptive_manager
if _adaptive_manager is None:
_adaptive_manager = AdaptiveConcurrencyManager()
return _adaptive_manager
def get_adaptive_rpm_manager() -> AdaptiveRPMManager:
"""获取全局自适应 RPM 管理器单例"""
global _adaptive_rpm_manager
if _adaptive_rpm_manager is None:
_adaptive_rpm_manager = AdaptiveRPMManager()
return _adaptive_rpm_manager
# 向后兼容别名
AdaptiveConcurrencyManager = AdaptiveRPMManager
get_adaptive_manager = get_adaptive_rpm_manager

View File

@@ -1,29 +1,33 @@
"""
并发管理器 - 支持 Redis 或内存的并发控
RPM 限制管理器 - 支持 Redis 或内存的 Key 级别 RPM 限
功能:
1. Endpoint 级别的并发限制
2. ProviderAPIKey 级别的并发限制
3. 分布式环境下优先使用 Redis多实例共享
4. 在开发/单实例场景下自动降级为内存计数
5. 自动释放和异常处理Redis 提供 TTL内存模式请确保手动释放
1. ProviderAPIKey 级别的 RPM 限制(按分钟窗口计数)
2. 分布式环境下优先使用 Redis多实例共享
3. 在开发/单实例场景下自动降级为内存计数
4. 支持缓存用户优先级(预留槽位机制)
"""
import asyncio
import math
import os
import time
from contextlib import asynccontextmanager
from datetime import timedelta # noqa: F401 - kept for potential future use
from typing import Optional, Tuple
from typing import Optional
import redis.asyncio as aioredis
from src.config.constants import RPMDefaults
from src.core.logger import logger
class ConcurrencyManager:
"""分布式并发管理器"""
"""Key RPM 限制管理器"""
_instance: Optional["ConcurrencyManager"] = None
_redis: Optional[aioredis.Redis] = None
_key_rpm_bucket_seconds: int = 60
_key_rpm_key_ttl_seconds: int = 120 # 2 分钟,足够覆盖当前分钟与边界
def __new__(cls):
"""单例模式"""
@@ -37,9 +41,22 @@ class ConcurrencyManager:
return
self._memory_lock: asyncio.Lock = asyncio.Lock()
self._memory_endpoint_counts: dict[str, int] = {}
self._memory_key_counts: dict[str, int] = {}
# Key RPM 计数器:{key_id: (bucket, count)}bucket = floor(now / 60)
self._memory_key_rpm_counts: dict[str, tuple[int, int]] = {}
self._owns_redis: bool = False
self._last_cleanup_bucket: int = 0 # 上次清理时的 bucket用于定期清理过期数据
self._last_cleanup_time: float = 0 # 上次清理的时间戳,用于强制定期清理
self._cleanup_interval_seconds: int = 300 # 强制清理间隔5 分钟)
self._cleanup_task: Optional[asyncio.Task] = None # 后台清理任务
# 内存模式下的最大条目限制,防止内存泄漏(支持环境变量覆盖)
self._max_memory_rpm_entries: int = int(
os.getenv("RPM_MAX_MEMORY_ENTRIES", str(RPMDefaults.MAX_MEMORY_RPM_ENTRIES))
)
# 早期告警阈值(达到此比例时记录警告)
self._memory_warning_threshold: float = float(
os.getenv("RPM_MEMORY_WARNING_THRESHOLD", str(RPMDefaults.MEMORY_WARNING_THRESHOLD))
)
self._memory_initialized = True
async def initialize(self) -> None:
@@ -56,212 +73,304 @@ class ConcurrencyManager:
if self._redis:
logger.info("[OK] ConcurrencyManager 已复用全局 Redis 客户端")
else:
logger.warning("[WARN] Redis 不可用,并发控制降级为内存模式(仅在单实例环境下安全)")
logger.warning("[WARN] Redis 不可用,RPM 限制降级为内存模式(仅在单实例环境下安全)")
# 内存模式下启动后台清理任务
self._start_background_cleanup()
except Exception as e:
logger.error(f"[ERROR] 获取全局 Redis 客户端失败: {e}")
logger.warning("[WARN] 并发控制将降级为内存模式(仅在单实例环境下安全)")
logger.warning("[WARN] RPM 限制将降级为内存模式(仅在单实例环境下安全)")
self._redis = None
self._owns_redis = False
# 内存模式下启动后台清理任务
self._start_background_cleanup()
def _start_background_cleanup(self) -> None:
"""启动后台定期清理任务(仅内存模式需要)"""
if self._cleanup_task is not None:
return # 已经启动
async def cleanup_loop():
"""后台清理循环"""
while True:
try:
await asyncio.sleep(60) # 每分钟检查一次
async with self._memory_lock:
current_bucket = self._get_rpm_bucket()
self._cleanup_expired_memory_rpm_counts(current_bucket, force=False)
except asyncio.CancelledError:
break
except Exception as e:
logger.debug(f"后台清理任务异常: {e}")
try:
self._cleanup_task = asyncio.create_task(cleanup_loop())
logger.debug("[OK] 内存模式后台清理任务已启动")
except RuntimeError:
# 没有事件循环时忽略
pass
async def close(self) -> None:
"""关闭 Redis 连接"""
# 停止后台清理任务
if self._cleanup_task is not None:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
if self._redis and self._owns_redis:
await self._redis.close()
logger.info("ConcurrencyManager Redis 连接已关闭")
self._redis = None
self._owns_redis = False
def _get_endpoint_key(self, endpoint_id: str) -> str:
"""获取 Endpoint 并发计数的 Redis Key"""
return f"concurrency:endpoint:{endpoint_id}"
@classmethod
def _get_rpm_bucket(cls, now_ts: Optional[float] = None) -> int:
"""获取当前 RPM 计数桶(按分钟)"""
ts = now_ts if now_ts is not None else time.time()
return int(ts // cls._key_rpm_bucket_seconds)
def _get_key_key(self, key_id: str) -> str:
"""获取 ProviderAPIKey 并发计数的 Redis Key"""
return f"concurrency:key:{key_id}"
@classmethod
def _get_key_key(cls, key_id: str, bucket: Optional[int] = None) -> str:
"""获取 ProviderAPIKey RPM 计数的 Redis Key按分钟桶"""
b = bucket if bucket is not None else cls._get_rpm_bucket()
return f"rpm:key:{key_id}:{b}"
async def get_current_concurrency(
self, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
) -> Tuple[int, int]:
def _get_memory_key_rpm_count(self, key_id: str, bucket: int) -> int:
"""获取内存模式下 Key 在指定 bucket 的 RPM 计数"""
stored = self._memory_key_rpm_counts.get(key_id)
if not stored:
return 0
stored_bucket, count = stored
if stored_bucket != bucket:
# 旧桶数据已过期,删除以防止内存泄漏
del self._memory_key_rpm_counts[key_id]
return 0
return count
def _set_memory_key_rpm_count(self, key_id: str, bucket: int, count: int) -> None:
"""设置内存模式下 Key 在指定 bucket 的 RPM 计数"""
current_size = len(self._memory_key_rpm_counts)
warning_threshold = int(self._max_memory_rpm_entries * self._memory_warning_threshold)
high_threshold = int(self._max_memory_rpm_entries * 0.8)
critical_threshold = int(self._max_memory_rpm_entries * 0.95)
# 分级告警:根据使用率记录不同级别的日志
if current_size >= critical_threshold and key_id not in self._memory_key_rpm_counts:
logger.critical(
f"[CRITICAL] 内存 RPM 计数器接近上限 ({current_size}/{self._max_memory_rpm_entries})"
f"强烈建议启用 Redis继续增长可能导致 RPM 限制失效"
)
elif current_size >= high_threshold and key_id not in self._memory_key_rpm_counts:
# 每 100 个条目告警一次,避免日志过多
if current_size % 100 == 0:
logger.error(
f"[ERROR] 内存 RPM 计数器使用率过高 ({current_size}/{self._max_memory_rpm_entries})"
f"建议启用 Redis"
)
elif current_size >= warning_threshold and key_id not in self._memory_key_rpm_counts:
if current_size == warning_threshold:
logger.warning(
f"[WARN] 内存 RPM 计数器达到 {self._memory_warning_threshold:.0%} 阈值 "
f"({current_size}/{self._max_memory_rpm_entries}),建议启用 Redis"
)
# 检查是否超过最大条目限制
if (
key_id not in self._memory_key_rpm_counts
and current_size >= self._max_memory_rpm_entries
):
# 触发强制清理
self._cleanup_expired_memory_rpm_counts(bucket, force=True)
# 如果清理后仍然超过限制,执行 LRU 淘汰(删除最旧的 20%
if len(self._memory_key_rpm_counts) >= self._max_memory_rpm_entries:
evict_count = max(1, self._max_memory_rpm_entries // 5)
# 按 bucket时间排序删除最旧的
sorted_keys = sorted(
self._memory_key_rpm_counts.items(),
key=lambda x: x[1][0] # 按 bucket 排序
)
for k, _ in sorted_keys[:evict_count]:
del self._memory_key_rpm_counts[k]
logger.warning(
f"[WARN] 内存 RPM 计数器达到上限,已淘汰 {evict_count} 个最旧条目"
)
self._memory_key_rpm_counts[key_id] = (bucket, count)
def _cleanup_expired_memory_rpm_counts(self, current_bucket: int, force: bool = False) -> None:
"""
获取当前并发数
清理内存中过期的 RPM 计数(必须在持有 _memory_lock 时调用)
清理策略:
- 常规清理:每分钟最多执行一次(当 bucket 变化时)
- 强制清理:每 5 分钟执行一次(防止长时间无请求导致内存泄漏)
"""
now = time.time()
# 检查是否需要清理
should_cleanup = (
current_bucket != self._last_cleanup_bucket # 分钟切换
or force # 强制清理
or (now - self._last_cleanup_time > self._cleanup_interval_seconds) # 超时清理
)
if not should_cleanup:
return
self._last_cleanup_bucket = current_bucket
self._last_cleanup_time = now
expired_keys = []
for key_id, (stored_bucket, _count) in self._memory_key_rpm_counts.items():
if stored_bucket < current_bucket:
expired_keys.append(key_id)
for key_id in expired_keys:
del self._memory_key_rpm_counts[key_id]
if expired_keys:
logger.debug(f"[CLEANUP] 清理了 {len(expired_keys)} 个过期的内存 RPM 计数")
async def get_key_rpm_count(self, key_id: str) -> int:
"""
获取 Key 当前 RPM 计数
Args:
endpoint_id: Endpoint ID可选
key_id: ProviderAPIKey ID可选
key_id: ProviderAPIKey ID
Returns:
(endpoint_concurrency, key_concurrency)
当前分钟窗口内的请求数
"""
if self._redis is None:
async with self._memory_lock:
endpoint_count = (
self._memory_endpoint_counts.get(endpoint_id, 0) if endpoint_id else 0
)
key_count = self._memory_key_counts.get(key_id, 0) if key_id else 0
return endpoint_count, key_count
endpoint_count = 0
key_count = 0
bucket = self._get_rpm_bucket()
# 定期清理过期数据,避免内存泄漏
self._cleanup_expired_memory_rpm_counts(bucket)
return self._get_memory_key_rpm_count(key_id, bucket)
try:
if endpoint_id:
endpoint_key = self._get_endpoint_key(endpoint_id)
result = await self._redis.get(endpoint_key)
endpoint_count = int(result) if result else 0
if key_id:
key_key = self._get_key_key(key_id)
result = await self._redis.get(key_key)
key_count = int(result) if result else 0
return int(result) if result else 0
except Exception as e:
logger.error(f"获取并发数失败: {e}")
logger.error(f"获取 RPM 计数失败: {e}")
return 0
return endpoint_count, key_count
async def check_available(
async def check_rpm_available(
self,
endpoint_id: str,
endpoint_max_concurrent: Optional[int],
key_id: str,
key_max_concurrent: Optional[int],
key_rpm_limit: Optional[int],
is_cached_user: bool = False,
cache_reservation_ratio: Optional[float] = None,
) -> bool:
"""
检查是否可以获取并发槽位(不实际获取
检查是否可以通过 RPM 限制(不实际增加计数
Args:
endpoint_id: Endpoint ID
endpoint_max_concurrent: Endpoint 最大并发数None 表示不限制)
key_id: ProviderAPIKey ID
key_max_concurrent: Key 最大并发数(None 表示不限制)
key_rpm_limit: Key RPM 限制(每分钟最大请求数,None 表示不限制)
is_cached_user: 是否是缓存用户
cache_reservation_ratio: 缓存预留比例
Returns:
是否可用True/False
"""
if self._redis is None:
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
key_count = self._memory_key_counts.get(key_id, 0)
if (
endpoint_max_concurrent is not None
and endpoint_count >= endpoint_max_concurrent
):
return False
if key_max_concurrent is not None and key_count >= key_max_concurrent:
return False
if key_rpm_limit is None:
return True
endpoint_count, key_count = await self.get_current_concurrency(endpoint_id, key_id)
# 从配置读取默认值
from src.config.settings import config
# 检查 Endpoint 级别限制
if endpoint_max_concurrent is not None and endpoint_count >= endpoint_max_concurrent:
return False
if cache_reservation_ratio is None:
cache_reservation_ratio = config.cache_reservation_ratio
# 检查 Key 级别限制
if key_max_concurrent is not None and key_count >= key_max_concurrent:
return False
key_count = await self.get_key_rpm_count(key_id)
return True
if is_cached_user:
return key_count < key_rpm_limit
else:
# 新用户只能使用 (1 - cache_reservation_ratio) 的槽位
available_for_new = max(1, math.floor(key_rpm_limit * (1 - cache_reservation_ratio)))
return key_count < available_for_new
async def acquire_slot(
async def acquire_rpm_slot(
self,
endpoint_id: str,
endpoint_max_concurrent: Optional[int],
key_id: str,
key_max_concurrent: Optional[int],
is_cached_user: bool = False, # 新增:是否是缓存用户
cache_reservation_ratio: Optional[float] = None, # 缓存预留比例None 时从配置读取
ttl_seconds: Optional[int] = None, # TTL 秒数None 时从配置读取
key_rpm_limit: Optional[int],
is_cached_user: bool = False,
cache_reservation_ratio: Optional[float] = None,
) -> bool:
"""
尝试获取并发槽位(支持缓存用户优先级)
尝试获取 RPM 槽位(支持缓存用户优先级)
Args:
endpoint_id: Endpoint ID
endpoint_max_concurrent: Endpoint 最大并发数None 表示不限制)
key_id: ProviderAPIKey ID
key_max_concurrent: Key 最大并发数(None 表示不限制)
key_rpm_limit: Key RPM 限制(每分钟最大请求数,None 表示不限制)
is_cached_user: 是否是缓存用户(缓存用户可使用全部槽位)
cache_reservation_ratio: 缓存预留比例None 时从配置读取
ttl_seconds: TTL 秒数None 时从配置读取
Returns:
是否成功获取True/False
缓存预留机制说明:
- 假设 key_max_concurrent = 10, cache_reservation_ratio = 0.3
- 新用户最多使用: 7个槽位 (10 * (1 - 0.3))
- 缓存用户最多使用: 10个槽位(全部)
- 预留的3个槽位专门给缓存用户,保证他们的请求优先
- 假设 key_rpm_limit = 100, cache_reservation_ratio = 0.3
- 新用户最多使用: 70 RPM (100 * (1 - 0.3))
- 缓存用户最多使用: 100 RPM(全部)
- 预留的 30 RPM 专门给缓存用户,保证他们的请求优先
"""
# 从配置读取默认值
from src.config.settings import config
if cache_reservation_ratio is None:
cache_reservation_ratio = config.cache_reservation_ratio
if ttl_seconds is None:
ttl_seconds = config.concurrency_slot_ttl
if self._redis is None:
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
key_count = self._memory_key_counts.get(key_id, 0)
bucket = self._get_rpm_bucket()
# 定期清理过期数据,避免内存泄漏
self._cleanup_expired_memory_rpm_counts(bucket)
# Endpoint 限制
if (
endpoint_max_concurrent is not None
and endpoint_count >= endpoint_max_concurrent
):
return False
key_count = self._get_memory_key_rpm_count(key_id, bucket)
# Key 限制,包含缓存预留
if key_max_concurrent is not None:
# Key RPM 限制,包含缓存预留
if key_rpm_limit is not None:
if is_cached_user:
if key_count >= key_max_concurrent:
if key_count >= key_rpm_limit:
return False
else:
# 新用户只能使用 (1 - cache_reservation_ratio) 的槽位
available_for_new = max(
1, math.ceil(key_max_concurrent * (1 - cache_reservation_ratio))
1, math.floor(key_rpm_limit * (1 - cache_reservation_ratio))
)
if key_count >= available_for_new:
return False
# 通过限制,更新计数
self._memory_endpoint_counts[endpoint_id] = endpoint_count + 1
self._memory_key_counts[key_id] = key_count + 1
self._set_memory_key_rpm_count(key_id, bucket, key_count + 1)
return True
endpoint_key = self._get_endpoint_key(endpoint_id)
key_key = self._get_key_key(key_id)
bucket = self._get_rpm_bucket()
key_key = self._get_key_key(key_id, bucket=bucket)
try:
# 使用 Lua 脚本保证原子性(新增缓存预留逻辑)
# 使用 Lua 脚本保证原子性(支持缓存预留逻辑)
lua_script = """
local endpoint_key = KEYS[1]
local key_key = KEYS[2]
local endpoint_max = tonumber(ARGV[1])
local key_max = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])
local is_cached = tonumber(ARGV[4]) -- 0=新用户, 1=缓存用户
local cache_ratio = tonumber(ARGV[5]) -- 缓存预留比例
local key_key = KEYS[1]
local key_max = tonumber(ARGV[1])
local key_ttl = tonumber(ARGV[2])
local is_cached = tonumber(ARGV[3]) -- 0=新用户, 1=缓存用户
local cache_ratio = tonumber(ARGV[4]) -- 缓存预留比例
-- 获取当前值
local endpoint_count = tonumber(redis.call('GET', endpoint_key) or '0')
local key_count = tonumber(redis.call('GET', key_key) or '0')
-- 检查 endpoint 限制(-1 表示不限制)
if endpoint_max >= 0 and endpoint_count >= endpoint_max then
return 0 -- 失败endpoint 已满
end
-- 检查 key 限制(支持缓存预留)
if key_max >= 0 then
if is_cached == 0 then
-- 新用户:只能使用 (1 - cache_ratio) 的槽位
local available_for_new = math.floor(key_max * (1 - cache_ratio))
local available_for_new = math.max(1, math.floor(key_max * (1 - cache_ratio)))
if key_count >= available_for_new then
return 0 -- 失败:新用户配额已满
end
@@ -274,10 +383,8 @@ class ConcurrencyManager:
end
-- 增加计数
redis.call('INCR', endpoint_key)
redis.call('EXPIRE', endpoint_key, ttl)
redis.call('INCR', key_key)
redis.call('EXPIRE', key_key, ttl)
redis.call('EXPIRE', key_key, key_ttl)
return 1 -- 成功
"""
@@ -285,12 +392,10 @@ class ConcurrencyManager:
# 执行脚本
result = await self._redis.eval(
lua_script,
2, # 2 个 KEYS
endpoint_key,
1, # 1 个 KEY
key_key,
endpoint_max_concurrent if endpoint_max_concurrent is not None else -1,
key_max_concurrent if key_max_concurrent is not None else -1,
ttl_seconds,
key_rpm_limit if key_rpm_limit is not None else -1,
self._key_rpm_key_ttl_seconds,
1 if is_cached_user else 0, # 缓存用户标志
cache_reservation_ratio, # 预留比例
)
@@ -299,143 +404,68 @@ class ConcurrencyManager:
if success:
user_type = "缓存用户" if is_cached_user else "新用户"
logger.debug(
f"[OK] 获取并发槽位成功: endpoint={endpoint_id}, key={key_id}, "
f"类型={user_type}"
)
logger.debug(f"[OK] 获取 RPM 槽位成功: key={key_id}, 类型={user_type}")
else:
endpoint_count, key_count = await self.get_current_concurrency(endpoint_id, key_id)
key_count = await self.get_key_rpm_count(key_id)
# 计算新用户可用槽位
if key_max_concurrent and not is_cached_user:
available_for_new = int(key_max_concurrent * (1 - cache_reservation_ratio))
# 计算新用户可用 RPM
if key_rpm_limit and not is_cached_user:
available_for_new = int(key_rpm_limit * (1 - cache_reservation_ratio))
user_info = f"新用户配额={available_for_new}, 当前={key_count}"
else:
user_info = f"缓存用户, 当前={key_count}/{key_max_concurrent}"
user_info = f"缓存用户, 当前={key_count}/{key_rpm_limit}"
logger.warning(
f"[WARN] 并发槽位已满: endpoint={endpoint_id}({endpoint_count}/{endpoint_max_concurrent}), "
f"key={key_id}({user_info})"
)
logger.warning(f"[WARN] RPM 限制已达上限: key={key_id}({user_info})")
return success
except Exception as e:
logger.error(f"获取并发槽位失败,降级到内存模式: {e}")
logger.error(f"获取 RPM 槽位失败,降级到内存模式: {e}")
# Redis 异常时降级到内存模式进行保守限流
# 使用较低的限制值(原限制的 50%)避免上游 API 被打爆
async with self._memory_lock:
endpoint_count = self._memory_endpoint_counts.get(endpoint_id, 0)
key_count = self._memory_key_counts.get(key_id, 0)
bucket = self._get_rpm_bucket()
self._cleanup_expired_memory_rpm_counts(bucket)
key_count = self._get_memory_key_rpm_count(key_id, bucket)
# 降级模式下使用更保守的限制50%
fallback_endpoint_limit = (
max(1, endpoint_max_concurrent // 2)
if endpoint_max_concurrent is not None
else None
)
fallback_key_limit = (
max(1, key_max_concurrent // 2) if key_max_concurrent is not None else None
fallback_rpm_limit = (
max(1, key_rpm_limit // 2) if key_rpm_limit is not None else None
)
if (
fallback_endpoint_limit is not None
and endpoint_count >= fallback_endpoint_limit
):
if fallback_rpm_limit is not None and key_count >= fallback_rpm_limit:
logger.warning(
f"[FALLBACK] Endpoint 并发达到降级限制: {endpoint_count}/{fallback_endpoint_limit}"
)
return False
if fallback_key_limit is not None and key_count >= fallback_key_limit:
logger.warning(
f"[FALLBACK] Key 并发达到降级限制: {key_count}/{fallback_key_limit}"
f"[FALLBACK] Key RPM 达到降级限制: {key_count}/{fallback_rpm_limit}"
)
return False
# 更新内存计数
self._memory_endpoint_counts[endpoint_id] = endpoint_count + 1
self._memory_key_counts[key_id] = key_count + 1
logger.debug(
f"[FALLBACK] 使用内存模式获取槽位: endpoint={endpoint_id}, key={key_id}"
)
self._set_memory_key_rpm_count(key_id, bucket, key_count + 1)
logger.debug(f"[FALLBACK] 使用内存模式获取 RPM 槽位: key={key_id}")
return True
async def release_slot(self, endpoint_id: str, key_id: str) -> None:
"""
释放并发槽位
Args:
endpoint_id: Endpoint ID
key_id: ProviderAPIKey ID
"""
if self._redis is None:
async with self._memory_lock:
if endpoint_id in self._memory_endpoint_counts:
self._memory_endpoint_counts[endpoint_id] = max(
0, self._memory_endpoint_counts[endpoint_id] - 1
)
if self._memory_endpoint_counts[endpoint_id] == 0:
self._memory_endpoint_counts.pop(endpoint_id, None)
if key_id in self._memory_key_counts:
self._memory_key_counts[key_id] = max(0, self._memory_key_counts[key_id] - 1)
if self._memory_key_counts[key_id] == 0:
self._memory_key_counts.pop(key_id, None)
return
endpoint_key = self._get_endpoint_key(endpoint_id)
key_key = self._get_key_key(key_id)
try:
# 使用 Lua 脚本保证原子性(不会减到负数)
lua_script = """
local endpoint_key = KEYS[1]
local key_key = KEYS[2]
local endpoint_count = tonumber(redis.call('GET', endpoint_key) or '0')
local key_count = tonumber(redis.call('GET', key_key) or '0')
if endpoint_count > 0 then
redis.call('DECR', endpoint_key)
end
if key_count > 0 then
redis.call('DECR', key_key)
end
return 1
"""
await self._redis.eval(lua_script, 2, endpoint_key, key_key)
logger.debug(f"[OK] 释放并发槽位: endpoint={endpoint_id}, key={key_id}")
except Exception as e:
logger.error(f"释放并发槽位失败: {e}")
@asynccontextmanager
async def concurrency_guard(
async def rpm_guard(
self,
endpoint_id: str,
endpoint_max_concurrent: Optional[int],
key_id: str,
key_max_concurrent: Optional[int],
is_cached_user: bool = False, # 新增:是否是缓存用户
cache_reservation_ratio: Optional[float] = None, # 缓存预留比例None 时从配置读取
key_rpm_limit: Optional[int],
is_cached_user: bool = False,
cache_reservation_ratio: Optional[float] = None,
):
"""
并发控制上下文管理器(支持缓存用户优先级)
RPM 限制上下文管理器(支持缓存用户优先级)
用法:
async with manager.concurrency_guard(
endpoint_id, endpoint_max, key_id, key_max,
async with manager.rpm_guard(
key_id, key_rpm_limit,
is_cached_user=True # 缓存用户
):
# 执行请求
response = await send_request(...)
如果获取失败,会抛出 ConcurrencyLimitError 异常
注意RPM 是按分钟窗口计数,不需要在请求结束后释放
"""
# 从配置读取默认值
from src.config.settings import config
@@ -444,11 +474,9 @@ class ConcurrencyManager:
cache_reservation_ratio = config.cache_reservation_ratio
# 尝试获取槽位(传递缓存用户参数)
acquired = await self.acquire_slot(
endpoint_id,
endpoint_max_concurrent,
acquired = await self.acquire_rpm_slot(
key_id,
key_max_concurrent,
key_rpm_limit,
is_cached_user,
cache_reservation_ratio,
)
@@ -458,7 +486,7 @@ class ConcurrencyManager:
user_type = "缓存用户" if is_cached_user else "新用户"
raise ConcurrencyLimitError(
f"并发限制已达上限: endpoint={endpoint_id}, key={key_id}, 类型={user_type}"
f"RPM 限制已达上限: key={key_id}, 类型={user_type}"
)
# 记录开始时间和状态
@@ -469,7 +497,7 @@ class ConcurrencyManager:
try:
yield # 执行请求
except Exception as e:
except Exception:
# 记录异常
exception_occurred = True
raise
@@ -498,7 +526,7 @@ class ConcurrencyManager:
# 告警:槽位占用时间过长(超过 60 秒)
if slot_duration > 60:
logger.warning(
f"[WARN] 并发槽位占用时间过长: "
f"[WARN] 请求耗时过长: "
f"key_id={key_id[:8] if key_id else 'unknown'}..., "
f"duration={slot_duration:.1f}s, "
f"exception={exception_occurred}"
@@ -506,67 +534,64 @@ class ConcurrencyManager:
except Exception as metric_error:
# 指标记录失败不应影响业务逻辑
logger.debug(f"记录并发指标失败: {metric_error}")
logger.debug(f"记录指标失败: {metric_error}")
# 自动释放槽位(即使发生异常)
await self.release_slot(endpoint_id, key_id)
# 注意RPM 计数不需要在请求结束后释放,它会在分钟窗口过期后自动重置
async def reset_concurrency(
self, endpoint_id: Optional[str] = None, key_id: Optional[str] = None
) -> None:
async def reset_key_rpm(self, key_id: str) -> None:
"""
重置并发计数(管理功能,慎用)
重置 Key RPM 计数(管理功能,慎用)
Args:
endpoint_id: Endpoint ID可选None 表示重置所有 endpoint
key_id: ProviderAPIKey ID可选None 表示重置所有 key
key_id: ProviderAPIKey ID
"""
if self._redis is None:
async with self._memory_lock:
if endpoint_id:
self._memory_endpoint_counts.pop(endpoint_id, None)
logger.info(f"[RESET] 重置 Endpoint 并发计数(内存): {endpoint_id}")
else:
count = len(self._memory_endpoint_counts)
self._memory_endpoint_counts.clear()
if count:
logger.info(f"[RESET] 重置所有 Endpoint 并发计数(内存): {count}")
if key_id:
self._memory_key_counts.pop(key_id, None)
logger.info(f"[RESET] 重置 Key 并发计数(内存): {key_id}")
else:
count = len(self._memory_key_counts)
self._memory_key_counts.clear()
if count:
logger.info(f"[RESET] 重置所有 Key 并发计数(内存): {count}")
self._memory_key_rpm_counts.pop(key_id, None)
logger.info(f"[RESET] 重置 Key RPM 计数(内存): {key_id}")
return
try:
if endpoint_id:
endpoint_key = self._get_endpoint_key(endpoint_id)
await self._redis.delete(endpoint_key)
logger.info(f"[RESET] 重置 Endpoint 并发计数: {endpoint_id}")
else:
# 重置所有 endpoint
keys = await self._redis.keys("concurrency:endpoint:*")
if keys:
await self._redis.delete(*keys)
logger.info(f"[RESET] 重置所有 Endpoint 并发计数: {len(keys)}")
if key_id:
key_key = self._get_key_key(key_id)
await self._redis.delete(key_key)
logger.info(f"[RESET] 重置 Key 并发计数: {key_id}")
else:
# 重置所有 key
keys = await self._redis.keys("concurrency:key:*")
if keys:
await self._redis.delete(*keys)
logger.info(f"[RESET] 重置所有 Key 并发计数: {len(keys)}")
deleted_count = await self._scan_and_delete(f"rpm:key:{key_id}:*")
logger.info(f"[RESET] 重置 Key RPM 计数: {key_id}, 删除 {deleted_count} 个键")
except Exception as e:
logger.error(f"重置并发计数失败: {e}")
logger.error(f"重置 Key RPM 计数失败: {e}")
async def reset_all_rpm(self) -> None:
"""重置所有 Key RPM 计数(管理功能,慎用)"""
if self._redis is None:
async with self._memory_lock:
count = len(self._memory_key_rpm_counts)
self._memory_key_rpm_counts.clear()
if count:
logger.info(f"[RESET] 重置所有 Key RPM 计数(内存): {count}")
return
try:
deleted_count = await self._scan_and_delete("rpm:key:*")
if deleted_count:
logger.info(f"[RESET] 重置所有 Key RPM 计数: {deleted_count}")
except Exception as e:
logger.error(f"重置所有 Key RPM 计数失败: {e}")
async def _scan_and_delete(self, pattern: str, batch_size: int = 100) -> int:
"""使用 SCAN 遍历并分批删除匹配的键,避免阻塞 Redis"""
if self._redis is None:
return 0
deleted_count = 0
cursor = 0
while True:
cursor, keys = await self._redis.scan(cursor, match=pattern, count=batch_size)
if keys:
# 分批删除,每批最多 batch_size 个
for i in range(0, len(keys), batch_size):
batch = keys[i : i + batch_size]
await self._redis.delete(*batch)
deleted_count += len(batch)
if cursor == 0:
break
return deleted_count
# 全局单例
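Note: the reset paths above clean up Redis keys under the "rpm:key:{key_id}:*" prefix with SCAN instead of KEYS, deleting in batches so large keyspaces never block Redis. Below is a minimal sketch of the fixed-window counter such keys imply, assuming a "rpm:key:{key_id}:{minute}" layout and a 120-second TTL; both are illustrative guesses, not the project's actual implementation.

import time

from redis.asyncio import Redis


async def incr_key_rpm(redis: Redis, key_id: str) -> int:
    """Count one request against this key's current minute window (sketch)."""
    minute_window = int(time.time() // 60)
    redis_key = f"rpm:key:{key_id}:{minute_window}"  # prefix matches the reset pattern above
    async with redis.pipeline(transaction=True) as pipe:
        pipe.incr(redis_key)         # count the request
        pipe.expire(redis_key, 120)  # assumed TTL: the window key expires on its own, no release step
        count, _ = await pipe.execute()
    return int(count)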

View File

@@ -3,7 +3,7 @@
"""
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Tuple
from typing import Dict, Optional
from src.core.logger import logger
@@ -62,7 +62,7 @@ class RateLimitDetector:
def detect_from_headers(
headers: Dict[str, str],
provider_name: str = "unknown",
current_concurrent: Optional[int] = None,
current_usage: Optional[int] = None,
) -> RateLimitInfo:
"""
从响应头中检测速率限制类型
@@ -70,7 +70,7 @@ class RateLimitDetector:
Args:
headers: 429响应的HTTP头
provider_name: 提供商名称(用于选择解析策略)
current_concurrent: 当前并发数(用于判断是否为并发限制)
current_usage: 当前使用量(RPM 计数,用于启发式判断是否为并发限制)
Returns:
RateLimitInfo对象
@@ -80,16 +80,16 @@ class RateLimitDetector:
# 根据提供商选择解析策略
if "anthropic" in provider_name.lower() or "claude" in provider_name.lower():
return RateLimitDetector._parse_anthropic_headers(headers_lower, current_concurrent)
return RateLimitDetector._parse_anthropic_headers(headers_lower, current_usage)
elif "openai" in provider_name.lower():
return RateLimitDetector._parse_openai_headers(headers_lower, current_concurrent)
return RateLimitDetector._parse_openai_headers(headers_lower, current_usage)
else:
return RateLimitDetector._parse_generic_headers(headers_lower, current_concurrent)
return RateLimitDetector._parse_generic_headers(headers_lower, current_usage)
@staticmethod
def _parse_anthropic_headers(
headers: Dict[str, str],
current_concurrent: Optional[int] = None,
current_usage: Optional[int] = None,
) -> RateLimitInfo:
"""
解析 Anthropic Claude API 的速率限制头
@@ -127,29 +127,66 @@ class RateLimitDetector:
raw_headers=headers,
)
# 2. 可能的并发限制判断(多条件综合)
# 条件:当前并发数存在,且 remaining > 0(说明不是 RPM 耗尽)
# 同时 retry_after 较短(并发限制通常 retry_after 较短,如 1-10 秒)
is_likely_concurrent = (
current_concurrent is not None
and current_concurrent >= 2 # 至少有 2 个并发
and (requests_remaining is None or requests_remaining > 0) # RPM 未耗尽
and (retry_after is None or retry_after <= 30) # 短暂等待
)
# 2. 并发限制判断(多条件策略)
# 注意:current_usage 是 RPM 计数(当前分钟请求数),不是真正的并发数
#
# 判断条件(满足任一即可):
# A. 强判断:remaining > 0 且 retry_after <= 30(Provider 明确告知还有配额但需要等待)
# B. 弱判断:只有 retry_after <= 5 且缺少 remaining 头(短等待时间是并发限制的典型特征)
#
# 选择保守的 retry_after 阈值:
# - 强判断用 30 秒(有 remaining 头时)
# - 弱判断用 5 秒(无 remaining 头时,更保守)
is_likely_concurrent = False
concurrent_reason = ""
# 条件 A:remaining > 0 且 retry_after <= 30
if (
requests_remaining is not None
and requests_remaining > 0
and retry_after is not None
and retry_after <= 30
):
is_likely_concurrent = True
concurrent_reason = f"remaining={requests_remaining} > 0, retry_after={retry_after}s <= 30s"
# 条件 B:无 remaining 头但 retry_after 很短(<= 5 秒)
elif (
requests_remaining is None
and retry_after is not None
and retry_after <= 5
):
is_likely_concurrent = True
concurrent_reason = f"no remaining header, retry_after={retry_after}s <= 5s"
if is_likely_concurrent:
logger.info(
f"检测到可能的并发限制: current_concurrent={current_concurrent}, "
f"remaining={requests_remaining}, retry_after={retry_after}"
)
logger.info(f"检测到并发限制: {concurrent_reason}")
return RateLimitInfo(
limit_type=RateLimitType.CONCURRENT,
retry_after=retry_after,
current_usage=current_concurrent,
current_usage=current_usage,
raw_headers=headers,
)
# 3. 未知类型
# 3. 默认视为 RPM 限制(更保守的处理)
# 无法明确区分时,视为 RPM 限制,让系统降低 RPM
# 这比误判为并发限制(不降 RPM)更安全
if retry_after is not None or requests_limit is not None:
logger.info(
f"无法明确区分限制类型,保守视为 RPM 限制: "
f"remaining={requests_remaining}, retry_after={retry_after}"
)
return RateLimitInfo(
limit_type=RateLimitType.RPM,
retry_after=retry_after,
limit_value=requests_limit,
remaining=requests_remaining,
reset_at=requests_reset,
current_usage=current_usage,
raw_headers=headers,
)
# 4. 完全没有信息,标记为未知
return RateLimitInfo(
limit_type=RateLimitType.UNKNOWN,
retry_after=retry_after,
@@ -159,7 +196,7 @@ class RateLimitDetector:
@staticmethod
def _parse_openai_headers(
headers: Dict[str, str],
current_concurrent: Optional[int] = None,
current_usage: Optional[int] = None,
) -> RateLimitInfo:
"""
解析 OpenAI API 的速率限制头
@@ -195,23 +232,55 @@ class RateLimitDetector:
raw_headers=headers,
)
# 2. 可能的并发限制(多条件综合判断)
is_likely_concurrent = (
current_concurrent is not None
and current_concurrent >= 2
and (requests_remaining is None or requests_remaining > 0)
and (retry_after is None or retry_after <= 30)
)
# 2. 并发限制判断(多条件策略)
# 判断条件(满足任一即可):
# A. 强判断:remaining > 0 且 retry_after <= 30
# B. 弱判断:只有 retry_after <= 5 且缺少 remaining 头
is_likely_concurrent = False
concurrent_reason = ""
if (
requests_remaining is not None
and requests_remaining > 0
and retry_after is not None
and retry_after <= 30
):
is_likely_concurrent = True
concurrent_reason = f"remaining={requests_remaining} > 0, retry_after={retry_after}s <= 30s"
elif (
requests_remaining is None
and retry_after is not None
and retry_after <= 5
):
is_likely_concurrent = True
concurrent_reason = f"no remaining header, retry_after={retry_after}s <= 5s"
if is_likely_concurrent:
logger.info(f"检测到并发限制: {concurrent_reason}")
return RateLimitInfo(
limit_type=RateLimitType.CONCURRENT,
retry_after=retry_after,
current_usage=current_concurrent,
current_usage=current_usage,
raw_headers=headers,
)
# 3. 未知类型
# 3. 默认视为 RPM 限制(更保守的处理)
if retry_after is not None or requests_limit is not None:
logger.info(
f"无法明确区分限制类型,保守视为 RPM 限制: "
f"remaining={requests_remaining}, retry_after={retry_after}"
)
return RateLimitInfo(
limit_type=RateLimitType.RPM,
retry_after=retry_after,
limit_value=requests_limit,
remaining=requests_remaining,
reset_at=requests_reset,
current_usage=current_usage,
raw_headers=headers,
)
# 4. 完全没有信息,标记为未知
return RateLimitInfo(
limit_type=RateLimitType.UNKNOWN,
retry_after=retry_after,
@@ -221,7 +290,7 @@ class RateLimitDetector:
@staticmethod
def _parse_generic_headers(
headers: Dict[str, str],
current_concurrent: Optional[int] = None,
current_usage: Optional[int] = None,
) -> RateLimitInfo:
"""
解析通用的速率限制头
@@ -247,23 +316,54 @@ class RateLimitDetector:
raw_headers=headers,
)
# 2. 可能的并发限制
is_likely_concurrent = (
current_concurrent is not None
and current_concurrent >= 2
and (remaining is None or remaining > 0)
and (retry_after is None or retry_after <= 30)
)
# 2. 并发限制判断(多条件策略)
# 判断条件(满足任一即可):
# A. 强判断:remaining > 0 且 retry_after <= 30
# B. 弱判断:只有 retry_after <= 5 且缺少 remaining 头
is_likely_concurrent = False
concurrent_reason = ""
if (
remaining is not None
and remaining > 0
and retry_after is not None
and retry_after <= 30
):
is_likely_concurrent = True
concurrent_reason = f"remaining={remaining} > 0, retry_after={retry_after}s <= 30s"
elif (
remaining is None
and retry_after is not None
and retry_after <= 5
):
is_likely_concurrent = True
concurrent_reason = f"no remaining header, retry_after={retry_after}s <= 5s"
if is_likely_concurrent:
logger.info(f"检测到并发限制: {concurrent_reason}")
return RateLimitInfo(
limit_type=RateLimitType.CONCURRENT,
retry_after=retry_after,
current_usage=current_concurrent,
current_usage=current_usage,
raw_headers=headers,
)
# 3. 未知类型
# 3. 默认视为 RPM 限制(更保守的处理)
if retry_after is not None or limit_value is not None:
logger.info(
f"无法明确区分限制类型,保守视为 RPM 限制: "
f"remaining={remaining}, retry_after={retry_after}"
)
return RateLimitInfo(
limit_type=RateLimitType.RPM,
retry_after=retry_after,
limit_value=limit_value,
remaining=remaining,
current_usage=current_usage,
raw_headers=headers,
)
# 4. 完全没有信息,标记为未知
return RateLimitInfo(
limit_type=RateLimitType.UNKNOWN,
retry_after=retry_after,
@@ -317,7 +417,7 @@ class RateLimitDetector:
def detect_rate_limit_type(
headers: Dict[str, str],
provider_name: str = "unknown",
current_concurrent: Optional[int] = None,
current_usage: Optional[int] = None,
) -> RateLimitInfo:
"""
检测速率限制类型(便捷函数)
@@ -325,9 +425,9 @@ def detect_rate_limit_type(
Args:
headers: 429响应头
provider_name: 提供商名称
current_concurrent: 当前并发数
current_usage: 当前使用量(RPM 计数)
Returns:
RateLimitInfo对象
"""
return RateLimitDetector.detect_from_headers(headers, provider_name, current_concurrent)
return RateLimitDetector.detect_from_headers(headers, provider_name, current_usage)
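Note: the three parsers above (Anthropic, OpenAI, generic) now share the same A/B decision, so it is worth stating once in plain form. The sketch below restates that heuristic for readability only; the function name and string results are illustrative and not the project's API (the real code builds RateLimitInfo objects and also treats a bare limit header as an RPM hint).

from typing import Optional


def classify_429(retry_after: Optional[int], remaining: Optional[int]) -> str:
    """Illustrative restatement of the concurrency-vs-RPM heuristic documented above."""
    # A. strong signal: quota is demonstrably left over, yet we are told to wait briefly
    if remaining is not None and remaining > 0 and retry_after is not None and retry_after <= 30:
        return "concurrent"
    # B. weak signal: no remaining header at all, and the wait is very short
    if remaining is None and retry_after is not None and retry_after <= 5:
        return "concurrent"
    # conservative default: with any hint present, assume the RPM budget is exhausted
    if retry_after is not None:
        return "rpm"
    return "unknown"


assert classify_429(retry_after=3, remaining=None) == "concurrent"
assert classify_429(retry_after=20, remaining=0) == "rpm"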

View File

@@ -1,135 +0,0 @@
"""
RPM (Requests Per Minute) 限流服务
"""
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
from sqlalchemy.orm import Session
from src.core.batch_committer import get_batch_committer
from src.core.logger import logger
from src.models.database import Provider
from src.models.database_extensions import ProviderUsageTracking
class RPMLimiter:
"""RPM限流器"""
def __init__(self, db: Session):
self.db = db
# 内存中的RPM计数器 {provider_id: (count, window_start)}
self._rpm_counters: Dict[str, Tuple[int, float]] = {}
def check_and_increment(self, provider_id: str) -> bool:
"""
检查并递增RPM计数
Returns:
True if allowed, False if rate limited
"""
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
return True
rpm_limit = provider.rpm_limit
if rpm_limit is None:
# 未设置限制
return True
if rpm_limit == 0:
logger.warning(f"Provider {provider.name} is fully restricted by RPM limit=0")
return False
current_time = time.time()
# 检查是否需要重置
if provider.rpm_reset_at and provider.rpm_reset_at < datetime.now(timezone.utc):
provider.rpm_used = 0
provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
self.db.commit() # 立即提交事务,释放数据库锁
# 检查是否超限
if provider.rpm_used >= rpm_limit:
logger.warning(f"Provider {provider.name} RPM limit exceeded")
return False
# 递增计数
provider.rpm_used += 1
if not provider.rpm_reset_at:
provider.rpm_reset_at = datetime.fromtimestamp(current_time + 60, tz=timezone.utc)
self.db.commit() # 立即提交事务,释放数据库锁
return True
def record_usage(
self, provider_id: str, success: bool, response_time_ms: float, cost_usd: float
):
"""记录使用情况到追踪表"""
# 获取当前分钟窗口
now = datetime.now(timezone.utc)
window_start = now.replace(second=0, microsecond=0)
window_end = window_start + timedelta(minutes=1)
# 查找或创建追踪记录
tracking = (
self.db.query(ProviderUsageTracking)
.filter(
ProviderUsageTracking.provider_id == provider_id,
ProviderUsageTracking.window_start == window_start,
)
.first()
)
if not tracking:
tracking = ProviderUsageTracking(
provider_id=provider_id, window_start=window_start, window_end=window_end
)
self.db.add(tracking)
# 更新统计
tracking.total_requests += 1
if success:
tracking.successful_requests += 1
else:
tracking.failed_requests += 1
tracking.total_response_time_ms += response_time_ms
tracking.avg_response_time_ms = tracking.total_response_time_ms / tracking.total_requests
tracking.total_cost_usd += cost_usd
self.db.flush() # 只 flush不立即 commit
# RPM 使用统计是非关键数据,使用批量提交
get_batch_committer().mark_dirty(self.db)
logger.debug(f"Recorded usage for provider {provider_id}")
def get_rpm_status(self, provider_id: str) -> Dict:
"""获取提供商的RPM状态"""
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
if not provider:
return {"error": "Provider not found"}
return {
"provider_id": provider_id,
"provider_name": provider.name,
"rpm_limit": provider.rpm_limit,
"rpm_used": provider.rpm_used,
"rpm_reset_at": provider.rpm_reset_at.isoformat() if provider.rpm_reset_at else None,
"available": (
provider.rpm_limit - provider.rpm_used if provider.rpm_limit is not None else None
),
}
def reset_rpm_counter(self, provider_id: str):
"""手动重置RPM计数器"""
provider = self.db.query(Provider).filter(Provider.id == provider_id).first()
if provider:
provider.rpm_used = 0
provider.rpm_reset_at = None
self.db.commit() # 立即提交事务,释放数据库锁
logger.info(f"Reset RPM counter for provider {provider.name}")

View File

@@ -89,24 +89,29 @@ class RequestExecutor:
try:
# 计算动态预留比例
reservation_manager = get_adaptive_reservation_manager()
# 获取当前并发数用于计算负载
# 获取当前 RPM 计数用于计算负载
# 注意:key 侧返回的是 RPM 计数(不会在请求结束时减少,靠 TTL 过期)
try:
_, current_key_concurrent = await self.concurrency_manager.get_current_concurrency(
_, current_key_rpm = await self.concurrency_manager.get_current_concurrency(
endpoint_id=endpoint.id,
key_id=key.id,
)
except Exception as e:
logger.debug(f"获取并发数失败(用于预留计算): {e}")
current_key_concurrent = 0
logger.debug(f"获取 RPM 计数失败(用于预留计算): {e}")
current_key_rpm = 0
# 获取有效的并发限制(自适应或固定)
effective_key_limit = (
key.learned_max_concurrent if key.max_concurrent is None else key.max_concurrent
)
# 获取有效的 RPM 限制(自适应或固定)
if key.rpm_limit is None:
# 自适应模式:优先使用学习值,否则使用默认初始限制,避免无限制打爆上游
from src.config.constants import RPMDefaults
effective_key_limit = int(key.learned_rpm_limit or RPMDefaults.INITIAL_LIMIT)
else:
effective_key_limit = int(key.rpm_limit)
reservation_result = reservation_manager.calculate_reservation(
key=key,
current_concurrent=current_key_concurrent,
current_usage=current_key_rpm,
effective_limit=effective_key_limit,
)
dynamic_reservation_ratio = reservation_result.ratio
@@ -115,24 +120,22 @@ class RequestExecutor:
f"ratio={dynamic_reservation_ratio:.0%}, phase={reservation_result.phase}, "
f"confidence={reservation_result.confidence:.0%}")
async with self.concurrency_manager.concurrency_guard(
endpoint_id=endpoint.id,
endpoint_max_concurrent=endpoint.max_concurrent,
async with self.concurrency_manager.rpm_guard(
key_id=key.id,
key_max_concurrent=effective_key_limit,
key_rpm_limit=effective_key_limit,
is_cached_user=is_cached_user,
cache_reservation_ratio=dynamic_reservation_ratio,
):
# 获取当前 RPM 计数(guard 内再次获取以获得最新值)
try:
_, key_concurrent = await self.concurrency_manager.get_current_concurrency(
endpoint_id=endpoint.id,
key_rpm_count = await self.concurrency_manager.get_key_rpm_count(
key_id=key.id,
)
except Exception as e:
logger.debug(f"获取并发数失败guard 内): {e}")
key_concurrent = None
logger.debug(f"获取 RPM 计数失败guard 内): {e}")
key_rpm_count = None
context.concurrent_requests = key_concurrent
context.concurrent_requests = key_rpm_count # 用于记录,实际是 RPM 计数
context.start_time = time.time()
response = await request_func(provider, endpoint, key)
@@ -142,15 +145,18 @@ class RequestExecutor:
health_monitor.record_success(
db=self.db,
key_id=key.id,
api_format=(
api_format.value if isinstance(api_format, APIFormat) else api_format
),
response_time_ms=context.elapsed_ms,
)
# 自适应模式:max_concurrent = NULL
if key.max_concurrent is None and key_concurrent is not None:
# 自适应模式:rpm_limit = NULL
if key.rpm_limit is None and key_rpm_count is not None:
self.adaptive_manager.handle_success(
db=self.db,
key=key,
current_concurrent=key_concurrent,
current_rpm=key_rpm_count,
)
# 根据是否为流式请求,标记不同状态
@@ -162,7 +168,7 @@ class RequestExecutor:
db=self.db,
candidate_id=candidate_id,
status_code=200,
concurrent_requests=key_concurrent,
concurrent_requests=key_rpm_count,
)
else:
# 非流式请求:标记为 success 状态
@@ -171,7 +177,7 @@ class RequestExecutor:
candidate_id=candidate_id,
status_code=200,
latency_ms=context.elapsed_ms,
concurrent_requests=key_concurrent,
concurrent_requests=key_rpm_count,
extra_data={
"is_cached_user": is_cached_user,
"model_name": model_name,

View File

@@ -289,10 +289,10 @@ class RequestResult:
status_code = 500
error_type = "internal_error"
# 构建错误消息:优先使用上游响应作为主要错误信息
if isinstance(exception, ProviderNotAvailableException) and exception.upstream_response:
error_message = exception.upstream_response
else:
# 构建错误消息:优先使用友好的 message 属性
# upstream_response 仅用于调试/链路追踪,不作为客户端错误消息
error_message = getattr(exception, "message", None)
if not error_message or not isinstance(error_message, str):
error_message = str(exception)
return cls(
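Note: the replacement above stops exposing upstream_response to clients (it stays available for tracing) and prefers a human-readable message attribute. A minimal standalone sketch of that fallback, with an illustrative name:

def safe_error_message(exc: Exception) -> str:
    # prefer an explicit message attribute when the exception defines one
    msg = getattr(exc, "message", None)
    if not msg or not isinstance(msg, str):
        msg = str(exc)  # otherwise fall back to the generic string form
    return msg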

View File

@@ -96,8 +96,6 @@ class QuotaScheduler:
logger.info(f"Resetting quota for provider {provider.name}")
provider.monthly_used_usd = 0.0
provider.rpm_used = 0 # 同时重置RPM计数
provider.rpm_reset_at = None
provider.quota_last_reset_at = now
reset_count += 1
@@ -126,8 +124,6 @@ class QuotaScheduler:
provider = db.query(Provider).filter(Provider.id == provider_id).first()
if provider and provider.billing_type == ProviderBillingType.MONTHLY_QUOTA:
provider.monthly_used_usd = 0.0
provider.rpm_used = 0
provider.rpm_reset_at = None
provider.quota_last_reset_at = now
db.commit()
logger.info(f"Force reset quota for provider {provider.name}")
@@ -140,8 +136,6 @@ class QuotaScheduler:
)
for provider in providers:
provider.monthly_used_usd = 0.0
provider.rpm_used = 0
provider.rpm_reset_at = None
provider.quota_last_reset_at = now
db.commit()
logger.info(f"Force reset quotas for {len(providers)} providers")

View File

@@ -93,6 +93,10 @@ class UsageRecorder:
if metadata.original_model and metadata.original_model != metadata.model:
target_model = metadata.model
# 非流式成功时,返回给客户端的是提供商响应头(透传)+ content-type
client_response_headers = dict(metadata.provider_response_headers) if metadata.provider_response_headers else {}
client_response_headers["content-type"] = "application/json"
await UsageService.record_usage(
db=self.db,
user=self.user,
@@ -115,6 +119,7 @@ class UsageRecorder:
request_body=request_body or result.request_body,
provider_request_headers=metadata.provider_request_headers,
response_headers=metadata.provider_response_headers,
client_response_headers=client_response_headers,
response_body=result.response_data if isinstance(result.response_data, dict) else {},
request_id=self.request_id,
provider_id=metadata.provider_id,
@@ -181,6 +186,8 @@ class UsageRecorder:
request_body=request_body or result.request_body,
provider_request_headers=metadata.provider_request_headers,
response_headers={},
# 失败请求返回给客户端的是 JSON 错误响应
client_response_headers={"content-type": "application/json"},
response_body={"error": result.error_message} if result.error_message else {},
request_id=self.request_id,
provider_id=metadata.provider_id,
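Note: the success path above records exactly what the client received: the provider's headers passed through plus a JSON content type, while the failure path records only {"content-type": "application/json"}. A small sketch of the success-path construction, with an illustrative helper name:

from typing import Dict, Optional


def build_client_response_headers(provider_headers: Optional[Dict[str, str]]) -> Dict[str, str]:
    headers = dict(provider_headers) if provider_headers else {}
    # non-streaming responses are re-serialized as JSON before being returned to the client
    headers["content-type"] = "application/json"
    return headers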

View File

@@ -40,6 +40,7 @@ class UsageRecordParams:
request_body: Optional[Any]
provider_request_headers: Optional[Dict[str, Any]]
response_headers: Optional[Dict[str, Any]]
client_response_headers: Optional[Dict[str, Any]]
response_body: Optional[Any]
request_id: str
provider_id: Optional[str]
@@ -223,6 +224,7 @@ class UsageService:
request_body: Optional[Any],
provider_request_headers: Optional[Dict[str, Any]],
response_headers: Optional[Dict[str, Any]],
client_response_headers: Optional[Dict[str, Any]],
response_body: Optional[Any],
request_id: str,
provider_id: Optional[str],
@@ -288,6 +290,13 @@ class UsageService:
db, response_headers
)
# 处理返回给客户端的响应头
processed_client_response_headers = None
if should_log_headers and client_response_headers:
processed_client_response_headers = SystemConfigService.mask_sensitive_headers(
db, client_response_headers
)
# 计算真实成本(表面成本 * 倍率),免费套餐实际费用为 0
if is_free_tier:
actual_input_cost = 0.0
@@ -351,6 +360,7 @@ class UsageService:
"request_body": processed_request_body,
"provider_request_headers": processed_provider_request_headers,
"response_headers": processed_response_headers,
"client_response_headers": processed_client_response_headers,
"response_body": processed_response_body,
}
@@ -360,12 +370,13 @@ class UsageService:
db: Session,
provider_api_key_id: Optional[str],
provider_id: Optional[str],
api_format: Optional[str] = None,
) -> Tuple[float, bool]:
"""获取费率倍数和是否免费套餐(使用缓存)"""
from src.services.cache.provider_cache import ProviderCacheService
return await ProviderCacheService.get_rate_multiplier_and_free_tier(
db, provider_api_key_id, provider_id
db, provider_api_key_id, provider_id, api_format
)
@classmethod
@@ -484,6 +495,7 @@ class UsageService:
existing_usage.provider_request_headers = usage_params["provider_request_headers"]
existing_usage.response_body = usage_params["response_body"]
existing_usage.response_headers = usage_params["response_headers"]
existing_usage.client_response_headers = usage_params["client_response_headers"]
# 更新 token 和费用信息
existing_usage.input_tokens = usage_params["input_tokens"]
@@ -656,9 +668,9 @@ class UsageService:
Returns:
(usage_params 字典, total_cost 总成本)
"""
# 获取费率倍数和是否免费套餐
# 获取费率倍数和是否免费套餐(传递 api_format 支持按格式配置的倍率)
actual_rate_multiplier, is_free_tier = await cls._get_rate_multiplier_and_free_tier(
params.db, params.provider_api_key_id, params.provider_id
params.db, params.provider_api_key_id, params.provider_id, params.api_format
)
# 计算成本
@@ -704,6 +716,7 @@ class UsageService:
request_body=params.request_body,
provider_request_headers=params.provider_request_headers,
response_headers=params.response_headers,
client_response_headers=params.client_response_headers,
response_body=params.response_body,
request_id=params.request_id,
provider_id=params.provider_id,
@@ -753,6 +766,7 @@ class UsageService:
request_body: Optional[Any] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
response_headers: Optional[Dict[str, Any]] = None,
client_response_headers: Optional[Dict[str, Any]] = None,
response_body: Optional[Any] = None,
request_id: Optional[str] = None,
provider_id: Optional[str] = None,
@@ -785,7 +799,8 @@ class UsageService:
status_code=status_code, error_message=error_message, metadata=metadata,
request_headers=request_headers, request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers, response_body=response_body,
response_headers=response_headers, client_response_headers=client_response_headers,
response_body=response_body,
request_id=request_id, provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id, status=status,
@@ -844,6 +859,7 @@ class UsageService:
request_body: Optional[Any] = None,
provider_request_headers: Optional[Dict[str, Any]] = None,
response_headers: Optional[Dict[str, Any]] = None,
client_response_headers: Optional[Dict[str, Any]] = None,
response_body: Optional[Any] = None,
request_id: Optional[str] = None,
provider_id: Optional[str] = None,
@@ -878,7 +894,8 @@ class UsageService:
status_code=status_code, error_message=error_message, metadata=metadata,
request_headers=request_headers, request_body=request_body,
provider_request_headers=provider_request_headers,
response_headers=response_headers, response_body=response_body,
response_headers=response_headers, client_response_headers=client_response_headers,
response_body=response_body,
request_id=request_id, provider_id=provider_id,
provider_endpoint_id=provider_endpoint_id,
provider_api_key_id=provider_api_key_id, status=status,

Some files were not shown because too many files have changed in this diff Show More